Decision Tree Algorithm in Java: Implementation Example (Part 1)

package xx;

import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

public class DicisionTree {

	public static void main(String[] args) throws Exception {
		String[] attrNames = new String[] { "AGE", "INCOME", "STUDENT",
				"CREDIT_RATING" };

		// Read the sample set
		Map<Object, List<Sample>> samples = readSamples(attrNames);

		// Build the decision tree
		Object decisionTree = generateDecisionTree(samples, attrNames);

		// Print the decision tree
		outputDecisionTree(decisionTree, 0, null);
	}

	/**
	 * Reads the pre-classified sample set and returns a Map: category -> list of samples belonging to that category
	 */
	static Map<Object, List<Sample>> readSamples(String[] attrNames) {

		// Sample attributes and their category (the last element of each row is the sample's category)
		Object[][] rawData = new Object[][] {
				{ "<30  ", "High  ", "No ", "Fair     ", "0" },
				{ "<30  ", "High  ", "No ", "Excellent", "0" },
				{ "30-40", "High  ", "No ", "Fair     ", "1" },
				{ ">40  ", "Medium", "No ", "Fair     ", "1" },
				{ ">40  ", "Low   ", "Yes", "Fair     ", "1" },
				{ ">40  ", "Low   ", "Yes", "Excellent", "0" },
				{ "30-40", "Low   ", "Yes", "Excellent", "1" },
				{ "<30  ", "Medium", "No ", "Fair     ", "0" },
				{ "<30  ", "Low   ", "Yes", "Fair     ", "1" },
				{ ">40  ", "Medium", "Yes", "Fair     ", "1" },
				{ "<30  ", "Medium", "Yes", "Excellent", "1" },
				{ "30-40", "Medium", "No ", "Excellent", "1" },
				{ "30-40", "High  ", "Yes", "Fair     ", "1" },
				{ ">40  ", "Medium", "No ", "Excellent", "0" } };

		// Read each row's attributes and category, build a Sample object for it, and partition the sample set by category
		Map<Object, List<Sample>> ret = new HashMap<Object, List<Sample>>();
		for (Object[] row : rawData) {
			Sample sample = new Sample();
			int i = 0;
			for (int n = row.length - 1; i < n; i++)
				sample.setAttribute(attrNames[i], row[i]);
			sample.setCategory(row[i]);
			List<Sample> samples = ret.get(row[i]);
			if (samples == null) {
				samples = new LinkedList<Sample>();
				ret.put(row[i], samples);
			}
			samples.add(sample);
		}

		return ret;
	}

	/**
	 * Builds the decision tree
	 */
	static Object generateDecisionTree(
			Map<Object, List<Sample>> categoryToSamples, String[] attrNames) {

		// If the samples all belong to a single category, that category is the classification for a new sample
		if (categoryToSamples.size() == 1)
			return categoryToSamples.keySet().iterator().next();

		// If no attributes are left to decide on, use the category with the most samples as the classification for a new sample, i.e. choose the category by majority vote
		if (attrNames.length == 0) {
			int max = 0;
			Object maxCategory = null;
			for (Entry<Object, List<Sample>> entry : categoryToSamples.entrySet()) {
				int cur = entry.getValue().size();
				if (cur > max) {
					max = cur;
					maxCategory = entry.getKey();
				}
			}
			return maxCategory;
		}

		// Select the test attribute
		Object[] rst = chooseBestTestAttribute(categoryToSamples, attrNames);

		// Root node of the decision tree; its branching attribute is the selected test attribute
		Tree tree = new Tree(attrNames[(Integer) rst[0]]);

		// A test attribute that has already been used must not be selected again
		String[] subA = new String[attrNames.length - 1];
		for (int i = 0, j = 0; i < attrNames.length; i++)
			if (i != (Integer) rst[0])
				subA[j++] = attrNames[i];

		// Generate the branches according to the branching attribute
		@SuppressWarnings("unchecked")
		Map<Object, Map<Object, List<Sample>>> splits =
				(Map<Object, Map<Object, List<Sample>>>) rst[2];
		for (Entry<Object, Map<Object, List<Sample>>> entry : splits.entrySet()) {
			Object attrValue = entry.getKey();
			Map<Object, List<Sample>> split = entry.getValue();
			Object child = generateDecisionTree(split, subA);
			tree.setChild(attrValue, child);
		}

		return tree;
	}

	/**
	 * Selects the optimal test attribute. "Optimal" means that branching on the
	 * selected attribute minimizes the total amount of information still needed
	 * in the branches to classify a new sample, which is equivalent to maximizing
	 * the information gain obtained from testing that attribute.
	 * Returns an array: selected attribute index, total information, Map(attribute value -> (category -> sample list))
	 */
	static Object[] chooseBestTestAttribute(
			Map<Object, List<Sample>> categoryToSamples, String[] attrNames) {

		int minIndex = -1; // Index of the optimal attribute
		double minValue = Double.MAX_VALUE; // Minimum amount of information
		Map<Object, Map<Object, List<Sample>>> minSplits = null; // Optimal split scheme

		// For each attribute, compute the total amount of information needed in the
		// branches to classify a new sample if that attribute were the test attribute,
		// and pick the attribute with the smallest total as the optimal one
		for (int attrIndex = 0; attrIndex < attrNames.length; attrIndex++) {

			int allCount = 0; // Counter for the total number of samples

			// Build a Map for the current attribute: attribute value -> (category -> sample list)
			Map<Object, Map<Object, List<Sample>>> curSplits =
					new HashMap<Object, Map<Object, List<Sample>>>();
			for (Entry<Object, List<Sample>> entry : categoryToSamples.entrySet()) {
				Object category = entry.getKey();
				List<Sample> samples = entry.getValue();
				for (Sample sample : samples) {
					Object attrValue = sample
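This first part of the listing stops inside chooseBestTestAttribute, and the Sample and Tree helper classes it relies on (as well as outputDecisionTree) are not shown here. Purely as an illustration of the calls made above (setAttribute and setCategory on Sample, the Tree constructor and setChild), here is a minimal sketch of what those classes might look like; the getter names and field layout are assumptions, not part of the original code.

import java.util.HashMap;
import java.util.Map;

// Sketch only: shaped after the calls made in the listing above, not the original class.
class Sample {

	// Attribute name -> attribute value
	private Map<String, Object> attributes = new HashMap<String, Object>();

	// Category (class label) of this sample
	private Object category;

	void setAttribute(String name, Object value) {
		attributes.put(name, value);
	}

	Object getAttribute(String name) { // assumed getter, used when splitting by attribute value
		return attributes.get(name);
	}

	void setCategory(Object category) {
		this.category = category;
	}

	Object getCategory() { // assumed getter
		return category;
	}
}

// Sketch only: an inner node of the decision tree; leaves are plain category objects.
class Tree {

	// Name of the attribute this node branches on
	private String attribute;

	// Attribute value -> child subtree (a Tree) or category label (for a leaf)
	private Map<Object, Object> children = new HashMap<Object, Object>();

	Tree(String attribute) {
		this.attribute = attribute;
	}

	void setChild(Object attrValue, Object child) {
		children.put(attrValue, child);
	}

	String getAttribute() { // assumed getter
		return attribute;
	}

	Map<Object, Object> getChildren() { // assumed getter
		return children;
	}
}

For reference, the "amount of information" in the Javadoc of chooseBestTestAttribute is the usual entropy measure. The following stand-alone snippet (the EntropyDemo class name is only for this illustration and is not part of the original code) shows its value for the full training set above, where 9 of the 14 samples are labeled "1" and 5 are labeled "0":

public class EntropyDemo {
	public static void main(String[] args) {
		// Entropy of the class distribution 9 x "1" and 5 x "0"
		double p1 = 9.0 / 14, p0 = 5.0 / 14;
		double info = -p1 * (Math.log(p1) / Math.log(2))
				- p0 * (Math.log(p0) / Math.log(2));
		System.out.println(info); // prints approximately 0.940
	}
}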