Accurate , Large minibatch SGD Trainging ImageNet in 1 Hour

대용량 혹은 분산 처리 환경에서 BatchSize와 Learning rates를 설정하는데 있어 자주 실수 하는 부분을 정리해주는 논문

관련된 논문

  • “Don’t Decay the learning Rate, Increase the Batch Size
    Distributed Training of neural Networks
    Fast and easy distributed deep learning n TensorFlow

Motivate

  • 데이터가 많아 짐에 따라 Training 시간이 점점더 오래 걸리고 있음( 몇시간에서 몇일, 몇주일까지 )
  • 회사 입장에서는 프로덕트를 늦게 되는 문제이고, 연구자 입장에서는 많은 실험을 하지 못하는 문제를 야기함
  • GPU가 많이 있더라도 GPU자원을 Scaling하고 효율적으로 사용하는 것은 어려운 문제
  • Facebook이나 Google에서 데이터도 많고 서버도 많은 큰 회사에서 이런 연구들을 진행

사전지식

  • SGD는 전체 X에 속해 있는 x에 대해서 Minimize를 진행하게 됨
  • 미니배치는 하나의 배치안에서 평균을 내고 그 다음에 Learning rate (에타로 표시함) 를 곱해서 weight를 업데이트 함
  • 물론 weight update할때모멤텀 같은것도 씀

Linear Scaling Rule

  • Batchsize를 증가시키면 Learningrate도 같은 배율로 증가 해야 함
  • Batchsize가 n인 상황과 kn인 상황을 나눔
  • n은 한 머신에서 돌릴수 있는 크기라고 생각하면 좋고, kn은 k개의 머신이라고 생각하면 됨
  • Batchsize : n인 경우
    • k번 iterations을 하고 learning rate는 에타
    • j가 0부터 k-1개까지 진행
    • j가 0부터 시작하면, 0번째 배치의 loss의 gradient를 구하고, 1번째 배치의 loss의 gradient를 구하고
    • 결국 j가 k-1이 될때까지 반복함
  • Batchsize가 kn인 경우
    • Single step kn개의 배치를 k개의 n개 크기의 배치라고 생각함
    • kn으로 나우어서 nomalize한다는 점
    • loss를 계산하는 지점이 n개의 배치는 Wt+1, Wt+2로 계속 바껴가면서 계산하나 Wt로 고정되어서 한점에서만 계산함
  • 우리의 목적은 Batchsize가 n개던 kn개던 Weight가 같은 지점에 도착하기를 원함
  • k보다 작은 모든 인덱스 j에 대해서 loss의 gredient가 같다고 가정하면(수학적으로 엄밀히 증명하지 않고 informal intuition이라고함)
  • 시그마 텀도 같아 지고
  • 에타와 에타^이 같아 질려면 에타^이 k에타가 되면 같아짐
  • Batchsize는 n은 gradient 방향으로 조금씩 k방향으로 나간다고 생각하고,
  • Batchsize kn은 한번에 많이 가고 대신 learning rate에 k를 곱한다고 생각하면 편함

Warmup

  • Batchsize n과 kn의 로스가 비슷할 것이다라는 가정이 성립하지 않는 상황을 설명하는데, 네트워크 훈련의 처음은 잘 성립이 안됨. 왜냐하면 처음에는 웨이트가 조금씩 변경되어도 그래디언트가 많이 변하기 때문임
  • 그래서 Warmup 단계를 추가로 설명함
  • 두가지 방법이 있음 상수 Warmup, 점진적Warmup
  • 정해진 상수로 미리 Warmup 함
  • 점진적 Warmup은 learning rates를 바꿔 가면서 warmup을 시킴
  • object detection 이나 segmentation에서는 상수 warmup도 효과를 보이나 classification에서는 점진적 방법이 성능이 좋음(경험적으로 )
  • Batch Normailization은 일반적으로 훌륭하지만 loss를 분석할때는 힘듬
  • loss는 input data가 independence 하다고 가정하고 시작하지만 BN을 쓰면 이미 input data를 가지고 통계적으로 수치를 구하기 때문에 independence가 깨짐
  • 분산 GPU환경에서는 모든 Woker의 분산과 평균을 구해서 더해주는 worker간의 communication cost가 어마어마 하게 필요하게 됨
  • 이와 같이 independence가 깨진 loss는 lb라고 하고, x가 특정 배치에 dependence한 loss 임
  • 여기서 식을 약간 변형하여
  • 하나의 배치를 X의 n승 즉, n개의 training set의 Cartesian Product라 가정하면 Batch B가 그냥 하나의 Sample이 될수 있고, 이렇게 바꾸면 indepence가 아직은 존재 한다고 함
  • 그러면 n이 분산컴퓨팅에 대한 하이퍼 파라미터가 아니고, 그냥 BN에 대한 Hyper parameter라고 가정할수 있음
  • 보통 분산 환경에서 모든 worker에서 norm을 구하는 것보다는 한 worker안에서 BN을 진행해야 하는게 좋고 이러면 Communication Cost도 줄이고, Loss를 위해서도 좋음

Subtleties and Pitfalls of Distributed SGD

  • 우리가 일반적으로 계산하는 크로스 엔트로피를 스케일링 하는것과 learning rate를 스케일링 하는것은 다름
  • weight decay를 적용하는경우 learning rate가 weight decay term에도 붙어 있기 때문에 leraning rate 의 스케일 조절과 크로스엔트로피의 스케일 조절은 다름
  • 모멘텀 구현은 크게 2가지가 있음
  • 첫째는 loss의 gradient 를 구하고 그 다음에 learning rate를 구하는 방법
  • 둘째는 모멘텀을 구할때 이미 learning rate 에타가 이미 곱해지는 상태임
  • 그래서 learning rate가 고정일때는 두개가 동일하나, learning rate가 변경될때에는 momenturm correction을 해야 한다.
  • 분산 컴퓨팅 환경에서는 gradient를 aggregation 할때 먼저 로컬머신에서 평균을 구하고, k개의 머신끼리 더해서 평균을 구해야 함
  • 그렇지만 보통 분산컴퓨팅 환경에서는 각각 더해주는 작업을 이미 하고 있기 때문에 로컬 머신에서 먼저 k를 나누고 더하는게 좋음
  • random sample보다 random shuffling이 더 좋다고 여러 논문이 이야기함
  • 분산환경에서는 매 epoch마다 전체 Data를 k개로 파티셔닝하고, k개의 worker가 각각 파티셔닝된 데이터를 처리 하는 형태로 구현해야함
  • 모든 worker가 큰 데이터 셋에서 샘플링 하던지, 매 epoch마다 새로운 파티셔닝을 해야 하는데 같은 파티셔닝만 쓴다던지 하면 문제가 생김

Experiments

  • learning rate scale하는 것 하나만으로 8k까지는 균일하게 냄
  • 8k보다 커지면 문제가 생김
  • baseline은 256배치로 single machine에서 실행한 결과
  • constant가 에러가 25.88, gradual이 23.74로 좀더 적음
  • 분산을 보면 gradual이 0.09로 보다 안정적으로 훈련이 된다는것을 볼수 있음. 그래프도 마찬가지
  • 8k까지는 잘되나 마지막 3개는 배치사이즈가 커짐으로서 훈련이 잘 안됨
  • batchsize와 learning rate의 상관관계를 알아보기 위해 batchsize고정하고 learning rate만 변경해서 해보니 그래프가 다르게 그려짐
  • 같은 결과를 내기 위해서 배치사이즈를 늘릴때 러닝레이트도 늘려야함
  • 훈련이 잘 안되는 초기 상태에서는 warmup을 해야 함(constant, gradual)
  • 구현시 자주 하는 실수
    • 크로스 엔트로피 스케일과 러닝레이트 스케일은 다르다
    • 모멘텀을 적용할때 러닝레이트 변경을 하면 모멤텀 보정을 해야한다.
    • 분산컴퓨터 환경에서는 서메이션을 하기 때문에 미리 k로 나누자.
    • 분산컴퓨터 환경에서 셔플링으로 k개의 파티션으로 나누어서 워커로 분배

참고자료

  • “https://arxiv.org/pdf/1706.02677.pdf”
  • “https://www.youtube.com/watch?v=g3McZgloCJo&index=32&list=PLWKf9beHi3TgstcIn8K6dI_85_ppAxzB8”

Leave a Reply

Your email address will not be published. Required fields are marked *