torch.where usage
Foreword
This article describes two usages of torch.where(). The first is the conventional one and is documented in the official documentation; the second applies torch.where() to bool tensors combined with a logical operator such as &.
1. Conventional usage of torch.where()
Let's first look at the explanation in the official documentation:
torch.where(condition, x, y)
Return a tensor of elements selected from either x or y, depending on condition. (A new tensor is created whose elements are taken from x or y; condition, x and y must satisfy the broadcasting rules.) In short: where condition is True, the element of x is returned; otherwise the element of y is returned.
Parameters
condition (bool tensor): where True, yield x, otherwise yield y
x (tensor or scalar): value(s) selected at indices where condition=True
y (tensor or scalar): value(s) selected at indices where condition=False
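To make the selection rule concrete, here is a minimal sketch with made-up 1-D data (not from the article) that compares torch.where against an explicit element-by-element loop. It only covers same-shape inputs; broadcasting is discussed next.

```python
import torch

# Made-up example data for illustration only
cond = torch.tensor([True, False, True])
x = torch.tensor([1, 2, 3])
y = torch.tensor([10, 20, 30])

# torch.where picks x[i] where cond[i] is True, otherwise y[i]
result = torch.where(cond, x, y)

# The same selection written as an explicit element-by-element loop
manual = torch.tensor(
    [xi if ci else yi for ci, xi, yi in zip(cond.tolist(), x.tolist(), y.tolist())]
)

print(result.tolist())  # [1, 20, 3]
print(torch.equal(result, manual))  # True
```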
Many blog posts claim that x and y must have the same shape, which is simply wrong. The official documentation states it clearly: The tensors condition, x, y must be broadcastable. In other words, condition, x and y only need to be broadcastable against each other; their shapes do not have to be identical. The following subsections show several cases.
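Whether a set of shapes is broadcastable can be checked without building any data. The sketch below uses torch.broadcast_shapes on the shapes used later in section 1.3, and then shows that torch.where raises an error when the shapes are not broadcastable (the (2,) vs (3,) tensors are made up for illustration).

```python
import torch

# Shapes (3, 1) and (1, 3) are broadcastable: both expand to (3, 3)
ok_shape = torch.broadcast_shapes((3, 1), (1, 3))
print(ok_shape)  # torch.Size([3, 3])

# Shapes (2,) and (3,) are not broadcastable, so torch.where raises an error
cond = torch.tensor([True, False])  # shape (2,)
x = torch.tensor([1, 2, 3])         # shape (3,)
y = torch.tensor([4, 5, 6])         # shape (3,)
try:
    torch.where(cond, x, y)
    raised = False
except RuntimeError as err:
    raised = True
    print("not broadcastable:", err)
```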
1.1 Same shape
First, the case where the shapes are the same:
import torch
x = torch.tensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
y = torch.tensor([[5, 6, 7], [7, 8, 9], [9, 10, 11]])
# where x > 5 take the element of x, otherwise the element of y
z = torch.where(x > 5, x, y)
print(f'x = {x}')
print(f'=========================')
print(f'y = {y}')
print(f'=========================')
print(f'x > 5 = {x > 5}')
print(f'=========================')
print(f'z = {z}')
>print result:
x = tensor([[1, 2, 3],
[3, 4, 5],
[5, 6, 7]])
=========================
y = tensor([[ 5, 6, 7],
[ 7, 8, 9],
[ 9, 10, 11]])
=========================
x > 5 = tensor([[False, False, False],
[False, False, False],
[False, True, True]])
=========================
z = tensor([[5, 6, 7],
[7, 8, 9],
[9, 6, 7]])
The code above defines x and y, both with shape=(3, 3). The condition is condition = x > 5, i.e. whether each element of x is greater than 5. Rows 0 and 1 are entirely False, and only columns 1 and 2 of row 2 are True. As stated earlier, positions where the condition is True take the value from x, and positions where it is False take the value from y. So in the newly created z, rows 0 and 1 plus row 2, column 0 take their values from y, while the remaining two elements take their values from x. The shape of z is also (3, 3).
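When condition, x and y already share one shape, as in this example, the same result can be obtained with boolean-mask assignment. This is a sketch of an equivalent formulation, not how torch.where is implemented; note that mask indexing does not broadcast automatically, so it only matches torch.where for equal shapes.

```python
import torch

# x and y from section 1.1
x = torch.tensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
y = torch.tensor([[5, 6, 7], [7, 8, 9], [9, 10, 11]])
cond = x > 5

# Start from a copy of y, then overwrite the positions where cond is True with x
z_mask = y.clone()
z_mask[cond] = x[cond]

z_where = torch.where(cond, x, y)
print(z_where.tolist())  # [[5, 6, 7], [7, 8, 9], [9, 6, 7]]
print(torch.equal(z_mask, z_where))  # True
```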
1.2 Scalar case
x = 3
y = torch.tensor([[1, 5, 7]])
# the scalar x is broadcast to the shape of y
z = torch.where(y > 2, y, x)
print(f'y > 2 = {y > 2}')
print(f'=========================')
print(f'z = {z}')
>print result:
y > 2 = tensor([[False, True, True]])
=========================
z = tensor([[3, 5, 7]])
Here x is a scalar and condition = y > 2. If you wonder why condition is not set to x > 2, the reason is simple: x is a Python int, so x > 2 is a plain Python bool, not a bool Tensor. Scalars and tensors can be broadcast against each other!
For example:
a = torch.tensor([1, 5, 7])
b = 3
c = a + b
d = torch.tensor([3, 3, 3])
e = a + d
print(f'c = {c}')
print(f'e = {e}')
>print result:
c = tensor([ 4, 8, 10])
e = tensor([ 4, 8, 10])
In effect, b = 3 is stretched (broadcast) to [3, 3, 3], which is exactly d.
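The same stretching happens to the scalar argument of torch.where in section 1.2. The sketch below, assuming a recent PyTorch (2.x) where torch.where accepts Python scalars, checks that passing 3 gives the same result as passing a tensor filled with 3 (built with torch.full_like), and confirms that x > 2 on a Python int is a plain bool.

```python
import torch

y = torch.tensor([[1, 5, 7]])
x = 3

# `x > 2` on a Python int is a plain Python bool, not a bool tensor
print(type(x > 2))  # <class 'bool'>

# Passing the scalar 3 behaves like passing a tensor of 3s with y's shape
z_scalar = torch.where(y > 2, y, x)
z_tensor = torch.where(y > 2, y, torch.full_like(y, x))
print(z_scalar.tolist())  # [[3, 5, 7]]
print(torch.equal(z_scalar, z_tensor))  # True
```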
1.3 Different shapes
Strictly speaking, the scalar case already involves different shapes. Here is an example with two tensors of different shapes:
x = torch.tensor([[1, 3, 5]])
y = torch.tensor([[2], [4], [6]])
z = torch.where(x > 2, x, y)
print(f'x = {x}')
print(f'=========================')
print(f'y = {y}')
print(f'=========================')
print(f'x > 2 = {x > 2}')
print(f'=========================')
print(f'z = {z}')
>print result:
x = tensor([[1, 3, 5]])
=========================
y = tensor([[2],
[4],
[6]])
=========================
x > 2 = tensor([[False, True, True]])
=========================
z = tensor([[2, 3, 5],
[4, 3, 5],
[6, 3, 5]])
Above, x.shape=(1, 3) and y.shape=(3, 1), and condition = x > 2 has shape=(1, 3). These shapes are broadcastable, so the operation succeeds. When torch.where(x > 2, x, y) is computed, x, y and condition are each broadcast to shape (3, 3). Column 0 of the broadcast condition is False and columns 1 and 2 are True, so column 0 of z takes its values from y, while columns 1 and 2 take their values from x.
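The broadcasting step described above can be made explicit with torch.broadcast_tensors, which expands all three inputs to their common shape. This sketch reuses the section 1.3 data and checks that selecting on the expanded tensors gives the same result as calling torch.where on the originals.

```python
import torch

x = torch.tensor([[1, 3, 5]])      # shape (1, 3)
y = torch.tensor([[2], [4], [6]])  # shape (3, 1)
cond = x > 2                       # shape (1, 3)

# Expand all three tensors to the common shape explicitly
cond_b, x_b, y_b = torch.broadcast_tensors(cond, x, y)
print(cond_b.shape, x_b.shape, y_b.shape)  # each is torch.Size([3, 3])

# Selecting on the expanded tensors matches torch.where on the originals
print(torch.equal(torch.where(cond_b, x_b, y_b), torch.where(cond, x, y)))  # True
```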
Readers are encouraged to try other broadcasting forms themselves.
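As one further form to try, a condition of shape (3, 1) is broadcast across the columns, so it keeps or replaces whole rows at once. The data below is made up for illustration, and the scalar 0 as the third argument assumes a recent PyTorch (2.x).

```python
import torch

scores = torch.tensor([[1, 2, 3, 4],
                       [5, 6, 7, 8],
                       [9, 10, 11, 12]])             # shape (3, 4)
keep_row = torch.tensor([[True], [False], [True]])   # shape (3, 1)

# The (3, 1) condition and the scalar 0 are both broadcast to (3, 4),
# so row 1 is replaced entirely by zeros
masked = torch.where(keep_row, scores, 0)
print(masked.tolist())  # [[1, 2, 3, 4], [0, 0, 0, 0], [9, 10, 11, 12]]
```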
2. Special usage of torch.where()
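torch.where(a & b)
Here a and b are both bool Tensors, and a & b is computed first, so torch.where receives a single bool Tensor. The call returns a tuple: the first item is a Tensor holding the row indices of the positions where both a and b are True, and the second item is a Tensor holding the corresponding column indices. (In general there is one index Tensor per dimension of the input; according to the official documentation, this single-argument form is identical to torch.nonzero(condition, as_tuple=True).)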
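A minimal check of that equivalence, using a small hypothetical mask, is shown below. It also confirms that a 3-D input yields a tuple of three index tensors.

```python
import torch

# Hypothetical 2-D bool mask for illustration
mask = torch.tensor([[True, False], [False, True]])

rows, cols = torch.where(mask)
nz_rows, nz_cols = torch.nonzero(mask, as_tuple=True)
print(rows.tolist(), cols.tolist())  # [0, 1] [0, 1]
print(torch.equal(rows, nz_rows) and torch.equal(cols, nz_cols))  # True

# One index tensor per dimension: a 3-D input gives a 3-tuple
cube = torch.ones(2, 2, 2, dtype=torch.bool)
print(len(torch.where(cube)))  # 3
```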
Please see the example:
a = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 1]], dtype=torch.bool)
b = torch.ones((3, 3), dtype=torch.bool)
# a & b is True only where both a and b are True
c = torch.where(a & b)
print(f'a = {a}')
print(f'=========================')
print(f'b = {b}')
print(f'=========================')
print(f'c = {c}')
>print result:
a = tensor([[False, True, True],
[ True, False, False],
[False, False, True]])
=========================
b = tensor([[True, True, True],
[True, True, True],
[True, True, True]])
=========================
c = (tensor([0, 0, 1, 2]), tensor([1, 2, 0, 2]))
c is a tuple: item 0 holds the row indices where both a and b are True, and item 1 holds the corresponding column indices.
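The returned tuple is usually consumed in one of two ways: zipped into (row, col) coordinates, or used directly as an index. The sketch below does both with the a and b from this example; the values tensor is a hypothetical stand-in for any (3, 3) tensor you might want to gather from.

```python
import torch

a = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 1]], dtype=torch.bool)
b = torch.ones((3, 3), dtype=torch.bool)
rows, cols = torch.where(a & b)

# Pair the row and column indices into (row, col) coordinates
coords = list(zip(rows.tolist(), cols.tolist()))
print(coords)  # [(0, 1), (0, 2), (1, 0), (2, 2)]

# The indices can also gather the matching elements of another (3, 3) tensor
values = torch.arange(9).reshape(3, 3)  # hypothetical data: 0..8
print(values[rows, cols].tolist())  # [1, 2, 3, 8]
```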
Summary
Those are the two usages of torch.where(). They may look cumbersome at first, but with some practice they become straightforward. The main point to watch in the conventional usage is the broadcasting mechanism; the special usage takes a single bool tensor and returns the indices of its True elements. Comments and corrections are welcome!
Please respect the original work; reposting is not permitted!
Reference links
https://pytorch.org/docs/stable/generated/torch.where.html#torch.where
https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics
https://numpy.org/doc/stable/user/basics.broadcasting.html