U-Net(O. Ronneberger 2015) 논문리뷰 및 pytorch 구현

2024. 6. 25. 01:29머신러닝&딥러닝/CNN

728x90

U-Net은 2015년 공개된 모델로 segmentation task에서 가장 영향력이 컸던 모델 중 하나이다. 본 포스팅에서는 U-Net을 간단하게 요약하면서, pytorch로 구현한 아키텍처를 살펴보고자 한다.


Introduction

기존의 CNN 모델은 주로 분류(classification)에 초점을 맞추고 있었다. 그러나, 많은 경우, 특히 의생명 분야의 이미지에서는 단순히 어떤 카테고리에 속하는지보다 그 카테고리에 속하는 물체가 이미지 어디에 있는지 구분해야 하는 경우가 있다. 즉, 이는 픽셀 단위로 분류 작업을 하는 것으로 볼 수 있으며, 각 픽셀은 특정 레이블로 매핑되어야 한다. 이러한 segmentation 작업을 위한 기존 연구로는 Ciresan et al.이 있다. 이 논문에서는 Sliding Window 방식을 사용하여 단일 픽셀 수준에서 객체를 분류했으며, 이는 매우 높은 계산 비용을 초래한다.

이 모델은 localization, 즉 근접한 위치 정보를 학습시킬 수 있다는 장점이 있으나, 패치 중첩 등으로 매우 느리다는 단점이 있었고 localization과 context 간의 trade-off가 발생한다는 단점이 있었다.

 

Architecture of U-Net

U-Net의 아키텍처는 세 가지 특징을 가진다.

  1. Contraction Path: 1채널(grayscale)짜리 572 x 572 이미지를 받는다. (실제 이미지는 이보다 작으며, Mirror padding이라는 기법이 적용되었다) 이를 double-conv, maxpool을 통해 bottleneck으로 임베딩한다.
  2. Expansion Path: Bottleneck을 up-sampling하여 
  3. Skip Connection: Contraction path에서 double-conv를 거친 텐서들이 그대로 복사되어 Expansion path로 들어간다.

코드 구현 with pytorch

코드 구현은 Aladdin Persson (www.youtube.com/@AladdinPersson) 이 구현한 U-Net에서 조금 변형을 가했다. 전체 버전은 https://github.com/owenchokor/paper_review/tree/main의 unet2.py에서 확인할 수 있다.

 

GitHub - owenchokor/paper_review: AI & Medical paper reviews

AI & Medical paper reviews. Contribute to owenchokor/paper_review development by creating an account on GitHub.

github.com

input으로 들어온 텐서 x는 아래와 같은 순서를 거친다.

(Contraction) DoubleConv - maxpool - DoubleConv - maxpool - Doubleconv - maxpool - DoubleConv - maxpool
(BottleNeck) DoubleConv
(Expansion) UpConv - (+skip connection) DoubleConv - UpConv - (+skip connection) DoubleConv - UpConv - (+skip connection) DoubleConv - UpConv - (+skip connection) DoubleConv - final conv
DoubleConv 구현
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace = True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)

 

UNET 정의
class UNET(nn.Module):
    def __init__(
            self, in_channels=3, out_channels=1, features = [64, 128, 256, 512]
    ):
        super(UNET, self).__init__()
        
        self.upconvs = nn.ModuleList()
        self.updoubles = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(2, 2)
        
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        for feature in reversed(features):
            self.upconvs.append(
                nn.ConvTranspose2d(feature*2, feature, 2, 2),
            )
        for feature in reversed(features):
            self.updoubles.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = DoubleConv(features[0], out_channels)


    def forward(self, x):
        skip_connections = []
        for md in self.downs:
            x = md(x)
            skip_connections.append(x)
            x = self.pool(x)
        x = self.bottleneck(x)
        skip_connections.reverse()

        for i in range(len(self.upconvs)):
            x = self.upconvs[i](x)
            try:
                x = torch.cat((skip_connections[i], x), dim = 1)
            except RuntimeError:
                x = TF.resize(x, size = skip_connections[i].shape[2:])
            x = self.updoubles[i](x)
        return self.final_conv(x)
반응형