기초 노트/PyTorch

파이토치 에러 핸들링 - TypeError: cross_entropy_loss(): argument 'input' (position 1) must be Tensor, not NoneType

플로라도 2024. 4. 14. 22:59
loss = nn.CrossEntropyLoss()(preds, label.to(device))

 

에러 메세지 : TypeError: cross_entropy_loss(): argument 'input' (position 1) must be Tensor, not NoneType

 

해당 에러 메세지는 학습 루프안의 loss를 계산하는 과정에서 발생하였고,

사유는 첫번째 포지션 인자인 preds가 NoneType이라는 에러였지만, 

원인은 모델의 forward 메서드에서 리턴값이 누락된 경우였다.