カテゴリ: 「Pytorch 関数を簡単に理解する」総合カタログ
入力テンソルの要素を、指定された次元に沿った値で並べ替えます。指定しない場合はdim
、入力の最後の次元が選択されます。descending
指定した場合True
、要素は値によって降順に並べ替えられ、それ以外の場合は昇順に並べ替えられます。stable
そうであればTrue
、ソート ルーチンは安定し、同等の要素の順序が維持されます。
文法
torch.sort(input, dim=-1, descending=False, stable=False, *, out=None)
パラメータ
input
:[Tensor
] 入力テンソルdim
: [オプション、int
] 並べ替えの基準となるディメンション- 降順 : [オプション、
bool
] 並べ替え順序、False
昇順、True
降順、デフォルトでは昇順 - 安定 : [オプション、
bool
] ソートが安定しているかどうかを示します。安定している場合はTrue
、同等の要素の順序が保持されます。
戻り値
タプル(values, indices)
( はvalues
ソートされた値) は、indices
元の入力テンソルの要素のインデックスです。
例
>>> x = torch.randn(3, 4)
>>> sorted, indices = torch.sort(x)
>>> sorted
tensor([[-0.2162, 0.0608, 0.6719, 2.3332],
[-0.5793, 0.0061, 0.6058, 0.9497],
[-0.5071, 0.3343, 0.9553, 1.0960]])
>>> indices
tensor([[ 1, 0, 2, 3],
[ 3, 1, 0, 2],
[ 0, 3, 1, 2]])
>>> sorted, indices = torch.sort(x, 0)
>>> sorted
tensor([[-0.5071, -0.2162, 0.6719, -0.5793],
[ 0.0608, 0.0061, 0.9497, 0.3343],
[ 0.6058, 0.9553, 1.0960, 2.3332]])
>>> indices
tensor([[ 2, 0, 0, 1],
[ 0, 1, 1, 2],
[ 1, 2, 2, 0]])
>>> x = torch.tensor([0, 1] * 9)
>>> x.sort()
torch.return_types.sort(
values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1]))
>>> x.sort(stable=True)
torch.return_types.sort(
values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17]))