본문 바로가기

Codes

Java - K-means Algorithm

K-means Algoritm (K평균 군집화 알고리즘)
K-means (MacQueen, 1967) 은 유명한 군집화 (Clustering) 문제를 해결하는 가장 간단한 자율학습 (Unsupervised Learning) 알고리즘중 하나이다. 사전에 정해진 어떤수의 클러스터를 통해서 주어진 데이터 집합을 분류하는 간단하고 쉬운 방법.
k-means 는 partitional clustering 에 속한다.

data 이외에 cluster 의 수  k를 input 으로 하며 이때  k를 seed point 라고 한다. seed point 는 임의로 선택되며 바람직한 cluster 구조에 관한 어떤 지식들이 seed point를 선택하는데에 사용될 수 있다. Forgy' algorithm 과 다른점은 하나의 sample 이 하나의 cluster 에 합류하자마자 곧 cluster 의 centroid 가 다시 계산된다는 것이다. 또한 Forgy' algorithm 이 반복적(iterative) 한 반면에 k-means algorithm 은 data set에서 단지 두 번만의 pass 가 이루어진다. 그 과정은 다음과 같다.

1. 처음에 k cluster 로서 시작한다. 남아있는 n-k sample들에 대해서는 가장 가까이 있는 centroid를 찾는다. 이것에 가장 가까이 있는 centroid를 가지는 것이 확인된 cluster 에 sample을 포함시킨다. 각각의 sample 들이 할당된 후에 할당된 cluster 의 centroid 가 다시 계산된다.

2. 그 data를 두 번 처리한다. 각 sample에 대하여 가장 가까이 있는 centroid를 찾는다. 가장 가까이 있는 centroid를 가진 것으로 확인된 cluster 에 sample을 위치시킨다. (이 step 에서는 어떤 centroid 도 다시 계산하지 않는다.

(reference : AIstudy - http://www.aistudy.com)

 


위 설명을 바탕으로 한번 구현해 보았다.
실제 데이터들을 바탕으로 써 먹을수 있게끔 구현하였고,
visualization은 알아서 하면 될 듯.

< 500개의 데이터를 k-means 알고리즘으로 군집화(30개의 클래스) >

weight.java

public class weight {

           public double [] value;

           public int num;

 

           public weight(int length, boolean rnd){

                     value = new double[length];

                     if(rnd)

                                for(int i = 0; i < length; i++)

                                          value[i] = Math.random();

                     num = -1; // non-clustering

           }

          

           public void setNumber(int num){

                     this.num = num;

           }

          

           public int getNumber(){

                     return num;

           }

          

           public double getLength(){

                     return value.length;

           }

          

           public void set(int index, double val){

                     value[index] = val;

           }

          

           public double get(int index){

                     return value[index];

           }

          

           public double distance(weight w){

                     return Math.sqrt(distanceSq(w));

           }

 

           public double distanceSq(weight w){

                     if(w.getLength() != value.length)

                                return -1; // error

                     else{

                                double distSq = 0;

                                double d;

 

                                for(int i = 0; i < value.length; i++){

                                          d = value[i] - w.get(i);

                                          distSq += d * d;

                                }

 

                                return distSq;

                     }

           }

}

 


KCluster.java

public class KCluster extends weight{

           public int num;

          

           public KCluster(int length, boolean rnd, int num){

                     super(length, rnd);

                     this.num = num;

           }

 

           public int getNumber(){

                     return num;

           }

          

           public void setWeight(weight w){

                     value = w.value;

           }

}


ClusteringEngine.java

import java.util.ArrayList;

 

 

public class ClusteringEngine implements Runnable{

           volatile Thread timer;

           public ArrayList<weight> dataSet;

           public ArrayList<KCluster> kSet;

           public int length;

           public double threshold = 0.005;

 

           public double err=0;

           @Override

           public void run() {

                     // TODO Auto-generated method stub

                     while(timer == Thread.currentThread()){

                                try{

                                          Thread.sleep(1000);

                                }catch(InterruptedException e){ }

                                // running Method

                                clustering();

                                if(!rePosition())

                                          stop();

                     }

           }

 

           public void start(){

                     timer = new Thread(this);

                     timer.start();

           }

 

           public void stop(){

                     timer = null;

           }

 

           public ClusteringEngine(int length,int dataSize, int kSize){

                     dataSet = new ArrayList<weight>();

                     kSet = new ArrayList<KCluster>();

                     this.length = length;

                    

                     for(int i = 0; i < dataSize; i++) // initializing dataSet

                                dataSet.add(new weight(length,true));

                    

                     for(int i = 0; i < kSize; i++) // initializing KClusterSet

                                kSet.add(new KCluster(length, true, i));

           }

 

           public void clustering(){

                     for(weight w : dataSet)

                                w.setNumber(getBestClass(w));

           }

 

           public int getBestClass(weight w){ // 웨이트와 가장 가까운 k 찾아 거기에 해당하는 넘버를 리턴

                     KCluster min = kSet.get((int)(Math.random() * kSet.size()));

                     for(KCluster k : kSet)

                                if(!min.equals(k) && min.distance(w) > k.distance(w))

                                          min = k;

                     return min.getNumber();

           }

 

           public boolean rePosition(){
                     double avgDist = 0;

                     for(int i = 0; i < kSet.size(); i++){

                                weight avgWeight = averaging(i);

                                KCluster k = kSet.get(i);

                                avgDist += avgWeight.distance(k);

                                k.setWeight(avgWeight);

                     }

                     avgDist /= kSet.size();

                     err = avgDist;

                     if(avgDist > threshold)

                                return true;  // 재배치를 하였으면 true

                     else

                                return false; // 아니면 false

           }

          

           public weight averaging(int num){ // 클래스 넘버에 해당하는 데이터만 찾아서 평균위치를 찾음.

                     weight avg = new weight(length, false);

                     int count = 0;

                     for(weight w : dataSet)

                                if(num == w.getNumber()){

                                          for(int i = 0; i < length; i++)

                                                     avg.set(i, avg.get(i) + w.get(i));

                                          count++;

                                }

                     for(int i = 0; i < length; i++)

                                avg.set(i, avg.get(i) / count);

                     return avg;

           }

          

           public ArrayList<weight> getDataArray(){

                     return dataSet;

           }

          

           public ArrayList<KCluster> getKArray(){

                     return kSet;

           }

}

'Codes' 카테고리의 다른 글

Flocking in Java - testFrame  (0) 2011.11.01
Java - 최적화된 회전 방향 결정  (0) 2011.09.01