KNN算法
2024-10-18 08:20:38
1.算法讲解
KNN算法是一个最基本、最简单的有监督算法,基本思路就是给定一个样本,先通过距离计算,得到这个样本最近的topK个样本,然后根据这topK个样本的标签,投票决定给定样本的标签;
训练过程:只需要加载训练数据;
测试过程:通过之前加载的训练数据,计算测试数据集中各个样本的标签,从而完成测试数据集的标注;
2.代码
具体代码如下:
#!/usr/bin/env/ python
# -*- coding: utf-8 -*-
import csv
import random
from matplotlib import pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
class KNN(object):
def __init__(self):
self._trainData = None
self._trainDataLabel = None
# 计算距离
def _computerDist(self,testData):
m = testData.shape[0]
n = self._trainData.shape[0]
dist = np.zeros((m,n))
for i in range(m):
for j in range(n):
dist[i][j] = np.sum( (testData[i,:] - self._trainData[j,:])**2 )
return dist
# 模型训练,knn只需要加载训练数据集
def train(self,dataset):
self._trainData = dataset[:,0:-1]
self._trainDataLabel = np.array(dataset[:,-1],dtype = np.int)
# 预测测试数据集
def predict(self,testData,topK = 3):
dist = self._computerDist(testData)
num_test = testData.shape[0]
predLable = np.zeros(num_test)
for i in range(num_test):
labelList = []
# 得到前topK样本的索引
idxList = np.argsort(dist[i,:])[:topK].tolist()
# 根据这些索引,得到对应的标签
labelList = self._trainDataLabel[idxList]
# 统计各个标签数目
counts = np.bincount(labelList)
# 将标签数目最大的标签值作为样本的标签
predLable[i] = np.argmax(counts)
return predLable
# 测试准确率
def test(self,testData,testLabel,topK = 3):
predLabel = self.predict(testData,topK)
predLabel = np.array(predLabel,dtype = int)
num_correct = np.sum(predLabel == testLabel)
num_test = testLabel.shape[0]
accuracy = float(num_correct) / num_test
print "testLabel:" + str(testLabel)
print "predLabel:" + str(predLabel)
print "get: %d / % d => accuracy: %f" %(num_correct,num_test,accuracy)
return predLabel
# 画出结果图
def plotResult(self,testData,predLable):
X = self._trainData
y = self._trainDataLabel
pca = PCA(n_components=2)
X_r = pca.fit(X).transform(X)
test_r = pca.fit(testData).transform(testData)
plt.figure()
for c, i in zip("rgb", [0, 1, 2]):
plt.scatter(X_r[y == i, 0], X_r[y == i, 1], c=c)
plt.scatter(test_r[predLable == i,0],test_r[predLable == i,1],s= 30,c = c,marker = 'D')
plt.legend()
plt.title('KNN of IRIS dataset')
plt.show()
# 加载数据集
def loadDataSet(self,fileName,splitRatio = 0.9):
lines = csv.reader(open(fileName,"rb") )
dataset = list(lines)
for i in range(len(dataset)):
dataset[i] = [float(x) for x in dataset[i]]
trainSize = int(len(dataset) * splitRatio)
random.shuffle(dataset)
trainData = np.array(dataset[:trainSize])
testData = np.array(dataset[trainSize:])
return trainData,testData
if __name__ == "__main__":
fileName = 'iris.csv'
KNNobj = KNN()
trainData,testData = KNNobj.loadDataSet(fileName,0.8)
# 抽取出测试数据
testdata = testData[:,0:-1]
# 抽取出测试标签数据
testdataLabel = np.array(testData[:,-1],dtype = int)
# 训练模型
KNNobj.train(trainData)
# 测试模型
predLabel = KNNobj.test(testdata,testdataLabel,3)
# 画出结果分布
KNNobj.plotResult(testdata,predLabel)
3.结果分析
本实例中,训练数据样本量为120个,测试数据样本量为30个,topK=3;
运行结果如下:
get: 29 / 30 => accuracy: 0.966667
结果分布图如下所示:
其中圆心点为训练数据,菱形点为测试数据;不同颜色代表不同的类;
4.参考链接
Comparison of LDA and PCA 2D projection of Iris dataset
最新文章
- esxi 6 虚拟机安装复制
- PC远程调试移动设备
- 在可以调用 OLE 之前,必须将当前线程设置为单线程单元(STA)模式
- 使用Microsoft.Office.Interop.Excel.Application xlApp 生成Excel
- python asyncio笔记
- [DataTable]控件排序事件中用DataView及DataTable排序
- jQuery判断页面滚动条滚动方向
- Asp.net文件缓存依赖
- 'datetime.datetime' has no attribute 'datetime'问题
- 容斥原理及SQL in关键字在EF中的应用
- Mac下Kali虚拟机与宿主机共享文件夹
- java----JDOM解析XML
- PHP超全局变量$_SERVER分析
- php 异步执行脚本
- codefroces 450B矩阵快速幂
- Sublime Text webstorm等编译器快速编写HTML/CSS代码的技巧
- J06-Java IO流总结六 《 BufferedReader和BufferedWriter 》
- 自己定义带三角形箭头的TextView
- 函数响应式编程(FRP)框架--ReactiveCocoa
- ECUST 12级 Practise