在深度学习加载模型时会用到torch.load函数,可能会因为函数的安全性导致报错。
如下列是一段针对模型评估和优化的代码
from sklearn.metrics import confusion_matrix, classification_report
# 加载模型
model = SimpleCNN()
model.load_state_dict(torch.load('pneumonia_model.pth'))
model.eval()
# 在测试集上评估模型
y_true = []
y_pred = []
with torch.no_grad():
for images, labels in test_loader:
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
y_true.extend(labels.cpu().numpy())
y_pred.extend(predicted.cpu().numpy())
# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)
# 打印分类报告
print(classification_report(y_true, y_pred, target_names=['NORMAL', 'PNEUMONIA']))
而在实际运行时可能会报错:
FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. model.load_state_dict(torch.load('pneumonia_model.pth'))
这个警告是关于PyTorch的torch.load
函数的安全性。我们可以从基础开始,详细解释这个问题并提供解决方案。
一、问题解释
PyTorch的torch.load
函数默认使用pickle
模块来加载模型权重,而pickle
模块存在潜在的安全风险,因为它可以加载并执行任意Python代码。这意味着,如果你从不受信任的来源加载模型权重,可能会执行恶意代码。
二、解决方案
为了确保安全,PyTorch计划在未来的版本中更改weights_only
的默认值为True
,以限制加载过程中可能执行的函数。以下是解决这个警告的方案
1. 使用weights_only=True
参数
model.load_state_dict(torch.load('pneumonia_model.pth', weights_only=True))
通过设置weights_only=True
,可以确保只加载模型的权重,而不会执行其他可能的代码。
另外要注意:检查文件的来源
确保你从信任的来源加载模型权重。不要从不受信任的来源加载模型文件。
修改后的代码
from sklearn.metrics import confusion_matrix, classification_report
# 加载模型
model = SimpleCNN()
model.load_state_dict(torch.load('pneumonia_model.pth', weights_only=True))
model.eval()
# 在测试集上评估模型
y_true = []
y_pred = []
with torch.no_grad():
for images, labels in test_loader:
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
y_true.extend(labels.cpu().numpy())
y_pred.extend(predicted.cpu().numpy())
# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)
# 打印分类报告
print(classification_report(y_true, y_pred, target_names=['NORMAL', 'PNEUMONIA']))