생성 모델로 만든 이미지의 평가 방법 (S.Azizi 2023 논문 리뷰)

2024. 4. 15. 20:16머신러닝&딥러닝/생성모델

728x90

생성 모델로 만든 이미지가 원본 이미지랑 얼마나 비슷한지 어떻게 비교할까? 이는 해당 생성 모델의 성능과도 직결되는 문제이기 때문에, 적절한 평가 지표가 필요하다. 현재 융합의학기술원에서 디지털 병리 연구의 일환으로 DDPM으로 사구체 이미지를 생성한 synthetic data를 만들고 있는데, 아래와 같이 육안으로 보았을 때도 생성된 이미지의 Good case와 Bad case를 나눌 수 있다. 하지만, 육안으로 생성된 이미지의 퀄리티를 평가하는 것 말고 조금 더 정량적인 방법은 없을까?

Synthetic Data를 만드는 근본적인 이유는, Classification 혹은 Segmentation 등 다른 task를 수행하는 모델의 데이터셋을 보강하여 accuracy를 높이기 위함이다. 따라서, 가장 실용적이고 중요한 방법은 이들 모델의 성능이 synthetic data로 얼마나 올라가는지 확인하는 것이다. 본 포스팅에서는 논문 리뷰(1)를 통해 이의 구체적인 방법을 알아보도록 하자.
 
(1) S.Azizi et al., Synthetic Data from Diffusion Models Improves ImageNet Classification, CVPR, 2023
 

Introduction

최신 생성모델(Deep Generation Model)

  • DDPMs
  • GANs

최근 연구 결과에 따르면, DDPM이 각종 GAN보다 training stability와 quality 측면에서 더 우수한 성과를 보이고 있다. 성과를 평가한 주요한 task는 class가 조건으로 분류된 generative model, 그리고 open vocabulary text-to-image generation 이다.
 
이러한 합성 이미지 데이터로, 어려운 classification 혹은 discrimination 과제의 정확도를 더 높일 수는 없을까? 합성 데이터로 기존 모델의 성능을 올리는 것을 Generative Data Augmentation이라고 하는데, 본 논문에서는 특히 ImageNet 데이터에 대한 data augmentation을 다룬다. 또한, 이러한 데이터로 학습시킨 class-conditional model이 기존 sota 분류 모델보다 높은 accuracy를 보인다는 점을 시사한다.
 

Related Work

Synthetic Data

 
Synthetic Data는 아래와 같은 작업을 수행할 때 쓰인다.

  • Semantic image segmentation: 이미지의 각 부분을 픽셀 단위로 어떤 카테고리(class/object)에 속하는지 판별하는 것
  • optical flow estimation: 이미지 혹은 비디오에서 움직이는 사물을 트래킹하는 것
  • human motion understanding: Human motion을 학습하여 robotics 등에 적용 가능
  • dense prediction tasts: Medical Data도 dense prediction의 예시가 될 수 있을듯

최근의 diffusion model은 데이터를 학습하여 새로운 데이터를 생성하는 것뿐 아니라 text-to-image로도 우수한 성과를 보이고 있다.

Distilation and Transfer

 
본 논문에서는 큰 multimodal dataset으로 pretrain되고, ImageNet으로 fine-tuning된 diffusion model로 생성한 데이터로 classification model을 훈련해 보았다. 기존 모델과는 classification model 자체를 큰 multimodal dataset으로 pretrain하고 ImageNet으로 fine tuning 했다는 점이 다르다.
 
이러한 방식은 knowledge distillation과 관련이 있다. Knowledge Distillation이란, teacher network로부터 학습한 지식을 student network로 transfer하는 과정을 말한다. 즉 여기서는 diffusion model로 학습한 knowledge를 classifier로 전이하는 것이다.
 

Diffustion Model Applications

 
Diffusiuon Model은 이미지, 음성, 비디오 등 다양한 분야에서 적용된다. 특히, large-scale text-to-image generation 분야에서는 DALL-E 2, Imagen, eDiff, GLIDE 등의 모델이 사용되고 있다. 최근에는 이들을 training data를 augmentation하는 데 사용 가능하다는 연구결과가 많이 나오고 있는데, He et al. 은 GLIDE로 zero-shot learning을 개선할 수 있다는 결과를 발표했으며 Trabucco et al.은 pretrained diffusion model로 few-shot learning을 개선할 수 있다는 결과를 발표했다. 한편, Bansal et al. 과 Sariyildiz et al.에서는 pretrained diffution model만으로는 accuracy가 개선되지 않았다는 결과를 발표했다.
 

Background

Diffusion이란?

 
Diffusion model에 대한 전반적인 설명을 한다. Noising / Denoising 과정을 통해 새로운 이미지를 학습하는 전반적인 과정은 지난 번 DDPM 포스팅에서 다루었으니 여기서는 패스하겠다. https://cascade.tistory.com/59

DDPM을 통한 이미지 생성 및 보간 (J. Ho 2020 논문 리뷰)

DDPM이란?DDPM(Denosing Diffusion Probabilistic Model)은 발전된 형태의 diffusion 생성모델로, 이미지에 gaussian noise 를 조금씩 첨가하여 완전한 noise image로 만들어지는 과정 (q) 을 학습하여, 완전한 noise 이미지

cascade.tistory.com

 

Classification Accuracy Score

 
일반적으로, generative model의 visual quality를 평가할 때는 FID와 IS를 사용한다. 

  • FID(Frechet Inception Distance) score: 영상 집합 간의 거리를 잰다.
  • IS(Inception Score): 영상 집합 자체의 우수도를 평가하는 지표로, 품질과 다양성을 평가한다.

한편 이 지표들은 세 가지의 단점이 있다. 첫째, non-GAN 모델을 더욱 까다롭게 평가하는 경향이 있다. 둘째, sampling 할 때 여러 variation을 준 경우, 이 지표들을 의도적으로 높일 수 있다. 셋째, downstream task의 성능과 완전한 상관관계를 보이는 것은 아니다.
 
따라서, 본 논문에서는 CAS(Classification accuracy score)를 사용하는데, 이는 ResNet-50으로 평가하는 ImageNet validation set을 이용한다. 이제까지의 연구 결과에서는, 100% synthetic data만 사용하거나, 아주 적은 양의 synthetic data가 원본 데이터에 합쳐진 경우에는 CAS 점수가 낮게 나왔다. 반면, Cascade diffusion model은 BigGAN-deep이나 VQ-VAE보다 높은 CAS 점수가 나왔다.
 

Generative Model Training and Sampling

ImageNet ILSVRC 2012 데이터셋(ImageNet-1K)은 1000개 카테고리 및 128만 개의 labeled triaining data, 5만 개의 validation data를 포함한다. 이 데이터셋을 generated data의 평가기준으로 삼았다. 또한 이미지 generation은 large-scale text-to-image diffusion(Imagen) 모델을 사용했고, 프롬프트 임력은 1~2개의 주어진 클래스 이름을 사용했다.
 

Imagen Fine-tuning

 
본 논문에서는 Imagen을 fine tuning해서 사용했는데, 사전학습된 모델은 text와 embedding을 매핑하는 pretrained text encoder를 포함한다. Image Generator는 이 embedding을 바탕으로 64x64 이미지를 생성하며, super resolution model에 의해 256x256으로, 다시 1024x1024로 upsampling한다. 즉 diffusion cascade가 세 단계로 진행되는 것이다. 또한 각 단계를 text cross-attention layer를 포함한다.

  1. 64 x 64 이미지 최초 생성: 파라미터 20억개
  2. 256 x 256 upsampling: 파라미터 6억개
  3. 1024x1024 upsampling: 파라미터 4억개
Imagen 아키텍처 (출처: AssemblyAI)

 
이때, fine tuning은 첫 번째와 두 번째 단계에서만 하고, 셋째 단계인 1024x1024로 upsampling하는 단계는 그대로 두었다. 옵티마이저로는 1단계에서는 Adafactor를, 2단계에서는 Adam을 썼다.
 

Sampling parameters

 
좋은 hyperparameter를 선별하기 위해, FID score를 이용해 최적의 값을 찾았다. 선별의 대상이 되는 hyperparameter는 아래 세 가지이다.

  1. guidance weight: 모델이 이미지의 세부 사항을 얼마나 강조할지를 조절(값이 클수록 선명한 이미지 생성)
  2. log-variance: 샘플링 과정에서 노이즈의 분산 조절, log-variance가 낮으면 더 일관되고 안정된 이미지 생성
  3. sampling step의 횟수: noise 를 제거하는 step 수를 조절

실험해본 결과, logvar coeff = 0.0, guidance weight 1.25, sampling step 수 1000에서 최적의 FID score를 냈다. 이제, 이 hyperparameter를 이용하여, 120만 개의 이미지를 생성 후 FID, IS, CAS 값을 재었다. (CAS를 잴 때는 guidance weight에도 variation을 준 것 같다.) 마찬가지로, guidance weight = 1.25일 때 CAS가 최적이었다.

이제, super resolution을 위한 sampling parameter를 결정해야 한다. 여기서도 마찬가지로 각종 hyperparameter를 조절해 가면서 실험해본 결과, FID와 CAS의 아래와 같은 관계를 찾았다. 전반적으로  FID와 CAS는 correlation이 강한 양상이다.

이 최적의 hyperparameter 를 이용하여, ImageNet의 class마다 같은 수의 데이터로 만들어 데이터셋을 augmentation했다. 그 결과, training dataset을 약 10배 가까이 증폭했다.
 

Results

Sample Quality

 
이 Fine tuning한 Imagen 모델은 FID train, IS 의 메트릭에서 기존의 모든 모델보다 높은 점수를 받았다.

 

Classification Accuracy Score

 
CAS는 class의 종류에 따라 synthetic data의 포함 여부가 다른 영향을 주었다. 아래와 같이 확인했을 때, 기존의 CDM 모델보다 Fine tuned Imagen이 전체적으로 CAS가 높고, 256x256보다는 1024x1024가 더 나은 것을 확인할 수 있다.

순서대로 256x256 CDM, 256x256 Imagen(FineTuned), 1024x1024 Imagen(FineTuned)
원본 데이터와 Synthetic data 합치기

 
원본 데이터와 synthetic data를 섞으면 training 결과가 어떻게 될까? 본 연구에서는 Synthetic Data가 많아질수록 Top-1 accuracy가 높아지는 경향을 보였다.

이것은 기존의 연구와 조금 다른 점이 있는데, Ravuri and Vinyals는 대부분의 모델에서 원본과 synthetic을 섞으면 Top 5 accuracy가 낮아짐을 밝혔고, Big- GAN-deep에서는 적은 synthetic data에서는 조금 상승하지만 많이 섰으면 결국 떨어진다고 밝혔다.

그러나, 이미지 해상도가 높은 경우에는 synthetic dataset이 크더라도 정확도가 더 이상 높아지지 않았다.

반응형