tensorflow2.x学习笔记二十二:tf.where和tf.boolean_mask的使用

一、tf.where

tf.where(condition, x=None, y=None,name=None)
  • condition:一个Tensor,数据类型为bool类型

  • 如果x、y均为空,那么返回condition中值为True的位置

  • 如果x、y不为空,那么x、y必须有相同的形状。如果x、y是标量,那么condition也必须是标量。如果x、y是向量,那么condition必须和x的某一维有相同的形状或者和x形状一致。返回值和x、y有相同的形状,如果condition对应位置值为True那么返回的Tensor对应位置为x的值,否则为y的值

下面举几个例子:

  • x,y均为None
import tensorflow as tf

a = tf.constant([[-1,1,-1],[2,2,-2],[3,-3,3]],
                dtype=tf.float32)
print(c)
'''
tf.Tensor(
[[-1.  1. -1.]
 [ 2.  2. -2.]
 [ 3. -3.  3.]], shape=(3, 3), dtype=float32)
'''
b= tf.where(a<0) 
print(b)

'''
tf.Tensor(
[[0 0]
 [0 2]
 [1 2]
 [2 1]], shape=(4, 2), dtype=int64)
'''
  • x,y不为None,condition的形状和x的形状不同,分别是只有第一维相同和只有第二维相同
mask=np.array([True,False,True])
c= tf.where(mask,tf.fill(a.shape,np.nan),a) 
print(mask.shape)
print(c)

'''
(3,)
tf.Tensor(
[[nan  1. nan]
 [nan  2. nan]
 [nan -3. nan]], shape=(3, 3), dtype=float32)
'''
mask=np.array([[True],[False],[True]])
d= tf.where(mask,tf.fill(a.shape,np.nan),a) 
print(mask.shape)
print(d)
'''
(3,1)
tf.Tensor(
[[nan nan nan]
 [ 2.  2. -2.]
 [nan nan nan]], shape=(3, 3), dtype=float32)
'''
  • x,y不为None,condition的形状和x的形状相同
e = tf.where(a<0,tf.fill(a.shape,np.nan),a) 
print(e)

'''
tf.Tensor(
[[nan  1. nan]
 [ 2.  2. nan]
 [ 3. nan  3.]], shape=(3, 3), dtype=float32)
'''

二、tf.boolean_mask

tf.boolean_mask(tensor,mask,name='boolean_mask',axis=None)
  • tensor:是N维度的

  • mask:是K维度的;注意K小于等于N

  • name:可选项也就是这个操作的名字,

  • axis:是一个0维度的inttensor,表示的是从参数tensor的哪个axis开始mask,默认的情况下axis=0,表示从第一维度进行mask,因此K+axis小于等于N

  • 返回的是N-K+1维度的tensor,也就是maskTrue的地方保存下来。
    下面三个例子中的mask的维度分别是三维、二维和一维

tensor = np.array([[[1, 2], [3, 4], [5, 6]],
                   [[1, 2], [3, 4], [5, 6]]])

mask=np.array([[[True,True],[False,True],[False,False]],
               [[True,True],[False,True],[False,False]]])

z=tf.boolean_mask(tensor, mask)
print(tensor)
print(tensor.shape)
print(mask.shape)
print(z)

'''
[[[1 2]
  [3 4]
  [5 6]]

 [[1 2]
  [3 4]
  [5 6]]]
  
(2, 3, 2)
(2, 3, 2)
tf.Tensor([1 2 4 1 2 4], shape=(6,), dtype=int32)
'''
mask=np.array([[True,False,False],[True,False,False]])

z=tf.boolean_mask(tensor, mask)
print(tensor.shape)
print(mask.shape)
print(z)

'''
(2, 3, 2)
(2, 3)

tf.Tensor(
[[1 2]
 [1 2]], shape=(2, 2), dtype=int32)
'''
mask=np.array([True,False])

z=tf.boolean_mask(tensor, mask)
print(tensor.shape)
print(mask.shape)
print(z)

(2, 3, 2)
(2,)

tf.Tensor(
[[[1 2]
  [3 4]
  [5 6]]], shape=(1, 3, 2), dtype=int32)

猜你喜欢

转载自blog.csdn.net/qq_39507748/article/details/105556270