实现两个矩阵的无循环计算欧氏距离 Euclidean distance

navigation:

*[1.问题描述](#1.problems sources)

*[2.解决方法](#2.no loop cal the distances)

1.问题来源

kNN算法中会计算两个矩阵的距离

可以使用循环的方法来实现,效率较低

def compute_distances_one_loop(self, X):
"""
train:5000x3072
test: 500x3072
- X: A numpy array of shape (num_test, D) containing test data
Returns:
- dists: A numpy array of shape (num_test, num_train) where dists[i, j]
is the Euclidean distance between the ith test point and the jth training
point.
"""
num_test = X.shape[0]
num_train = self.X_train.shape[0]
dists = np.zeros((num_test, num_train))
for i in range(num_test):
#######################################################################
# TODO: #
# Compute the l2 distance between the ith test point and all training #
# points, and store the result in dists[i, :]. #
#######################################################################
distance=np.sqrt(np.sum(np.square(self.X_train - X[i,:]),axis=1))
dists[i,:]=distance
return dists

2.无循环计算L2 distances

一眼看到这个代码,真的是被深深折服!厉害,值得细细学习搞懂。

def compute_distances_no_loops(self, X):
"""
Compute the distance between each test point in X and each training point
in self.X_train using no explicit loops.
Input / Output: Same as compute_distances_two_loops
"""
num_test = X.shape[0]
num_train = self.X_train.shape[0]
dists = np.zeros((num_test, num_train)) #########################################################################
# TODO: #
# Compute the l2 distance between all test points and all training #
# points without using any explicit loops, and store the result in #
# dists. #
# #
# You should implement this function using only basic array operations; #
# in particular you should not use functions from scipy. #
# #
# HINT: Try to formulate the l2 distance using matrix multiplication #
# and two broadcast sums. #
######################################################################### M = np.dot(X, self.X_train.T)
nrow=M.shape[0]
ncol=M.shape[1]
te = np.diag(np.dot(X,X.T))
tr = np.diag(np.dot(self.X_train,self.X_train.T))
te= np.reshape(np.repeat(te,ncol),M.shape)
tr = np.reshape(np.repeat(tr, nrow), M.T.shape)
sq=-2 * M +te+tr.T
dists = np.sqrt(sq) return dists

可能一下子有点懵,不着急 我们举个例子一步一步理解

要先知道计算L2的距离公式:

\[L2(x_{i},x_{j})=(\sum_{i=1}^{n} \mid x_{i}^{(l)} - x_{j}^{(l)} \mid ^{2})^{\frac{1}{2}}
\]

计算L2距离需要得到 两点距离差的平方和的开方

再熟悉一个基本公式

\[(a-b)^{2}= a^{2}- 2ab+b^{2}
\]

# 假设 x:4x3  ,y: 2x3
# 最后输出一个 2x4矩阵
import numpy as np
>>> x=np.array([[1,2,3],[3,4,5],[5,6,7],[7,8,9]])
>>> x
array([[1, 2, 3],
[3, 4, 5],
[5, 6, 7],
[7, 8, 9]])
>>> y=np.array([[2,3,4],[1,2,3]])
>>> y
array([[2, 3, 4],
[1, 2, 3]])
# 计算两个矩阵的乘积
>>> M=np.dot(y,x.T)
>>> M
array([[20, 38, 56, 74],
[14, 26, 38, 50]])
# 保存乘积矩阵的行列
>>> nrow=M.shape[0]
>>> ncol=M.shape[1]
>>> nrow
2
>>> ncol
4

先计算,提取出对角元素

>>> te=np.diag(np.dot(y,y.T))
>>> tr=np.diag(np.dot(x,x.T))
>>> te
array([29, 14])
>>> tr
array([ 14, 50, 110, 194])

按对角元素来进行扩充,满足矩阵计算要求

得到\(a^{2}\),\(b^{2}\)

# 继续整理
>>> te=np.reshape(np.repeat(te,ncol),M.shape) # ncol:4 ,M: 2x4
>>> tr=np.reshape(np.repeat(tr,nrow),M.T.shape) #nrow:2 ,M.T:4x2
>>> te
array([[29, 29, 29, 29],
[14, 14, 14, 14]])
>>> tr
array([[ 14, 14],
[ 50, 50],
[110, 110],
[194, 194]])

\(-2ab\)就是-2*M

计算距离的开方

>>> sq=-2*M+te+tr.T
>>> dists=np.sqrt(sq)
>>> sq
array([[ 3, 3, 27, 75],
[ 0, 12, 48, 108]])
>>> dists
array([[ 1.73205081, 1.73205081, 5.19615242, 8.66025404],
[ 0. , 3.46410162, 6.92820323, 10.39230485]])

最新文章

  1. MySQL:动态开启慢查询日志(Slow Query Log)
  2. BADI_MATERIAL_CHECK(物料主数据表的增强检查)
  3. java一行一行写入或读取数据
  4. learning to rank
  5. .net验证码生成及使用
  6. Spring+SpringMVC+Mybatis 利用AOP自定义注解实现可配置日志快照记录
  7. linux list all users.
  8. 在Linux上怎么安装和配置DenyHosts工具
  9. bootstrap在 刷新页面,tab选择页面不会改变。
  10. --@angularJS--指令与指令之间的交互demo
  11. java TreeSet 应用
  12. VS2013创建Windows服务
  13. 5分钟了解MySQL5.7的undo log在线收缩新特性
  14. 测试驱动开发实践3————从testList开始
  15. [Reinforcement Learning] Model-Free Prediction
  16. 【Vuex】mapGetters 辅助函数
  17. STM32 --- 断言(assert_param)的开启和使用
  18. 爬虫的基本操作 requests / BeautifulSoup 的使用
  19. centos6.6安装hadoop-2.5.0(五、部署过程中的问题解决)
  20. Latex数学公式编写

热门文章

  1. 第二章 jQuery框架使用准备
  2. Spring源码分析之环境搭建
  3. netty源码解解析(4.0)-16 ChannelHandler概览
  4. SpringBoot底层原理及分析
  5. pycharm与monkeyrunner测试
  6. Linux-Windows 端口转发
  7. 按需制作最小的本地yum源
  8. Spring aop 拦截自定义注解+分组验证参数
  9. dns自动配置shell脚本
  10. RE最全面的正则表达式----终结篇 特殊处理