tf.slice(input_, begin, size, name = None)
- 这个函数的作用是从输入数据input中提取出一块切片
- 切片的尺寸是size,切片的开始位置是begin。
- 切片的尺寸size表示输出tensor的数据维度,其中size[i]表示在第i维度上面的元素个数。
- 开始位置begin表示切片相对于输入数据input_的每一个偏移量
import tensorflow as tf
sess=tf.Session()
input=tf.constant([[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]])
data=tf.slice(input,[1,0,0],[1,1,3])
print(sess.run(data))
#输出的结果是:[[[3 3 3]]]
data1=tf.slice(input,[1,0,0],[1,2,3])
print(sess.run(data1))
#输出的结果是:[[[3 3 3]
# [4 4 4]]]
解析:tf.slice(input,[1,0,0],[1,1,3])
[1,0,0]:第一位为1,表示从axis=0(行)的第2个位置开始,即从[[3,3,3],[4,4,4]开始
第二位为0,表示从axis=1(列)的第一个位置开始,即从[3,3,3]开始
第三位为0,表示从axis=2(维度)的第一个位置开始
[1,1,3]:第一位为1,表示后面两行只取[[3,3,3],[4,4,4]这行
第二位为1,表示两列中只取[3,3,3]这列
第三位为3,表示最后取出来的数是三个维度的。(如果为2,则只取前两个数,最后结果为[[[3,3]]])
综合begin和size得出的数据为:[[[3,3,3]]],下一个结果同理可得