[Paper Review] Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
Abstract
Deep Neural Network를 학습시킬 때, 각 layer에서 이전 layer의 parameter가 변하므로 input distribution이 매번 달라지게 된다. 이 현상은 training을 learning rate를 줄이고 parameter initialization을 신중하게 함으로써 해결될 수 있으나 training 자체의 어려움은 여전히 존재하게 된다. 이 현상을 internal covariate shift현상이라고 한다.
본 논문에서는 layer의 training batch 마다 input을 normalize해서 해결하고자 한다.
Introduction
SGD algorithm
Deep learning에서 많이 쓰이는 optimization method중 하나인 Stochastic Gradient Descent(SGD)는 아래와 같이 전체 training dataset에 대해서 각 weight의 loss 값을 최소화 하는 방향으로 업데이트 하고자 한다.
이 때, SGD는 Mini batch를 사용하여 parameter를 아래와 같이 업데이트 한다.
위와 같이 Mini batch를 사용하면 장점이 두가지 존재한다.
1. loss에 대한 gradient 값을 계산하는 것이 전체 training dataset에 대해서 고려하는 것과 비슷한 효과를 준다.
2. 각 training sample에 대해서 gradient를 계산하는 것보다 mini batch에 대해서 계산하는 것이 parallel computation을 이용할 수 있으므로 더 효율적이다.
SGD는 간단하고 효율적인 알고리즘임에도 불구하고, Model hyper parameter를 신중히 tuning해야한다는 단점이 존재한다.(learning rate, weight initialization 등)
Internal Covariate Shift
u를 input 이라고 하고, W와 b는 learnable parameter라고 생각하자. g는 sigmoid 함수이다.
sigmoid의 특성상, sigmoid의 input값이 증가하게되면 non-linearity saturation regime에 갇히게 된다. 즉, gradient 값이 점점 0에 가깝게 되서 weight parameter가 업데이트 되지 않게 된다.
input u는 다른 sub network로 부터 나온 output이라고 했을때, 그 sub network로 들어오는 input의 distribution이 계속 달라짐에 따라 parameter가 계속해서 변하게 되면 non-linearity saturation regime에 더 갇히게 될 가능성이 크다.
이 때, Input u의 distribution을 고정시키면 학습진행이 안정화 될 것이다.
Deep network의 internal node(weight parameter)가 계속해서 바뀌는 현상을 Internal Covariate Shift라고 한다. 이를 제거하기 위해서 본 논문에서는 Batch Normalization을 제안한다. Batch normalization을 통해 아래와 같은 효과를 볼 수 있다.
- Accelerate training(allows to use higher learning rate)
- Reduce dependence of gradients on the scale of parameters or of weight initial values
- Sigmoid와 같은 Saturating nonlinearity가 saturated modes에 빠지지 않도록 하여 gradient값이 0이 되지 않도록 한다.
Normalization via Mini-Batch Statistics
본 논문에서는 다음과 같은 두가지 명료화를 진행했다.
1. 각 scalar feature를 normalize하되, 평균값은 0를, 분산값은 1을 갖도록 한다. 즉, 어떤 layer에 d차원 input x = (x1, x2, x3, .... xd)가 들어간다고 했을 때, 아래와 같이 normalize된다.
이 때, 평균값과 분산값은 전체 training dataset에 대해서 계산된다.
위와 같이 normalize하게 되면 그 layer가 represent하고 있는 부분이 변할 수 있다. 예를 들어, Sigmoid에 들어가기 전에 normalize해주는 경우 linear regime of nonlinearity로 강제로 들어가게 한다. 이를 해결하기 위해서, 매 activation x(k)마다 아래와 같이 scale & shift parameter를 도입하여 normalize이전의 input distribution을 복원할 수 있도록 하였다.
r값과 B값은 학습 가능한 parameter이다.
2. Mini batch 별로 mean과 variance 를 계산하여 normalize한다.
자세한 Batch Normalization Transform의 알고리즘은 아래와 같다.
Training and Inference with Batch Normalized Networks
Batch normalization을 적용함으로써, Activation function을 거치기 전에 Batch Normalization을 진행한다. 이 때, Stochastic Gradient Descent를 이용하여 optimze한다. Training할 때는 alg.1 처럼 mini -batch를 normalize하는 것이 학습 속도를 빠르게 하지만 Inference시에는 mini batch에 대해서 normalize할 필요가 없고 오로지 input에 대해서만 고려해야한다. 따라서 학습시에 사용되었던 mini batch들의 평균값과 분산값을 이용하여 normalize해준다. 즉, inference시에는 평균값과 분산값이 학습이 안되고 고정된 값으로 존재한다.
Batch Normalization을 적용한 Network의 Training과 Inference Algorithm은 아래와 같다
Batch-Normalized Convolutional Networks
W와 b는 fully connected 혹은 Convolutional layer의 weight와 bias 값이라고 하자. g는 sigmoid 혹은 ReLU와 같은 activation function이라고 하자. Non-linearity function에 들어가기전에 BN을 적용한다.
구체적으로, 학습시에는 아래와 같은 그림 처럼 Channel wise하게 normalize를 진행한다. m개의 batch에 대해서 각 feature map의 크기가 pxq라고 한다면 mxpxq 개의 scalar에 대해서 평균과 분산값을 구한다음, scale and shift parameter를 가지고 normalize하게 된다.
Inference 시에는, 고정된 평균과 분산값을 가지고 normalize해주게 된다.
Batch Normalization enables higher learning rates
Batch Norm이 없는 deep network에서는 너무 큰 learning rate를 적용하면 gradient exploding혹은 vanishing이 일어나서 local minima에 빠지기 쉬웠다. BN을 적용하게 되면, 이러한 문제를 어느정도 해결해줄 수 있다. 앞서 설명한 것처럼 BN은 saturated regime of Non-linearity에 빠지지 않도록 해주기 때문이다.
또한, BN은 weight parameter scale에 대해서 robust하다. BN이 없는 network에서 Learning rate를 크게 주면, weight parameters의 scale 값 또한 커지게 되어 gradient 값 또한 커진다. 그러나 BN을 포함시면 학습이 오히려 안정화 될 수 있다.
Input u에 대해서 Weight W를 곱한 결과 값을 BN에 넣었을 때와 aW를 곱한 결과 값을 BN에 넣었을 때 결과값은 동일하다. 따라서 아래와 같은 등식들이 성립하고, a값이 커지면 커질 수록 weight gradient 값이 작아지게 되므로 학습을 안정화 시킨다.
Experiments
Activations over time
맨 왼쪽에서 보면, BN을 적용했을 때 확실히 더 빠르게 수렴하고 더 높은 test accuracy를 보이는 것을 알 수 있고 (b), (c)에서 보면 가운데 값이 평균, 위 아래가 분산 값이라고 생각하면 되는데 BN이 없으면 input distribution이 계속해서 변하는 것을 볼 수 있고 이는 Internal covariate shift 를 일으킨다. 반면 BN는 input distribution이 안정적이다.
Imagenet에 대해서 실험해봤을때, Inception 보다 훨씬 높은 test accuracy를 달성했다. BN-x5는 Baseline 모델(Inception)의 초기 learning rate(0.0015) * 5 한 값으로 준 모델을 의미한다.
앙상블을 사용했을 때, State of the art를 달성할 수 있었다고 한다.
댓글
댓글 쓰기