前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >kNN 解决鸢尾花和手写数字识别分类问题

kNN 解决鸢尾花和手写数字识别分类问题

作者头像
石晓文
发布2019-06-24 21:53:45
1K0
发布2019-06-24 21:53:45
举报
文章被收录于专栏:小小挖掘机小小挖掘机

摘要:运用 kNN 解决鸢尾花和手写数字识别分类问题,熟悉 Sklearn 的一般套路。

今天我们以两个常见的数据集鸢尾花手写数字识别为例,练习 Sklearn 使用 kNN 算法解决机器学习分类问题,作为对之前四篇文章的小结。

练习完这两个案例,相信会对 kNN 算法有一个比较全面的理解,同时能学会 Sklearn 处理机器学习的一些固定套路。

预测鸢尾花数据集分类

鸢(yuan)尾花是 20 世纪 30 年代的一个经典数据集。该数据集包括三种花共 150 个样本,每个样本有 4 个数值型特征,分别是花萼长度(cm)、花萼宽度(cm)、花瓣长度(cm)、花瓣宽度(cm)。

取数据集的前 5 行预览一下:

现在的任务是:随机给一些样本,要判定它们分别属于哪一种花。和葡萄酒数据集的问题很相似。

所以我们同样可以用 kNN 算法来找到答案。思路很简单,加载数据集并划分训练集和测试集,在训练集上训练模型,然后把测试集应用到模型中,预测样本分别属于哪类花,最后计算分类的准确率。

mark

最后分类准确度达到了 97.8%,45 个测试花的样本仅预测错了 1 样本,准确率相当高。这还是在我们没有对模型做任何调优的情况下得到的。可见 kNN 算法的确是种效果很好的算法

接着,再来尝试一个相对大型点的分类数据集手写数字识别数据集,看看 kNN 算法性能如何。

手写数字识别数据集预测

mark

这个数据集来源 1998年 的一个手写数字实验。包含 1797 个样本,每个样本有 64 个特征(由 8 * 8 构成的 64 个数字像素点)每个样本的标签分别是 0-9 自然数中的一个。

现在的任务是:随机给一些数字样本,判定它们是哪个数字。这个任务 kNN 模型也能很好地完成,过程和刚才的鸢尾花一样,我们就直接贴代码:

plt.imshow 是一个图像处理函数,详细使用可以参考:plt.imshow 教程 %%time 是 jupyter book 的一个魔法命令,可以计算单元格执行的运算时间

箭头处对比了我们手写的模型和 Sklearn 中的模型的运行时间:8.74s 和 105ms,我们的算法慢了 80 倍差距非常大,主要是我们使用的是最简单粗暴的算法,而 Sklearn 用了更优化的方法。

这里也反映了 kNN 算法的运行时间会随着数据集维度增大而增大,所以得优化 kNN 算法才可以,比如使用 kd 树、球树等,我们放到后面再讲。

到这儿,我们练习了三个实例,相信对 kNN 算法有了比较深的认识了,不过现在又有一个新的问题。

我们在建立模型时,一直默认 k 参数(选择近邻样本点个数)为 3,这个参数对模型分类结果影响很大,它还可以是很多其他值,那是不是选 3 得到的模型就一定是最好的呢?

另外,在第一篇文章中,我们计算模型距离时默认使用的是欧拉距离,而欧拉距离是不是对任何数据集都是最好的计算方法呢?

上面的答案显然是不那么绝对的。

这里就涉及到了 kNN 算法的超参数问题,上面的 k 和距离都是超参数。不同的超参数会得到不同的 kNN 模型,为了得到更好的 kNN 模型,我们就需要好好了解一下超参数,下一篇文章就来介绍它。

本文参与?腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-06-22,如有侵权请联系?cloudcommunity@tencent.com 删除

本文分享自 小小挖掘机 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与?腾讯云自媒体分享计划? ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 预测鸢尾花数据集分类
  • 手写数字识别数据集预测
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
http://www.vxiaotou.com