如何使用TensorFlow创建CNN

如何使用TensorFlow创建CNN

这篇文章主要介绍“如何使用TensorFlow创建CNN”,在日常操作中,相信很多人在如何使用TensorFlow创建CNN问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”如何使用TensorFlow创建CNN”的疑惑有所帮助!接下来,请跟着小编一起来学习吧!

使用TensorFlow创建CNN

#-*-coding:utf-8-*-importtensorflowastfimportnumpyasnp#下载mnist数据集fromtensorflow.examples.tutorials.mnistimportinput_datamnist=input_data.read_data_sets('./mnist_data/',one_hot=True)#fromtensorflow.contrib.learn.python.learn.datasets.mnistimportread_data_sets##mnist=read_data_sets('./mnist_data/',one_hot=True)n_output_layer=10#定义待训练的神经网络defconvolutional_neural_network(data):weights={'w_conv1':tf.Variable(tf.random_normal([5,5,1,32])),'w_conv2':tf.Variable(tf.random_normal([5,5,32,64])),'w_fc':tf.Variable(tf.random_normal([7*7*64,1024])),'out':tf.Variable(tf.random_normal([1024,n_output_layer]))}biases={'b_conv1':tf.Variable(tf.random_normal([32])),'b_conv2':tf.Variable(tf.random_normal([64])),'b_fc':tf.Variable(tf.random_normal([1024])),'out':tf.Variable(tf.random_normal([n_output_layer]))}data=tf.reshape(data,[-1,28,28,1])conv1=tf.nn.relu(tf.add(tf.nn.conv2d(data,weights['w_conv1'],strides=[1,1,1,1],padding='SAME'),biases['b_conv1']))conv1=tf.nn.max_pool(conv1,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')conv2=tf.nn.relu(tf.add(tf.nn.conv2d(conv1,weights['w_conv2'],strides=[1,1,1,1],padding='SAME'),biases['b_conv2']))conv2=tf.nn.max_pool(conv2,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')fc=tf.reshape(conv2,[-1,7*7*64])fc=tf.nn.relu(tf.add(tf.matmul(fc,weights['w_fc']),biases['b_fc']))#dropout剔除一些"神经元"#fc=tf.nn.dropout(fc,0.8)output=tf.add(tf.matmul(fc,weights['out']),biases['out'])returnoutput#每次使用100条数据进行训练batch_size=100X=tf.placeholder('float',[None,28*28])Y=tf.placeholder('float')#使用数据训练神经网络deftrain_neural_network(X,Y):predict=convolutional_neural_network(X)#cost_func=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=predict,labels=Y))cost_func=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=predict,labels=Y))optimizer=tf.train.AdamOptimizer().minimize(cost_func)#learningrate默认0.001epochs=1withtf.Session()assession:#session.run(tf.initialize_all_variables())session.run(tf.global_variables_initializer())epoch_loss=0forepochinrange(epochs):foriinrange(int(mnist.train.num_examples/batch_size)):x,y=mnist.train.next_batch(batch_size)_,c=session.run([optimizer,cost_func],feed_dict={X:x,Y:y})epoch_loss+=cprint(epoch,':',epoch_loss)correct=tf.equal(tf.argmax(predict,1),tf.argmax(Y,1))accuracy=tf.reduce_mean(tf.cast(correct,'float'))print('准确率:',accuracy.eval({X:mnist.test.images,Y:mnist.test.labels}))train_neural_network(X,Y)

执行结果:

准确率:0.9789

tflearn

下面使用tflearn重写上面代码,tflearn是TensorFlow的高级封装,类似Keras。

tflearn提供了更简单、直观的接口。和scikit-learn差不多,代码如下:

#-*-coding:utf-8-*-importtflearnfromtflearn.layers.convimportconv_2d,max_pool_2dfromtflearn.layers.coreimportinput_data,dropout,fully_connectedfromtflearn.layers.estimatorimportregressiontrain_x,train_y,test_x,test_y=tflearn.datasets.mnist.load_data(data_dir="./mnist_data/",one_hot=True)train_x=train_x.reshape(-1,28,28,1)test_x=test_x.reshape(-1,28,28,1)#定义神经网络模型conv_net=input_data(shape=[None,28,28,1],name='input')conv_net=conv_2d(conv_net,32,2,activation='relu')conv_net=max_pool_2d(conv_net,2)conv_net=conv_2d(conv_net,64,2,activation='relu')conv_net=max_pool_2d(conv_net,2)conv_net=fully_connected(conv_net,1024,activation='relu')conv_net=dropout(conv_net,0.8)conv_net=fully_connected(conv_net,10,activation='softmax')conv_net=regression(conv_net,optimizer='adam',loss='categorical_crossentropy',name='output')model=tflearn.DNN(conv_net)#训练model.fit({'input':train_x},{'output':train_y},n_epoch=13,validation_set=({'input':test_x},{'output':test_y}),snapshot_step=300,show_metric=True,run_id='mnist')model.save('./mnist.model')#保存模型"""model.load('mnist.model')#加载模型model.predict([test_x[1]])#预测"""

到此,关于“如何使用TensorFlow创建CNN”的学习就结束了,希望能够解决大家的疑惑。理论与实践的搭配能更好的帮助大家学习,快去试试吧!若想继续学习更多相关知识,请继续关注恰卡编程网网站,小编会继续努力为大家带来更多实用的文章!

发布于 2022-01-14 22:30:29
收藏
分享
海报
0 条评论
40
上一篇:mnist数据集问题怎么解决 下一篇:CNN的相关知识点有哪些
目录

    0 条评论

    本站已关闭游客评论,请登录或者注册后再评论吧~

    忘记密码?

    图形验证码