How to Train Your Own Dataset with TensorFlow.NET in C#
How do you train a dataset with TensorFlow.NET in C#? Many newcomers are unclear on this, so this article walks through the whole process in detail; anyone with this need can follow along.
What is TensorFlow.NET?
TensorFlow.NET is a contribution of the SciSharp STACK open-source community. Its mission is to build a machine learning platform that belongs entirely to .NET developers; for C# developers in particular it aims to be a near-zero-learning-cost platform. It integrates a large set of APIs and low-level wrappers so that TensorFlow's Python code style and programming habits carry over almost seamlessly to .NET. The figure below compares the Python and C# implementations of the same TF task, from which the syntactic similarity is readily apparent.
Thanks to its strong performance on .NET, and combined with SciSharp's NumSharp, SharpCV, Pandas.NET, Keras.NET, Matplotlib.Net, and other modules, TensorFlow.NET can be used entirely without a Python environment. Its algorithms have been integrated into the lower layers of Microsoft's official ML.NET, and Google lists it in the official TensorFlow documentation as a recommendation for .NET developers worldwide.
SciSharp product structure
Algorithms integrated at the lower layers of Microsoft ML.NET
Officially recommended by Google for .NET developers
URL: https://www.tensorflow.org/versions/r2.0/api_docs
Project Overview
This article uses TensorFlow.NET to build a simple image-classification model that performs single-character OCR on printed characters at an industrial site. Full-size raw images come from an industrial camera; OpenCV handles the image preprocessing and character segmentation up front, extracting small single-character crops that are fed to TF for inference. The inference results are concatenated in order into the complete string and returned to the main program for the downstream production-line steps.
In practice, if you want to train on your own images, simply replace the contents of the training folders, keeping the prescribed structure, with your own pictures. Both GPU and CPU are supported. The complete code for this project is on GitHub:
https://github.com/SciSharp/SciSharp-Stack-Examples/blob/master/src/TensorFlowNET.Examples/ImageProcessing/CnnInYourOwnData.cs
Model Description
The CNN model in this project consists of 2 convolution + pooling layers and 1 fully connected layer, using the common ReLU activation; it is a fairly shallow convolutional network. One of its hyperparameters, the learning rate, uses a custom dynamically decaying schedule, described in detail later. The shape of each layer is shown in the figure below:
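The per-layer shapes in the figure can be reproduced by hand. A minimal sketch follows; note that the filter count in the last line is an illustrative assumption, not a value taken from the figure (with SAME padding, a layer's output size is ceil(input / stride)):

```csharp
using System;

class ShapeCheck
{
    // SAME padding: output size = ceil(input / stride)
    static int Same(int size, int stride) => (size + stride - 1) / stride;

    static void Main()
    {
        int h = 64, w = 64;              // input: 64 x 64 x 1 grayscale
        h = Same(h, 1); w = Same(w, 1);  // conv1, stride 1 -> 64 x 64
        h = Same(h, 2); w = Same(w, 2);  // pool1, ksize 2, stride 2 -> 32 x 32
        h = Same(h, 1); w = Same(w, 1);  // conv2, stride 1 -> 32 x 32
        h = Same(h, 2); w = Same(w, 2);  // pool2, ksize 2, stride 2 -> 16 x 16

        int numFilters2 = 32;            // illustrative assumption
        Console.WriteLine($"flatten size: {h * w * numFilters2}");
    }
}
```

With these assumed values the flattened vector feeding FC1 would have 16 × 16 × 32 = 8192 features; substitute your own filter counts to check against the figure.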
Dataset Description
To keep the training time of this demo short, the image dataset is a small selection of OCR characters (X, Y, Z). Its characteristics:
Number of classes: 3 classes (X / Y / Z)
Image size: Width 64 × Height 64
Image channels: 1 channel (grayscale)
Dataset sizes:
train: X - 384 pcs; Y - 384 pcs; Z - 384 pcs
validation: X - 96 pcs; Y - 96 pcs; Z - 96 pcs
test: X - 96 pcs; Y - 96 pcs; Z - 96 pcs
Other notes: the dataset has already been augmented with random flips, shifts, scaling, and mirroring
The overall dataset looks like this:
Code Walkthrough
Environment setup
.NET framework: .NET Framework 4.7.2 or later, or .NET Core 2.2 or later
CPU configuration: either Any CPU or x64
GPU configuration: you must set up CUDA and the environment variables yourself; CUDA v10.1 and cuDNN v7.5 are recommended
Libraries and namespace references
Install the necessary dependencies from NuGet, mainly the SciSharp libraries, as shown below:
Note: install the latest versions where possible; for CV, use SciSharp's SharpCV so that variables can be passed between modules without conversion.
<PackageReference Include="Colorful.Console" Version="1.2.9" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="1.15.0" />
<PackageReference Include="SciSharp.TensorFlowHub" Version="0.0.5" />
<PackageReference Include="SharpCV" Version="0.2.0" />
<PackageReference Include="SharpZipLib" Version="1.2.0" />
<PackageReference Include="System.Drawing.Common" Version="4.7.0" />
<PackageReference Include="TensorFlow.NET" Version="0.14.0" />
Reference the namespaces, including NumSharp, Tensorflow, and SharpCV:
using NumSharp;
using NumSharp.Backends;
using NumSharp.Backends.Unmanaged;
using SharpCV;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using Tensorflow;
using static Tensorflow.Binding;
using static SharpCV.Binding;
using System.Collections.Concurrent;
using System.Threading.Tasks;
Main Logic Structure
The main logic:
Prepare the data
Build the compute graph
Train
Predict
public bool Run()
{
    PrepareData();
    BuildGraph();

    using (var sess = tf.Session())
    {
        Train(sess);
        Test(sess);
    }

    TestDataOutput();
    return accuracy_test > 0.98;
}
Loading the Dataset
Downloading and extracting the dataset
Dataset URL: https://github.com/SciSharp/SciSharp-Stack-Examples/blob/master/data/data_CnnInYourOwnData.zip
Download and extraction code (for some of the wrapped helper methods, see the complete code on GitHub):
string url = "https://github.com/SciSharp/SciSharp-Stack-Examples/blob/master/data/data_CnnInYourOwnData.zip";
Directory.CreateDirectory(Name);
Utility.Web.Download(url, Name, "data_CnnInYourOwnData.zip");
Utility.Compress.UnZip(Name + "\\data_CnnInYourOwnData.zip", Name);
Building the label dictionary
Read the names of the subfolders under the directory and use them as the classification dictionary, which makes One-hot encoding convenient later:
private void FillDictionaryLabel(string DirPath)
{
    string[] str_dir = Directory.GetDirectories(DirPath, "*", SearchOption.TopDirectoryOnly);
    int str_dir_num = str_dir.Length;
    if (str_dir_num > 0)
    {
        Dict_Label = new Dictionary<Int64, string>();
        for (int i = 0; i < str_dir_num; i++)
        {
            string label = (str_dir[i].Replace(DirPath + "\\", "")).Split('\\').First();
            Dict_Label.Add(i, label);
            print(i.ToString() + " : " + label);
        }
        n_classes = Dict_Label.Count;
    }
}
Reading and shuffling the file lists
Read the train, validation, and test file lists from the folders and shuffle them into random order.
Read the directories:
ArrayFileName_Train = Directory.GetFiles(Name + "\\train", "*.*", SearchOption.AllDirectories);
ArrayLabel_Train = GetLabelArray(ArrayFileName_Train);

ArrayFileName_Validation = Directory.GetFiles(Name + "\\validation", "*.*", SearchOption.AllDirectories);
ArrayLabel_Validation = GetLabelArray(ArrayFileName_Validation);

ArrayFileName_Test = Directory.GetFiles(Name + "\\test", "*.*", SearchOption.AllDirectories);
ArrayLabel_Test = GetLabelArray(ArrayFileName_Test);
Get the labels:
private Int64[] GetLabelArray(string[] FilesArray)
{
    Int64[] ArrayLabel = new Int64[FilesArray.Length];
    for (int i = 0; i < ArrayLabel.Length; i++)
    {
        string[] labels = FilesArray[i].Split('\\');
        string label = labels[labels.Length - 2];
        ArrayLabel[i] = Dict_Label.Single(k => k.Value == label).Key;
    }
    return ArrayLabel;
}
Random shuffle:
public (string[], Int64[]) ShuffleArray(int count, string[] images, Int64[] labels)
{
    ArrayList mylist = new ArrayList();
    string[] new_images = new string[count];
    Int64[] new_labels = new Int64[count];
    Random r = new Random();
    for (int i = 0; i < count; i++)
    {
        mylist.Add(i);
    }
    for (int i = 0; i < count; i++)
    {
        int rand = r.Next(mylist.Count);
        new_images[i] = images[(int)(mylist[rand])];
        new_labels[i] = labels[(int)(mylist[rand])];
        mylist.RemoveAt(rand);
    }
    print("shuffle array list: " + count.ToString());
    return (new_images, new_labels);
}
Preloading part of the dataset
The validation and test datasets and their labels are preloaded once into NDArray format.
private void LoadImagesToNDArray()
{
    // Load labels
    y_valid = np.eye(Dict_Label.Count)[new NDArray(ArrayLabel_Validation)];
    y_test = np.eye(Dict_Label.Count)[new NDArray(ArrayLabel_Test)];
    print("Load Labels To NDArray : OK!");

    // Load images
    x_valid = np.zeros(ArrayFileName_Validation.Length, img_h, img_w, n_channels);
    x_test = np.zeros(ArrayFileName_Test.Length, img_h, img_w, n_channels);
    LoadImage(ArrayFileName_Validation, x_valid, "validation");
    LoadImage(ArrayFileName_Test, x_test, "test");
    print("Load Images To NDArray : OK!");
}

private void LoadImage(string[] a, NDArray b, string c)
{
    for (int i = 0; i < a.Length; i++)
    {
        b[i] = ReadTensorFromImageFile(a[i]);
        Console.Write(".");
    }
    Console.WriteLine();
    Console.WriteLine("Load Images To NDArray: " + c);
}

private NDArray ReadTensorFromImageFile(string file_name)
{
    using (var graph = tf.Graph().as_default())
    {
        var file_reader = tf.read_file(file_name, "file_reader");
        var decodeJpeg = tf.image.decode_jpeg(file_reader, channels: n_channels, name: "DecodeJpeg");
        var cast = tf.cast(decodeJpeg, tf.float32);
        var dims_expander = tf.expand_dims(cast, 0);
        var resize = tf.constant(new int[] { img_h, img_w });
        var bilinear = tf.image.resize_bilinear(dims_expander, resize);
        var sub = tf.subtract(bilinear, new float[] { img_mean });
        var normalized = tf.divide(sub, new float[] { img_std });

        using (var sess = tf.Session(graph))
        {
            return sess.run(normalized);
        }
    }
}
Building the Compute Graph
Build the static CNN compute graph; the learning rate is decreased once every n epochs.
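The decay rule, which is applied inside the Train method further below, can be isolated as a small sketch. The base rate, decay factor, step interval, and floor below are illustrative values, not the ones from the sample code:

```csharp
using System;

class DecaySketch
{
    static void Main()
    {
        // Illustrative hyperparameters (assumptions, not the sample's values)
        float lr = 0.001f;        // learning_rate_base
        float decay = 0.7f;       // learning_rate_decay
        float lrMin = 0.000001f;  // learning_rate_min
        int decayStep = 2;        // learning_rate_step: decay once every 2 epochs

        for (int epoch = 0; epoch < 10; epoch++)
        {
            // Same condition as in Train(): skip epoch 0, then decay every decayStep epochs
            if (epoch != 0 && epoch % decayStep == 0)
                lr = Math.Max(lr * decay, lrMin);
            Console.WriteLine($"epoch {epoch}: lr = {lr}");
        }
    }
}
```

In the real graph the decayed value is pushed into the `learning_rate` variable with `tf.assign`, so the Adam optimizer picks it up on the next step.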
#region BuildGraph
public Graph BuildGraph()
{
    var graph = new Graph().as_default();

    tf_with(tf.name_scope("Input"), delegate
    {
        x = tf.placeholder(tf.float32, shape: (-1, img_h, img_w, n_channels), name: "X");
        y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y");
    });

    var conv1 = conv_layer(x, filter_size1, num_filters1, stride1, name: "conv1");
    var pool1 = max_pool(conv1, ksize: 2, stride: 2, name: "pool1");
    var conv2 = conv_layer(pool1, filter_size2, num_filters2, stride2, name: "conv2");
    var pool2 = max_pool(conv2, ksize: 2, stride: 2, name: "pool2");
    var layer_flat = flatten_layer(pool2);
    var fc1 = fc_layer(layer_flat, h2, "FC1", use_relu: true);
    var output_logits = fc_layer(fc1, n_classes, "OUT", use_relu: false);

    // Some important parameters saved with graph, easy to load later
    var img_h_t = tf.constant(img_h, name: "img_h");
    var img_w_t = tf.constant(img_w, name: "img_w");
    var img_mean_t = tf.constant(img_mean, name: "img_mean");
    var img_std_t = tf.constant(img_std, name: "img_std");
    var channels_t = tf.constant(n_channels, name: "img_channels");

    // learning rate decay
    gloabl_steps = tf.Variable(0, trainable: false);
    learning_rate = tf.Variable(learning_rate_base);

    // create train images graph
    tf_with(tf.variable_scope("LoadImage"), delegate
    {
        decodeJpeg = tf.placeholder(tf.@byte, name: "DecodeJpeg");
        var cast = tf.cast(decodeJpeg, tf.float32);
        var dims_expander = tf.expand_dims(cast, 0);
        var resize = tf.constant(new int[] { img_h, img_w });
        var bilinear = tf.image.resize_bilinear(dims_expander, resize);
        var sub = tf.subtract(bilinear, new float[] { img_mean });
        normalized = tf.divide(sub, new float[] { img_std }, name: "normalized");
    });

    tf_with(tf.variable_scope("Train"), delegate
    {
        tf_with(tf.variable_scope("Loss"), delegate
        {
            loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits), name: "loss");
        });

        tf_with(tf.variable_scope("Optimizer"), delegate
        {
            optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss, global_step: gloabl_steps);
        });

        tf_with(tf.variable_scope("Accuracy"), delegate
        {
            var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred");
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy");
        });

        tf_with(tf.variable_scope("Prediction"), delegate
        {
            cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions");
            prob = tf.nn.softmax(output_logits, axis: 1, name: "prob");
        });
    });

    return graph;
}

/// <summary>
/// Create a 2D convolution layer
/// </summary>
/// <param name="x">input from previous layer</param>
/// <param name="filter_size">size of each filter</param>
/// <param name="num_filters">number of filters (or output feature maps)</param>
/// <param name="stride">filter stride</param>
/// <param name="name">layer name</param>
/// <returns>The output array</returns>
private Tensor conv_layer(Tensor x, int filter_size, int num_filters, int stride, string name)
{
    return tf_with(tf.variable_scope(name), delegate
    {
        var num_in_channel = x.shape[x.NDims - 1];
        var shape = new[] { filter_size, filter_size, num_in_channel, num_filters };
        var W = weight_variable("W", shape);
        // tf.summary.histogram("weight", W);
        var b = bias_variable("b", new[] { num_filters });
        // tf.summary.histogram("bias", b);
        var layer = tf.nn.conv2d(x, W,
            strides: new[] { 1, stride, stride, 1 },
            padding: "SAME");
        layer += b;
        return tf.nn.relu(layer);
    });
}

/// <summary>
/// Create a max pooling layer
/// </summary>
/// <param name="x">input to max-pooling layer</param>
/// <param name="ksize">size of the max-pooling filter</param>
/// <param name="stride">stride of the max-pooling filter</param>
/// <param name="name">layer name</param>
/// <returns>The output array</returns>
private Tensor max_pool(Tensor x, int ksize, int stride, string name)
{
    return tf.nn.max_pool(x,
        ksize: new[] { 1, ksize, ksize, 1 },
        strides: new[] { 1, stride, stride, 1 },
        padding: "SAME",
        name: name);
}

/// <summary>
/// Flattens the output of the convolutional layer to be fed into fully-connected layer
/// </summary>
/// <param name="layer">input array</param>
/// <returns>flattened array</returns>
private Tensor flatten_layer(Tensor layer)
{
    return tf_with(tf.variable_scope("Flatten_layer"), delegate
    {
        var layer_shape = layer.TensorShape;
        var num_features = layer_shape[new Slice(1, 4)].size;
        var layer_flat = tf.reshape(layer, new[] { -1, num_features });
        return layer_flat;
    });
}

/// <summary>
/// Create a weight variable with appropriate initialization
/// </summary>
/// <param name="name"></param>
/// <param name="shape"></param>
/// <returns></returns>
private RefVariable weight_variable(string name, int[] shape)
{
    var initer = tf.truncated_normal_initializer(stddev: 0.01f);
    return tf.get_variable(name,
        dtype: tf.float32,
        shape: shape,
        initializer: initer);
}

/// <summary>
/// Create a bias variable with appropriate initialization
/// </summary>
/// <param name="name"></param>
/// <param name="shape"></param>
/// <returns></returns>
private RefVariable bias_variable(string name, int[] shape)
{
    var initial = tf.constant(0f, shape: shape, dtype: tf.float32);
    return tf.get_variable(name,
        dtype: tf.float32,
        initializer: initial);
}

/// <summary>
/// Create a fully-connected layer
/// </summary>
/// <param name="x">input from previous layer</param>
/// <param name="num_units">number of hidden units in the fully-connected layer</param>
/// <param name="name">layer name</param>
/// <param name="use_relu">boolean to add ReLU non-linearity (or not)</param>
/// <returns>The output array</returns>
private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
{
    return tf_with(tf.variable_scope(name), delegate
    {
        var in_dim = x.shape[1];
        var W = weight_variable("W_" + name, shape: new[] { in_dim, num_units });
        var b = bias_variable("b_" + name, new[] { num_units });
        var layer = tf.matmul(x, W) + b;
        if (use_relu)
            layer = tf.nn.relu(layer);
        return layer;
    });
}
#endregion
Training and Saving the Model
Batch reading uses SharpCV's cv2.imread, which loads a local image file directly into an NDArray, giving a seamless bridge between CV and NumPy-style arrays.
.NET's thread-safe blocking queue BlockingCollection<T> is used to reproduce TensorFlow's native queue manager FIFOQueue.
When training, samples must be read from disk into memory before they can be used. We run multiple threads alongside the session and use the queue manager for inter-thread enqueue/dequeue of files, with the queue capacity bounded: the main thread trains on the data in the queue while another thread performs the local file IO. Data reading and model training thus run asynchronously, which reduces the overall training time.
The model can be saved either after every training epoch or only when a new best model is reached.
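The queueing pattern described above can be reduced to a minimal, self-contained sketch. Here plain integers stand in for the (x_batch, y_batch, iteration) tuples used in the real Train method, and the capacity of 8 is an illustrative assumption:

```csharp
using System;
using System.Collections.Concurrent;
using System.Threading.Tasks;

class QueueSketch
{
    static void Main()
    {
        // Bounded queue: the producer blocks once 8 batches are waiting,
        // so disk IO never runs far ahead of training.
        var queue = new BlockingCollection<int>(boundedCapacity: 8);

        var producer = Task.Run(() =>
        {
            for (int batch = 0; batch < 20; batch++)
                queue.Add(batch);      // stands in for reading a batch from disk
            queue.CompleteAdding();    // signal: no more batches will arrive
        });

        // Main thread: consumes until CompleteAdding has been called
        // and the queue is drained.
        foreach (var batch in queue.GetConsumingEnumerable())
        {
            // stands in for sess.run(optimizer, ...) on the main thread
            Console.WriteLine($"training on batch {batch}");
        }
        producer.Wait();
    }
}
```

GetConsumingEnumerable blocks while the queue is empty and exits cleanly after CompleteAdding, which is exactly how Train's consumer loop knows the epoch is finished.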
#region Train
public void Train(Session sess)
{
    // Number of training iterations in each epoch
    var num_tr_iter = (ArrayLabel_Train.Length) / batch_size;

    var init = tf.global_variables_initializer();
    sess.run(init);

    var saver = tf.train.Saver(tf.global_variables(), max_to_keep: 10);
    path_model = Name + "\\MODEL";
    Directory.CreateDirectory(path_model);

    float loss_val = 100.0f;
    float accuracy_val = 0f;

    var sw = new Stopwatch();
    sw.Start();
    foreach (var epoch in range(epochs))
    {
        print($"Training epoch: {epoch + 1}");
        // Randomly shuffle the training data at the beginning of each epoch
        (ArrayFileName_Train, ArrayLabel_Train) = ShuffleArray(ArrayLabel_Train.Length, ArrayFileName_Train, ArrayLabel_Train);
        y_train = np.eye(Dict_Label.Count)[new NDArray(ArrayLabel_Train)];

        // decay learning rate
        if (learning_rate_step != 0)
        {
            if ((epoch != 0) && (epoch % learning_rate_step == 0))
            {
                learning_rate_base = learning_rate_base * learning_rate_decay;
                if (learning_rate_base <= learning_rate_min) { learning_rate_base = learning_rate_min; }
                sess.run(tf.assign(learning_rate, learning_rate_base));
            }
        }

        // Load local images asynchronously, use queue, improve train efficiency
        BlockingCollection<(NDArray c_x, NDArray c_y, int iter)> BlockC = new BlockingCollection<(NDArray C1, NDArray C2, int iter)>(TrainQueueCapa);
        Task.Run(() =>
        {
            foreach (var iteration in range(num_tr_iter))
            {
                var start = iteration * batch_size;
                var end = (iteration + 1) * batch_size;
                (NDArray x_batch, NDArray y_batch) = GetNextBatch(sess, ArrayFileName_Train, y_train, start, end);
                BlockC.Add((x_batch, y_batch, iteration));
            }
            BlockC.CompleteAdding();
        });

        foreach (var item in BlockC.GetConsumingEnumerable())
        {
            sess.run(optimizer, (x, item.c_x), (y, item.c_y));

            if (item.iter % display_freq == 0)
            {
                // Calculate and display the batch loss and accuracy
                var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, item.c_x), new FeedItem(y, item.c_y));
                loss_val = result[0];
                accuracy_val = result[1];
                print("CNN: " + ($"iter {item.iter.ToString("000")}: Loss = {loss_val.ToString("0.0000")}, Training Accuracy = {accuracy_val.ToString("P")} {sw.ElapsedMilliseconds}ms"));
                sw.Restart();
            }
        }

        // Run validation after every epoch
        (loss_val, accuracy_val) = sess.run((loss, accuracy), (x, x_valid), (y, y_valid));
        print("CNN: " + "---------------------------------------------------------");
        print("CNN: " + $"global steps: {sess.run(gloabl_steps)}, learning rate: {sess.run(learning_rate)}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
        print("CNN: " + "---------------------------------------------------------");

        if (SaverBest)
        {
            if (accuracy_val > max_accuracy)
            {
                max_accuracy = accuracy_val;
                saver.save(sess, path_model + "\\CNN_Best");
                print("CKPT Model is saved.");
            }
        }
        else
        {
            saver.save(sess, path_model + string.Format("\\CNN_Epoch_{0}_Loss_{1}_Acc_{2}", epoch, loss_val, accuracy_val));
            print("CKPT Model is saved.");
        }
    }
    Write_Dictionary(path_model + "\\dic.txt", Dict_Label);
}

private void Write_Dictionary(string path, Dictionary<Int64, string> mydic)
{
    FileStream fs = new FileStream(path, FileMode.Create);
    StreamWriter sw = new StreamWriter(fs);
    foreach (var d in mydic) { sw.Write(d.Key + "," + d.Value + "\r\n"); }
    sw.Flush();
    sw.Close();
    fs.Close();
    print("Write_Dictionary");
}

private (NDArray, NDArray) Randomize(NDArray x, NDArray y)
{
    var perm = np.random.permutation(y.shape[0]);
    np.random.shuffle(perm);
    return (x[perm], y[perm]);
}

private (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end)
{
    var slice = new Slice(start, end);
    var x_batch = x[slice];
    var y_batch = y[slice];
    return (x_batch, y_batch);
}

private unsafe (NDArray, NDArray) GetNextBatch(Session sess, string[] x, NDArray y, int start, int end)
{
    NDArray x_batch = np.zeros(end - start, img_h, img_w, n_channels);
    int n = 0;
    for (int i = start; i < end; i++)
    {
        NDArray img4 = cv2.imread(x[i], IMREAD_COLOR.IMREAD_GRAYSCALE);
        x_batch[n] = sess.run(normalized, (decodeJpeg, img4));
        n++;
    }
    var slice = new Slice(start, end);
    var y_batch = y[slice];
    return (x_batch, y_batch);
}
#endregion
Prediction on the Test Set
The trained model runs prediction on the test dataset, and the accuracy is computed.
A node that extracts the Top-1 probability of each prediction was added to the compute graph, so detailed prediction data can be printed during test-set prediction, which helps with debugging and tuning in a real project.
public void Test(Session sess)
{
    (loss_test, accuracy_test) = sess.run((loss, accuracy), (x, x_test), (y, y_test));
    print("CNN: " + "---------------------------------------------------------");
    print("CNN: " + $"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}");
    print("CNN: " + "---------------------------------------------------------");
    (Test_Cls, Test_Data) = sess.run((cls_prediction, prob), (x, x_test));
}

private void TestDataOutput()
{
    for (int i = 0; i < ArrayLabel_Test.Length; i++)
    {
        Int64 real = ArrayLabel_Test[i];
        int predict = (int)(Test_Cls[i]);
        var probability = Test_Data[i, predict];
        string result = (real == predict) ? "OK" : "NG";
        string fileName = ArrayFileName_Test[i];
        string real_str = Dict_Label[real];
        string predict_str = Dict_Label[predict];
        print((i + 1).ToString() + " | " + "result: " + result + " | " + "real_str: " + real_str + " | "
            + "predict_str: " + predict_str + " | " + "probability: " + probability.GetSingle().ToString() + " | "
            + "fileName: " + fileName);
    }
}