플로라도의 data workout

PyTorch cuda에서 CPU추론으로 변경할때 확인해야 할 사항 본문

기초 노트/PyTorch

PyTorch cuda에서 CPU추론으로 변경할때 확인해야 할 사항

플로라도 2024. 7. 29. 21:05

1.모델

1-1. 가중치 로드 

model.load(weight_path, map_location=device) #  코드 확인

 

 

1-2. 모델 객체 이동 확인

 

모델객체.to(device) # 코드 확인

 

 

1-3. 모델 평가모드 설정

with torch.no_grad() 루프와  model.eval()모드 확인

 

2.  tensor이동

'data(tensor).to(device)' 코드 확인

 

3. nn.DataParallel 삭제

멀티-GPU 환경에서 학습한 스크립트 일부를 이용할시, 추론단계에서는 'nn.DataParallel' 관련 코드 삭제

 

<참고>

<1>
model.to('cuda')
model = nn.DataParallel(model)

<2>
model.to('cuda')

 

<1>과 <2>의 model의 state_dict의 key값은 상이하다.

이후 모델 저장 및 로드시 문제가될 수 있음

 

4. cudnn 관련코드 삭제

'torch.backends.cudnn' 관련 코드 삭제

 

 

 

 

<참고> 관련

 

model.load_state_dict(state_dict, strict=False)

 

멀티-GPU로 학습된 모델을 배포하기 위해서 CPU 환경에서 inference세팅을 맞추는데

흔히 가중치가 꺠진 문제, torch.size관련 runtimError등이 아니라 정상적으로 모델 코드가 작동하는데도 불구하고, output이 가중치가 깨진 마냥 매우 튀는 경우가 발생하였다.

 

현상의 원인은 strict=False에 있었는데, 멀티-GPU로 학습시 nn.DataParallel(model)에 의해 학습이 완료된 model의 가중치 state_dict의 레이어의 key값들은 'module.'이라는 prefix가 붙기 때문에

inference.py에서 단일 CPU환경으로 로드한 model의 state_dict의 key값과 일치하지 않기 때문이었다.

 

모델의 가중치를 strict=True옵션으로 로드한 뒤 inference파일을 다시 실행하면,

원래 저장되어있던 pretrained weight의 state_dict key값인 "module.FeatureExtraction.ConvNet.layer4.2.bn1.num_batches_tracked" 등의 다수 "module.' prefix를 가진 missing key에러를 확인할 수 있었으며,

기존 학습된 가중치의 key값들이 nn.DataParallel버전임을 확인하고 copyStateDict라는 아래의 함수를 통해 nn.DataParallel로 생성된 state_dict의 이름들을 맞춰줘서 해결하였다.

 

 

 

def copyStateDict(state_dict):
	if list(state_dict.keys())[0].startswith("module"):
			start_idx = 1
	else:
			start_idx = 0
			new_state_dict = OrderedDict()
	for k, v in state_dict.items():
			name = ".".join(k.split(".")[start_idx:])
			new_state_dict[name] = v
	return new_state_dict

 

위의 copyStateDict함수는 아래의 github에서 참고하였으며, PyTorch 포럼 다수에서도 state_dict의 key값 일부를 차용해서 조정하는 아래와 같은 코드가 많이 공유되어 있는 것을 확인할 수 있었다.

 

 

def copyStateDict(state_dict):
    if not next(iter(state_dict)).startswith('module.'):
        # Not a DataParallel model
        return state_dict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove 'module.' prefix
        new_state_dict[name] = v
    return new_state_dict

 


Reference)

 

https://github.com/clovaai/CRAFT-pytorch