[DEBUG 일기] 'amp' 이름을 가져올 수 없습니다.

문제 설명:

학습에 WongKinYiu/PyTorch_YOLOv4를 사용할 때 오류가 보고됩니다.

Traceback (most recent call last):
  File "train.py", line 15, in <module>
    from torch.cuda import amp
ImportError: cannot import name 'amp'

원인 분석:

1. PyTorch1.6 이상에서만 torch.cuda에서 amp를 가져올 수 있습니다.
2. 그렇지 않으면 apex를 직접 설치해야 하며 소스 코드가 다음으로 변경됩니다.

from apex import amp

해결책:

1. PyTorch 및 CUDA 버전이 일치하는지 확인합니다.
https://pytorch.org/get-started/previous-versions/
2. PyTorch 버전을 1.6 이상으로 업데이트합니다(모델에 호환되지 않는 다른 버전이 있을 수 있으므로 권장됨). 암호)

또는
에이펙스 설치

git clone https://github.com/NVIDIA/apex.git
cd apex
python setup.py install --cpp_ext --cuda_ext
(有时会安装失败,检查PyTorch和CUDA版本 或者 去掉--cuda_ext 便可顺利安装)

from torch.cuda import amp로 변경 됩니다 from apex import amp.

추천

출처blog.csdn.net/lucifer479/article/details/111322564