小编给大家分享一下Matplotlib如何绘制混淆矩阵,希望大家阅读完这篇文章之后都有所收获,下面让我们一起去探讨吧!
代码如下:
importitertools
importmatplotlib.pyplotasplt
importnumpyasnp
#绘制混淆矩阵
defplot_confusion_matrix(cm,classes,normalize=False,title='Confusionmatrix',cmap=plt.cm.Blues):
"""
-cm:计算出的混淆矩阵的值
-classes:混淆矩阵中每一行每一列对应的列
-normalize:True:显示百分比,False:显示个数
"""
ifnormalize:
cm=cm.astype('float')/cm.sum(axis=1)[:,np.newaxis]
print("显示百分比:")
np.set_printoptions(formatter={'float':'{:0.2f}'.format})
print(cm)
else:
print('显示具体数字:')
print(cm)
plt.imshow(cm,interpolation='nearest',cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks=np.arange(len(classes))
plt.xticks(tick_marks,classes,rotation=45)
plt.yticks(tick_marks,classes)
#matplotlib版本问题,如果不加下面这行代码,则绘制的混淆矩阵上下只能显示一半,有的版本的matplotlib不需要下面的代码,分别试一下即可
plt.ylim(len(classes)-0.5,-0.5)
fmt='.2f'ifnormalizeelse'd'
thresh=cm.max()/2.
fori,jinitertools.product(range(cm.shape[0]),range(cm.shape[1])):
plt.text(j,i,format(cm[i,j],fmt),
horizontalalignment="center",
color="white"ifcm[i,j]>threshelse"black")
plt.tight_layout()
plt.ylabel('Truelabel')
plt.xlabel('Predictedlabel')
plt.show()
测试数据:
cnf_matrix=np.array([[8707,64,731,164,45],
[1821,5530,79,0,28],
[266,167,1982,4,2],
[691,0,107,1930,26],
[30,0,111,17,42]])
attack_types=['Normal','DoS','Probe','R2L','U2R']
第一种情况:显示百分比
plot_confusion_matrix(cnf_matrix,classes=attack_types,normalize=True,title='Normalizedconfusionmatrix')
效果:
第二种情况:显示数字
plot_confusion_matrix(cnf_matrix,classes=attack_types,normalize=False,title='Normalizedconfusionmatrix')
效果:
看完了这篇文章,相信你对“Matplotlib如何绘制混淆矩阵”有了一定的了解,如果想了解更多相关知识,欢迎关注恰卡编程网行业资讯频道,感谢各位的阅读!