본문 바로가기
  • 기술을 이야기하지만 사람을 생각합니다.
20. 인공지능과 딥러닝

[PyTorch로 시작하는 딥러닝 기초] 09-4 Batch-Normalization

by WE DONE IT. 2020. 2. 9.

  

edwith의 [부스트코스] 파이토치로 시작하는 딥러닝 기초<Lab-09-4 Batch Normalization> 강의를 정리한 내용입니다.

 

[LECTURE] Lab-09-4 Batch Normalization : edwith

학습목표 Batch Normalization 에 대해 알아본다. 핵심키워드 Batch Normalization 경사 소실(Gradient Vanishing) / 폭발(Explodi... - tkddyd

www.edwith.org


Batch Normalization

  • Gradient Vanishing / Exploding

  • Internal Covariate Shift

  • Batch Normalization

  • Code: mnist_batchnorm

Gradient Vanishing / Exploding

 

Gradient Vanishing

Gradient가 앞단으로 전파되면서 점점 옅어지게 되어 너무 작아져서 소멸하게 되는 문제이다.

Gradient Exploding

반면, exploidng은 gradient가 너무 크게 계산되어 너무 큰 값 또는 nand 값이 나오는 문제이다.   

 

Solution

  • Change activation function
    : 활성화 함수 중 sigmoid에서 이 문제가 발생하기 때문에 ReLU를 사용하기도 한다.

  • Careful initialization
    : weight 초기화를 잘 해보자는 의미로써, He initialization, Xavier initialization 등을 사용한다.

  • Small learning rate
    : Gradient Exploding 문제를 해결하기 위해 learning rate 값을 작게 할 수 있다. 

  • Batch Normalization
    : 학습 과정을 안전하게 할 수 있으며, 학습 속도의 가속 등 다양한 이점이 있음

  

Internal Covariate Shift

Covariate Shift의 개념 : 학습셋과 검증셋 분포의 차이가 문제를 발생시킴

위 그래프처럼 빨간색 그래프가 train set 파란색 점선이 test set이라고 할 때, 학습셋과 검증셋은 분포(distribution)의 차이가 어떤 문제점을 발생시킨 게 Covariate Shift의 개념이다. 입력과 출력의 분포가 다르다는 것도 이와 유사한 개념이다.

 

Internal Covariate Shift 개념: 레이어를 거치면서 이미지 분포 형태가 변화하게 된다.

여러 장의 고양이 이미지들은 어떤 분포 형태를 띄고 있으며, forward 되면서 첫 번째 레이어를 통과할 때 covariate shift 문제가 발생한다고 가정해 보자. 1, 2, 3, 4 레이어를 지나면서 분포가 약간씩 변하게 된다. 이러한 현상을 Internal Covariate Shift 라고 한다. 

 

이처럼 입력 데이터와 출력 데이터 간의 차이가 발생하는 문제를 대응하기 위해, 입력 데이터를 정규화(normalize) 하여 neural network를 사용했다.

 

Batch Normalization에서 주장하는 Internal Covariate Shift 문제는 입력과 출력 데이터에서 문제를 바라보지 않는다.
한 레이어 마다 입력과 출력을 가지고 있는데 이 레이어들끼리 covariate shift가 발생하며, 레이어가 깊어질 수록 distribution이 더 크게 발생한다는 점이다. 그래서 이 문제를 해결하고자 Batch Normalization을 이용한다.

 

Batch Normalization

각 레이어마다 Normalization을 하는 레이어를 두어, 변형된 분포가 나오지 않도록 하는 것이다. 미니배치 마다 normalization을 한다는 뜻에서 Batch Normalization이라고 한다. 

 Batch Normalization 논문에서 표현한 알고리즘  (출처: https://arxiv.org/pdf/1502.03167.pdf)

  • 위 식에서 입실론은 계산할 때 0으로 나눠져서 nand 문제가 발생하는 것을 막기 위한 아주 작은 숫자이다.
  • Normalize된 결과에 감마와 베타라는 scale과 shift transform을 적용한다. 
    • 데이터를 계속 Normalize 하게 되면 activation function의 non-linearity 같은 성질을 잃게 되는데,
      이 문제를 완화하기 위함이다. 
    • Normalize가 끝난 뒤, 결과값에 감마를 곱하고 shift를 더해주는 연산을 하면 Batch Normalization 계산이 끝난다. 

Train & Eval mode 

Train mode는 실제 학습할 때 사용하는 모드이며 dropout을 이용하여 노드를 껐다 켰다 하게 된다. Evaluation mode는 검증할 때 선언 후, dropout을 사용하지 않고 전체 뉴런을 상용한다.

Batch Normalization에서도 이 모드를 사용하는데, train과 실제 inference, 테스트를 할 때 차이점이 존재하기 때문에 사용한다.

 

왜 Batch Normalization에서도 사용해야 하는가?

batch size 값을 정한 뒤, 뉴럴 네트워크에 넣어서 학습하게 된다. 시그마와 뮤 값을 평균과 분산을 계산한다. 다음으로 normalize를 진행하여 X hat을 계산한다. 그 다음, 학습했던 감마를 X hat에 곱하고 베타를 더하는 방식으로 transform을 한다.

 최종적으로 Batch Normalization 하는 과정에서 X가 변경되면 뮤와 variance가 전혀 다른 값이 나올 수 있다. 즉, batch 값이 바뀌면다른 결과가 나올 수 있다. 따라서 이 과정에서도 trian과 evaluation 모드를 따로 둔다.

 

Leaning mean과 Learning variance

  1. 학습셋에서 sample mean과 sample variance 평균을 따로 저장한다.
  2. Inference (테스트)할 때 입력 데이터의 뮤와 variance가 아닌, leanring mean과 learning variance를 계산한다.
  3. Batch Normalization 학습이 끝난 뒤 입력 batch 데이터와 상관 없이 변하지 않는 고정값이 된다.
  4. 이 값을 inference 할 때에는 이 값을 이용하여 mean과 variance로 normalizae를 시키는 방식을 취한다.

따라서, batch에 있는 데이터가 변화하더라도 normalize하는 mean과 varianece 값이 바뀌지 않게 된다.

 

Batch Normalization에서 dropout을 할 때 문제점: 데이터가 바뀜에 따라서 mean과 variance 값이 바뀜
이 문제를 해결하기 위해 sample mean과 sample variance를 저장한 뒤, 다시 불러오는 방식으로 해결함.

Code: mnist_batchnorm

Batch Normalization 사용 여부에 따른 성능을 비교한 코드이다.


 # nn layers
linear1 = torch.nn.Linear(784, 32, bias=True) 
linear2 = torch.nn.Linear(32, 32, bias=True) 
linear3 = torch.nn.Linear(32, 10, bias=True) 

relu = torch.nn.ReLU()
bn1 = torch.nn.BatchNorm1d(32)
bn2 = torch.nn.BatchNorm1d(32)

nn_linear1 = torch.nn.Linear(784, 32, bias=True) 
nn_linear2 = torch.nn.Linear(32, 32, bias=True) nn_linear3 = torch.nn.Linear(32, 10, bias=True)


# model
bn_model = torch.nn.Sequential(linear1, bn1, relu,
                              linear2, bn2, relu,
                              linear3).to(device)
nn_model = torch.nn.Sequential(nn_linear1, relu,
                               nn_linear2, relu,
                               nn_linear3).to(device)

for epoch in range(training_epochs): 
	bn_model.train() # set the model to train mode (반드시 선언해야됨)
    
for X, Y in train_loader:
	# reshape input image into [batch_size by 784] 
    # label is not one-hot encoded
	X = X.view(-1, 28 * 28).to(device)
	Y = Y.to(device)
    
       bn_optimizer.zero_grad()
       bn_prediction = bn_model(X)
       bn_loss = criterion(bn_prediction, Y)
       bn_loss.backward()
       bn_optimizer.step()
       nn_optimizer.zero_grad()
       nn_prediction = nn_model(X)
       nn_loss = criterion(nn_prediction, Y)
       nn_loss.backward()
       nn_optimizer.step()

 

Batch Normalization 사용 여부에 따른 성능 비교

  1. 학습을 진행하면서 BatchNorm을 썼을 때와 안 썼을 때 loss 비교
  2. 학습셋의 정확도 비교
  3. 검증할 때 BatchNorm 사용 여부에 따른 loss 비교
  4. 검증셋의 정확도 비교

MNIST가 아닌 다른 데이터를 사용했을 때 BatchNorm의 성능은 더욱 빛을 바랄 수 있다.

댓글