Pytorch関数の徹底解説 - torch.sort

カテゴリ: 「Pytorch 関数を簡単に理解する」総合カタログ


入力テンソルの要素を、指定された次元に沿った値で並べ替えます。指定しない場合はdim、入力の最後の次元が選択されます。descending指定した場合True、要素は値によって降順に並べ替えられ、それ以外の場合は昇順に並べ替えられます。stableそうであればTrue、ソート ルーチンは安定し、同等の要素の順序が維持されます。

文法

torch.sort(input, dim=-1, descending=False, stable=False, *, out=None)

パラメータ

  • input:[ Tensor] 入力テンソル
  • dim: [オプション、int] 並べ替えの基準となるディメンション
  • 降順 : [オプション、bool] 並べ替え順序、False昇順、True降順、デフォルトでは昇順
  • 安定 : [オプション、bool] ソートが安定しているかどうかを示します。安定している場合はTrue、同等の要素の順序が保持されます。

戻り値

タプル(values, indices)( はvaluesソートされた値) は、indices元の入力テンソルの要素のインデックスです。

>>> x = torch.randn(3, 4)
>>> sorted, indices = torch.sort(x)
>>> sorted
tensor([[-0.2162,  0.0608,  0.6719,  2.3332],
        [-0.5793,  0.0061,  0.6058,  0.9497],
        [-0.5071,  0.3343,  0.9553,  1.0960]])
>>> indices
tensor([[ 1,  0,  2,  3],
        [ 3,  1,  0,  2],
        [ 0,  3,  1,  2]])

>>> sorted, indices = torch.sort(x, 0)
>>> sorted
tensor([[-0.5071, -0.2162,  0.6719, -0.5793],
        [ 0.0608,  0.0061,  0.9497,  0.3343],
        [ 0.6058,  0.9553,  1.0960,  2.3332]])
>>> indices
tensor([[ 2,  0,  0,  1],
        [ 0,  1,  1,  2],
        [ 1,  2,  2,  0]])
        
>>> x = torch.tensor([0, 1] * 9)
>>> x.sort()
torch.return_types.sort(
    values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
    indices=tensor([ 2, 16,  4,  6, 14,  8,  0, 10, 12,  9, 17, 15, 13, 11,  7,  5,  3,  1]))
>>> x.sort(stable=True)
torch.return_types.sort(
    values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
    indices=tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16,  1,  3,  5,  7,  9, 11, 13, 15, 17]))

おすすめ

転載: blog.csdn.net/hy592070616/article/details/131884831