源码有一个写法:
def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: # noqa: F811
pass
forward
,它的第一个参数 input
是一个 Tensor
类型的变量,第二个参数 hx
是一个可选的 Tensor
类型变量,这里使用了 Python 3.7 引入的类型注解语法。
函数返回值类型是一个由两个 Tensor
类型变量组成的元组(Tuple)【 -> 的意思是返回值】
除此之外,这段代码还包含了一个 pass
语句,它的作用是占位符,表示这个函数目前并没有实现任何功能。如果你需要在这个函数中添加具体的代码实现,可以将这个 pass
语句替换为你的代码。
2、@classmethod
@classmethod
是 Python 中的一个装饰器(Decorator),用于定义类方法。
类方法是与类相关联的方法,而不是与实例相关联的方法。可以直接“类.方法”直接调用类方法而不需要实例化
使用 @classmethod
装饰器来定义一个类方法,可以在方法中使用 cls
参数来引用类本身,而不是实例本身【在类方法中,第一个参数通常被约定为 cls
,表示类本身。然而,你可以使用任何名称。】。例如:
class MyClass:
var = 123
@classmethod
def class_method(cls):
print(cls.var)
# 使用“@classmethod”的时候,可以直接调用类方法,不需要实例化
MyClass.class_method()
# 不使用“@classmethod”的时候,需要实例化才能调用类方法
class1 = MyClass()
class1.class_method()
在上面的示例代码中,我们使用 @classmethod
装饰器定义了一个名为 class_method
的类方法。在该方法中,我们使用 cls
参数来引用类本身,并打印了类变量 var
的值。最后,我们在不创建实例的情况下调用了该方法,输出了类变量的值。
3、eval()
eval() 可以将字符串形式的 Python 表达式作为参数进行求值
例如上面的代码,就可以将一个字符串转为参数传递给方法了,因为有时候你需要动态变化你的参数,所以就有了上面的写法
4、nn.parameter()
nn.Parameter()
是 PyTorch 中用于将 tensor 转换为 nn.Parameter
类型的函数。nn.Parameter
实际上是一个特殊的张量类型,它会被自动注册为模型的可学习参数,即在训练过程中需要更新的参数。与普通的 Tensor 不同,nn.Parameter
的属性包括要求梯度(requires_grad)和所处设备(device)等。
在 PyTorch 中,使用 nn.Parameter
将 tensor 转换为模型参数有两个好处:
-
自动追踪计算图:将 tensor 包装为
nn.Parameter
之后,PyTorch 会自动将其加入计算图中,并记录相应的梯度信息,这样就可以通过自动微分实现反向传播,即计算模型参数的梯度。 -
方便管理模型参数:使用
nn.Parameter
可以使得模型参数更方便地集中管理,例如可以使用model.parameters()
方法自动获取模型中的所有参数,或者使用model.named_parameters()
方法获取模型中每个参数的名称和值等信息。
使用 nn.Parameter
的一般流程是:首先定义一个 tensor,然后将其转换成 nn.Parameter
类型并赋予初始值,最后将其添加到模型中作为可训练参数。下面是一段示例代码:
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.weight = nn.Parameter(torch.randn(3, 5)) # 定义可训练参数为 3x5 的张量
self.bias = nn.Parameter(torch.zeros(3)) # 定义可训练参数为长度为 3 的张量
def forward(self, x):
return torch.matmul(x, self.weight.t()) + self.bias # 使用可训练参数计算输出
在这个例子中,我们定义了一个 MyModel 类,并将 weight 和 bias 转换为 nn.Parameter
类型。在模型的 forward 函数中,我们使用 weight 和 bias 计算输出,这样就可以利用反向传播算法,根据损失函数对 weight 和 bias 进行梯度更新。