플로라도의 data workout
PyTorch cuda에서 CPU추론으로 변경할때 확인해야 할 사항 본문
1.모델
1-1. 가중치 로드
model.load(weight_path, map_location=device) # 코드 확인
1-2. 모델 객체 이동 확인
모델객체.to(device) # 코드 확인
1-3. 모델 평가모드 설정
with torch.no_grad() 루프와 model.eval()모드 확인
2. tensor이동
'data(tensor).to(device)' 코드 확인
3. nn.DataParallel 삭제
멀티-GPU 환경에서 학습한 스크립트 일부를 이용할시, 추론단계에서는 'nn.DataParallel' 관련 코드 삭제
<참고>
<1>
model.to('cuda')
model = nn.DataParallel(model)
<2>
model.to('cuda')
<1>과 <2>의 model의 state_dict의 key값은 상이하다.
이후 모델 저장 및 로드시 문제가될 수 있음
4. cudnn 관련코드 삭제
'torch.backends.cudnn' 관련 코드 삭제
<참고> 관련
model.load_state_dict(state_dict, strict=False)
멀티-GPU로 학습된 모델을 배포하기 위해서 CPU 환경에서 inference세팅을 맞추는데
흔히 가중치가 꺠진 문제, torch.size관련 runtimError등이 아니라 정상적으로 모델 코드가 작동하는데도 불구하고, output이 가중치가 깨진 마냥 매우 튀는 경우가 발생하였다.
현상의 원인은 strict=False에 있었는데, 멀티-GPU로 학습시 nn.DataParallel(model)에 의해 학습이 완료된 model의 가중치 state_dict의 레이어의 key값들은 'module.'이라는 prefix가 붙기 때문에
inference.py에서 단일 CPU환경으로 로드한 model의 state_dict의 key값과 일치하지 않기 때문이었다.
모델의 가중치를 strict=True옵션으로 로드한 뒤 inference파일을 다시 실행하면,
원래 저장되어있던 pretrained weight의 state_dict key값인 "module.FeatureExtraction.ConvNet.layer4.2.bn1.num_batches_tracked" 등의 다수 "module.' prefix를 가진 missing key에러를 확인할 수 있었으며,
기존 학습된 가중치의 key값들이 nn.DataParallel버전임을 확인하고 copyStateDict라는 아래의 함수를 통해 nn.DataParallel로 생성된 state_dict의 이름들을 맞춰줘서 해결하였다.
def copyStateDict(state_dict):
if list(state_dict.keys())[0].startswith("module"):
start_idx = 1
else:
start_idx = 0
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = ".".join(k.split(".")[start_idx:])
new_state_dict[name] = v
return new_state_dict
위의 copyStateDict함수는 아래의 github에서 참고하였으며, PyTorch 포럼 다수에서도 state_dict의 key값 일부를 차용해서 조정하는 아래와 같은 코드가 많이 공유되어 있는 것을 확인할 수 있었다.
def copyStateDict(state_dict):
if not next(iter(state_dict)).startswith('module.'):
# Not a DataParallel model
return state_dict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove 'module.' prefix
new_state_dict[name] = v
return new_state_dict
Reference)
'기초 노트 > PyTorch' 카테고리의 다른 글
파이토치(PyTorch) 데이터로더(DataLoader)의 모든 것 (0) | 2024.06.15 |
---|---|
파이토치 torch.nn.functional.avg_pool2d vs torch.nn.AvgPool2d의 차이점 (0) | 2024.05.02 |
파이토치 Global Average pooling 구현 (0) | 2024.04.30 |
파이토치 에러 핸들링 - TypeError: cross_entropy_loss(): argument 'input' (position 1) must be Tensor, not NoneType (0) | 2024.04.14 |
파이토치 텐서의 기본 구성요소 (0) | 2024.04.14 |