문제 설명:
학습에 WongKinYiu/PyTorch_YOLOv4를 사용할 때 오류가 보고됩니다.
Traceback (most recent call last):
File "train.py", line 15, in <module>
from torch.cuda import amp
ImportError: cannot import name 'amp'
원인 분석:
1. PyTorch1.6 이상에서만 torch.cuda에서 amp를 가져올 수 있습니다.
2. 그렇지 않으면 apex를 직접 설치해야 하며 소스 코드가 다음으로 변경됩니다.
from apex import amp
해결책:
1. PyTorch 및 CUDA 버전이 일치하는지 확인합니다.
https://pytorch.org/get-started/previous-versions/
2. PyTorch 버전을 1.6 이상으로 업데이트합니다(모델에 호환되지 않는 다른 버전이 있을 수 있으므로 권장됨). 암호)
또는
에이펙스 설치
git clone https://github.com/NVIDIA/apex.git
cd apex
python setup.py install --cpp_ext --cuda_ext
(有时会安装失败,检查PyTorch和CUDA版本 或者 去掉--cuda_ext 便可顺利安装)
from torch.cuda import amp
로 변경 됩니다 from apex import amp
.