CNN에서 backpropagation이 이루어지는 원리

2024. 3. 27. 11:04머신러닝&딥러닝/CNN

728x90

MLP(Multilayered Perceptron)에서 backpropagation은

1) Loss function의 input에 대한 편미분값의 역전파
2) Loss function의 parameter에 대한 편미분값의 계산

 

의 두 가지로 이루어진다. 이때 1) 에서 input은 CNN에서 커널 weight이 된다. CNN에서는 input을 1차원 벡터로 받는 MLP와 달리, 텐서를 input으로 받을 수 있으며 아래 세 가지 가정을 따른다.

1) Local Receptive Field: 커널을 사용하여 위치가 근접한 셀들의 정보를 읽는다
2) Shared Weights: 하나의 kernel을 이용한 컨볼루션 연산은 모두 같은 가중치와 bias를 공유한다
3) sub-sampling: 불필요한 정보를 없애기 위해 공간적 해상도를 낮추는 pooling을 사용한다

 

아래 글에서 MLP에서 backpropagation 연산이 어떤 방식으로 이루어지는지 chain rule을 통해 계산했다. 그렇다면 CNN에서는 위 3가지 가정을 만족하면서 어떻게 backpropagation을 계산할까?

https://cascade.tistory.com/36

 

Gradient descent & Backpropagation (Rumelhart 1986 논문 리뷰)

딥러닝 분야에는 다양한 아키텍처가 있다. CNN, RNN, Transformer 등으로 계보를 잇는 아키텍처들은 새로운 optimization 알고리즘이나 원리를 추가하면서 발전하지만, 모두 기본적인 학습 방식은 비슷하

cascade.tistory.com

 

CNN에서의 backpropagation 계산

커널 weight에 대한 편미분값 계산

X: input, Z: output, K: kernel weight, B: bias

 

CNN도 MLP와 마찬가지로 forward propagation 과정과 backpropagation 과정을 따른다. 앞서 말한 대로, output에 대한 Loss function의 편미분값은 input에 대한 Loss function의 편미분값으로 계산되며, 이 과정에서 Kernel과 bias에 대한 Loss function의 편미분값이 계산된다.

위는 3x3 input에 2x2 kernel이 convolution되는 간단한 과정을 나타낸 그림이다. Kernel weight에 대한 loss function의 편미분값은 chain rule에 의해 아래와 같이 계산된다.

 

이때 dZ/dK는 쉽게 알 수 있는데, 위의 convolution 전개식에서 Z는 X*K 꼴들의 선형결합으로 되어 있기 때문에 dZ/dK를 계산하면 K 앞에 계수처럼 붙어 있는 X만 남는다. 즉, 아래와 같이 식을 바꿀 수 있다.

 

그러면 dL/dK를 아래와 같이 행렬 형태로 고쳐 쓴다면, 아래와 같이 표현할 수 있다.

따라서, loss function의 K에 대한 미분은 conv(X, dL/dZ)로 계산할 수 있다.


Input에 대한 편미분값 계산

 

Input에 대한 Loss의 편미분값은 마찬가지로 chain rule에 의해 아래와 같이 적을 수 있다.

이번에는 dZ/dX를 소거해야 하기 때문에, K와 관련된 항이 남을 것으로 예상해볼 수 있다.

X_11으로 미분할 때는 K_11이, X_12로 미분할 때는 K_12와 K_11이 남음을 알 수 있다. 이걸 쭉 전개해 보자.

 

쭉 전개하면... 뭔가 잘 보이진 않는데 결론을 보면 "그렇구나" 하게 된다. 이는 dL/dZ 행렬에 zero padding을 씌운 뒤, 커널 weight를 180도 회전한 행렬을 convolution한 결과가 된다.

 

따라서, 아래와 같이 정리할수 있다.


결론

CNN에서의 backpropagation 계산도 MLP와 마찬가지로 아래의 두 과정을 따른다.

1) Loss function의 input에 대한 편미분값의 역전파
2) Loss function의 parameter에 대한 편미분값의 계산

 

그리고 각각의 과정은 아래 수식으로 계산이 된다.

반응형