怎么在C#中利用TensorFlow.NET训练数据集
怎么在C#中利用TensorFlow.NET训练数据集?很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。
什么是TensorFlow.NET?
TensorFlow.NET 是 SciSharp STACK
开源社区团队的贡献,其使命是打造一个完全属于.NET开发者自己的机器学习平台,特别对于C#开发人员来说,是一个“0”学习成本的机器学习平台,该平台集成了大量API和底层封装,力图使TensorFlow的Python代码风格和编程习惯可以无缝移植到.NET平台,下图是同样TF任务的Python实现和C#实现的语法相似度对比,从中读者基本可以略窥一二。
由于TensorFlow.NET在.NET平台的优秀性能,同时搭配SciSharp的NumSharp、SharpCV、Pandas.NET、Keras.NET、Matplotlib.Net等模块,可以完全脱离Python环境使用,目前已经被微软ML.NET官方的底层算法集成,并被谷歌写入TensorFlow官网教程推荐给全球开发者。
SciSharp 产品结构
微软 ML.NET底层集成算法
谷歌官方推荐.NET开发者使用
URL: https://www.tensorflow.org/versions/r2.0/api_docs
项目说明
本文利用TensorFlow.NET构建简单的图像分类模型,针对工业现场的印刷字符进行单字符OCR识别,从工业相机获取原始大尺寸的图像,前期使用OpenCV进行图像预处理和字符分割,提取出单个字符的小图,送入TF进行推理,推理的结果按照顺序组合成完整的字符串,返回至主程序逻辑进行后续的生产线工序。
实际使用中,如果你们需要训练自己的图像,只需要把训练的文件夹按照规定的顺序替换成你们自己的图片即可。支持GPU或CPU方式,该项目的完整代码在GitHub如下:
https://github.com/SciSharp/SciSharp-Stack-Examples/blob/master/src/TensorFlowNET.Examples/ImageProcessing/CnnInYourOwnData.cs
模型介绍
本项目的CNN模型主要由 2个卷积层&池化层 和 1个全连接层 组成,激活函数使用常见的Relu,是一个比较浅的卷积神经网络模型。其中超参数之一"学习率",采用了自定义的动态下降的学习率,后面会有详细说明。具体每一层的Shape参考下图:
数据集说明
为了模型测试的训练速度考虑,图像数据集主要节选了一小部分的OCR字符(X、Y、Z),数据集的特征如下:
分类数量:3 classes 【X/Y/Z】
图像尺寸:Width 64 × Height 64
图像通道:1 channel(灰度图)
数据集数量:
train:X - 384pcs ; Y - 384pcs ; Z - 384pcs
validation:X - 96pcs ; Y - 96pcs ; Z - 96pcs
test:X - 96pcs ; Y - 96pcs ; Z - 96pcs
其它说明:数据集已经经过 随机 翻转/平移/缩放/镜像 等预处理进行增强
整体数据集情况如下图所示:
代码说明
环境设置
.NET 框架:使用.NET Framework 4.7.2及以上,或者使用.NET CORE 2.2及以上
CPU 配置: Any CPU 或 X64 皆可
GPU 配置:需要自行配置好CUDA和环境变量,建议 CUDA v10.1,Cudnn v7.5
类库和命名空间引用
从NuGet安装必要的依赖项,主要是SciSharp相关的类库,如下图所示:
注意事项:尽量安装最新版本的类库,CV须使用 SciSharp 的 SharpCV 方便内部变量传递
引用命名空间,包括 NumSharp、Tensorflow 和 SharpCV ;
usingNumSharp; usingNumSharp.Backends; usingNumSharp.Backends.Unmanaged; usingSharpCV; usingSystem; usingSystem.Collections; usingSystem.Collections.Generic; usingSystem.Diagnostics; usingSystem.IO; usingSystem.Linq; usingSystem.Runtime.CompilerServices; usingTensorflow; usingstaticTensorflow.Binding; usingstaticSharpCV.Binding; usingSystem.Collections.Concurrent; usingSystem.Threading.Tasks;
主逻辑结构
主逻辑:
准备数据
创建计算图
训练
预测
publicboolRun()
{
PrepareData();
BuildGraph();
using(varsess=tf.Session())
{
Train(sess);
Test(sess);
}
TestDataOutput();
returnaccuracy_test>0.98;
}数据集载入
数据集下载和解压
数据集地址:https://github.com/SciSharp/SciSharp-Stack-Examples/blob/master/data/data_CnnInYourOwnData.zip
数据集下载和解压代码 ( 部分封装的方法请参考 GitHub完整代码 ):
stringurl="https://github.com/SciSharp/SciSharp-Stack-Examples/blob/master/data/data_CnnInYourOwnData.zip"; Directory.CreateDirectory(Name); Utility.Web.Download(url,Name,"data_CnnInYourOwnData.zip"); Utility.Compress.UnZip(Name+"\\data_CnnInYourOwnData.zip",Name);
字典创建
读取目录下的子文件夹名称,作为分类的字典,方便后面One-hot使用
privatevoidFillDictionaryLabel(stringDirPath)
{
string[]str_dir=Directory.GetDirectories(DirPath,"*",SearchOption.TopDirectoryOnly);
intstr_dir_num=str_dir.Length;
if(str_dir_num>0)
{
Dict_Label=newDictionary();
for(inti=0;i文件List读取和打乱
从文件夹中读取train、validation、test的list,并随机打乱顺序。
读取目录
ArrayFileName_Train=Directory.GetFiles(Name+"\\train","*.*",SearchOption.AllDirectories);
ArrayLabel_Train=GetLabelArray(ArrayFileName_Train);
ArrayFileName_Validation=Directory.GetFiles(Name+"\\validation","*.*",SearchOption.AllDirectories);
ArrayLabel_Validation=GetLabelArray(ArrayFileName_Validation);
ArrayFileName_Test=Directory.GetFiles(Name+"\\test","*.*",SearchOption.AllDirectories);
ArrayLabel_Test=GetLabelArray(ArrayFileName_Test);
获得标签
privateInt64[]GetLabelArray(string[]FilesArray)
{
Int64[]ArrayLabel=newInt64[FilesArray.Length];
for(inti=0;ik.Value==label).Key;
}
returnArrayLabel;
} 随机乱序
public(string[],Int64[])ShuffleArray(intcount,string[]images,Int64[]labels)
{
ArrayListmylist=newArrayList();
string[]new_images=newstring[count];
Int64[]new_labels=newInt64[count];
Randomr=newRandom();
for(inti=0;i部分数据集预先载入
Validation/Test数据集和标签一次性预先载入成NDArray格式。
privatevoidLoadImagesToNDArray()
{
//Loadlabels
y_valid=np.eye(Dict_Label.Count)[newNDArray(ArrayLabel_Validation)];
y_test=np.eye(Dict_Label.Count)[newNDArray(ArrayLabel_Test)];
print("LoadLabelsToNDArray:OK!");
//LoadImages
x_valid=np.zeros(ArrayFileName_Validation.Length,img_h,img_w,n_channels);
x_test=np.zeros(ArrayFileName_Test.Length,img_h,img_w,n_channels);
LoadImage(ArrayFileName_Validation,x_valid,"validation");
LoadImage(ArrayFileName_Test,x_test,"test");
print("LoadImagesToNDArray:OK!");
}
privatevoidLoadImage(string[]a,NDArrayb,stringc)
{
for(inti=0;i计算图构建
构建CNN静态计算图,其中学习率每n轮Epoch进行1次递减。
#regionBuildGraph
publicGraphBuildGraph()
{
vargraph=newGraph().as_default();
tf_with(tf.name_scope("Input"),delegate
{
x=tf.placeholder(tf.float32,shape:(-1,img_h,img_w,n_channels),name:"X");
y=tf.placeholder(tf.float32,shape:(-1,n_classes),name:"Y");
});
varconv1=conv_layer(x,filter_size1,num_filters1,stride1,name:"conv1");
varpool1=max_pool(conv1,ksize:2,stride:2,name:"pool1");
varconv2=conv_layer(pool1,filter_size2,num_filters2,stride2,name:"conv2");
varpool2=max_pool(conv2,ksize:2,stride:2,name:"pool2");
varlayer_flat=flatten_layer(pool2);
varfc1=fc_layer(layer_flat,h2,"FC1",use_relu:true);
varoutput_logits=fc_layer(fc1,n_classes,"OUT",use_relu:false);
//Someimportantparametersavedwithgraph,easytoloadlater
varimg_h_t=tf.constant(img_h,name:"img_h");
varimg_w_t=tf.constant(img_w,name:"img_w");
varimg_mean_t=tf.constant(img_mean,name:"img_mean");
varimg_std_t=tf.constant(img_std,name:"img_std");
varchannels_t=tf.constant(n_channels,name:"img_channels");
//learningratedecay
gloabl_steps=tf.Variable(0,trainable:false);
learning_rate=tf.Variable(learning_rate_base);
//createtrainimagesgraph
tf_with(tf.variable_scope("LoadImage"),delegate
{
decodeJpeg=tf.placeholder(tf.@byte,name:"DecodeJpeg");
varcast=tf.cast(decodeJpeg,tf.float32);
vardims_expander=tf.expand_dims(cast,0);
varresize=tf.constant(newint[]{img_h,img_w});
varbilinear=tf.image.resize_bilinear(dims_expander,resize);
varsub=tf.subtract(bilinear,newfloat[]{img_mean});
normalized=tf.divide(sub,newfloat[]{img_std},name:"normalized");
});
tf_with(tf.variable_scope("Train"),delegate
{
tf_with(tf.variable_scope("Loss"),delegate
{
loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels:y,logits:output_logits),name:"loss");
});
tf_with(tf.variable_scope("Optimizer"),delegate
{
optimizer=tf.train.AdamOptimizer(learning_rate:learning_rate,name:"Adam-op").minimize(loss,global_step:gloabl_steps);
});
tf_with(tf.variable_scope("Accuracy"),delegate
{
varcorrect_prediction=tf.equal(tf.argmax(output_logits,1),tf.argmax(y,1),name:"correct_pred");
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32),name:"accuracy");
});
tf_with(tf.variable_scope("Prediction"),delegate
{
cls_prediction=tf.argmax(output_logits,axis:1,name:"predictions");
prob=tf.nn.softmax(output_logits,axis:1,name:"prob");
});
});
returngraph;
}
///
///Createa2Dconvolutionlayer
///
///inputfrompreviouslayer
///sizeofeachfilter
///numberoffilters(oroutputfeaturemaps)
///filterstride
///layername
///Theoutputarray
privateTensorconv_layer(Tensorx,intfilter_size,intnum_filters,intstride,stringname)
{
returntf_with(tf.variable_scope(name),delegate
{
varnum_in_channel=x.shape[x.NDims-1];
varshape=new[]{filter_size,filter_size,num_in_channel,num_filters};
varW=weight_variable("W",shape);
//vartf.summary.histogram("weight",W);
varb=bias_variable("b",new[]{num_filters});
//tf.summary.histogram("bias",b);
varlayer=tf.nn.conv2d(x,W,
strides:new[]{1,stride,stride,1},
padding:"SAME");
layer+=b;
returntf.nn.relu(layer);
});
}
///
///Createamaxpoolinglayer
///
///inputtomax-poolinglayer
///sizeofthemax-poolingfilter
///strideofthemax-poolingfilter
///layername
///Theoutputarray
privateTensormax_pool(Tensorx,intksize,intstride,stringname)
{
returntf.nn.max_pool(x,
ksize:new[]{1,ksize,ksize,1},
strides:new[]{1,stride,stride,1},
padding:"SAME",
name:name);
}
///
///Flattenstheoutputoftheconvolutionallayertobefedintofully-connectedlayer
///
///inputarray
///flattenedarray
privateTensorflatten_layer(Tensorlayer)
{
returntf_with(tf.variable_scope("Flatten_layer"),delegate
{
varlayer_shape=layer.TensorShape;
varnum_features=layer_shape[newSlice(1,4)].size;
varlayer_flat=tf.reshape(layer,new[]{-1,num_features});
returnlayer_flat;
});
}
///
///Createaweightvariablewithappropriateinitialization
///
///
///
///
privateRefVariableweight_variable(stringname,int[]shape)
{
variniter=tf.truncated_normal_initializer(stddev:0.01f);
returntf.get_variable(name,
dtype:tf.float32,
shape:shape,
initializer:initer);
}
///
///Createabiasvariablewithappropriateinitialization
///
///
///
///
privateRefVariablebias_variable(stringname,int[]shape)
{
varinitial=tf.constant(0f,shape:shape,dtype:tf.float32);
returntf.get_variable(name,
dtype:tf.float32,
initializer:initial);
}
///
///Createafully-connectedlayer
///
///inputfrompreviouslayer
///numberofhiddenunitsinthefully-connectedlayer
///layername
///booleantoaddReLUnon-linearity(ornot)
///Theoutputarray
privateTensorfc_layer(Tensorx,intnum_units,stringname,booluse_relu=true)
{
returntf_with(tf.variable_scope(name),delegate
{
varin_dim=x.shape[1];
varW=weight_variable("W_"+name,shape:new[]{in_dim,num_units});
varb=bias_variable("b_"+name,new[]{num_units});
varlayer=tf.matmul(x,W)+b;
if(use_relu)
layer=tf.nn.relu(layer);
returnlayer;
});
}
#endregion 模型训练和模型保存
Batch数据集的读取,采用了 SharpCV 的cv2.imread,可以直接读取本地图像文件至NDArray,实现CV和Numpy的无缝对接;
使用.NET的异步线程安全队列BlockingCollection,实现TensorFlow原生的队列管理器FIFOQueue;
在训练模型的时候,我们需要将样本从硬盘读取到内存之后,才能进行训练。我们在会话中运行多个线程,并加入队列管理器进行线程间的文件入队出队操作,并限制队列容量,主线程可以利用队列中的数据进行训练,另一个线程进行本地文件的IO读取,这样可以实现数据的读取和模型的训练是异步的,降低训练时间。
模型的保存,可以选择每轮训练都保存,或最佳训练模型保存
#regionTrain
publicvoidTrain(Sessionsess)
{
//Numberoftrainingiterationsineachepoch
varnum_tr_iter=(ArrayLabel_Train.Length)/batch_size;
varinit=tf.global_variables_initializer();
sess.run(init);
varsaver=tf.train.Saver(tf.global_variables(),max_to_keep:10);
path_model=Name+"\\MODEL";
Directory.CreateDirectory(path_model);
floatloss_val=100.0f;
floataccuracy_val=0f;
varsw=newStopwatch();
sw.Start();
foreach(varepochinrange(epochs))
{
print($"Trainingepoch:{epoch+1}");
//Randomlyshufflethetrainingdataatthebeginningofeachepoch
(ArrayFileName_Train,ArrayLabel_Train)=ShuffleArray(ArrayLabel_Train.Length,ArrayFileName_Train,ArrayLabel_Train);
y_train=np.eye(Dict_Label.Count)[newNDArray(ArrayLabel_Train)];
//decaylearningrate
if(learning_rate_step!=0)
{
if((epoch!=0)&&(epoch%learning_rate_step==0))
{
learning_rate_base=learning_rate_base*learning_rate_decay;
if(learning_rate_base<=learning_rate_min){learning_rate_base=learning_rate_min;}
sess.run(tf.assign(learning_rate,learning_rate_base));
}
}
//Loadlocalimagesasynchronously,usequeue,improvetrainefficiency
BlockingCollection<(NDArrayc_x,NDArrayc_y,intiter)>BlockC=newBlockingCollection<(NDArrayC1,NDArrayC2,intiter)>(TrainQueueCapa);
Task.Run(()=>
{
foreach(variterationinrange(num_tr_iter))
{
varstart=iteration*batch_size;
varend=(iteration+1)*batch_size;
(NDArrayx_batch,NDArrayy_batch)=GetNextBatch(sess,ArrayFileName_Train,y_train,start,end);
BlockC.Add((x_batch,y_batch,iteration));
}
BlockC.CompleteAdding();
});
foreach(variteminBlockC.GetConsumingEnumerable())
{
sess.run(optimizer,(x,item.c_x),(y,item.c_y));
if(item.iter%display_freq==0)
{
//Calculateanddisplaythebatchlossandaccuracy
varresult=sess.run(new[]{loss,accuracy},newFeedItem(x,item.c_x),newFeedItem(y,item.c_y));
loss_val=result[0];
accuracy_val=result[1];
print("CNN:"+($"iter{item.iter.ToString("000")}:Loss={loss_val.ToString("0.0000")},TrainingAccuracy={accuracy_val.ToString("P")}{sw.ElapsedMilliseconds}ms"));
sw.Restart();
}
}
//Runvalidationaftereveryepoch
(loss_val,accuracy_val)=sess.run((loss,accuracy),(x,x_valid),(y,y_valid));
print("CNN:"+"---------------------------------------------------------");
print("CNN:"+$"gloablsteps:{sess.run(gloabl_steps)},learningrate:{sess.run(learning_rate)},validationloss:{loss_val.ToString("0.0000")},validationaccuracy:{accuracy_val.ToString("P")}");
print("CNN:"+"---------------------------------------------------------");
if(SaverBest)
{
if(accuracy_val>max_accuracy)
{
max_accuracy=accuracy_val;
saver.save(sess,path_model+"\\CNN_Best");
print("CKPTModelissave.");
}
}
else
{
saver.save(sess,path_model+string.Format("\\CNN_Epoch_{0}_Loss_{1}_Acc_{2}",epoch,loss_val,accuracy_val));
print("CKPTModelissave.");
}
}
Write_Dictionary(path_model+"\\dic.txt",Dict_Label);
}
privatevoidWrite_Dictionary(stringpath,Dictionarymydic)
{
FileStreamfs=newFileStream(path,FileMode.Create);
StreamWritersw=newStreamWriter(fs);
foreach(vardinmydic){sw.Write(d.Key+","+d.Value+"\r\n");}
sw.Flush();
sw.Close();
fs.Close();
print("Write_Dictionary");
}
private(NDArray,NDArray)Randomize(NDArrayx,NDArrayy)
{
varperm=np.random.permutation(y.shape[0]);
np.random.shuffle(perm);
return(x[perm],y[perm]);
}
private(NDArray,NDArray)GetNextBatch(NDArrayx,NDArrayy,intstart,intend)
{
varslice=newSlice(start,end);
varx_batch=x[slice];
vary_batch=y[slice];
return(x_batch,y_batch);
}
privateunsafe(NDArray,NDArray)GetNextBatch(Sessionsess,string[]x,NDArrayy,intstart,intend)
{
NDArrayx_batch=np.zeros(end-start,img_h,img_w,n_channels);
intn=0;
for(inti=start;i测试集预测
训练完成的模型对test数据集进行预测,并统计准确率
计算图中增加了一个提取预测结果Top-1的概率的节点,最后测试集预测的时候可以把详细的预测数据进行输出,方便实际工程中进行调试和优化。
publicvoidTest(Sessionsess)
{
(loss_test,accuracy_test)=sess.run((loss,accuracy),(x,x_test),(y,y_test));
print("CNN:"+"---------------------------------------------------------");
print("CNN:"+$"Testloss:{loss_test.ToString("0.0000")},testaccuracy:{accuracy_test.ToString("P")}");
print("CNN:"+"---------------------------------------------------------");
(Test_Cls,Test_Data)=sess.run((cls_prediction,prob),(x,x_test));
}
privatevoidTestDataOutput()
{
for(inti=0;i看完上述内容是否对您有帮助呢?如果还想对相关知识有进一步的了解或阅读更多相关文章,请关注恰卡编程网行业资讯频道,感谢您对恰卡编程网的支持。
推荐阅读
-
4个理由告诉你Java为何排行第一
本文由码农网 –单劼原创翻译,转载请看清文末的转载要求,欢迎参与我们的付费投稿计划!Java已经有20年的历史了,甚...
-
写给精明Java开发者的测试技巧
我们都会为我们的代码编写测试,不是吗?毫无疑问,我知道这个问题的答案可能会从“当然,但你知道怎样才能避免写测试吗?”到“必须...
-
Java 微服务框架 Redkale 入门介绍
Redkale功能Redkale虽然只有1.xM大小,但是麻雀虽小五脏俱全。既可作为服务器使用,也可当工具包使用。作为独立的工...
-
Java内存管理原理及内存区域详解
一、概述Java虚拟机在执行Java程序的过程中会把它所管理的内存划分为若干不同的数据区域,这些区域都有各自的用途以及创建和销毁...
-
2015年Java开发岗位面试题归类
下面是我自己收集整理的Java岗位今天面经遇到的面试题,可以用它来好好准备面试。一、Java基础1.String...
-
Java 虚拟机类加载机制和字节码执行引擎
引言我们知道java代码编译后生成的是字节码,那虚拟机是如何加载这些class字节码文件的呢?加载之后又是如何进行方法调用的呢?...
-
Java常量池理解与总结
一.相关概念什么是常量用final修饰的成员变量表示常量,值一旦给定就无法改变!final修饰的变量有三种:静态...
-
Java 实现线程死锁
概述春节的时候去面试了一家公司,笔试题里面有一道是使用简单的代码实现线程的‘死锁’,当时没有想到这道题考的是Sync...
-
Java:过去、未来的互联网编程之王
Java对你而言是什么?一门你大学里学过的语言?一个IT行业的通用语言?你相信Java已经为下一次互联网爆炸做好了准备么?Java...
-
20个高级Java面试题汇总
本文由码农网 –小峰原创翻译,转载请看清文末的转载要求,欢迎参与我们的付费投稿计划!这是一个高级Java面试系列题中...
