while (1): study();

[논문 리뷰] Train longer, generalize better: closing the generalization gap in large batch training of neural networks 본문

논문 리뷰

[논문 리뷰] Train longer, generalize better: closing the generalization gap in large batch training of neural networks

전국민실업화 2021. 6. 27. 23:44
728x90

 이스라엘 연구팀에서 2017년에 발표한 연구 논문입니다. Tabnet에서 사용된 Ghost BN에 대해서 찾아보다가 논문을 읽게 되어 리뷰해보려 합니다. Tabnet의 논문 리뷰는 이후에 진행하도록 하겠습니다.

 

1. Introduction

 

 아직까지 인공지능 모델의 훈련에 있어 SGD는 중요한 역할을 하고 있습니다. 여담으로 Adam과 같이 Adaptive learning rate를 지원하는 방식은 훈련의 초기에 (특히 Transformer 아키텍처에서) 불안정한 양상을 보인더랬죠. 그럼으로 인해 SOTA급 성능을 뽑아내기 위해서 SGD 튜닝은 거의 불가피해 보입니다.

 

 그러나 SGD의 특성상 국소 최적해에 빠지는 등 일반화의 문제에 대해서 많은 관심이 끌렸고, 그 중 하나가 배치 사이즈가 너무 크면 모델이 일반화가 잘 되지 않는다는 것입니다. 저자들은 이 문제에 대해서 'Generalization gap'이라고 불렀습니다.

 

 사실 작은 배치 사이즈를 사용하기 시작한건 초기 컴퓨터의 메모리 자원 문제 때문이었던 걸로 알고 있습니다. 대신 많이 돌리면 언젠가는 근사하겠지~ 라는 심정으로 훈련을 시켰던 게 SGD였을텐데, 하드웨어 가격이 많이 떨어진 지금에 와서는 오히려 너무 크게 가져가면 일반화가 안된다니 아이러니한 것 같습니다. 

 

 그러나 큰 배치사이즈를 가져간다는 것은 훈련 속도가 향상됨을 의미하고, 결국 저자들은 이 Generalization gap을 해결하기 위한 방법론을 제안하려 합니다.

 

5. Matching weight increment statictics for different mini-batch sizes

 

 Parameters update step에 대한 공분산이 델타w라고 했을 때 이는 다음과 같이 계산됩니다.

이 말은 즉, 공분산을 일정하게 유지시키려면 학습율을 미니배치 사이즈의 제곱근에 비례하는 값으로 업데이트하면 됩니다.

또한 Ghost Batch Normalization을 사용하면 Generalization gap을 줄일 수 있습니다. 이는 간단하게 말하면 'Large batch를 사용하면 generalization gap이 커지니까, Large batch를 그 내부에서 다시 small batch로 나누어서 계산하자!'라는 겁니다. 전개는 다음과 같습니다. 각각의 virtual batches에 대해서 평균과 표준편차를 구하고, 평균내는 방식입니다. 나중에 추론에서 사용할 파라미터도 매 배치에 대해서 정보를 취합하며 업데이트하는 것이 눈에 띕니다.

 

Ghost Batch Normalization

 

5. Adapting number of weight updates eliminates generalization gap

 

 과적합에 대한 두려움 때문에 사람들은 validation error가 안정기에 접어들기 전 warm-up stage를 중단해버립니다. 그러나 저자들은 validation error가 떨어지지 않더라도 계속 같은 learning rate로 warm-up할 것을 권장합니다. 그렇게 함으로써 상당한 성능 향상을 얻을 수 있다고 합니다.

 

 저자들은 이러한 관찰 결과를 바탕으로 사실 generalization gap이라는 현상은 큰 배치 사이즈에서 오는 것이 아니라, 상대적으로 적은 업데이트 횟수에서 기인하는 것이라고 주장합니다. 배치 사이즈가 커지면 그만큼 업데이트 횟수가 적어지고, 여기에서 일반화 문제가 발생한다는 것이죠. 결국 Ghost Batch Normalization 또한 업데이트 횟수가 많아지는 효과를 얻음으로써 상대적으로 일반화가 잘 되는 거였다고 볼 수도 있겠네요.

 

7. Discussion

 

 저자들이 generalization gap을 해결하기 위해 제시한 바는 4가지입니다. 

 

1. SGD를 모멘텀, 그래디언트 클리핑, 그리고 learning rate schduler와 함께 사용할 것.

2. Adaptive learning rate를 미니배치의 제곱근에 비례하는 수치로 업데이트할 것

3. Ghost Batch Normalization을 사용할 것

4. 충분한 warm-up(high learning rate training) iteration을 취할 것

 

결론 

 

사실 learning rate schdule 튜닝 자체가 정말 힘든 일이기도 하고, Transformer 기반 아키텍처도 Radam이나 pre-LN정도만 사용하면 잘 돌아가는 것을 확인했는데 굳이 SGD를 사용해야 할까.. 싶기도 합니다. 또 Ghost Batch Normalization도 Tabnet을 써보고싶다는 욕심만 아니었으면 볼 일도 없었을 것 같은데 말이죠.

 

 다만, 일반화의 문제가 배치 사이즈에서 기인하는 것이 아닌, 업데이트 빈도에서 기인하는 것이라는 주장은 상당히 인상깊었습니다. 엄청 흥미롭게 읽은 것은 아니지만 그래도 하나는 얻어가는 것 같아서 기분이 좋습니다:)

 

728x90
Comments