KNN 算法解析及 Java、Python 代码实现(二)
大,则不考虑
double distance=calDistance(testdate,t); // 训练元祖和测试元祖的距离
KNNNode node=pq.peek(); // 查询返回队列头的元素,此时为记录最大的值
if(distance
{
pq.remove();
pq.add(new KNNNode(i,distance,t.get(t.size()-1).toString()));
// 当此次训练元祖记录和测试元祖距离小于 队列中最大距离是,被选中
}
}
return getMostClass(pq);
}
/*
* PriorityQueue 的操作是值针对对头进行的
*/
private String getMostClass(PriorityQueue pq)
{
Map count=new HashMap(); // 利用hashMap计算哪个类别是最多的
for(int i=0;i
{
KNNNode node=pq.remove();
String c=node.getC();
if(count.containsKey(c))
{
count.put(c,count.get(c)+1);// HashMap的key值不能重复
}
else
{
count.put(c,1); //加入新的key-value
}
}
int maxIndex=-1;
int maxCount=0;
// HashMap 的操作没有针对序号的 get(i),只有针对key值的操作,所以先进行keys的获得
//keySet() 获得hashmap的key值序列
Object []classes=count.keySet().toArray();
for(int i=0;i
} catch (Exception e) {
{
if(count.get(classes[i])>maxCount)
{
maxIndex=i;
maxCount=count.get(classes[i]);
}
}
return classes[maxIndex].toString();
}
}
package cluster;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
public class TestKNN {
public void read(List
- > datas, String path){
try {
BufferedReader br = new BufferedReader(new FileReader(new File(path)));
String data = br.readLine();
List l = null;
while (data != null) {
String t[] = data.split(" ");
l = new ArrayList();
for (int i = 0; i < t.length; i++) {
l.add(Double.parseDouble(t[i]));
}
datas.add(l);
data = br.readLine();
}
e.printStackTrace();
}
}
public void printtestdate(List testdate)
{
for(int i=0;i
System.out.print(testdate.get(i)+" ");
}
public static void main(String args[])
{
TestKNN tknn=new TestKNN();
String path1="TestSet/KNN/dates.txt";
String path2="TestSet/KNN/testdate.txt";
List
- > dates=new ArrayList
- >();
List
- > testdate=new ArrayList
- >();
try{
tknn.read(dates, path1);
tknn.read(testdate, path2);
KNN knn=new KNN();
for(int i=0;i
{
List tdata=dates.get(i);
System.out.print("训练元组:");
tknn.printtestdate(tdata);
System.out.println("");
//System.out.println(knn.knn(dates,tdata, 3));
}
for(int i=0;i
{
List tdata=testdate.get(i);
System.out.print("测试元组:");
tknn.printtestdate(tdata);
System.out.println("所属类别为:");
System.out.println(Math.round(Float.parseFloat(knn.knn(dates,tdata, 3))));
}
}catch(Exception e)
{
e.printStackTrace();
}
}
}
Python 版本
# -*- coding: gb2312 -*-
import math
import string
#计算v1与v2之间的欧拉距离
def euclidean(v1,v2):
d=0.0
for i in range(len(v1)):
d+=(v1[i]-v2[i])**2
return math.sqrt(d)
def getdistances(data, vec1, label_index=8):
    """Return (distance, label) pairs for every row of data, nearest first.

    Args:
        data: sequence of training rows; each row must have at least
            len(vec1) numeric components plus a label at label_index.
        vec1: the query vector.
        label_index: position of the class label within each row
            (defaults to 8, the column used by the training data here).

    Returns:
        A list of (euclidean(vec1, row), row[label_index]) tuples sorted
        ascending; equal distances are ordered by label.
    """
    distancelist = [(euclidean(vec1, row), row[label_index]) for row in data]
    distancelist.sort()
    return distancelist
# 训练元组表示
vlist1=["1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1",
"1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1",
"1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1",
"1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0",
"1.0 0