TensorFlow high-level operations

This part covers how to select elements by their coordinates (where), how to update a tensor at given coordinates (scatter_nd), and how to generate a coordinate grid (meshgrid).

1.where

The argument to where is a bool tensor, i.e. a tensor whose elements are all True or False. tf.where(tensor) returns the coordinates of the elements that are True.

import tensorflow as tf

# Randomly generate a [3,3] tensor from a normal distribution
a = tf.random.normal([3, 3])
print(a)
# Build the corresponding bool mask
mask = a > 0
print(mask)
# Use the mask to get the values of the elements where it is True
print(tf.boolean_mask(a, mask))
# Use where to get the positions of the True elements
indices = tf.where(mask)
print(indices)
# Gather the elements of a at those positions
print(tf.gather_nd(a, indices))
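To make the outputs concrete, here is a small sketch that uses a fixed tensor instead of random values (the values are arbitrary, chosen only for illustration). It shows that tf.where(mask) returns an int64 tensor with one row per True element and one column per dimension, and that tf.gather_nd at those coordinates yields the same values as tf.boolean_mask.

```python
import tensorflow as tf

# Fixed, illustrative values so the output is reproducible
a = tf.constant([[1.0, -2.0, 3.0],
                 [-4.0, 5.0, -6.0],
                 [7.0, -8.0, 9.0]])
mask = a > 0

# One row per True element, one column per dimension of the mask
indices = tf.where(mask)
print(indices.dtype, indices.shape)  # dtype int64, shape (5, 2)

# Gathering at those coordinates reproduces boolean_mask (row-major order)
picked = tf.gather_nd(a, indices)
print(picked.numpy())  # [1. 3. 5. 7. 9.]
```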

print(mask)
# Define tensor a with all elements equal to 1
a = tf.ones([3, 3])
# Define tensor b with all elements equal to 0
b = tf.zeros([3, 3])
# Take the value from a where mask is True and from b where it is False
print(tf.where(mask, a, b))
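The three-argument form acts as an element-wise select. As a small illustrative sketch (the input values are made up), choosing the tensor itself as the "True" branch and zeros as the "False" branch keeps the positive entries and zeroes the rest, which for this particular choice matches tf.nn.relu:

```python
import tensorflow as tf

a = tf.constant([[1.0, -2.0], [-3.0, 4.0]])

# Keep a where a > 0, otherwise take 0 from the second branch
clipped = tf.where(a > 0, a, tf.zeros_like(a))
print(clipped.numpy())

# For these branches the result is the same as ReLU
print(bool(tf.reduce_all(clipped == tf.nn.relu(a))))  # True
```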

 

2.scatter_nd

import tensorflow as tf

# Specify the indices to update
indices = tf.constant([[4], [3], [1], [7]])
# Specify the update values
updates = tf.constant([9, 10, 11, 12])
# Specify the shape of the base tensor (filled with zeros)
shape = tf.constant([8])
print(tf.scatter_nd(indices, updates, shape))
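scatter_nd starts from an all-zero tensor of the given shape and writes updates[i] at position indices[i]; every position not listed stays 0. The sketch below repeats the inputs above so the result can be read off directly, and adds a second, made-up case showing that when an index repeats, tf.scatter_nd sums the updates instead of overwriting.

```python
import tensorflow as tf

# Same inputs as above: write 9, 10, 11, 12 at positions 4, 3, 1, 7 of a length-8 zero vector
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
result = tf.scatter_nd(indices, updates, tf.constant([8]))
print(result.numpy())  # [ 0 11  0 10  9  0  0 12]

# Repeated index: the two updates at position 1 are added together (5 + 6)
dup = tf.scatter_nd(tf.constant([[1], [1]]), tf.constant([5, 6]), tf.constant([3]))
print(dup.numpy())  # [ 0 11  0]
```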

 

 

# Specify the indices of the slices to update
indices = tf.constant([[0], [2]])
# Specify the update values: two 4x4 slices
updates = tf.constant([
    [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
    [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]]
])
print(updates.shape)
# Specify the shape of the base tensor
shape = tf.constant([4, 4, 4])
print(tf.scatter_nd(indices, updates, shape))
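Here each index addresses a whole 4x4 slice along the first axis, so slices 0 and 2 receive the updates while slices 1 and 3 stay zero. If the slices should be written into an existing tensor rather than into zeros, tf.tensor_scatter_nd_update does the same kind of write on a tensor you supply. A minimal sketch, using an all-ones base tensor purely for illustration:

```python
import tensorflow as tf

indices = tf.constant([[0], [2]])
# One 4x4 block with rows filled with 5, 6, 7, 8; stacked twice -> shape (2, 4, 4)
block = tf.constant([[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]])
updates = tf.stack([block, block])

# scatter_nd: slices 0 and 2 are written, slices 1 and 3 stay zero
zeros_based = tf.scatter_nd(indices, updates, tf.constant([4, 4, 4]))

# tensor_scatter_nd_update: same write, but into an existing tensor of ones
base = tf.ones([4, 4, 4], dtype=tf.int32)
ones_based = tf.tensor_scatter_nd_update(base, indices, updates)

# Untouched slice 1: all zeros in the first case, all ones in the second
print(zeros_based[1].numpy().sum(), ones_based[1].numpy().sum())  # 0 16
```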

 

3.meshgrid

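import tensorflow as tf

# Generate the y axis: range -2 to 2, 5 elements
y = tf.linspace(-2.0, 2.0, 5)
print(y)
# Generate the x axis the same way
x = tf.linspace(-2.0, 2.0, 5)
# Generate the coordinate grid
points_x, points_y = tf.meshgrid(x, y)
print(points_x.shape)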

Then tf.stack can be used to combine x and y, producing the coordinates of the points.
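A minimal sketch of that stacking step, reusing the same -2 to 2 grid. With tf.meshgrid's default indexing ("xy"), both outputs have shape [len(y), len(x)], and stacking them along a new last axis gives points[i, j] = (x[j], y[i]). The final reshape into a flat list of (x, y) pairs is an illustrative extra, e.g. for evaluating some function z = f(x, y) at every grid point.

```python
import tensorflow as tf

x = tf.linspace(-2.0, 2.0, 5)
y = tf.linspace(-2.0, 2.0, 5)
points_x, points_y = tf.meshgrid(x, y)  # each has shape [5, 5]

# Pair x and y along a new last axis: points[i, j] = (x[j], y[i])
points = tf.stack([points_x, points_y], axis=2)
print(points.shape)  # (5, 5, 2)

# Flatten to a list of 25 (x, y) coordinates
coords = tf.reshape(points, [-1, 2])
print(coords[:3].numpy())
```

Passing indexing="ij" to tf.meshgrid swaps the roles of the two axes, so the first output then varies along rows instead of columns.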

 


Source: www.cnblogs.com/zdm-code/p/12233301.html