torch.where() usage


Foreword

This article describes two usages of torch.where(). The first is the conventional one, which is also the one described in the official documentation. The second is the single-argument form applied to a bool tensor, here one obtained by combining bool tensors with &.


1. Conventional usage of torch.where()

Let's first look at how the official documentation explains it:

torch.where(condition, x, y)
Return a tensor of elements selected from either x or y, depending on condition. (A new tensor is created: each of its elements is taken from x or from y, and condition, x and y must be broadcastable.) The operation is defined as: out_i = x_i if condition_i else y_i.
Parameters
condition (BoolTensor): where True, yield x; otherwise, yield y
x (Tensor or Scalar): value (if x is a scalar) or values selected at indices where condition is True
y (Tensor or Scalar): value (if y is a scalar) or values selected at indices where condition is False
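To make the rule out_i = x_i if condition_i else y_i concrete, here is a small sketch that compares torch.where with a plain Python loop over 1-D tensors. The helper where_loop and the example tensors are hypothetical and only for illustration; the loop is far slower than torch.where and, unlike it, does not broadcast.

```python
import torch

# Hypothetical helper for illustration only: a slow, pure-Python version of
# torch.where(condition, x, y) for 1-D tensors of equal length (no broadcasting).
def where_loop(condition, x, y):
    # out[i] = x[i] if condition[i] else y[i]
    return torch.tensor([xi if ci else yi
                         for ci, xi, yi in zip(condition.tolist(), x.tolist(), y.tolist())])

cond = torch.tensor([True, False, True])
x = torch.tensor([1, 2, 3])
y = torch.tensor([10, 20, 30])

print(torch.where(cond, x, y))  # tensor([ 1, 20,  3])
print(where_loop(cond, x, y))   # tensor([ 1, 20,  3])
```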

Many blog posts I have read claim that x and y must have the same shape, which is simply wrong. The official documentation states it clearly: "The tensors condition, x, y must be broadcastable." In other words, condition, x and y only need to be broadcastable to a common shape; they do not need to have the same shape. The examples below walk through several cases.
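As a quick check of the broadcasting claim, the sketch below passes condition, x and y with three different shapes. It uses torch.broadcast_shapes (available in recent PyTorch releases) to show the common shape; the tensors themselves are made up for illustration.

```python
import torch

cond = torch.tensor([[True], [False]])    # shape (2, 1)
x = torch.tensor([1, 2, 3])               # shape (3,)
y = torch.zeros(2, 3, dtype=torch.int64)  # shape (2, 3)

# The three shapes differ, but they all broadcast to (2, 3).
print(torch.broadcast_shapes(cond.shape, x.shape, y.shape))  # torch.Size([2, 3])
z = torch.where(cond, x, y)
print(z)  # row 0 comes from x, row 1 comes from y

# Shapes (2,) and (3,) are not broadcastable, so torch.where raises.
raised = False
try:
    torch.where(torch.tensor([True, False]), torch.tensor([1, 2, 3]), torch.tensor([0, 0, 0]))
except RuntimeError:
    raised = True
print("raised:", raised)  # raised: True
```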

1.1 Same shape

First, the case where x and y have the same shape:

import torch

x = torch.tensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
y = torch.tensor([[5, 6, 7], [7, 8, 9], [9, 10, 11]])
z = torch.where(x > 5, x, y)

print(f'x = {x}')
print(f'=========================')
print(f'y = {y}')
print(f'=========================')
print(f'x > 5 = {x > 5}')
print(f'=========================')
print(f'z = {z}')

Output:
x = tensor([[1, 2, 3],
        [3, 4, 5],
        [5, 6, 7]])
=========================
y = tensor([[ 5,  6,  7],
        [ 7,  8,  9],
        [ 9, 10, 11]])
=========================
x > 5 = tensor([[False, False, False],
        [False, False, False],
        [False,  True,  True]])
=========================
z = tensor([[5, 6, 7],
        [7, 8, 9],
        [9, 6, 7]])

Above, x and y both have shape (3, 3). The condition x > 5 checks whether each element of x is greater than 5: rows 0 and 1 are entirely False, and only columns 1 and 2 of row 2 are True. As stated earlier, positions where the condition is True take their value from x, and positions where it is False take their value from y. So in the newly created z, rows 0 and 1 and element (2, 0) come from y, the remaining two elements come from x, and z also has shape (3, 3).
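To make the "True takes x, False takes y" rule concrete, here is a sketch that builds the same z by hand with boolean-mask indexing. This is an alternative formulation for comparison, not how torch.where is implemented internally; the names mask and manual are illustrative.

```python
import torch

x = torch.tensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
y = torch.tensor([[5, 6, 7], [7, 8, 9], [9, 10, 11]])
mask = x > 5

# Equivalent formulation with boolean-mask assignment:
# start from a copy of y, then overwrite the positions where mask is True with x.
manual = y.clone()
manual[mask] = x[mask]

z = torch.where(mask, x, y)
print(torch.equal(z, manual))  # True
```

The clone keeps y itself unchanged, just as torch.where leaves its inputs untouched and returns a new tensor.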

1.2 Scalar case

import torch

x = 3
y = torch.tensor([[1, 5, 7]])
z = torch.where(y > 2, y, x)

print(f'y > 2 = {y > 2}')
print(f'=========================')
print(f'z = {z}')

Output:
y > 2 = tensor([[False,  True,  True]])
=========================
z = tensor([[3, 5, 7]])

Here x is a scalar and condition = y > 2. If you ask why I did not write condition = x > 2, the reason is simple: x is a plain Python int, so x > 2 is a Python bool, not a bool tensor. Scalars and tensors can be broadcast together, just as in ordinary arithmetic.
For example:

import torch

a = torch.tensor([1, 5, 7])
b = 3
c = a + b
d = torch.tensor([3, 3, 3])
e = a + d

print(f'c = {c}')
print(f'e = {e}')

Output:
c = tensor([ 4,  8, 10])
e = tensor([ 4,  8, 10])

In effect, the scalar b = 3 is stretched (broadcast) to [3, 3, 3], which is exactly d, so c and e are equal.
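The same idea applies to the torch.where example in this section. The sketch below, assuming a PyTorch version that accepts a Python scalar as the third argument of torch.where (as the article's own example does), checks that x > 2 is a plain bool and that passing the scalar gives the same result as passing it explicitly expanded to y's shape. The name x_expanded is illustrative.

```python
import torch

x = 3
y = torch.tensor([[1, 5, 7]])

# A Python int compared with 2 gives a plain Python bool, not a tensor.
print(type(x > 2))  # <class 'bool'>

# Broadcasting the scalar behaves like expanding it to y's shape.
x_expanded = torch.tensor(x).expand_as(y)  # shape (1, 3), a view with no copy
z_scalar = torch.where(y > 2, y, x)
z_tensor = torch.where(y > 2, y, x_expanded)
print(torch.equal(z_scalar, z_tensor))  # True
```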

1.3 Different shapes

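Strictly speaking, a scalar already has a different shape from the tensor it is combined with, so the previous section is a special case of this one. Here is an example with two tensors of different shapes: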

import torch

x = torch.tensor([[1, 3, 5]])
y = torch.tensor([[2], [4], [6]])
z = torch.where(x > 2, x, y)

print(f'x = {x}')
print(f'=========================')
print(f'y = {y}')
print(f'=========================')
print(f'x > 2 = {x > 2}')
print(f'=========================')
print(f'z = {z}')

Output:
x = tensor([[1, 3, 5]])
=========================
y = tensor([[2],
        [4],
        [6]])
=========================
x > 2 = tensor([[False,  True,  True]])
=========================
z = tensor([[2, 3, 5],
        [4, 3, 5],
        [6, 3, 5]])

Above, x.shape = (1, 3) and y.shape = (3, 1), and condition = x > 2 has shape (1, 3). These three shapes are broadcastable, so the operation succeeds. When torch.where(x > 2, x, y) is evaluated, x, y and condition are each broadcast to shape (3, 3): x becomes three copies of the row [1, 3, 5], y becomes [[2, 2, 2], [4, 4, 4], [6, 6, 6]], and condition becomes three copies of [False, True, True].
As a result, column 0 of z takes its values from y, while columns 1 and 2 take their values from x.
Readers are welcome to experiment with other broadcasting patterns on their own.
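One way to see the broadcasting step explicitly is torch.broadcast_tensors, which expands its arguments to a common shape. The sketch below, reusing the tensors from this section, expands condition, x and y to (3, 3) first and checks that torch.where gives the same z either way; the *_b names are illustrative.

```python
import torch

x = torch.tensor([[1, 3, 5]])      # shape (1, 3)
y = torch.tensor([[2], [4], [6]])  # shape (3, 1)
cond = x > 2                       # shape (1, 3)

# Make the implicit broadcasting explicit: all three become shape (3, 3).
cond_b, x_b, y_b = torch.broadcast_tensors(cond, x, y)
print(x_b)
print(y_b)

# Applying torch.where to the pre-broadcast tensors gives the same z.
z = torch.where(cond, x, y)
print(torch.equal(z, torch.where(cond_b, x_b, y_b)))  # True
```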


2. Special usage of torch.where()

torch.where(a & b)
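a and b are both bool tensors, so a & b is a bool tensor as well. Called with this single argument, torch.where returns a tuple: the first item is a tensor of the row indices of the positions where both a and b are True, and the second item is a tensor of the corresponding column indices. (The official documentation notes that torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).)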

Consider the following example:

import torch

a = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 1]], dtype=torch.bool)
b = torch.ones((3, 3), dtype=torch.bool)
c = torch.where(a & b)

print(f'a = {a}')
print(f'=========================')
print(f'b = {b}')
print(f'=========================')
print(f'c = {c}')

Output:
a = tensor([[False,  True,  True],
        [ True, False, False],
        [False, False,  True]])
=========================
b = tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])
=========================
c = (tensor([0, 0, 1, 2]), tensor([1, 2, 0, 2]))

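c is a tuple: item 0 holds the row indices of the positions where both a and b are True, and item 1 holds the corresponding column indices. Since b is all True here, a & b equals a, so these are simply the positions of the True entries of a: (0, 1), (0, 2), (1, 0) and (2, 2).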
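The sketch below checks the single-argument form against torch.nonzero(..., as_tuple=True), pairs the indices into coordinates, and uses the tuple directly as an index. The tensor values is a hypothetical example tensor introduced only to show the indexing.

```python
import torch

a = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 1]], dtype=torch.bool)
b = torch.ones((3, 3), dtype=torch.bool)

rows, cols = torch.where(a & b)

# The single-argument form matches torch.nonzero(..., as_tuple=True).
nz_rows, nz_cols = torch.nonzero(a & b, as_tuple=True)
print(torch.equal(rows, nz_rows) and torch.equal(cols, nz_cols))  # True

# Pair the indices up to list the (row, column) coordinates of each True entry.
print(list(zip(rows.tolist(), cols.tolist())))  # [(0, 1), (0, 2), (1, 0), (2, 2)]

# The index tensors can be used directly to gather those elements.
values = torch.arange(9).reshape(3, 3)  # hypothetical data: [[0,1,2],[3,4,5],[6,7,8]]
print(values[rows, cols])  # tensor([1, 2, 3, 8])

# For a 3-D condition the tuple has three index tensors, one per dimension.
print(len(torch.where(torch.ones(2, 2, 2, dtype=torch.bool))))  # 3
```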


Summary

Those are the two usages of torch.where(). They may look fiddly at first, but they become second nature with practice. The point most worth remembering is that the conventional form relies on broadcasting: condition, x and y only need to be broadcastable, not identical in shape. Comments and corrections are welcome!
Please respect the original work; reposting is not permitted.


Reference links

https://pytorch.org/docs/stable/generated/torch.where.html#torch.where
https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics
https://numpy.org/doc/stable/user/basics.broadcasting.html


Origin blog.csdn.net/euqlll/article/details/127791397