k近邻法的C++实现:kd树(一)

2015-01-22 20:58:31 · 作者: · 浏览: 19
1.k近邻算法的思想
?
给定一个训练集,对于新的输入实例,在训练集中找到与该实例最近的k个实例,这k个实例中的多数属于某个类,就把该输入实例分为这个类。
?
因为要找到最近的k个实例,所以计算输入实例与训练集中实例之间的距离是关键!
?
k近邻算法最简单的方法是线性扫描,这时要计算输入实例与每一个训练实例的距离,当训练集很大时,非常耗时,这种方法不可行,为了提高k近邻的搜索效率,常常考虑使用特殊的存储结构存储训练数据,以减少计算距离的次数,具体方法很多,这里介绍实现经典的kd树方法。
?
2.构造kd树
?
kd树是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构,kd树是二叉树。
?
下面举例说明:
?
给定一个二维空间的数据集: T = {(2,3), (5,4), (9,6), (4,7), (8,1), (7,2)},构造一个平衡kd树。
?
根结点对应包含数据集T的矩形选择x(1) 轴,6个数据点的x(1) 坐标的中位数是7,以超平面x(1) = 7将空间分为左右两个子矩形(子结点)
左矩形以x(2) = 4为中位数分为两个子矩形
右矩形以x(2) = 6 分为两个子矩形
如此递归,直到两个子区域没有实例存在时停止
?
?
?
?
?
3.利用kd树搜索最近邻
?
输入:已构造的kd树;目标点x;
?
输出:x的最近邻
?
在kd树中找出包含目标点x的叶结点:从根结点出发,递归的向下访问kd树,若目标点x的当前维的坐标小于切分点的坐标,则移动到左子结点,否则移动到右子结点,直到子结点为叶结点为止。
以此叶结点为“当前最近点”
递归地向上回退,在每个结点进行以下操作:(a)如果该结点保存的实例点比当前最近点距离目标点更近,则以该实例点为“当前最近点”;
(b)当前最近点一定存在于某结点一个子结点对应的区域,检查该子结点的父结点的另
一子结点对应区域是否有更近的点(即检查另一子结点对应的区域是否与以目标点为球
心、以目标点与“当前最近点”间的距离为半径的球体相交);如果相交,可能在另一
个子结点对应的区域内存在距目标点更近的点,移动到另一个子结点,接着递归进行最
近邻搜索;如果不相交,向上回退
?
当回退到根结点时,搜索结束,最后的“当前最近点”即为x的最近邻点。
4.C++实现
?
??
? 1 #include
? 2 #include
? 3 #include
? 4 #include
? 5 #include
? 6 using namespace std;
? 7?
? 8?
? 9?
?10?
?11 struct KdTree{
?12 ? ? vector root;
?13 ? ? KdTree* parent;
?14 ? ? KdTree* leftChild;
?15 ? ? KdTree* rightChild;
?16 ? ? //默认构造函数
?17 ? ? KdTree(){parent = leftChild = rightChild = NULL;}
?18 ? ? //判断kd树是否为空
?19 ? ? bool isEmpty()
?20 ? ? {
?21 ? ? ? ? return root.empty();
?22 ? ? }
?23 ? ? //判断kd树是否只是一个叶子结点
?24 ? ? bool isLeaf()
?25 ? ? {
?26 ? ? ? ? return (!root.empty()) &&?
?27 ? ? ? ? ? ? rightChild == NULL && leftChild == NULL;
?28 ? ? }
?29 ? ? //判断是否是树的根结点
?30 ? ? bool isRoot()
?31 ? ? {
?32 ? ? ? ? return (!isEmpty()) && parent == NULL;
?33 ? ? }
?34 ? ? //判断该子kd树的根结点是否是其父kd树的左结点
?35 ? ? bool isLeft()
?36 ? ? {
?37 ? ? ? ? return parent->leftChild->root == root;
?38 ? ? }
?39 ? ? //判断该子kd树的根结点是否是其父kd树的右结点
?40 ? ? bool isRight()
?41 ? ? {
?42 ? ? ? ? return parent->rightChild->root == root;
?43 ? ? }
?44 };
?45?
?46 int data[6][2] = {{2,3},{5,4},{9,6},{4,7},{8,1},{7,2}};
?47?
?48 template
?49 vector > Transpose(vector > Matrix)
?50 {
?51 ? ? unsigned row = Matrix.size();
?52 ? ? unsigned col = Matrix[0].size();
?53 ? ? vector > Trans(col,vector(row,0));
?54 ? ? for (unsigned i = 0; i < col; ++i)
?55 ? ? {
?56 ? ? ? ? for (unsigned j = 0; j < row; ++j)
?57 ? ? ? ? {
?58 ? ? ? ? ? ? Trans[i][j] = Matrix[j][i];
?59 ? ? ? ? }
?60 ? ? }
?61 ? ? return Trans;
?62 }
?63?
?64 template
?65 T findMiddleva lue(vector vec)
?66 {
?67 ? ? sort(vec.begin(),vec.end());
?68 ? ? auto pos = vec.size() / 2;
?69 ? ? return vec[pos];
?70 }
?71?
?72?
?73 //构建kd树
?74 void buildKdTree(KdTree* tree, vector > data, unsigned depth)
?75 {
?76?
?77 ? ? //样本的数量
?78 ? ? unsigned samplesNum = data.size();
?79 ? ? //终止条件
?80 ? ? if (samplesNum == 0)
?81 ? ? {
?82 ? ? ? ? return;
?83 ? ? }
?84 ? ? if (samplesNum == 1)
?85 ? ? {
?86 ? ? ? ? tree->root = data[0];
?87 ? ? ? ? return;
?88 ? ? }
?89 ? ? //样本的维度
?90 ? ? unsigned k = data[0].size();
?91 ? ? vector > transData = Transpose(data);
?92 ? ? //选择切分属性
?93 ? ? unsigned splitAttribute = depth % k;
?94 ? ? vector splitAttributeva lues = transData[splitAttribute];
?95 ? ? //选择切分值