EM算法的简明实现

当然是教学用的简明实现了,这份实现是针对双硬币模型的。

双硬币模型

假设有两枚硬币A、B,以相同的概率随机选择一个硬币,进行如下的抛硬币实验:共做5次实验,每次实验独立的抛十次,结果如图中a所示,例如某次实验产生了H、T、T、T、H、H、T、H、T、H,H代表正面朝上。

假设试验数据记录员可能是实习生,业务不一定熟悉,造成a和b两种情况

a表示实习生记录了详细的试验数据,我们可以观测到试验数据中每次选择的是A还是B

b表示实习生忘了记录每次试验选择的是A还是B,我们无法观测实验数据中选择的硬币是哪个

问在两种情况下分别如何估计两个硬币正面出现的概率?

a情况相信大家都很熟悉,既然能观测到试验数据是哪枚硬币产生的,就可以统计正反面的出现次数,直接利用最大似然估计即可。

b情况就无法直接进行最大似然估计了,只能用EM算法,接下来引用nipunbatra博主的简明EM算法Python实现。

 # -*- coding: utf-8 -*-
"""
Created on Tue Jul 4 18:23:28 2017 @author: Administrator
""" import numpy as np
from scipy import stats priors = [0.6, 0.5]
observations = np.array([[1,0,0,0,1,1,0,1,0,1],
[1,1,1,1,0,1,1,1,1,1],
[1,0,1,1,1,1,1,0,1,1],
[1,0,1,0,0,0,1,1,0,0],
[0,1,1,1,0,1,1,1,0,1]]) def em_single(priors, observations):
"""
input:
priors:[theta_A, theta_B]
obvervations:m*n matrix output: """
theta_A = priors[0]
theta_B = priors[1]
counts = {'A':{'H':0,'T':0}, 'B':{'H':0,'T':0}} # e-step
for observation in observations:
len_observation = len(observation)
num_heads = observation.sum() # 正面个数
num_tails = len_observation - num_heads # 反面个数
# 两个二项分布
contribution_A = stats.binom.pmf(num_heads, len_observation, theta_A)
contribution_B = stats.binom.pmf(num_heads, len_observation, theta_B)
# 采用各自硬币的权重
weight_A = contribution_A/(contribution_A+contribution_B)
weight_B = contribution_B/(contribution_A+contribution_B) # 更新在当前参数下,硬币A和B产生正反面的次数
counts['A']['H'] += weight_A * num_heads
counts['A']['T'] += weight_A * num_tails
counts['B']['H'] += weight_B * num_heads
counts['B']['T'] += weight_B * num_tails # M-step
new_theta_A = counts['A']['H']/(counts['A']['H'] + counts['A']['T'])
new_theta_B = counts['B']['H']/(counts['B']['H'] + counts['B']['T']) return [new_theta_A, new_theta_B] def em(observations, prior, tol=1e-6, iterations=10000):
"""
EM算法
param observations: 观察数据
param prior: 模型初值
param tol: 迭代结束阈值
param iteration: 最大迭代数
return: 局部最优的模型参数
"""
import math
iter = 0
while iter < iterations:
new_prior = em_single(prior, observations)
delta_change = np.abs(new_prior[0]-prior[0])
if delta_change < tol:
break
else:
prior = new_prior
iter += 1
print (iter) return [new_prior, iter] y = em(observations, priors)

参考自:http://www.hankcs.com/ml/em-algorithm-and-its-generalization.html

最新文章

  1. 冷门JS技巧
  2. lsof -ntP -i:端口取出 动行程序的PID 然后xargs kill -9 这个进程
  3. IO/ACM中来自浮点数的陷阱(收集向)
  4. ZOJ-3870 Team Formation
  5. CSS3实战之新增的选择器
  6. PHP中使用多线程
  7. Webstorm 不识别es6 import React from ‘react’——webstorm不支持jsx语法怎么办
  8. include a image in devexpress datagrid
  9. 21 Merge Two Sorted Lists(两链表归并排序Easy)
  10. css基本属性
  11. voc-fcn-alexnet网络结构理解
  12. 每天学点Linux-切割命令split
  13. CSS3 3D酷炫立方体变换动画
  14. IDA远程调试 在内存中dump Dex文件
  15. webservice 测试页面
  16. 在一个由 &#39;L&#39; , &#39;R&#39; 和 &#39;X&#39; 三个字符组成的字符串(例如&quot;RXXLRXRXL&quot;)中进行移动操作。一次移动操作指用一个&quot;LX&quot;替换一个&quot;XL&quot;,或者用一个&quot;XR&quot;替换一个&quot;RX&quot;。现给定起始字符串start和结束字符串end,请编写代码,当且仅当存在一系列移动操作使得start可以转换成end时, 返回True。
  17. Windows xcopy
  18. js中文乱码问题,编码设为utf-8,但还是乱码问题。
  19. iOS去掉icon的(自带磨光效果)gloss effects
  20. PTA基础编程题目集6-3简单求和 (函数题)

热门文章

  1. JAVA MyBatis使用技巧收集
  2. lua -- mysql导出json
  3. [转]《RabbitMQ官方指南》安装指南
  4. MySQL的binlog日志&lt;转&gt;
  5. Java编程的逻辑 (84) - 反射
  6. c++类成员函数后边加const是为什么?
  7. Nginx系列二:(Nginx Rewrite 规则、Nginx 防盗链、Nginx 动静分离、Nginx+keepalived 实现高可用)
  8. java-信息安全(九)-基于DH,非对称加密,对称加密等理解HTTPS
  9. Hibernate HQL的使用
  10. 多密钥ssh-key生成与管理