最近邻算法,最直接的理解就是,输入数据的特征与已有数据的特征一一进行比对,最靠近哪一个就将输入数据划分为那一个所属的类,当然,以此来统计k个最靠近特征中所属类别最多的类,那就变成了k近邻算法。本博客同样对sklearn的乳腺癌数据进行最近邻算法分类,基本的内容同上一篇博客内容一样,就是最近邻计算的是距离,优化的是最小距离问题,这里采用L1距离(曼哈顿距离)或者L2距离(欧氏距离),计算特征之间的绝对距离:

# 计算L1距离(曼哈顿)
distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices=1)
# L2距离(欧式距离)
distance = tf.sqrt(tf.reduce_sum(tf.square(tf.add(xtr, tf.negative(xte))), reduction_indices=1))

优化问题就是获得最小距离的标签:

pred = tf.arg_min(distance, 0)

最后衡量最近邻算法的性能的时候就通过统计正确分类和错误分类的个数来计算准确率,完整的代码如下:

from __future__ import print_function
import tensorflow as tf
import sklearn.datasets
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets as skd
from sklearn.model_selection import train_test_split # 加载乳腺癌数据集,该数据及596个样本,每个样本有30维,共有两类
cancer = skd.load_breast_cancer() # 将数据集的数据和标签分离
X_data = cancer.data
Y_data = cancer.target
print("X_data.shape = ", X_data.shape)
print("Y_data.shape = ", Y_data.shape) # 将数据和标签分成训练集和测试集
x_train,x_test,y_train,y_test = train_test_split(X_data,Y_data,test_size=0.2,random_state=1)
print("y_test=", y_test)
print("x_train.shape = ", x_train.shape)
print("x_test.shape = ", x_test.shape)
print("y_train.shape = ", y_train.shape)
print("y_test.shape = ", y_test.shape) # tf的图模型输入
xtr = tf.placeholder("float", [None, 30])
xte = tf.placeholder("float", [30]) # 计算L1距离(曼哈顿)
# distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices=1)
# L2距离(欧式距离)
distance = tf.sqrt(tf.reduce_sum(tf.square(tf.add(xtr, tf.negative(xte))), reduction_indices=1))
# Prediction: Get min distance index (Nearest neighbor)
pred = tf.arg_min(distance, 0) accuracy = 0.
error_count = 0 init = tf.global_variables_initializer() with tf.Session() as sess:
    sess.run(init)     for i in range(x_test.shape[0]):
        # 获取最近邻类
        nn_index = sess.run(pred, feed_dict={xtr: x_train, xte: x_test[i, :]})
        print("Test", i, "Prediction:", y_train[nn_index], "True Class:", y_test[i])
        if y_train[nn_index] == y_test[i]:
            accuracy += 1./len(x_test)
        else:
            error_count = error_count + 1
    print("完成!")
    print("准确分类:", x_test.shape[0] - error_count)
    print("错误分类:", error_count)
    print("准确率:", accuracy)

最近邻算法的表现如下:

这里有几点影响:

1、数据集,一般,训练集越大,相对来说准确率相对就高一些;

2、使用欧氏距离度量的时候会比用曼哈顿距离要好一些。

朱雀桥边野草花,乌衣巷口夕阳斜。

旧时王谢堂前燕,飞入寻常百姓家。

  -- 刘禹锡 《乌衣巷》

最新文章

  1. 把图标改成web字体
  2. VoxelGrid体素滤波器对点云进行下采样
  3. [转]N种内核注入DLL的思路及实现
  4. javascript基础知识-命名提前,作用域
  5. signalR制作微信墙 开源
  6. 信息安全系统设计基础实验一:Linux开发环境的配置和使用
  7. array_count_values函数
  8. checkbox 选择一个checkbox,其他checkbox也会选择
  9. cf C. Vasya and Robot
  10. 关于C语言中的强符号、弱符号、强引用和弱引用的一些陋见,欢迎指正
  11. 爬虫入门系列(二):优雅的HTTP库requests
  12. selenium+python对页面元素进行高亮显示
  13. css模板
  14. 关闭默认共享,禁止ipc$空连接
  15. linux --- Ansible-playbook篇
  16. wx事件处理二
  17. [LOJ#2878]. 「JOISC 2014 Day2」邮戳拉力赛[括号序列dp]
  18. Javascript:10天设计一门语言
  19. Java容器涉及的类(代码)
  20. NET 知识体系结构

热门文章

  1. 004.Delphi插件之QPlugins,参数传递
  2. Linux间传输文件 scp
  3. 如何拯救被Due逼疯的留学生们?
  4. python中numpy矩阵运算操作大全(非常全)!
  5. 吴裕雄--天生自然java开发常用类库学习笔记:观察者设计模式
  6. Problem B: Bulbs
  7. python基础数据类型--列表(list)
  8. 响应式布局之 px、em、 rem
  9. python中添加requests资源包
  10. Listener(Web监听器、活化、钝化)