Usage of the tf.split function (TensorFlow 1.13.0)

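tf.split(value, num_or_size_splits, axis=0, num=None, name='split'):

axis is the dimension along which the input tensor is cut: 0 means the first dimension, 1 the second, and so on. num_or_size_splits is the number of pieces: if it is 2, the input tensor is cut into 2 equal parts along that axis, so the size of that dimension must be divisible by 2. The pieces are returned as a Python list of tensors. num_or_size_splits may also be a list of piece sizes instead of a single count.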
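To make the role of the axis concrete, here is a minimal pure-Python sketch of what an equal split does on a 2-D nested list. The helper split_2d is hypothetical (it is not part of TensorFlow), only handles axis 0 or 1 with an integer number of pieces, and assumes a non-empty rectangular list.

```python
# Hypothetical helper, NOT part of TensorFlow: mimics an equal tf.split
# on a 2-D nested list when the number of pieces is an integer.
def split_2d(rows, num_split, axis):
    """Cut a 2-D list into num_split equal pieces along axis 0 or 1."""
    if axis not in (0, 1):
        raise ValueError("this sketch only supports axis 0 or 1")
    size = len(rows) if axis == 0 else len(rows[0])
    if size % num_split != 0:
        # tf.split also rejects an integer split that does not divide evenly
        raise ValueError(f"dimension {axis} of size {size} is not divisible by {num_split}")
    step = size // num_split
    pieces = []
    for k in range(num_split):
        lo, hi = k * step, (k + 1) * step
        if axis == 0:
            # axis 0: take a block of whole rows
            pieces.append([list(r) for r in rows[lo:hi]])
        else:
            # axis 1: take the same block of columns from every row
            pieces.append([r[lo:hi] for r in rows])
    return pieces


A = [[1, 2, 3], [4, 5, 6]]
print(split_2d(A, 3, 1))  # [[[1], [4]], [[2], [5]], [[3], [6]]]
print(split_2d(A, 2, 0))  # [[[1, 2, 3]], [[4, 5, 6]]]
```

Cutting along axis 1 slices every row at the same column positions, while cutting along axis 0 hands out whole rows.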

For example:

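```python
import tensorflow as tf  # TensorFlow 1.x (graph mode with tf.Session)

A = [[1, 2, 3], [4, 5, 6]]  # a 2 x 3 matrix

# Split A into 3 equal pieces along axis 1 (the columns);
# x is a Python list of 3 tensors, each of shape (2, 1)
x = tf.split(A, 3, 1)

with tf.Session() as sess:
    c = sess.run(x)  # evaluates every piece; returns a list of NumPy arrays
    for ele in c:
        print(ele)
```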

Output:

[[1]
 [4]]
[[2]
 [5]]
[[3]
 [6]]
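As a cross-check that runs without a TensorFlow session, NumPy's np.split (a different library, swapped in here for illustration) behaves the same way when its second argument is an integer. Note one difference: if a list is passed, np.split treats it as split indices, whereas tf.split treats it as piece sizes, so only the integer case matches.

```python
import numpy as np

A = np.array([[1, 2, 3], [4, 5, 6]])

# Same arguments as the tf.split example: 3 equal pieces along axis 1 (columns)
cols = np.split(A, 3, axis=1)
for piece in cols:
    print(piece)  # prints the same three (2, 1) pieces as the output above

# Cutting along axis 0 (rows) into 2 pieces instead
rows = np.split(A, 2, axis=0)
print(rows[0])  # [[1 2 3]]
print(rows[1])  # [[4 5 6]]
```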


Source: www.cnblogs.com/HYWZ36/p/11408451.html