Bayesian regression

前面介绍的线性模型都是从最小二乘,均方误差的角度去建立的,从最简单的最小二乘到带正则项的 lasso,ridge 等。而 Bayesian regression 是从 Bayesian 概率模型的角度出发的,虽然最后也会转换成一个能量函数的形式。

从前面的线性模型中,我们都假设如下的关系:

y=wx" role="presentation">y=wxy=wx

上面这个关系式其实是直接从值的角度来考虑,其实我们也可以假设如下的关系:

y=wx+ϵ" role="presentation">y=wx+ϵy=wx+ϵ

这个 ϵ" role="presentation" style="position: relative;">ϵϵ 表示一种误差,或者噪声,如果估计的值非常准确,那么 ϵ=0" role="presentation" style="position: relative;">ϵ=0ϵ=0, 否则,这将是一个随机数。

如果我们有一组训练样本,那么每个观察值 y" role="presentation" style="position: relative;">yy 都会有个对应的 ϵ" role="presentation" style="position: relative;">ϵϵ, 而且我们假设 ϵ" role="presentation" style="position: relative;">ϵϵ 是满足独立同分布的。那么我们可以用概率的形式表示为:

p(y|w,x,α)=N(y|wx,α)" role="presentation">p(y|w,x,α)=N(y|wx,α)p(y|w,x,α)=N(y|wx,α)

对于一组训练集,我们可以表示为:

p(y|X,w)=∏i=1NN(yi|wxi,α)" role="presentation">p(y|X,w)=∏i=1NN(yi|wxi,α)p(y|X,w)=∏i=1NN(yi|wxi,α)

最后,利用最大似然估计,可以将上面的表达式转化为一个能量最小的形式。上面是从最大似然估计的角度去求系数。

下面我们考虑从最大后验概率的角度,

p(w|y)=p(y|w)p(w|α)p(α)" role="presentation">p(w|y)=p(y|w)p(w|α)p(α)p(w|y)=p(y|w)p(w|α)p(α)
p(w|α)=N(w|0,α−1I)" role="presentation">p(w|α)=N(w|0,α−1I)p(w|α)=N(w|0,α−1I)

p(α)" role="presentation" style="position: relative;">p(α)p(α) 本身是服从 gamma 分布的。

sklearn 上也给出了一个例子:

import numpy as np
import matplotlib.pyplot as plt
from scipy import stats from sklearn.linear_model import BayesianRidge, LinearRegression # #############################################################################
# Generating simulated data with Gaussian weights
np.random.seed(0)
n_samples, n_features = 100, 100
X = np.random.randn(n_samples, n_features) # Create Gaussian data
# Create weights with a precision lambda_ of 4.
lambda_ = 4.
w = np.zeros(n_features)
# Only keep 10 weights of interest
relevant_features = np.random.randint(0, n_features, 10)
for i in relevant_features:
w[i] = stats.norm.rvs(loc=0, scale=1. / np.sqrt(lambda_))
# Create noise with a precision alpha of 50.
alpha_ = 50.
noise = stats.norm.rvs(loc=0, scale=1. / np.sqrt(alpha_), size=n_samples)
# Create the target
y = np.dot(X, w) + noise # #############################################################################
# Fit the Bayesian Ridge Regression and an OLS for comparison
clf = BayesianRidge(compute_score=True)
clf.fit(X, y) ols = LinearRegression()
ols.fit(X, y) # #############################################################################
# Plot true weights, estimated weights, histogram of the weights, and
# predictions with standard deviations
lw = 2
plt.figure(figsize=(6, 5))
plt.title("Weights of the model")
plt.plot(clf.coef_, color='lightgreen', linewidth=lw,
label="Bayesian Ridge estimate")
plt.plot(w, color='gold', linewidth=lw, label="Ground truth")
plt.plot(ols.coef_, color='navy', linestyle='--', label="OLS estimate")
plt.xlabel("Features")
plt.ylabel("Values of the weights")
plt.legend(loc="best", prop=dict(size=12)) plt.figure(figsize=(6, 5))
plt.title("Histogram of the weights")
plt.hist(clf.coef_, bins=n_features, color='gold', log=True,
edgecolor='black')
plt.scatter(clf.coef_[relevant_features], 5 * np.ones(len(relevant_features)),
color='navy', label="Relevant features")
plt.ylabel("Features")
plt.xlabel("Values of the weights")
plt.legend(loc="upper left") plt.figure(figsize=(6, 5))
plt.title("Marginal log-likelihood")
plt.plot(clf.scores_, color='navy', linewidth=lw)
plt.ylabel("Score")
plt.xlabel("Iterations") # Plotting some predictions for polynomial regression
def f(x, noise_amount):
y = np.sqrt(x) * np.sin(x)
noise = np.random.normal(0, 1, len(x))
return y + noise_amount * noise degree = 10
X = np.linspace(0, 10, 100)
y = f(X, noise_amount=0.1)
clf_poly = BayesianRidge()
clf_poly.fit(np.vander(X, degree), y) X_plot = np.linspace(0, 11, 25)
y_plot = f(X_plot, noise_amount=0)
y_mean, y_std = clf_poly.predict(np.vander(X_plot, degree), return_std=True)
plt.figure(figsize=(6, 5))
plt.errorbar(X_plot, y_mean, y_std, color='navy',
label="Polynomial Bayesian Ridge Regression", linewidth=lw)
plt.plot(X_plot, y_plot, color='gold', linewidth=lw,
label="Ground Truth")
plt.ylabel("Output y")
plt.xlabel("Feature X")
plt.legend(loc="lower left")
plt.show()

最新文章

  1. mac好用的markdown编辑器
  2. List接口、Set接口、Map接口的方法
  3. java基础(环境设置,基础语法,函数数组)
  4. C#.NET 大型通用信息化系统集成快速开发平台 4.0 版本 - 多系统开发接口 - 苹果客户端开发接口
  5. Win10 兼容性 Visual studio web应用程序 ASP.NET 4.0 尚未在 Web 服务器上注册
  6. 安卓序列化漏洞 —— CVE-2015-3525
  7. 【jmeter】HTTP属性管理器HTTP Cookie Manager、HTTP Request Defaults
  8. Android WIFI 启动流程
  9. 链表回文串判断&&链式A+B
  10. css z-index属性
  11. Android获取SharedPreferences失败,且App无法启动
  12. 在Ceph创建虚拟机的过程改进分析
  13. 读《Ext.JS.4.First.Look》随笔
  14. 强制删除sql用户链接
  15. Golang源码探索(三) GC的实现原理
  16. truffle 开发入门教程
  17. Docker Swarm 高可用详解
  18. asp.net mvc项目使用spring.net发布到IIS后,在访问提示错误 Could not load type from string value 'DALMsSql.DBSessionFactory,DALMsSql'.
  19. MongoDB aggregate 运用篇(转)
  20. dev -c++ 快捷键

热门文章

  1. Python第二弹--------类和对象
  2. mydumper原理介绍
  3. 4.4 Routing -- Specifying A Route's Model
  4. 非线性方程(组):一维非线性方程(一)二分法、不动点迭代、牛顿法 [MATLAB]
  5. VS2010/MFC编程入门之十七(对话框:文件对话框)
  6. poj1228 Grandpa's Estate
  7. 日志处理(二) 日志组件logback的介绍及配置使用方法(转)
  8. DB 异常
  9. 20145339顿珠达杰 《网络对抗技术》 逆向与Bof基础
  10. calcite介绍