2024. 8. 13. 14:10ㆍ머신러닝&딥러닝/생성모델
본 포스팅에서는 간단한 코드로 구현한 디퓨전 모델(DDPM)의 pytorch 구현 코드를 분석하려고 한다. 코드는 https://colab.research.google.com/drive/1sjy9odlSSy0RBVgMTgP7s99NXsqglsUL?usp=sharing#scrollTo=LQnlc27k7Aiw 를 참고하였다. 작성자는 DeepFindr이다.
DDPM 논문에 대해 어느 정도 이해하고 있다는 전제 하에 작성하였기 때문에, 논문 내용을 모르는 사람은 먼저 읽고 오길 바란다. https://arxiv.org/abs/2006.11239 논문에는 diffusion loss에 대한 증명과 왜 사용할 수 있는지에 대한 내용이 주로 담겨있기 때문에 어디부터 코드로 구현해야 할지 감이 잘 잡히지 않을 수도 있는데, 본 포스팅에서 간단히 설명하고 넘어가고자 한다.
DDPM의 간단한 설명
DDPM은 간단히 말해서, U-Net으로 noise를 prediction하는 아키텍처이다. U-Net은 segmentation하는 데 많이 쓰이는 네트워크인데, segmentation은 픽셀 단위의 classification이므로 본질적으로 noise prediction과 같다.

즉, DDPM을 코드로 구현하기 위해서는 위 figure에서 두 가지가 필요하다. 첫째는 원본 이미지를 noise로 corruption시키는 noise scheduler. 둘째는 noise에서 원본을 복원하는 denoising U-Net. 이 두 가지를 구현하는 것이 DDPM 코드의 핵심이다.
Noise Scheduler의 구현
DDPM은 Hidden Markov Chain으로, x_t를 만들기 위해서는 x_0부터 시작해서 t회 노이즈를 주입해야 한다. 하지만, 논문 본문에서는 이를 q(x_t | x_0)으로 한 번에 처리하는 방법을 설명한다.

즉, 어떤 t를 주어진 timestep (0, 1, 2, ..., T) 에서 랜덤하게 샘플링하고, 해당하는 t만큼 노이즈를 한 번에 주입해 이미지를 corruption시키는 것이 noise scheduler이다. 이때, 주입하는 noise의 평균과 분산을 계산하기 위해서는 alpha, beta 값들이 필요한데, 코드에서는 자주 사용하는 상수값들을 미리 계산하여 벡터로 저장해놓고 여기에서 꺼내 쓴다.
import torch.nn.functional as F
def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
return torch.linspace(start, end, timesteps)
def get_index_from_list(vals, t, x_shape):
"""
Returns a specific index t of a passed list of values vals
while considering the batch dimension.
"""
batch_size = t.shape[0]
out = vals.gather(-1, t.cpu())
return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
def forward_diffusion_sample(x_0, t, device="cpu"):
"""
Takes an image and a timestep as input and
returns the noisy version of it
"""
noise = torch.randn_like(x_0)
sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
sqrt_one_minus_alphas_cumprod, t, x_0.shape
)
# mean + variance
return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
+ sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)
# Define beta schedule
T = 300
betas = linear_beta_schedule(timesteps=T)
# Pre-calculate different terms for closed form
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
미리 정의해 둔 알파 베타 관련 상수들은 아래와 같다.
- 알파, 베타 (더하면 1 되는 관계)
- 알파 바, 베타 바 (torch.cumprod를 사용해서 쭉 곱하면 됨)
- alphas_comprod_prev : F.pad을 이용해 패당을 추가하고, 인덱스를 하나씩 밀어서 t-1번째 인덱스에 접근할 수 있도록 함
- 루트 알파의 역수
- 루트 알파 바
- 루트 1-알파 바
- posterior variance : 아래 값을 저장해 두었다.

정의해둔 상수를 이용하여 forward_diffusion_sample이 리턴하는 값은, 원본 이미지 x_0에 t만큼의 노이즈가 주입된 결과이다. 이 결과의 평균은
sqrt_alphas_cumprod_t.to(device) * x_0.to(device) + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device)
와 같으며, 분산은
noise.to(device)
이다. Noise Scheduler는 학습되는 파라미터 없이, 미리 정해진 베타와 알파 값들로 동작된다.

timestep t만큼 corruption된 이미지는 위 식에서 epsilon_theta 안에 들어가있는 부분과 같다. 이제, 실제 noise (epsilon)과 예측된 noise (epslion_theta) 간의 loss를 계산하기 위해 U-Net을 구현하자. (위 식 앞에 붙은 괴상한 상수는 무시해도 된다.)
Denoising U-Net의 구현
먼저, U-Net의 double convolution block을 구현하자. 일반적인 U-Net과 다른 점은, Denoising U-Net은 time embedding이 포함된다. time_mlp을 통해 time embedding을 out_ch와 같은 차원으로 보내고, 첫 번째 convolution을 거친 이미지와 더해진다. (concat이 아니다!!)
class Block(nn.Module):
def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
super().__init__()
self.time_mlp = nn.Linear(time_emb_dim, out_ch)
if up:
self.conv1 = nn.Conv2d(2*in_ch, out_ch, 3, padding=1)
self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
else:
self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
self.bnorm1 = nn.BatchNorm2d(out_ch)
self.bnorm2 = nn.BatchNorm2d(out_ch)
self.relu = nn.ReLU()
def forward(self, x, t, ):
# First Conv
h = self.bnorm1(self.relu(self.conv1(x)))
# Time embedding
time_emb = self.relu(self.time_mlp(t))
# Extend last 2 dimensions
time_emb = time_emb[(..., ) + (None, ) * 2]
# Add time channel
h = h + time_emb
# Second Conv
h = self.bnorm2(self.relu(self.conv2(h)))
# Down or Upsample
return self.transform(h)
즉, Denoising U-Net에는 모든 block에 time embedding이 더해진다.

time embedding은 어떻게 만들까? 이는 transformer 논문에 등장한 sinusoidal embedding 방식을 차용한다.

즉, 0~T 의 timestep을 각각 0~T의 위치를 갖는 position처럼 생각하여 sinusoidal embedding 하는 것이다.
class SinusoidalPositionEmbeddings(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, time):
device = time.device
half_dim = self.dim // 2
embeddings = math.log(10000) / (half_dim - 1)
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
embeddings = time[:, None] * embeddings[None, :]
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
# TODO: Double check the ordering here
return embeddings
이를 합쳐서 U-Net을 구현하자. 3채널짜리 이미지를 initial convolution으로 64채널로 바꾸고, 4번 downsampling, 4번 upsampling하고 output convolution을 통과시켜 3채널짜리 noise prediction을 수행한다.
class SimpleUnet(nn.Module):
"""
A simplified variant of the Unet architecture.
"""
def __init__(self):
super().__init__()
image_channels = 3
down_channels = (64, 128, 256, 512, 1024)
up_channels = (1024, 512, 256, 128, 64)
out_dim = 3
time_emb_dim = 32
# Time embedding
self.time_mlp = nn.Sequential(
SinusoidalPositionEmbeddings(time_emb_dim),
nn.Linear(time_emb_dim, time_emb_dim),
nn.ReLU()
)
# Initial projection
self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)
# Downsample
self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i+1], \
time_emb_dim) \
for i in range(len(down_channels)-1)])
# Upsample
self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i+1], \
time_emb_dim, up=True) \
for i in range(len(up_channels)-1)])
# Edit: Corrected a bug found by Jakub C (see YouTube comment)
self.output = nn.Conv2d(up_channels[-1], out_dim, 1)
def forward(self, x, timestep):
# Embedd time
t = self.time_mlp(timestep)
# Initial conv
x = self.conv0(x)
# Unet
residual_inputs = []
for down in self.downs:
x = down(x, t)
residual_inputs.append(x)
for up in self.ups:
residual_x = residual_inputs.pop()
# Add residual x as additional channels
x = torch.cat((x, residual_x), dim=1)
x = up(x, t)
return self.output(x)
U-Net의 forward 함수에는 x, t가 들어간다. Downsampling/Upsampling block의 double convolution 사이에 time embedding이 더해져야 하기 때문이다.
Denoising U-Net은 학습이 되는 DDPM의 요소로, 모델의 크기를 결정한다. 아래 코드로 U-Net의 파라미터 개수를 확인해 보면, 62,433,123개가 나온다.
model = SimpleUnet()
print("Num params: ", sum(p.numel() for p in model.parameters()))
model
Loss Function
Noise Scheduler에서 노이즈를 주입할 때, time t에 대해 얼마만큼의 noise가 주입되었는지를 알고 있다. 또한, U-Net이 prediction한 noise는 U-Net의 output으로 나온다. 즉, 이 두 노이즈 간의 loss를 계산하면 된다. 여기에서는 L1 loss를 사용하였다. 왜 L2가 아닌 L1을 쓰는지에 대한 나의 추측을 적어보자면... 노이즈는 outlier로 작용할 수 있는 값이 많기 때문에 L2 loss를 사용하면 일부 노이즈에 대해 너무 큰 페널티가 부여되기 때문이 아닐까 한다.
def get_loss(model, x_0, t):
x_noisy, noise = forward_diffusion_sample(x_0, t, device)
noise_pred = model(x_noisy, t)
return F.l1_loss(noise, noise_pred)
위 세 가지 요소를 통해 Denoising U-Net을 학습시킴으로써 DDPM training을 모두 구현할 수 있다.
Sampling
시간 t만큼의 노이즈를 한 번에 주입하는 학습 과정과 달리, 샘플링 과정은 랜덤 노이즈로부터 시작하기 때문에 timestep T (fully corrupted image)에서 한 단계씩 denoising한다. 이것이 DDPM의 고질적인 문제인 느린 샘플링 시간의 주요한 원인이다.

한 단계 denoising을 할 때는, U-Net을 통해 timestep t에서의 noise prediction을 수행하고, 미리 정의해둔 상수들을 위 pseudocode에 나온 것과 같이 조합하여 x_t에서 노이즈를 뺀다. 그리고 mean, variance로 보정하여 x_t-1을 도출한다. 이 과정을 거쳐 x_0을 샘플링한다.
@torch.no_grad()
def sample_timestep(x, t):
"""
Calls the model to predict the noise in the image and returns
the denoised image.
Applies noise to this image, if we are not in the last step yet.
"""
betas_t = get_index_from_list(betas, t, x.shape)
sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
sqrt_one_minus_alphas_cumprod, t, x.shape
)
sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)
# Call model (current image - noise prediction)
model_mean = sqrt_recip_alphas_t * (
x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
)
posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
if t == 0:
# As pointed out by Luis Pereira (see YouTube comment)
# The t's are offset from the t's in the paper
return model_mean
else:
noise = torch.randn_like(x)
return model_mean + torch.sqrt(posterior_variance_t) * noise