原文地址:https://www.jianshu.com/p/1db700f866ee

问题描述





程序实现

# kNN_RBFN.py
# coding:utf-8
"""k-nearest-neighbour and uniform RBF-network classifiers for +1/-1 labels.

Every training/test row is laid out as ``[x_1, ..., x_d, y]``: the feature
values followed by the label in the last column.
"""
import numpy as np

try:
    import matplotlib.pyplot as plt
except ImportError:
    # Only plot_bar_chart needs matplotlib; the classifiers below work with
    # numpy alone.  ``plt`` stays defined so star-importers still get the name.
    plt = None


def ReadData(dataFile):
    """Load a whitespace-separated numeric text file into a 2-D float array.

    Args:
        dataFile: path of the file; each non-blank line holds one row.

    Returns:
        ``np.ndarray`` with one row per non-blank line of the file.

    Raises:
        ValueError: if a token cannot be parsed as a float.
    """
    with open(dataFile, 'r') as f:
        # Blank lines are skipped: an empty row would make the array ragged.
        rows = [[float(token) for token in line.split()]
                for line in f if line.strip()]
    return np.array(rows)


def sign(n):
    """Return 1 when ``n >= 0`` (ties count as positive), otherwise -1."""
    return 1 if n >= 0 else -1


def kNN(k, trainArray, dataX):
    """Classify each row of ``dataX`` by majority vote of its k nearest rows.

    Args:
        k: number of neighbours that vote.
        trainArray: training rows, features followed by the label column.
        dataX: feature rows to classify (no label column).

    Returns:
        1-D float array of predictions, one entry (+1 or -1) per row of
        ``dataX``.  A tied vote yields +1 (see ``sign``).
    """
    trainX = trainArray[:, :-1]
    trainY = trainArray[:, -1]
    num_data = dataX.shape[0]
    predY = np.zeros((num_data,))
    for n in range(num_data):
        # Squared Euclidean distance is enough for ranking neighbours.
        distArray = np.sum((trainX - dataX[n, :]) ** 2, axis=1)
        nearest = np.argsort(distArray, axis=0)[:k]
        predY[n] = sign(trainY[nearest].sum())
    return predY


def GetZeroOneError(predY, dataY):
    """Return the fraction of positions where ``predY`` and ``dataY`` differ."""
    # np.mean always divides in floating point, unlike ``count / length``
    # which truncates to an integer under Python 2.
    return np.mean(np.asarray(predY) != np.asarray(dataY))


def plot_bar_chart(X, Y, nameX, nameY, saveName):
    """Save a bar chart of ``Y`` against ``X`` with the y-axis fixed to [0, 1].

    Args:
        X: bar positions (also used as the tick labels).
        Y: bar heights; each value is printed above its bar.
        nameX: x-axis label.
        nameY: y-axis label; the title is ``nameY + " versus " + nameX``.
        saveName: path of the image file to write.

    Raises:
        ImportError: if matplotlib is not installed.
    """
    if plt is None:
        raise ImportError("matplotlib is required for plot_bar_chart")
    fig = plt.figure(figsize=(10, 6))
    # Positions are passed positionally: the old ``left=`` keyword is not
    # accepted by current matplotlib releases.
    plt.bar(X, Y, width=0.8, align="center", yerr=0.000001)
    for (c, w) in zip(X, Y):
        plt.text(c, w * 1.03, str(round(w, 4)))
    plt.xlabel(nameX)
    plt.ylabel(nameY)
    plt.xlim(X[0] - 1, X[-1] + 1)
    plt.xticks(X)
    plt.ylim(0, 1)
    plt.title(nameY + " versus " + nameX)
    plt.savefig(saveName)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    return


def RBFNetwork(k, gamma, trainArray, dataX):
    """Classify rows of ``dataX`` with a Gaussian-weighted vote of training rows.

    The score of a query ``x`` is ``sum_m y_m * exp(-gamma * ||x - x_m||^2)``
    taken over the ``k`` training rows with the largest weights; the
    prediction is the sign of that score.  Passing ``k`` equal to the number
    of training rows gives the full uniform RBF network.

    Args:
        k: number of training rows that contribute to each score.
        gamma: width parameter of the Gaussian kernel.
        trainArray: training rows, features followed by the label column.
        dataX: feature rows to classify (no label column).

    Returns:
        1-D float array of predictions, one entry (+1 or -1) per row of
        ``dataX``.  A zero score yields +1 (see ``sign``).
    """
    trainX = trainArray[:, :-1]
    trainY = trainArray[:, -1]
    num_data = dataX.shape[0]
    predY = np.zeros((num_data,))
    for n in range(num_data):
        exponent = -gamma * np.sum((trainX - dataX[n, :]) ** 2, axis=1)
        # Largest exponent first == largest Gaussian weight first.
        chosen = np.argsort(-exponent)[:k]
        score = 0.0
        if chosen.size:
            # Shifting by the largest exponent rescales every weight by the
            # same positive factor, so the sign of the score is unchanged,
            # but the weights can no longer all underflow to zero.
            weights = np.exp(exponent[chosen] - exponent[chosen[0]])
            score = np.dot(weights, trainY[chosen])
        predY[n] = sign(score)
    return predY


def main():
    """Run the kNN and RBF-network experiments and save the four bar charts."""
    dataArray = ReadData("hw8_train.dat")
    testArray = ReadData("hw8_test.dat")

    k_list = [1, 3, 5, 7, 9]
    ein_list = []
    eout_list = []
    for k in k_list:
        predY = kNN(k, dataArray, dataArray[:, :-1])
        ein_list.append(GetZeroOneError(predY, dataArray[:, -1]))
        predY = kNN(k, dataArray, testArray[:, :-1])
        eout_list.append(GetZeroOneError(predY, testArray[:, -1]))
    # 12
    plot_bar_chart(k_list, ein_list, nameX="k", nameY="Ein(gk-nbor)", saveName="12.png")
    # 14
    plot_bar_chart(k_list, eout_list, nameX='k', nameY="Eout(gk-nbor)", saveName="14.png")

    # Exponents of 10: the kernel width actually used is 10**gamma.
    gamma_list = [-3, -1, 0, 1, 2]
    num_train = dataArray.shape[0]
    ein_list = []
    eout_list = []
    for gamma in gamma_list:
        predY = RBFNetwork(num_train, 10 ** gamma, dataArray, dataArray[:, :-1])
        ein_list.append(GetZeroOneError(predY, dataArray[:, -1]))
        predY = RBFNetwork(num_train, 10 ** gamma, dataArray, testArray[:, :-1])
        eout_list.append(GetZeroOneError(predY, testArray[:, -1]))
    # 16
    plot_bar_chart(X=gamma_list, Y=ein_list, nameX="log10(gamma)", nameY="Ein(guniform)", saveName="16.png")
    # 18
    plot_bar_chart(X=gamma_list, Y=eout_list, nameX="log10(gamma)", nameY="Eout(guniform)", saveName="18.png")


if __name__ == "__main__":
    main()
# kMeans.py
# coding:utf-8
"""k-means clustering (Lloyd's algorithm) and its average squared error."""
import numpy as np
from numpy import random

# Safety cap on assign/update rounds; the loop normally stops earlier, as
# soon as the cluster assignment no longer changes.
_MAX_ITERATIONS = 1000


def kMeans(t, k, dataArray):
    """Cluster the rows of ``dataArray`` into ``k`` groups.

    The centres start at ``k`` distinct rows picked with a generator seeded
    by ``t``.  The function then alternates between assigning every row to
    its nearest centre and moving every centre to the mean of its rows,
    until the assignment stops changing.

    Args:
        t: seed of the random initialisation (same seed, same result).
        k: number of clusters, ``1 <= k <= number of rows``.
        dataArray: 2-D array with one data point per row.

    Returns:
        ``(centres, clusters)``: ``centres`` is a ``(k, d)`` array and
        ``clusters`` maps a centre index to the list of rows assigned to it.
        Indices of empty clusters are absent from ``clusters``.

    Raises:
        ValueError: if ``k`` is outside ``1..number of rows``.
    """
    dataArray = np.asarray(dataArray, dtype=float)
    num_data = dataArray.shape[0]
    if not 1 <= k <= num_data:
        raise ValueError("k must be between 1 and the number of data points")

    # A private generator keeps the global numpy random state untouched.
    rng = random.RandomState(t)
    # Sampling without replacement guarantees k distinct starting rows.
    centres = dataArray[rng.choice(num_data, size=k, replace=False)]

    labels = None
    for _ in range(_MAX_ITERATIONS):
        # (num_data, k) table of squared distances from each row to each centre.
        sq_dist = np.sum(
            (dataArray[:, np.newaxis, :] - centres[np.newaxis, :, :]) ** 2,
            axis=2)
        new_labels = np.argmin(sq_dist, axis=1)
        if labels is not None and np.array_equal(new_labels, labels):
            break  # converged: centres already are the means of these groups
        labels = new_labels
        for m in range(k):
            members = dataArray[labels == m]
            # An empty cluster keeps its previous centre.
            if len(members):
                centres[m] = members.mean(axis=0)

    clusters = {}
    for m in range(k):
        members = dataArray[labels == m]
        if len(members):
            clusters[m] = list(members)
    return centres, clusters


def GetEin(nowCentreArray, dict):
    """Return the mean squared distance of all points to their own centre.

    Args:
        nowCentreArray: ``(k, d)`` array of cluster centres.
        dict: mapping from centre index to the list of points assigned to
            that centre (the parameter keeps its original name for
            compatibility with existing callers).

    Returns:
        Total squared distance divided by the number of points, or ``0.0``
        when no points are given.
    """
    total = 0.0
    count = 0
    for i in range(nowCentreArray.shape[0]):
        if i not in dict:
            continue
        data = np.asarray(dict[i], dtype=float)
        if not len(data):
            continue
        total += np.sum((data - nowCentreArray[i]) ** 2)
        count += data.shape[0]
    return total / count if count else 0.0


def plot_bar_chart(X, Y, nameX, nameY, saveName):
    """Save a bar chart of ``Y`` against ``X`` (y-axis range left automatic).

    Args:
        X: bar positions (also used as the tick labels).
        Y: bar heights; each value is printed above its bar.
        nameX: x-axis label.
        nameY: y-axis label; the title is ``nameY + " versus " + nameX``.
        saveName: path of the image file to write.
    """
    # Imported here so the clustering functions do not need matplotlib.
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(10, 6))
    # Positions are passed positionally: the old ``left=`` keyword is not
    # accepted by current matplotlib releases.
    plt.bar(X, Y, width=0.8, align="center", yerr=0.000001)
    for (c, w) in zip(X, Y):
        plt.text(c, w * 1.03, str(round(w, 4)))
    plt.xlabel(nameX)
    plt.ylabel(nameY)
    plt.xlim(X[0] - 1, X[-1] + 1)
    plt.xticks(X)
    plt.title(nameY + " versus " + nameX)
    plt.savefig(saveName)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    return


def main():
    """Average Ein over 500 seeded runs for each k and save the bar chart."""
    # Only the script entry point needs the data loader.
    from kNN_RBFN import ReadData

    dataArray = ReadData("hw8_nolabel_train.dat")
    k_list = [2, 4, 6, 8, 10]
    num_experiments = 500
    ein_list = []
    for k in k_list:
        ein = 0.0
        for t in range(num_experiments):
            nowCentreArray, clusters = kMeans(t, k, dataArray)
            ein += GetEin(nowCentreArray, clusters)
        ein_list.append(ein / num_experiments)
    plot_bar_chart(k_list, ein_list, nameX="k", nameY="the average Ein over 500 experiments", saveName="20.png")


if __name__ == "__main__":
    main()

运行结果









最新文章

  1. 为什么Pojo类没有注解也没有spring中配置<bean>也能够被加载到容器中。
  2. 72. Generate Parentheses && Valid Parentheses
  3. js类式继承模式学习心得
  4. 有的机器不能通过session登录
  5. 树状数组+STL FZU 2029 买票问题
  6. mergeSort
  7. STL--set
  8. svn sc create 命令行创建服务自启动
  9. ASP.NET验证控件应用实例与详解。
  10. 试用MarkDown
  11. strace详解及实战
  12. 一文搞懂RAM、ROM、SDRAM、DRAM、DDR、flash等存储介质
  13. cocos CCLayer glDrawArrays(GL_TRIANGLE_STRIP, 0, 4);ios11闪退 spine动画
  14. jQuery横向上下排列鱼骨图形式信息展示代码时光轴样式(转自CSDN,原文链接附于文中)
  15. Bullet3的一些理解
  16. 2017面向对象程序设计(JAVA)第3周学习指导及要求(2017.9.6-2017.9.12)
  17. 树状数组训练题1:弱弱的战壕(vijos1066)
  18. linux-ubuntu14.04以下使用gdb出现的问题
  19. 1005 Spell It Right (20 分)
  20. Jmeter获取不到cookie(备注:前面和后面的几个步骤都可以获取到cookie)

热门文章

  1. Rust <3>:控制流
  2. jmeter压测、操作数据库、分布式、 linux下运行的简单介绍
  3. layer通过父页面调用子页面的方法及属性
  4. UVA 12821 Double Shortest Paths
  5. 利用单选框的单选特性作tab切换
  6. ReentrantReadWriteLock的相关使用
  7. 第三章 k8s的node节点配置
  8. JDBC getConnection细节
  9. Gym 101981K bfs
  10. 前端学习(八)sass和bootstrap(笔记)