Converting tensors between TensorFlow and PyTorch

1. Converting a TensorFlow tensor to a PyTorch tensor

  • TensorFlow (Tensor) -> numpy.ndarray -> PyTorch (Tensor)
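import torch
import tensorflow as tf

tf_tensor = tf.constant([1, 2, 3])
# TF2 (eager execution, the default): .numpy() returns the value as a numpy.ndarray
np_array = tf_tensor.numpy()
# In TF1 graph mode (eager disabled), evaluate the tensor in a session instead:
#   with tf.compat.v1.Session() as sess:
#       np_array = sess.run(tf_tensor)
torch_tensor = torch.from_numpy(np_array)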

2. Converting a PyTorch tensor to a TensorFlow tensor

  • PyTorch (Tensor) -> numpy.ndarray -> TensorFlow (Tensor)
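import torch
import tensorflow as tf

torch_tensor = torch.ones(100)
# .numpy() only works on CPU tensors that do not require grad;
# otherwise use torch_tensor.detach().cpu().numpy()
np_array = torch_tensor.numpy()
tf_tensor = tf.convert_to_tensor(np_array)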


Reposted from blog.csdn.net/wjinjie/article/details/127970641