FixMatch의 트레이닝 디테일을 살펴보자

2024. 8. 20. 21:18머신러닝&딥러닝/비지도학습, 준지도학습

728x90

FixMatch (Sohn et al, 2020)는 적은 수의 labeled data와 많은 수의 unlabeled data가 섞인 데이터셋에서 학습을 할 때 pseudo-labeling 기반으로 unlabeled data를 utilize하는 반지도학습(semi-supervised learning) 기법이다.

 

Fixmatch는 Mixup이라는 다소 복잡(?)한 unlabeled data utilization 기법을 가진 MixMatch(2019)에 비해 단순하면서도 보편적인 아이디어를 제공하여 많은 반지도학습 상황에 응용되어 왔다.

 

아이디어를 간단하게 소개하면, labeled data는 일반적인 supervised learning처럼 이용하고, unlabeled data는 모델의 confidence가 특정 임계값(threshold)을 넘으면 pseudo-label을 붙여서 학습에 이용하는 구조이다. 이 threshold는 상수값으로 고정된다. (이후 발전된 알고리즘들에서는 threshold를 adaptive하게 조정한다.)

 

반지도학습에서 unlabeled loss는 일반적으로 consistency regularization을 사용한다. 즉 레이블이 없더라도 같은 데이터라면 weakly augmented와 strongly augmented data에 대해 모델은 같은 값으로 예측해야 한다.

즉 unlabeled loss는 위와 같은 식으로 구성된다.

하이퍼파라미터

배치사이즈 B에 대해, labeled data는 아래 loss로 supervised model과 동일하게 계산된다.

B개의 labeled data가 있을 때, unlabeled data는 $\mu$개가 인퍼런스된다. 즉 $\mu = 7$로 두면, B개의 labeled data에 대해 7B개의 unlabeled data가 인퍼런스된다. 여기서 모델이 출력한 confidence 값이 threshold 이상이면, unlabeled data의 pseudo-labed과 모델의 예측값 사이의 cross entropy loss 가 loss term에 추가된다.

두 가지 loss term은 $\lambda$라는 하이퍼파라미터로 합쳐진다. 이 값이 0이면 supervised learning과 동일해지고, 커질수록 unlabeled data의 영향력이 커진다.

 

Threshold에 대한 하이퍼파라미터 테스트를 해본 결과 0.8 아래로는 높은 성능이 나오지 않는 것 같다. Temperature같은 경우 Softmax 시 각 logit값을 sharpening하는 상수를 말한다. threshold를 설정하는 경우 temperature를 1.0 근처로 두는 것이 좋은 것 같다.

Augmentation

Consistency Regularization을 위해서는 Weak augmentation과 Strong augmentation을 적용해야 하는데, 본 논문에서는 Weak로 정석적인 flip and shift augmentation을 주었다.

shift augmentation의 예시(출처 : Dataaspirant)

 

구체적으로, SVHN(숫자 이미지)를 제외한 데이터에 대해 50% 확률로 horizontalflip을, 12.5% 확률로 translation을 주었다.

 

Strong augmentation을 위해서는 RandAugment를 사용하였다. 이는 여러 가지 다양한 Augmentation 중 랜덤하게 하나를 골라 적용하는 것이다. 아래는 반지도학습 프레임워크 LAMDA-SSL 에서 구현한 Strong Augmentation이다.

def AutoContrast(X, **kwarg):
    return PIL.ImageOps.autocontrast(X)


def Brightness(X, min_v, max_v,magnitude,num_bins=10):
    v = min_v+float(max_v -min_v) * magnitude/ num_bins
    return PIL.ImageEnhance.Brightness(X).enhance(v)


def Color(X, min_v, max_v,magnitude,num_bins=10):
    v = min_v+float(max_v -min_v) * magnitude/ num_bins
    return PIL.ImageEnhance.Color(X).enhance(v)


def Contrast(X, min_v, max_v,magnitude,num_bins=10):
    v = min_v+float(max_v -min_v) * magnitude/ num_bins
    return PIL.ImageEnhance.Contrast(X).enhance(v)

def Equalize(X, **kwarg):
    return PIL.ImageOps.equalize(X)


def Identity(X, **kwarg):
    return X


def Invert(X, **kwarg):
    return PIL.ImageOps.invert(X)


def Posterize(X, min_v, max_v,magnitude,num_bins=10):
    v = int(min_v+(max_v -min_v) * magnitude/ num_bins)
    return PIL.ImageOps.posterize(X, v)


def Rotate(X, min_v, max_v,magnitude,num_bins=10):
    v = int(min_v+(max_v -min_v) * magnitude/ num_bins)
    if random.random() < 0.5:
        v = -v
    return X.rotate(v)


def Sharpness(X, min_v, max_v,magnitude,num_bins=10):
    v = min_v+float(max_v -min_v) * magnitude/ num_bins
    return PIL.ImageEnhance.Sharpness(X).enhance(v)


def ShearX(X, min_v, max_v,magnitude,num_bins=10):
    v = min_v+float(max_v -min_v) * magnitude/ num_bins
    if random.random() < 0.5:
        v = -v
    return X.transform(X.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))


def ShearY(X, min_v, max_v,magnitude,num_bins=10):
    v = min_v+float(max_v -min_v) * magnitude/ num_bins
    if random.random() < 0.5:
        v = -v
    return X.transform(X.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))


def Solarize(X, min_v, max_v,magnitude,num_bins=10):
    v = int(min_v+(max_v -min_v) * magnitude/ num_bins)
    return PIL.ImageOps.solarize(X, 256 - v)



def TranslateX(X, min_v, max_v,magnitude,num_bins=10):
    v = min_v+float(max_v -min_v) * magnitude/ num_bins
    if random.random() < 0.5:
        v = -v
    v = int(v * X.size[0])
    return X.transform(X.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))


def TranslateY(X, min_v, max_v,magnitude,num_bins=10):
    v = min_v+float(max_v -min_v) * magnitude/ num_bins
    if random.random() < 0.5:
        v = -v
    v = int(v * X.size[1])
    return X.transform(X.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))

 

Optimizer

일반적인 지도학습에서는 최신 옵티마이저인 Adam, AdamW를 주로 사용한다. 하지만, 본 논문에 따르면 Adam보다 SGD + Nesterov accelerating gradient를 추가한 옵티마이저가 성능이 더 좋다고 한다.

 

또한, learning rate scheduling에는 https://arxiv.org/abs/1608.03983논문에 나온 cosine learning rate decay를 사용했는데, 이를 learning rate를 $ \eta \cos\left(\frac{7\pi k}{16K}\right) $ 로 설정한다. $k/K$는 항상 0과 1 사이이므로 코사인값은 1에서 0으로 감소하는 식으로 스케줄링된다.

 

SGDR: Stochastic Gradient Descent with Warm Restarts

Restart techniques are common in gradient-free optimization to deal with multimodal functions. Partial warm restarts are also gaining popularity in gradient-based optimization to improve the rate of convergence in accelerated gradient schemes to deal with

arxiv.org

 

LAMDA-SSL 기준 Cosine Warmup은 아래와 같이 정의된다.

class CosineWarmup(LambdaLR):
    def __init__(self,
                 num_training_steps,
                 num_warmup_steps=0,
                 num_cycles=7./16,
                 last_epoch=-1,
                 verbose=False):
        # >> Parameter:
        # >> - num_training_steps: The total number of iterations for training.
        # >> - num_warmup_steps: The number of iterations to warm up.
        # >> - num_cycles: The upperbound of the multiplicative factor is num_cycles PI.
        # >> - last_epoch: The index of the last epoch.
        # >> - verbose: Whether to output redundant information.
        self.num_warmup_steps=num_warmup_steps
        self.num_cycles=num_cycles
        self.num_training_steps=num_training_steps
        self.verbose=verbose
        super().__init__(lr_lambda=self._lr_lambda,last_epoch=last_epoch,verbose=self.verbose)

    def _lr_lambda(self,current_step):
        if current_step < self.num_warmup_steps:
            return float(current_step) / float(max(1, self.num_warmup_steps))
        no_progress = float(current_step - self.num_warmup_steps) / \
            float(max(1, self.num_training_steps - self.num_warmup_steps))
        return max(0., math.cos(math.pi * self.num_cycles * no_progress))

 

Momentum은 0.9일 때까지 성능이 계속 좋아지다 그 이후 급격히 나빠지는 것을 확인할 수 있다. (Momentum 1.0은 gradient 반영 없이 관성만으로 파라미터가 움직인다) 또한 learning rate는 Adam 등에서와 달리 $10^{-2} \sim 10^{-1}$ 스케일에서 최적을 보였다.

이때, $\eta$를 labeled와 unlabeled의 비율인 $\mu$가 작을 때 batch size 와 선형적으로 스케일링해야 효과적이라고 한다.

 

 

Custom Dataset 적용

아래는 custom dataset(medical image) binary classification task를 위해 FixMatch를 직접 적용해 본 사례이며, 시험삼아 1 epoch 만 돌려 보았다. 하이퍼파라미터는 아래 이미지와 같이 적용하였다.

 

Tensorboard를 이용해 iteration마다 loss값을 찍어 봤는데, 일부 Loss가 크게 튀는 지점들이 있었다. Unlabeled data의 효과일 것으로 추측했는데, 모델이 안정적으로 학습이 완료되었다고 평가하려면 이러한 fluctuation이 줄어들때까지 학습시켜야 할 것 같다. (물론 overfitting을 막기 위해 Validation을 넣어야 한다)

 

아래는 1024 iteration 간격으로 validation accuracy를 찍어본 것인데, 역시 더 학습시켜도 될 것 같다.

 

반응형