플로라도의 data workout

파이토치 Global Average pooling 구현 본문

기초 노트/PyTorch

파이토치 Global Average pooling 구현

플로라도 2024. 4. 30. 00:01

torchivsion의 pre-trained된 Resnet18, Resnet50과 같은 모델의 가장 마지막단에는 

Global Average Pooling을 사용하여 Linear Classifier의 역할을 대체하고 있다.

 

Global Average Pooling의 아이디어는 Resnet뿐만 아니라 많은 딥러닝 아키텍쳐에서 쓰이는데 

파이토치에서 제공하는 pooling layer에는 별도의 global average pooling layer가 없다.

 

그러나 기존의 average pooling 2d와 adaptive_avg_pool2d로 gap를 간단히 구현할 수 있었다.

 

import torch
import torch.nn.functional as F

# 임의의 텐서를 생성;: 배치 크기 1, 10개의 채널, 28x28 spatial size의 feature map
input_tensor = torch.randn(1, 10, 28, 28)

### Global Average Pooling 적용 (1) ###
gap_output = F.adaptive_avg_pool2d(input_tensor, (1, 1))

### Global Average Pooling 적용 (2) ###
gap_output = F.avg_pool2d(input_tensor, input_tensor.size()[2:])  # Input, Kernel_size

# 두 케이스 모두 결과적으로 각 채널에 대해 하나의 평균값을 가진 1x10 크기의 벡터가 출력됩니다.

# 결과 텐서를 (배치 크기, 채널 수) 형태로 변환
gap_output = gap_output.view(input_tensor.size(0), -1)

print(gap_output.shape)  # torch.Size([1, 10])을 출력합니다.