This part covers how to select elements by their coordinates (tf.where), how to update a tensor at given coordinates (tf.scatter_nd), and how to generate a coordinate grid (tf.meshgrid).
1.where
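tf.where works on a tensor of bool type, i.e. one made up of True and False elements. Called with a single argument, tf.where(tensor) returns the coordinates of the elements that are True. Called with three arguments, tf.where(cond, x, y) takes elements from x where cond is True and from y where it is False.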
    import tensorflow as tf
    # Randomly generate a [3, 3] tensor from a normal distribution
    a = tf.random.normal([3, 3])
    print(a)
    # Build the corresponding bool mask
    mask = a > 0
    print(mask)
    # Use the mask to fetch the values of the True elements
    print(tf.boolean_mask(a, mask))
    # Use tf.where to get the coordinates of the True elements
    indices = tf.where(mask)
    print(indices)
    # Gather the elements of a at those coordinates
    print(tf.gather_nd(a, indices))
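To make the three calls above concrete without needing TensorFlow, here is a plain-Python sketch of what they compute on a 2-D nested list. The helpers boolean_mask_2d, where_1arg_2d and gather_nd_2d are hypothetical names for illustration, not TensorFlow APIs, and a fixed matrix stands in for tf.random.normal so the output is reproducible.

```python
# Plain-Python sketch (no TensorFlow) of tf.boolean_mask, tf.where(mask) and
# tf.gather_nd on a 2-D nested list. Helper names are hypothetical.

def boolean_mask_2d(a, mask):
    """Values of a where mask is True, in row-major order (like tf.boolean_mask)."""
    return [a[i][j] for i in range(len(a)) for j in range(len(a[i])) if mask[i][j]]

def where_1arg_2d(mask):
    """[row, col] coordinates of the True elements (like tf.where(mask))."""
    return [[i, j] for i in range(len(mask)) for j in range(len(mask[i])) if mask[i][j]]

def gather_nd_2d(a, indices):
    """Look up a[row][col] for each [row, col] pair (like tf.gather_nd)."""
    return [a[i][j] for i, j in indices]

# Fixed values instead of random ones, so the printed results are reproducible
a = [[0.5, -1.2, 0.3],
     [-0.7, 2.0, -0.1],
     [1.1, -0.4, 0.0]]
mask = [[x > 0 for x in row] for row in a]

print(boolean_mask_2d(a, mask))   # [0.5, 0.3, 2.0, 1.1]
indices = where_1arg_2d(mask)
print(indices)                    # [[0, 0], [0, 2], [1, 1], [2, 0]]
print(gather_nd_2d(a, indices))   # [0.5, 0.3, 2.0, 1.1]
```

Gathering at the coordinates returned by where gives back exactly what boolean_mask returns, which is the relationship the TensorFlow code above relies on.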
    print(mask)
    # Define tensor A with all elements equal to 1
    A = tf.ones([3, 3])
    # Define tensor B with all elements equal to 0
    B = tf.zeros([3, 3])
    # Take elements from A where mask is True and from B where it is False
    print(tf.where(mask, A, B))
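The three-argument form can be sketched the same way. where_3arg_2d is a hypothetical helper; it assumes cond, x and y all share one shape, whereas TensorFlow 2's tf.where also broadcasts them.

```python
# Plain-Python sketch (no TensorFlow) of the element-wise selection done by
# tf.where(cond, x, y). Assumes cond, x and y have the same 2-D shape.

def where_3arg_2d(cond, x, y):
    """Pick x[i][j] where cond[i][j] is True, else y[i][j]."""
    return [[xv if c else yv for c, xv, yv in zip(crow, xrow, yrow)]
            for crow, xrow, yrow in zip(cond, x, y)]

mask = [[True, False, True],
        [False, True, False],
        [True, False, False]]
A = [[1.0] * 3 for _ in range(3)]   # like tf.ones([3, 3])
B = [[0.0] * 3 for _ in range(3)]   # like tf.zeros([3, 3])
result = where_3arg_2d(mask, A, B)
print(result)
# [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
```

With A all ones and B all zeros, the result is simply the mask written as 1.0 and 0.0.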
2.scatter_nd
    # Specify the indices to update
    indices = tf.constant([[4], [3], [1], [7]])
    # Specify the update values
    updates = tf.constant([9, 10, 11, 12])
    # Specify the output shape
    shape = tf.constant([8])
    print(tf.scatter_nd(indices, updates, shape))
    # -> [0, 11, 0, 10, 9, 0, 0, 12]
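The 1-D case of tf.scatter_nd amounts to starting from a zero tensor of the requested shape and writing each update at its index. A plain-Python sketch, with scatter_nd_1d as a hypothetical helper name:

```python
# Plain-Python sketch (no TensorFlow) of tf.scatter_nd for a 1-D output.

def scatter_nd_1d(indices, updates, size):
    """Start from `size` zeros, then add each update at its [index].
    Repeated indices accumulate, as they do in tf.scatter_nd."""
    out = [0] * size
    for (i,), u in zip(indices, updates):
        out[i] += u
    return out

result = scatter_nd_1d([[4], [3], [1], [7]], [9, 10, 11, 12], 8)
print(result)   # [0, 11, 0, 10, 9, 0, 0, 12]
```

tf.scatter_nd sums updates that land on the same index rather than overwriting them; the sketch mirrors that with `+=`, so for example `scatter_nd_1d([[1], [1]], [2, 3], 3)` gives `[0, 5, 0]`.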
    # Specify the indices to update (whole slices along the first axis)
    indices = tf.constant([[0], [2]])
    # Specify the update values: two [4, 4] slices
    updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
                           [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]]])
    print(updates.shape)   # (2, 4, 4)
    # Specify the output shape
    shape = tf.constant([4, 4, 4])
    print(tf.scatter_nd(indices, updates, shape))
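In the second example each index has a single component while the output has rank 3, so each index selects a whole [4, 4] slice along the first axis. A plain-Python sketch of that case (scatter_nd_slices is a hypothetical helper, hard-coded to a 3-D output and single-component indices):

```python
# Plain-Python sketch (no TensorFlow) of tf.scatter_nd when each index
# selects a whole slice: index [k] targets out[k], a [d1, d2] block.

def scatter_nd_slices(indices, updates, shape):
    d0, d1, d2 = shape
    out = [[[0] * d2 for _ in range(d1)] for _ in range(d0)]
    for (k,), block in zip(indices, updates):
        # Add element-wise so repeated indices accumulate, as in tf.scatter_nd
        for r in range(d1):
            for c in range(d2):
                out[k][r][c] += block[r][c]
    return out

block = [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]]
result = scatter_nd_slices([[0], [2]], [block, block], [4, 4, 4])
zeros = [[0] * 4 for _ in range(4)]
print(result[0] == block, result[1] == zeros, result[2] == block, result[3] == zeros)
# True True True True
```

Slices 1 and 3 receive no update and stay all zeros.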
3.meshgrid
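    # Generate the y axis: range -2 to 2, 5 elements
    # (start/stop are written as floats, since tf.linspace works on floating-point values)
    y = tf.linspace(-2., 2., 5)
    print(y)
    # Generate the x axis the same way
    x = tf.linspace(-2., 2., 5)
    # Generate the coordinate grid
    points_x, points_y = tf.meshgrid(x, y)
    print(points_x.shape)   # (5, 5)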
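With the default indexing='xy', tf.meshgrid(x, y) returns two grids of shape [len(y), len(x)]: points_x repeats x along every row and points_y repeats y down every column. A plain-Python sketch of that behaviour, using hypothetical helpers linspace and meshgrid_xy in place of the TensorFlow calls:

```python
# Plain-Python sketch (no TensorFlow) of tf.linspace and tf.meshgrid(x, y)
# with the default 'xy' indexing.

def linspace(start, stop, num):
    """num evenly spaced values from start to stop inclusive (assumes num >= 2)."""
    step = (stop - start) / (num - 1)
    return [start + k * step for k in range(num)]

def meshgrid_xy(x, y):
    """Both grids have shape [len(y), len(x)];
    points_x[i][j] = x[j] and points_y[i][j] = y[i]."""
    points_x = [list(x) for _ in y]
    points_y = [[yv] * len(x) for yv in y]
    return points_x, points_y

y = linspace(-2.0, 2.0, 5)
x = linspace(-2.0, 2.0, 5)
print(y)                                  # [-2.0, -1.0, 0.0, 1.0, 2.0]
points_x, points_y = meshgrid_xy(x, y)
print(len(points_x), len(points_x[0]))    # 5 5
print(points_x[0])                        # [-2.0, -1.0, 0.0, 1.0, 2.0]
print([row[0] for row in points_y])       # [-2.0, -1.0, 0.0, 1.0, 2.0]
```

Because x and y both have 5 elements here, the [len(y), len(x)] ordering is easy to miss; with 3 y values and 5 x values the grids would be 3 x 5.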
Then tf.stack can merge the x and y grids, producing the coordinates of each point.
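Stacking the two grids along a new last axis pairs each point's x and y values. In TensorFlow this would be something like tf.stack([points_x, points_y], axis=2), which turns the two [5, 5] grids into a [5, 5, 2] tensor; the text does not say which axis it stacks along, so axis=2 is an assumption here. A plain-Python sketch on a small hand-made 2 x 3 grid, with stack_last_axis as a hypothetical helper:

```python
# Plain-Python sketch (no TensorFlow) of stacking two [H, W] grids along a
# new last axis, analogous to tf.stack([points_x, points_y], axis=2).

def stack_last_axis(points_x, points_y):
    """Return an [H, W, 2] nested list whose entry [i][j] is [x, y]."""
    return [[[px, py] for px, py in zip(row_x, row_y)]
            for row_x, row_y in zip(points_x, points_y)]

# Small 2 x 3 grid built by hand: x varies along columns, y along rows
points_x = [[-1.0, 0.0, 1.0],
            [-1.0, 0.0, 1.0]]
points_y = [[-1.0, -1.0, -1.0],
            [1.0, 1.0, 1.0]]
points = stack_last_axis(points_x, points_y)
print(points[0][2])   # [1.0, -1.0]
print(points[1][0])   # [-1.0, 1.0]
```

Each entry points[i][j] is then the (x, y) coordinate of one grid point, which is the form usually wanted for plotting or evaluating a function over the grid.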