When you work on a typical classification problem, a linear layer almost always goes at the end of the model.

This is the simplest approach one can think of, and it makes sense.

However, if the input tensor's size is not fixed, the output tensor's size varies with it. The problem is that the usual linear layer assumes a fixed input size, so an error occurs as soon as the input size changes.
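
To make the problem concrete, here is a rough sketch (the feature sizes are made up for illustration): a linear layer declared for one flattened size breaks as soon as the spatial size changes.

>>> import torch
>>> import torch.nn as nn
>>> fc = nn.Linear(2 * 4 * 4, 10)     # expects exactly 2*4*4 = 32 input features
>>> x = torch.randn(1, 2, 4, 4)
>>> fc(x.flatten(1)).size()           # 32 features: works
torch.Size([1, 10])
>>> y = torch.randn(1, 2, 5, 5)
>>> fc(y.flatten(1))                  # 2*5*5 = 50 features: raises a shape-mismatch RuntimeError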


To handle this, we can use Adaptive Pooling, which is designed to produce a fixed-size output regardless of the input.[1, 2]


You can think of it simply as a pooling layer that forces the output to match the tuple of sizes you specify.


As an example, let's create a tensor of shape (b, c, h, w) as below and apply max pooling.

>>> import torch
>>> import torch.nn.functional as F
>>> a = torch.randn(1, 2, 5, 5)
>>> b = F.max_pool2d(a, stride=1, kernel_size=2)  # 2x2 window, stride 1: 5x5 -> 4x4
>>> b.size()
torch.Size([1, 2, 4, 4])
>>> b
tensor([[[[ 1.4379,  1.4379,  1.0259,  1.0259],
          [ 1.8803,  0.7414,  0.0418,  0.8152],
          [ 1.8803,  0.7414,  0.0418, -0.0604],
          [ 1.3686,  0.2596,  1.2160,  1.2160]],

         [[ 1.5899,  0.6822,  1.4687,  1.4687],
          [ 1.5899,  0.6822,  1.4687,  1.4687],
          [-0.1199,  0.7847,  0.7847,  0.7320],
          [-0.2405,  0.7847,  0.7847,  0.6991]]]])

The same result can be written with Adaptive Max Pooling as follows.

>>> c = F.adaptive_max_pool2d(a, (4, 4))  # (4, 4) is the target output (h, w)
>>> c.size()
torch.Size([1, 2, 4, 4])
>>> c
tensor([[[[ 1.4379,  1.4379,  1.0259,  1.0259],
          [ 1.8803,  0.7414,  0.0418,  0.8152],
          [ 1.8803,  0.7414,  0.0418, -0.0604],
          [ 1.3686,  0.2596,  1.2160,  1.2160]],

         [[ 1.5899,  0.6822,  1.4687,  1.4687],
          [ 1.5899,  0.6822,  1.4687,  1.4687],
          [-0.1199,  0.7847,  0.7847,  0.7320],
          [-0.2405,  0.7847,  0.7847,  0.6991]]]])

The argument passed to Adaptive Max Pooling, the tuple in the code above, is the (h, w) you want the result to have after the 2d max pooling is applied. In this case we wanted the output to come out as (b, c, 4, 4), so we passed the tuple (4, 4).
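
Because the argument is the target output size rather than a kernel size, the same call keeps giving the same output shape no matter what the input size is. Continuing the session above with a few arbitrary input sizes:

>>> F.adaptive_max_pool2d(torch.randn(1, 2, 5, 5), (4, 4)).size()
torch.Size([1, 2, 4, 4])
>>> F.adaptive_max_pool2d(torch.randn(1, 2, 9, 13), (4, 4)).size()
torch.Size([1, 2, 4, 4])
>>> F.adaptive_max_pool2d(torch.randn(1, 2, 32, 32), (4, 4)).size()
torch.Size([1, 2, 4, 4])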


*Note that torchvision's resnet uses AdaptiveAvgPool2d. The pre-trained weights appear to have been trained on ImageNet at the 224 input size.[3]

>>> import torchvision.models as models
>>> models.resnet18()
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
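
Since the (avgpool) right before the fc layer is adaptive, resnet18 should in principle accept input sizes other than 224 as well, as long as the image is large enough to survive all the stride-2 downsampling. A quick sketch with arbitrary sizes (untrained weights, just checking shapes):

>>> import torch
>>> import torchvision.models as models
>>> model = models.resnet18().eval()
>>> model(torch.randn(1, 3, 224, 224)).size()   # the size the pre-trained weights were trained with
torch.Size([1, 1000])
>>> model(torch.randn(1, 3, 320, 320)).size()   # a different input size still gives 1000 class logits
torch.Size([1, 1000])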

References

1. "Contiguous vs non-contiguous tensor", Nov 2018, discuss.pytorch.org/t/contigious-vs-non-contigious-tensor/30107

2. "Adaptive_avg_pool2d vs avg_pool2d", Oct 2018, discuss.pytorch.org/t/adaptive-avg-pool2d-vs-avg-pool2d/27011

3. "TORCHVISION.MODELS", pytorch.org/docs/stable/torchvision/models.html
