数据存储文件:buycomputer.properties

#数据个数(程序只读取 D1 到 D<datanum> 这些行,因此下面的 D15 不会被读取)
datanum=14
#属性及属性值
nodeAndAttribute=年龄:青/中/老,收入:高/中/低,学生:是/否,信誉:良/优,归类:买/不买
#数据
D1=青,高,否,良,不买
D2=青,高,否,优,不买
D3=中,高,否,良,买
D4=老,中,否,良,买
D5=老,低,是,良,买
D6=老,低,是,优,不买
D7=中,低,是,优,买
D8=青,中,否,良,不买
D9=青,低,是,良,买
D10=老,中,是,良,买
D11=青,中,是,优,买
D12=中,中,否,优,买
D13=中,高,是,良,买
D14=老,中,否,优,不买
D15=老,中,否,优,买

实体类:TreeNode.java

package com.id3.node;

import java.util.HashMap;
import java.util.Map;

/**
 * A named node of the decision tree: one attribute (column) together with
 * the per-value bookkeeping objects and the score computed for it.
 */
public class TreeNode {

    /** Name of the attribute this node represents. */
    private String nodeName;
    /** Per-value records, keyed by the attribute value. */
    private Map<String, Attributes> attributes;
    /** Score stored for this attribute by the algorithm. */
    private double gain;

    public String getNodeName() {
        return nodeName;
    }

    public void setNodeName(String nodeName) {
        this.nodeName = nodeName;
    }

    public Map<String, Attributes> getAttributes() {
        return attributes;
    }

    public void setAttributes(Map<String, Attributes> attributes) {
        this.attributes = attributes;
    }

    public double getGain() {
        return gain;
    }

    public void setGain(double gain) {
        this.gain = gain;
    }

    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("TreeNode [nodeName=");
        text.append(nodeName);
        text.append(", attributes=").append(attributes);
        text.append(", gain=").append(gain);
        text.append("]");
        return text.toString();
    }

    @Override
    public int hashCode() {
        long gainBits = Double.doubleToLongBits(gain);
        int hash = 31 + (attributes == null ? 0 : attributes.hashCode());
        hash = 31 * hash + (int) (gainBits ^ (gainBits >>> 32));
        hash = 31 * hash + (nodeName == null ? 0 : nodeName.hashCode());
        return hash;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        TreeNode that = (TreeNode) obj;
        return sameValue(attributes, that.attributes)
                && Double.doubleToLongBits(gain) == Double.doubleToLongBits(that.gain)
                && sameValue(nodeName, that.nodeName);
    }

    /** Null-tolerant equality: two nulls match, otherwise defers to {@code equals}. */
    private static boolean sameValue(Object left, Object right) {
        return left == null ? right == null : left.equals(right);
    }
}

/**
 * Bookkeeping for a single value of an attribute: how often it occurs, how the
 * outcomes are distributed among those occurrences, and what follows it in the
 * tree (either another node or a leaf label).
 */
class Attributes {

    /** The attribute value this record describes. */
    private String attrName;
    /** Node that follows this value in the tree, or {@code null}. */
    private TreeNode nextNode;
    /** Leaf label attached to this value, or {@code null}. */
    private String leafName;
    /** Occurrence counter for this value. */
    private int attrNum;
    /** Entropy figure stored for this value. */
    private double h;
    /** Outcome label -> occurrence count for rows carrying this value. */
    Map<String, Integer> resultNum = new HashMap<String, Integer>();

    public String getAttrName() {
        return attrName;
    }

    public void setAttrName(String attrName) {
        this.attrName = attrName;
    }

    public TreeNode getNextNode() {
        return nextNode;
    }

    public void setNextNode(TreeNode nextNode) {
        this.nextNode = nextNode;
    }

    public String getLeafName() {
        return leafName;
    }

    public void setLeafName(String leafName) {
        this.leafName = leafName;
    }

    public int getAttrNum() {
        return attrNum;
    }

    public void setAttrNum(int attrNum) {
        this.attrNum = attrNum;
    }

    public double getH() {
        return h;
    }

    public void setH(double h) {
        this.h = h;
    }

    public Map<String, Integer> getResultNum() {
        return resultNum;
    }

    public void setResultNum(Map<String, Integer> resultNum) {
        this.resultNum = resultNum;
    }

    @Override
    public int hashCode() {
        long entropyBits = Double.doubleToLongBits(h);
        int hash = 31 + (attrName == null ? 0 : attrName.hashCode());
        hash = 31 * hash + attrNum;
        hash = 31 * hash + (int) (entropyBits ^ (entropyBits >>> 32));
        hash = 31 * hash + (leafName == null ? 0 : leafName.hashCode());
        hash = 31 * hash + (nextNode == null ? 0 : nextNode.hashCode());
        hash = 31 * hash + (resultNum == null ? 0 : resultNum.hashCode());
        return hash;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        Attributes that = (Attributes) obj;
        return sameValue(attrName, that.attrName)
                && attrNum == that.attrNum
                && Double.doubleToLongBits(h) == Double.doubleToLongBits(that.h)
                && sameValue(leafName, that.leafName)
                && sameValue(nextNode, that.nextNode)
                && sameValue(resultNum, that.resultNum);
    }

    /** Null-tolerant equality: two nulls match, otherwise defers to {@code equals}. */
    private static boolean sameValue(Object left, Object right) {
        return left == null ? right == null : left.equals(right);
    }

    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("Attributes [attrName=");
        text.append(attrName);
        text.append(", nextNode=").append(nextNode);
        text.append(", leafName=").append(leafName);
        text.append(", attrNum=").append(attrNum);
        text.append(", h=").append(h);
        text.append(", resultNum=").append(resultNum);
        text.append("]");
        return text.toString();
    }
}

ID3算法:ID3Alogo.java

package com.id3.node;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties; /**
* ID3算法
* @author JoMint
*
*/
public class ID3Alogo { //存每个节点及其属性等相关变量
private List<TreeNode> treeList;
//存数据集
private List<Map<String, String>> dataList;
//遍历决策树时的开始节点
private Attributes startNode;
//决策结果变量的值
private List<String> resultList;
//结果属性节点
private TreeNode resultNode;
//决策树
private String str; //构建决策树的开始调用方法
public void ID3(String id3Name,String readPath,String printPath){ //初始化成员变量
initElement(id3Name);
//读数据
readData(readPath);
//构建决策树
cusTree(dataList, treeList, startNode);
//System.out.println(startNode.getNextNode().getAttributes().get("Overcast").getLeafName());
//遍历决策树,并把结果存入str中 printTree(startNode,"");
//打印决策树
System.out.println(str);
//输出决策树到文件
printTreetoTxt(printPath); } /**
* 初始化成员变量
*/
private void initElement(String id3Name) { //存每个节点及其属性等相关变量
treeList = new ArrayList<TreeNode>();
//存数据集
dataList = new ArrayList<Map<String,String>>();
//遍历决策树时的开始节点
startNode = new Attributes();
//决策结果变量的值
resultList = new ArrayList<String>();
//结果属性节点
TreeNode resultNode = null;
//决策树
str = id3Name+"决策树:\r\n"; } /**
* 读数据
*/
private void readData(String path) { Map<String, String> dataMap;
Map<String,Attributes> attrMap;
TreeNode treeNode;
int num; //创建读取properties文件的对象
Properties pro = new Properties(); try {
//为了读取中文字符,将读取文件的类型改为字符流读取
InputStream inputStream = new FileInputStream(path);
BufferedReader bf = new BufferedReader(new InputStreamReader(inputStream));
//加载数据文件
pro.load(bf);
//读取数据总个数
num = Integer.parseInt(pro.getProperty("datanum"));
//读取属性及属性值
String attribute = pro.getProperty("nodeAndAttribute");
//将每个属性分开,用数组存,遍历每个属性,再把每个属性的属性值分开,存到treeList中
String[] attArray = attribute.split(",");
for (int i = 0; i < attArray.length; i++) { treeNode = new TreeNode();
String[] temp = attArray[i].split(":");
String nodeName = temp[0];
String[] attr = temp[1].split("/");
treeNode.setNodeName(nodeName);
attrMap = new HashMap<String, Attributes>();
Attributes attributes;
for (int j = 0; j < attr.length; j++) {
//Map<String, Integer> map = new HashMap<String, Integer>();
attributes = new Attributes();
//map.put(attr[j], 0);
attributes.setAttrName(attr[j]);
attrMap.put(attr[j], attributes); //存入结果变量的值,为最后的判断做铺垫
if(i == attArray.length-1){ resultList.add(attr[j]); } }
treeNode.setAttributes(attrMap);
treeList.add(treeNode);
} //遍历数据集,将数据按行存入dataList中
for (int i = 1; i <= num; i++) { dataMap = new HashMap<String, String>();
String key = "D"+i;
String[] colline = pro.getProperty(key).split(",");
//System.out.println(key+"=="+colline.length);
for (int j = 0; j < treeList.size(); j++) {
//System.out.println(treeList.size());
dataMap.put(treeList.get(j).getNodeName(), colline[j]);
}
dataList.add(dataMap);
} //得到结果属性的名字
resultNode = treeList.get(treeList.size()-1); // System.out.println("************************resultNode==" + resultNode + "***********************");
} catch (FileNotFoundException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} } /**
* 数据处理
* @param cdataList
* @param ctreeList
*/
private List<List> dealData(List<Map<String, String>> dataList, List<TreeNode> treeList){ List<List> returnList= new ArrayList<List>();
int num = dataList.size(); /*
* 统计数据集中每个属性的属性值个数
*/
Map<String, Attributes> attrMap = new HashMap<String, Attributes>();
Map<String, Integer> resultMap;
for (int i = 0; i < treeList.size(); i++) {
for (int j = 0; j < dataList.size(); j++) {
//获得当前数据集中当前列当前行的属性值
String key = dataList.get(j).get(treeList.get(i).getNodeName());
attrMap = treeList.get(i).getAttributes();
//System.out.println(attrMap.get(key)+"=="+key);
//计算样本中对应的属性变量的个数
attrMap.get(key).setAttrNum(attrMap.get(key).getAttrNum()+1); //System.out.println("->"+attrMap.get(key)); //获得结果变量值
String result = dataList.get(j).get(treeList.get(treeList.size()-1).getNodeName());
resultMap = attrMap.get(key).getResultNum();
//如果包含这个结果变量,则数量上加1; 如果不包含,赋初值为1
if (resultMap.containsKey(result)) {
resultMap.put(result, resultMap.get(result)+1);
}else{
resultMap.put(result, 1);
}
}
}
/*
* 计算熵
*/
DecimalFormat df = new DecimalFormat("#.###");
for (int i = 0; i < treeList.size(); i++) {
//遍历 Attributes
//计算属性熵: gain
double gain = 0.0;
for (Map.Entry<String, Attributes> element : treeList.get(i).getAttributes().entrySet()) {
Attributes attr = treeList.get(i).getAttributes().get(element.getKey());
Map<String, Integer> result = attr.getResultNum();
//遍历每个 Attributes 的 resultNum
//计算属性值的熵 :h
double h = 0.0;
for (Map.Entry<String, Integer> element2 : result.entrySet()) {
double resultNum = (double)result.get(element2.getKey());
double attrNum = (double)attr.getAttrNum();
resultNum = resultNum/attrNum;
h -= (resultNum*(Math.log(resultNum)/Math.log((double)2)));
h = Double.parseDouble(df.format(h));
attr.setH(h);
//System.out.println("resultNum=========="+resultNum);
}
//System.out.println(" attr==>"+attr);
gain += ((double)attr.getAttrNum()/num)*attr.getH();
gain = Double.parseDouble(df.format(gain)); //System.out.println("gain=="+gain);
} treeList.get(i).setGain(gain);
//System.out.println(" gain-->"+treeList.get(i)); } //将处理好的dataList和treeList放在returnList中返回
returnList.add(dataList);
returnList.add(treeList); return returnList; // System.out.println("***************************************************+++++++↓");
// for (int i = 0; i < treeList.size(); i++) {
// System.out.println(treeList.get(i));
// }
// System.out.println();
// for (int i = 0; i < dataList.size(); i++) {
// System.out.println(dataList.get(i));
// }
//
// System.out.println("================================================="+num+"条数据=="+treeList.size()+"个属性");
// System.out.println("***************************************************+++++++↑"); } /**
* 构建决策树
* @param dataList
* @param treeList
*/
@SuppressWarnings("unchecked")
private void cusTree(List<Map<String, String>> dataList, List<TreeNode> treeList, Attributes cAttr){ List<List> curryList= new ArrayList<List>(); //处理数据 curryList = dealData(dataList, treeList); //从 curryList 中得到 dataList 和 treeList
dataList = (List<Map<String, String>>)curryList.get(0);
treeList = (List<TreeNode>)curryList.get(1); //判断当前处理的数据集中的决策结果,若决策结果相同的个数等于总的当前处理的数据集的条数,则遍历结束
//将当前的决策结果放入当前判断的属性值的后边
//返回到调用这个函数的父函数
for (TreeNode treeNode : treeList) {
if (treeNode.getNodeName().equals(resultNode.getNodeName())) {
for (String attr : resultList) {
if (treeNode.getAttributes().get(attr).getAttrNum() == dataList.size()) {
cAttr.setLeafName(attr);
return;
}
}
}
} // System.out.println("=_=_=_=_=_=_=datalist==="+dataList);
// System.out.println("=_=_=_=_=_=_=treelist==="+treeList); //寻找最优解 //得到根节点
TreeNode rootNode = treeList.get(0); for (TreeNode treeNode : treeList) { if(!treeNode.getNodeName().equals(treeList.get(treeList.size()-1).getNodeName())){
if(treeNode.getGain() < rootNode.getGain()){
rootNode = treeNode;
}
} }
// System.out.println("*********↓↓↓↓↓↓↓↓***********当前根节点为:"+rootNode.getNodeName()+"***********↓↓↓↓↓↓↓↓*********"); cAttr.setNextNode(rootNode); //对当前根节点的属性进行遍历,寻找下一个节点 //节点名
String nodeName = rootNode.getNodeName();
//属性名
String attrName = "";
//属性节点
Attributes attr = new Attributes();
//当前节点的属性值集合
Map<String, Attributes> attrMap = rootNode.getAttributes(); //遍历节点的每个属性值
for (Map.Entry<String, Attributes> entry : attrMap.entrySet()) { attr = attrMap.get(entry.getKey());
attrName = attr.getAttrName(); // System.out.println("*****************attrName========"+attrName+"******************"); //得到新的data集合对象 List<Map<String, String>> newDataList = new ArrayList<Map<String,String>>();
Map<String, String> newMap = new HashMap<String, String>();
//String attrName = rootNode.getAttributes().get("Sunny").getAttrName();
newMap.clear(); //删除dataList中已处理过的节点数据
//遍历dataList
for (Map<String, String> map : dataList) { if(map.containsKey(nodeName)){ if(map.get(nodeName).equals(attrName)){
newMap = new HashMap<String, String>();
for (Map.Entry<String, String> m : map.entrySet()) { //如果该节点不是已处理过的节点
if(!m.getKey().equals(nodeName)){
//得到新的节点
newMap.put(m.getKey(), map.get(m.getKey()));
} } //将新的节点存入newDataList中
newDataList.add(newMap);
} } }
// System.out.println("↓↓↓↓↓↓*******************新的data集合:*******************↓↓↓↓↓↓");
// for (Map<String, String> map : newDataList) {
// System.out.println(map);
// } //获得新的tree集合对象,而且值为初值 List<TreeNode> newTreeList = new ArrayList<TreeNode>(); //将treeList中的数据清空
clearTree(treeList); //删除treeList中已处理过的节点
for (TreeNode treeNode : treeList) {
if(!treeNode.getNodeName().equals(nodeName)){
newTreeList.add(treeNode);
}
}
// System.out.println("↓↓↓↓↓↓*******************新的tree集合:*******************↓↓↓↓↓↓");
// for (TreeNode treeNode : newTreeList) {
// System.out.println(treeNode);
// } //递归调用当前函数,继续找节点
cusTree(newDataList, newTreeList,attr);
}
} /**
* 输出决策树
* @param attr
*/
private void printTree(Attributes attr, String ceil) { String nodeName = attr.getNextNode().getNodeName();
Map<String, Attributes> attrMap = attr.getNextNode().getAttributes(); str += ceil+"----"+nodeName+"\r\n";
for (Map.Entry<String, Attributes> nextAttr : attrMap.entrySet()) { //如果当前属性值没有下一个节点,则将当前属性值的名称及决策结果输出
if(attrMap.get(nextAttr.getKey()).getNextNode() == null){ str += ceil+"-------"+attrMap.get(nextAttr.getKey()).getAttrName()+"\r\n";
str += ceil+"----------"+attrMap.get(nextAttr.getKey()).getLeafName()+"\r\n"; }else{ str += ceil+"-------"+attrMap.get(nextAttr.getKey()).getAttrName()+"\r\n";
printTree(attrMap.get(nextAttr.getKey()),"------");
}
} } /**
* 打印决策树到txt文本
* @param path
*/
private void printTreetoTxt(String path){ if(path == null || path.equals("")) return;
File file = new File(path);
File folder = file.getParentFile();
FileWriter fw;
try { if(!folder.exists()){
folder.mkdirs();
file.createNewFile();
} fw = new FileWriter(file);
fw.write(str); fw.flush();
fw.close();
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
} /**
* 还原初始数据
* @param treeList
*/
private void clearTree(List<TreeNode> treeList){ for (TreeNode treeNode : treeList) {
Map<String, Attributes> map = treeNode.getAttributes(); for (Map.Entry<String, Attributes> entry : map.entrySet()) {
Attributes attr = map.get(entry.getKey());
attr.setAttrNum(0);
attr.setH(0);
Map<String, Integer> map2 = attr.getResultNum();
map2.clear();
}
treeNode.setGain(0);
}
} }

主函数:ID3Main.java

package com.id3.node;

/**
 * Command-line entry point for the ID3 decision-tree builder.
 */
public class ID3Main {

    /**
     * Runs the algorithm once.
     *
     * <p>Optional arguments, in order: tree title, data file path, output file
     * path. Any argument that is not supplied falls back to the placeholder
     * text used previously, so replace the placeholders or pass real values.
     *
     * @param args optional {@code [title, dataFile, outputFile]}
     */
    public static void main(String[] args) {
        String treeName = args.length > 0 ? args[0] : "决策树名";
        String dataPath = args.length > 1 ? args[1] : "数据文件地址";
        String outputPath = args.length > 2 ? args[2] : "输出文件地址";

        ID3Alogo id3Alogo = new ID3Alogo();
        id3Alogo.ID3(treeName, dataPath, outputPath);
    }
}

最新文章

  1. SVN部署和使用
  2. 修改TNSLSNR的端口
  3. erlang mac os 10.9 卸载脚本
  4. js注入,黑客之路必备!
  5. apache 403错
  6. ubuntu下安装git,sublime,nodejs
  7. hiho 1182 : 欧拉路·三
  8. (转)UIColor 的使用
  9. bzoj 3053 HDU 4347 : The Closest M Points kd树
  10. HDU 1853Cyclic Tour(网络流之最小费用流)
  11. js 日期天数相加减,格式化yyyy-MM-dd
  12. 第003篇 深入体验C#项目开发(二)
  13. scrot-0.8
  14. 使用nodejs爬取和讯网高管增减持数据
  15. sublime编辑器代码背景刺眼怎么修改?
  16. json 函数
  17. Hive随机取某几行数据
  18. Char类型与String类型的数字字符转换时的不同点
  19. P4145 上帝造题的七分钟2 / 花神游历各国
  20. Git提取两次提交的差异文件

热门文章

  1. 4-20ma电流信号转0-5v()
  2. 自己用wireshark 抓了个包,分析了一下
  3. Dice Possibility
  4. 使用VNC远程管理VPS(Centos系统)
  5. css3的::selection属性
  6. .net task
  7. JAVA基础--方法的重写overwrite 和 重载overload
  8. Naive Bayes在mapreduce上的实现
  9. HDU 2121 Ice_cream’s world II 最小树形图 模板
  10. iOS多页面传值方式之单例传值singleton