2024. 6. 25. 01:29ㆍ머신러닝&딥러닝/CNN
U-Net은 2015년 공개된 모델로 segmentation task에서 가장 영향력이 컸던 모델 중 하나이다. 본 포스팅에서는 U-Net을 간단하게 요약하면서, pytorch로 구현한 아키텍처를 살펴보고자 한다.
Introduction
기존의 CNN 모델은 주로 분류(classification)에 초점을 맞추고 있었다. 그러나, 많은 경우, 특히 의생명 분야의 이미지에서는 단순히 어떤 카테고리에 속하는지보다 그 카테고리에 속하는 물체가 이미지 어디에 있는지 구분해야 하는 경우가 있다. 즉, 이는 픽셀 단위로 분류 작업을 하는 것으로 볼 수 있으며, 각 픽셀은 특정 레이블로 매핑되어야 한다. 이러한 segmentation 작업을 위한 기존 연구로는 Ciresan et al.이 있다. 이 논문에서는 Sliding Window 방식을 사용하여 단일 픽셀 수준에서 객체를 분류했으며, 이는 매우 높은 계산 비용을 초래한다.

이 모델은 localization, 즉 근접한 위치 정보를 학습시킬 수 있다는 장점이 있으나, 패치 중첩 등으로 매우 느리다는 단점이 있었고 localization과 context 간의 trade-off가 발생한다는 단점이 있었다.
Architecture of U-Net

U-Net의 아키텍처는 세 가지 특징을 가진다.
- Contraction Path: 1채널(grayscale)짜리 572 x 572 이미지를 받는다. (실제 이미지는 이보다 작으며, Mirror padding이라는 기법이 적용되었다) 이를 double-conv, maxpool을 통해 bottleneck으로 임베딩한다.
- Expansion Path: Bottleneck을 up-sampling하여
- Skip Connection: Contraction path에서 double-conv를 거친 텐서들이 그대로 복사되어 Expansion path로 들어간다.
코드 구현 with pytorch
코드 구현은 Aladdin Persson (www.youtube.com/@AladdinPersson) 이 구현한 U-Net에서 조금 변형을 가했다. 전체 버전은 https://github.com/owenchokor/paper_review/tree/main의 unet2.py에서 확인할 수 있다.
GitHub - owenchokor/paper_review: AI & Medical paper reviews
AI & Medical paper reviews. Contribute to owenchokor/paper_review development by creating an account on GitHub.
github.com
input으로 들어온 텐서 x는 아래와 같은 순서를 거친다.
(Contraction) DoubleConv - maxpool - DoubleConv - maxpool - Doubleconv - maxpool - DoubleConv - maxpool
(BottleNeck) DoubleConv
(Expansion) UpConv - (+skip connection) DoubleConv - UpConv - (+skip connection) DoubleConv - UpConv - (+skip connection) DoubleConv - UpConv - (+skip connection) DoubleConv - final conv
DoubleConv 구현
class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(DoubleConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace = True),
nn.Conv2d(out_channels, out_channels, 3, 1, 1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.conv(x)
UNET 정의
class UNET(nn.Module):
def __init__(
self, in_channels=3, out_channels=1, features = [64, 128, 256, 512]
):
super(UNET, self).__init__()
self.upconvs = nn.ModuleList()
self.updoubles = nn.ModuleList()
self.downs = nn.ModuleList()
self.pool = nn.MaxPool2d(2, 2)
for feature in features:
self.downs.append(DoubleConv(in_channels, feature))
in_channels = feature
for feature in reversed(features):
self.upconvs.append(
nn.ConvTranspose2d(feature*2, feature, 2, 2),
)
for feature in reversed(features):
self.updoubles.append(DoubleConv(feature*2, feature))
self.bottleneck = DoubleConv(features[-1], features[-1]*2)
self.final_conv = DoubleConv(features[0], out_channels)
def forward(self, x):
skip_connections = []
for md in self.downs:
x = md(x)
skip_connections.append(x)
x = self.pool(x)
x = self.bottleneck(x)
skip_connections.reverse()
for i in range(len(self.upconvs)):
x = self.upconvs[i](x)
try:
x = torch.cat((skip_connections[i], x), dim = 1)
except RuntimeError:
x = TF.resize(x, size = skip_connections[i].shape[2:])
x = self.updoubles[i](x)
return self.final_conv(x)
'머신러닝&딥러닝 > CNN' 카테고리의 다른 글
ImageNet 데이터셋의 AlexNet을 이용한 분류 (A.Krizhevsky 2012 논문 리뷰) (0) | 2024.04.22 |
---|---|
CNN에서 backpropagation이 이루어지는 원리 (0) | 2024.03.27 |
pytorch를 이용한 LeNet-5(1998) 구현 (0) | 2024.03.15 |
[CNN] LeNet-5를 활용한 손글씨 인식 (Yann LeCun 1998 논문 리뷰) (2) | 2024.02.26 |