PyTorch 1.5がリリースされました。このバージョンには、いくつかの新しい主要なAPIといくつかの改善、C ++フロントエンドの主要な更新、コンピュータビジョンモデルの「チャネルラスト」ストレージフォーマット、モデルの並列トレーニングが含まれています分散RPCフレームワークの安定版。
同時に、このバージョンは、ヘッセ行列とヤコビ行列のautogradの新しいAPIと、カスタムC ++クラスの作成を可能にするpybindに触発されたAPIも提供します。
C ++ FRONTEND API(安定版)
以前実験的としてマークされていたC ++フロントエンドAPIは現在Pythonと同等であり、全体的な機能は「安定した」状態に移行しています。主なハイライトは次のとおりです。
- カバレッジは約100%に達し、C ++ torch :: nnモジュール/関数に関するドキュメントを提供します。開発者はモデルをPython APIからC ++ APIに簡単に変換できます。
- C ++のオプティマイザーはPythonのオプティマイザーとは異なります。C++オプティマイザーは入力としてパラメーターグループを受け取ることができませんが、Pythonでは可能です。さらに、step関数の実装は完全に同じではありません。バージョン1.5では、C ++オプティマイザーは常に同等のPythonと同じように動作します。
- C ++にテンソル多次元インデックスAPIがないことはよく知られた問題であり、PyTorch GitHubの課題追跡とフォーラムで多くの議論を引き起こしています。以前の解決策は、arrow / select / index_select / masked_selectを組み合わせて使用することでした。これは、Python APIのエレガントなテンソル[:、0、...、mask]構文に比べて扱いにくく、エラーが発生しやすくなります。バージョン1.5では、開発者はtensor.index({Slice()、0、 "..."、mask})を使用して同じ目的を達成できます。
コンピュータビジョンモデルの「最後のチャネル」メモリ形式(実験的)
コンピュータビジョンモデルの「チャネルが最後」のストレージフォーマットは、現在実験段階にあります。このフォーマットのメモリレイアウトは、畳み込みアルゴリズムとハードウェア(NVIDIAのTensor Core、FBGEMM、QNNPACK)のパフォーマンスと効率を最大限に発揮できます。さらに、演算子を介して自動的に伝播するように設計されているため、メモリレイアウトを簡単に切り替えることができます。
カスタムC ++クラス(試験的)
このバージョンでは、カスタムC ++クラスをTorchScriptとPythonに同時にバインドするための新しいAPI torch.CutomClassHolderが追加されています。このAPIの構文はpybind11とほぼ同じです。開発者は、C ++クラスとメソッドをTorchScript型システムとランタイムシステムに公開して、TorchScript / PythonでC ++オブジェクトをインスタンス化して操作できるようにします。C ++バインディングの例:
template <class T>
struct MyStackClass : torch::CustomClassHolder {
std::vector<T> stack_;
MyStackClass(std::vector<T> init) : stack_(std::move(init)) {}
void push(T x) {
stack_.push_back(x);
}
T pop() {
auto val = stack_.back();
stack_.pop_back();
return val;
}
};
static auto testStack =
torch::class_<MyStackClass<std::string>>("myclasses", "MyStackClass")
.def(torch::init<std::vector<std::string>>())
.def("push", &MyStackClass<std::string>::push)
.def("pop", &MyStackClass<std::string>::pop)
.def("size", [](const c10::intrusive_ptr<MyStackClass>& self) {
return self->stack_.size();
});
次のように、PythonおよびTorchScriptで使用できるクラスを公開します。
@torch.jit.script
def do_stacks(s : torch.classes.myclasses.MyStackClass):
s2 = torch.classes.myclasses.MyStackClass(["hi", "mom"])
print(s2.pop()) # "mom"
s2.push("foobar")
return s2 # ["hi", "foobar"]
分散型RPCフレームワークAPI(安定版)
分散RPCフレームワークはバージョン1.4で実験的な形で登場し、現在は安定した状態です。このプロセスには、分散RPCフレームワークを全体的に信頼性と堅牢性を高めるための多くの機能強化とバグ修正が含まれます。プロファイリングサポート、RPCでのTorchScript関数の使用、使いやすい拡張機能など、2つの新機能も追加されました。
さらに、1.5以降、PyTorchはPython 2をサポートしなくなりました。将来、PythonのサポートはPython 3、特にPython 3.5、3.6、3.7、および3.8に限定される予定です。
より具体的な詳細については、発表を参照してください。
https://pytorch.org/blog/pytorch-1-dot-5-released-with-new-and-updated-apis