RECURRENT BATCH NORMALIZATION
(링크)
1. 서론
Batch Norm 은 신경망 훈련시 각각의 Batch 의 데이터의 분포가 매우 상이하기 때문에 발생하는 문제를 해결하기 위해 고안된 방법으로, 현재 대부분의 신경망 설계시 빠르고 효율적인 신경망의 훈련을 위해서 적용되고 있다. 본 논문에서는 RNN 구조에서 Batch Norm 을 적용하는 방법에 대해서 이야기 하고 있다.
2. 선행 연구
(1) LSTM
LSTM 은 Long Short Term Memory Network 의 약자로 RNN 의 Sequence 가 길어지면 길어질 수록 신호가 유실되는 Vanishing Problem 을 해결하기 위해 고안된 것으로, 설명에 앞서 기본적인 RNN 은 아래와 같은 구조로 동작한다.
Hidden State 는 현재 시점 이전의 정보를 저장하기 위한 용도로 사용되며, 위 그림의 ht 를 구하는 공식과 같이 통상적으로는 Activation Function 을 Tanh 으로하여 t-1 시점의 Hidden State 와 t 시점의 x 정보를 Input 으로 ht 시점의 새로운 Hidden State 를 도출하고 t+1 Step 에서 다시 사용하게 된다. y 는 t 시점의 x 에 대한 연산 결과로 각 Step 의 예측 결과로 볼 수 있다. 이전 시점의 정보를 현재 시점에 무언가를 예측시 활용 할 수 있도록 한다는 Concept 에 있어서 RNN 은 매우 널리 활용되고 있으나, 기본적인 RNN 은 그 길이가 길어지면 오래전의 정보는 누락 될 수 있다는 Vanishing Problem 문제를 가지고 있는데, 이러한 문제를 해결하기 위하여 LSTM 과 같은 Conveyor belt 와 Forget Gate 와 Input Gate를 가지고 이러한 문제를 해결하고자 하는 접근 방법이 고안되었다.
t-1, t, t+1 시점을 연결하여 그림을 그려보면 위와 같이 표현될 수 있는데, 위에는 일반적인 RNN 구조이고 아래는 Vanishing Problem 을 개선하기 위한 LSTM 구조이다.
수식으로 보면 위와 같이 표현될 수 있다. f,i,o 의 Activation Function 은 Sigmoid 로 0~1 사이의 출력 값을 갖도록 한다. f 는 Forget Gate 로 선택적으로 이전 정보를 잊거나 기억하도록 하는 역할을 합니다. (0이면 정보 삭제, 1이면 온전한 전달)이 되겠습니다. i*g는 input Gate 로 현재의 정보를 기억하기 위한 Gate 로 g 는 Tanh Activation Function 의 활용으로 -1 ~ 1까지 음/양 방향성 정보를 표현하며, i는 sigmoid 로 Activation Function 으로 0~1 기억의 정보를 표현합니다. 여기서 c 는 Conveyor Belt 로 History 정보를 효과적으로 전달하기 위한 목적으로 설계되어 있으며, 위에서 설명한 Forget Gate 와 Input Gate 의 조합으로 이루어 집니다.
주절주절 설명한 내용으로 그래프로 표현하면 위와 같이 명확하게 설명이 됩니다.
class LSTMCell(RNNCell): def __call__(self, x, state, scope=None): with tf.variable_scope(scope or type(self).__name__): c, h = state # Keep W_xh and W_hh separate here as well to reuse initialization methods x_size = x.get_shape().as_list()[1] W_xh = tf.get_variable('W_xh', [x_size, 4 * self.num_units], initializer=orthogonal_initializer()) W_hh = tf.get_variable('W_hh', [self.num_units, 4 * self.num_units], initializer=bn_lstm_identity_initializer(0.95)) bias = tf.get_variable('bias', [4 * self.num_units]) # hidden = tf.matmul(x, W_xh) + tf.matmul(h, W_hh) + bias # improve speed by concat. concat = tf.concat([x, h], 1) W_both = tf.concat([W_xh, W_hh], 0) hidden = tf.matmul(concat, W_both) + bias i, j, f, o = tf.split(hidden, 4, axis=1) new_c = c * tf.sigmoid(f) + tf.sigmoid(i) * tf.tanh(j) new_h = tf.tanh(new_c) * tf.sigmoid(o) return new_h, (new_c, new_h)
코드로도 위와 같이 명확하게 설명이 될 수 있겠습니다.
(2) Batch Normalization (링크)
Batch Norm 은 기본적으로 SGD 개념에 따르면, 전체의 훈련 Set 을 Random 하게 Mini-Batch 로 나누어 훈련하게 되는데, 이때, 각각의 Mini-Batch 의 분산이 상이하여 발생하는 문제를 해결하고자 Normalized 작업을 수행하여 준다로 이해하면 되겠다.
각각의 Mini Batch 의 분산은 1, 평균 0 이 되도록 회귀식을 만들어서 각각의 데이터에 적용하게 된다. 그래서 위에서 보는 y 가 나오게 되는 것이다.
size = x.get_shape().as_list()[1] scale = tf.get_variable('scale', [size], initializer=tf.constant_initializer(0.1)) offset = tf.get_variable('offset', [size]) pop_mean = tf.get_variable('pop_mean', [size], initializer=tf.zeros_initializer(), trainable=False) pop_var = tf.get_variable('pop_var', [size], initializer=tf.ones_initializer(), trainable=False) batch_mean, batch_var = tf.nn.moments(x, [0]) train_mean_op = tf.assign( pop_mean, pop_mean * decay + batch_mean * (1 - decay)) train_var_op = tf.assign( pop_var, pop_var * decay + batch_var * (1 - decay)) def batch_statistics(): with tf.control_dependencies([train_mean_op, train_var_op]): return tf.nn.batch_normalization(x, batch_mean, batch_var, offset, scale, epsilon) def population_statistics(): return tf.nn.batch_normalization(x, pop_mean, pop_var, offset, scale, epsilon)
코드로 보면 위와 같다. 위의 코드는 Mini Batch 에 대한 Normalized 와 Inference 시에는 전체 데이터에 대해서 Normalized 두 가지를 구현하고 있다. 코드를 잘 보면, Train 시에는 Mini Batch 에 대해서 평균과 분산을 바로 구해서 사용하고, Inference 시에는 전체 데이터에 대해서 평균과 분산을 구해서 사용하고 있다.
2. 핵심 IDEA
(1) Batch Normalized LSTM
DNN 구조에서 Batch Norm 은 각각의 Hidden Layer 에 적용하는 형태로 Simple 하게 적용하여 많이 사용하고 있는데, RNN 구조에서는 어떻게 하는 것이 효과적일까? 바로 이 질문에 대한 답변이 이 논문에 있다고 보면 되겠다.
위의 LSTM 그래프에 Batch Norm 이 적용되는 구간을 표시하였다. LSTM Cell 과 동일한 구조로 동작하되, 붉은 색으로 표시한 부분에서 Batch Norm 작업을 실행한다고 보면 되겠다.
class BNLSTMCell(RNNCell): '''Batch normalized LSTM as described in arxiv.org/abs/1603.09025''' def __init__(self, num_units, training): self.num_units = num_units self.training = training @property def state_size(self): return (self.num_units, self.num_units) @property def output_size(self): return self.num_units def __call__(self, x, state, scope=None): with tf.variable_scope(scope or type(self).__name__): c, h = state x_size = x.get_shape().as_list()[1] W_xh = tf.get_variable('W_xh', [x_size, 4 * self.num_units], initializer=orthogonal_initializer()) W_hh = tf.get_variable('W_hh', [self.num_units, 4 * self.num_units], initializer=bn_lstm_identity_initializer(0.95)) bias = tf.get_variable('bias', [4 * self.num_units]) xh = tf.matmul(x, W_xh) hh = tf.matmul(h, W_hh) bn_xh = batch_norm(xh, 'xh', self.training) bn_hh = batch_norm(hh, 'hh', self.training) hidden = bn_xh + bn_hh + bias i, j, f, o = tf.split(1, 4, hidden) new_c = c * tf.sigmoid(f) + tf.sigmoid(i) * tf.tanh(j) bn_new_c = batch_norm(new_c, 'c', self.training) new_h = tf.tanh(bn_new_c) * tf.sigmoid(o) return new_h, (new_c, new_h)
코드로 보면 위와 같다. 그래프로 설명한 것과 같이 붉은 색으로 표시한 부분에서 Batch Normalization 을 적용하는 것을 볼 수 있다.
(2) 결론
당연한 결론이겠지만, 아래 그림에서 보면, 파란선은 일반 LSTM , 붉은 선은 BN NORM 으로 훨씬 더 빠르게 훈련이 되고 Test Set 기준으로 평가시 일반화 성능도 더 좋은 것으로 연구 결과가 나왔다고 한다.