PyTorch 1.0 中文官方教程:使用字符级别特征的 RNN 网络进行姓氏分类

译者:hhxx2015

作者: Sean Robertson

我们将构建和训练字符级RNN来对单词进行分类。 字符级RNN将单词作为一系列字符读取,在每一步输出预测和“隐藏状态”,将其先前的隐藏状态输入至下一时刻。 我们将最终时刻输出作为预测结果,即表示该词属于哪个类。

具体来说,我们将在18种语言构成的几千个姓氏的数据集上训练模型,根据一个单词的拼写预测它是哪种语言的姓氏:

$ python predict.py Hinton
(-0.47) Scottish
(-1.52) English
(-3.57) Irish

$ python predict.py Schmidhuber
(-0.19) German
(-2.48) Czech
(-2.68) Dutch

阅读建议:

我默认你已经安装好了PyTorch,熟悉Python语言,理解“张量”的概念:

事先学习并了解RNN的工作原理对理解这个例子十分有帮助:

阅读全文/改进本文

猜你喜欢

转载自www.cnblogs.com/wizardforcel/p/10321624.html