人工智能A7论坛 >> Tensorflow和深度学习笔记_论坛版 >> 8.RNN在文本和图像中的应用

8.2 RNN做图像分类

通常,在大家的印象中,CNN解决图像问题,RNN解决文本、语音等序列问题。这固然不错,但是并不绝对。我们前面已经介绍过,CNN也可以用于文本分类问题,而这里将会让大家看到,RNN也是可以用于对图像(如MNIST数据集)进行分类的!其原理是,把图片中的所有“行”组成一个序列,也就是序列中每一个组成单元就是一行。

 

这里简要列举基本过程。

RNN的构造如下

def RNN(x, weights, biases):

    # Prepare data shape to match `rnn` function requirements

    # Current data input shape: (batch_size, n_steps, n_input)

    # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)

 

    # Permuting batch_size and n_steps

    x = tf.transpose(x, [1, 0, 2])

    # Reshaping to (n_steps*batch_size, n_input)

    x = tf.reshape(x, [-1, n_input])

    # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)

    x = tf.split(x, n_steps, 0)

 

    # Define a lstm cell with tensorflow

    lstm_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)

 

    # Get lstm cell output

    outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)

 

    # Linear activation, using rnn inner loop last output

    return tf.matmul(outputs[-1], weights['out']) + biases['out']

可以看到,这是一个用LSTM组成的RNN序列。

 

喂数据和训练的情况

batch_x, batch_y = mnist.train.next_batch(batch_size)

        # Reshape data to get 28 seq of 28 elements

        batch_x = batch_x.reshape((batch_size, n_steps, n_input))

        # Run optimization op (backprop)

        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})

 

可以看到,图像是28行,每一行有28个元素,这28行构成了一个序列。

具体代码:

https://github.com/aymericdamien/TensorFlow-Examples

TensorFlow-Examples-master/examples/3_NeuralNetworks/recurrent_network.py