RECURRENT BATCH NORMALIZATION

RECURRENT BATCH NORMALIZATION
(링크)

1. 서론

Batch Norm 은 신경망 훈련시 각각의 Batch 의 데이터의 분포가 매우 상이하기 때문에 발생하는 문제를 해결하기 위해 고안된 방법으로, 현재 대부분의 신경망 설계시 빠르고 효율적인 신경망의 훈련을 위해서 적용되고 있다. 본 논문에서는 RNN 구조에서 Batch Norm 을 적용하는 방법에 대해서 이야기 하고 있다.

2. 선행 연구

(1) LSTM

LSTM 은 Long Short Term Memory Network 의 약자로 RNN 의 Sequence 가 길어지면 길어질 수록 신호가 유실되는 Vanishing Problem 을 해결하기 위해 고안된 것으로, 설명에 앞서 기본적인 RNN 은 아래와 같은 구조로 동작한다.

Hidden State 는 현재 시점 이전의 정보를 저장하기 위한 용도로 사용되며, 위 그림의 ht 를 구하는 공식과 같이 통상적으로는 Activation Function 을 Tanh 으로하여 t-1 시점의 Hidden State 와 t 시점의 x 정보를 Input 으로 ht 시점의 새로운 Hidden State 를 도출하고 t+1 Step 에서 다시 사용하게 된다. y 는 t 시점의 x 에 대한 연산 결과로 각 Step 의 예측 결과로 볼 수 있다.  이전 시점의 정보를 현재 시점에 무언가를 예측시 활용 할 수 있도록 한다는 Concept 에 있어서 RNN 은 매우 널리 활용되고 있으나, 기본적인 RNN 은 그 길이가 길어지면 오래전의 정보는 누락 될 수 있다는 Vanishing Problem 문제를 가지고 있는데, 이러한 문제를 해결하기 위하여 LSTM 과 같은 Conveyor belt 와 Forget Gate 와 Input Gate를 가지고 이러한 문제를 해결하고자 하는 접근 방법이 고안되었다.

t-1, t, t+1 시점을 연결하여 그림을 그려보면 위와 같이 표현될 수 있는데,  위에는 일반적인 RNN 구조이고 아래는 Vanishing Problem 을 개선하기 위한 LSTM 구조이다.

수식으로 보면 위와 같이 표현될 수 있다. f,i,o 의 Activation Function 은 Sigmoid 로 0~1 사이의 출력 값을 갖도록 한다. f 는 Forget Gate 로 선택적으로 이전 정보를 잊거나 기억하도록 하는 역할을 합니다. (0이면 정보 삭제, 1이면 온전한 전달)이 되겠습니다. i*g는 input Gate 로 현재의 정보를 기억하기 위한 Gate 로 g 는 Tanh Activation Function 의 활용으로 -1 ~ 1까지 음/양 방향성 정보를 표현하며, i는 sigmoid 로 Activation Function 으로 0~1 기억의 정보를 표현합니다.  여기서 c 는 Conveyor Belt 로 History 정보를 효과적으로 전달하기 위한 목적으로 설계되어 있으며, 위에서 설명한 Forget Gate 와 Input Gate 의 조합으로 이루어 집니다.

주절주절 설명한 내용으로 그래프로 표현하면 위와 같이 명확하게 설명이 됩니다.

class LSTMCell(RNNCell):

    def __call__(self, x, state, scope=None):
        with tf.variable_scope(scope or type(self).__name__):
            c, h = state

            # Keep W_xh and W_hh separate here as well to reuse initialization methods
            x_size = x.get_shape().as_list()[1]
            W_xh = tf.get_variable('W_xh',
                [x_size, 4 * self.num_units],
                initializer=orthogonal_initializer())
            W_hh = tf.get_variable('W_hh',
                [self.num_units, 4 * self.num_units],
                initializer=bn_lstm_identity_initializer(0.95))
            bias = tf.get_variable('bias', [4 * self.num_units])

            # hidden = tf.matmul(x, W_xh) + tf.matmul(h, W_hh) + bias
            # improve speed by concat.
            concat = tf.concat([x, h], 1)
            W_both = tf.concat([W_xh, W_hh], 0)
            hidden = tf.matmul(concat, W_both) + bias

            i, j, f, o = tf.split(hidden, 4, axis=1)

            new_c = c * tf.sigmoid(f) + tf.sigmoid(i) * tf.tanh(j)
            new_h = tf.tanh(new_c) * tf.sigmoid(o)

            return new_h, (new_c, new_h)

코드로도 위와 같이 명확하게 설명이 될 수 있겠습니다.

(2) Batch Normalization (링크)

Batch Norm 은 기본적으로 SGD 개념에 따르면, 전체의 훈련 Set 을 Random 하게 Mini-Batch 로 나누어 훈련하게 되는데, 이때, 각각의 Mini-Batch 의 분산이 상이하여 발생하는 문제를 해결하고자 Normalized 작업을 수행하여 준다로 이해하면 되겠다.

각각의 Mini Batch 의 분산은 1, 평균 0 이 되도록 회귀식을 만들어서 각각의 데이터에 적용하게 된다. 그래서 위에서 보는 y 가 나오게 되는 것이다.

size = x.get_shape().as_list()[1]

        scale = tf.get_variable('scale', [size],
            initializer=tf.constant_initializer(0.1))
        offset = tf.get_variable('offset', [size])

        pop_mean = tf.get_variable('pop_mean', [size],
            initializer=tf.zeros_initializer(),
            trainable=False)
        pop_var = tf.get_variable('pop_var', [size],
            initializer=tf.ones_initializer(),
            trainable=False)
        batch_mean, batch_var = tf.nn.moments(x, [0])

        train_mean_op = tf.assign(
            pop_mean,
            pop_mean * decay + batch_mean * (1 - decay))
        train_var_op = tf.assign(
            pop_var,
            pop_var * decay + batch_var * (1 - decay))

        def batch_statistics():
            with tf.control_dependencies([train_mean_op, train_var_op]):
                return tf.nn.batch_normalization(x, batch_mean, batch_var, offset, scale, epsilon)

        def population_statistics():
            return tf.nn.batch_normalization(x, pop_mean, pop_var, offset, scale, epsilon)

코드로 보면 위와 같다. 위의 코드는 Mini Batch 에 대한 Normalized 와 Inference 시에는 전체 데이터에 대해서 Normalized 두 가지를 구현하고 있다. 코드를 잘 보면, Train 시에는 Mini Batch 에 대해서 평균과 분산을 바로 구해서 사용하고, Inference 시에는 전체 데이터에 대해서 평균과 분산을 구해서 사용하고 있다.

2. 핵심 IDEA

(1) Batch Normalized LSTM

DNN 구조에서 Batch Norm 은 각각의 Hidden Layer 에 적용하는 형태로 Simple 하게 적용하여 많이 사용하고 있는데, RNN 구조에서는 어떻게 하는 것이 효과적일까? 바로 이 질문에 대한 답변이 이 논문에 있다고 보면 되겠다.

위의 LSTM 그래프에 Batch Norm 이 적용되는 구간을 표시하였다. LSTM Cell 과 동일한 구조로 동작하되,  붉은 색으로 표시한 부분에서 Batch Norm 작업을 실행한다고 보면 되겠다.

class BNLSTMCell(RNNCell):
    '''Batch normalized LSTM as described in arxiv.org/abs/1603.09025'''
    def __init__(self, num_units, training):
        self.num_units = num_units
        self.training = training

    @property
    def state_size(self):
        return (self.num_units, self.num_units)

    @property
    def output_size(self):
        return self.num_units

    def __call__(self, x, state, scope=None):
        with tf.variable_scope(scope or type(self).__name__):
            c, h = state

            x_size = x.get_shape().as_list()[1]
            W_xh = tf.get_variable('W_xh',
                [x_size, 4 * self.num_units],
                initializer=orthogonal_initializer())
            W_hh = tf.get_variable('W_hh',
                [self.num_units, 4 * self.num_units],
                initializer=bn_lstm_identity_initializer(0.95))
            bias = tf.get_variable('bias', [4 * self.num_units])

            xh = tf.matmul(x, W_xh)
            hh = tf.matmul(h, W_hh)

            bn_xh = batch_norm(xh, 'xh', self.training)
            bn_hh = batch_norm(hh, 'hh', self.training)

            hidden = bn_xh + bn_hh + bias

            i, j, f, o = tf.split(1, 4, hidden)

            new_c = c * tf.sigmoid(f) + tf.sigmoid(i) * tf.tanh(j)
            bn_new_c = batch_norm(new_c, 'c', self.training)

            new_h = tf.tanh(bn_new_c) * tf.sigmoid(o)

            return new_h, (new_c, new_h)

코드로 보면 위와 같다. 그래프로 설명한 것과 같이 붉은 색으로 표시한 부분에서 Batch Normalization 을 적용하는 것을 볼 수 있다.

(2) 결론

당연한 결론이겠지만,  아래 그림에서 보면, 파란선은 일반 LSTM , 붉은 선은 BN NORM 으로 훨씬 더 빠르게 훈련이 되고 Test Set 기준으로 평가시 일반화 성능도 더 좋은 것으로 연구 결과가 나왔다고 한다.

 

Leave a Reply

Your email address will not be published. Required fields are marked *