XUPWUP.NL
Login
Or create an account

KMeans clustering implementation

0 Posted 14 Nov 2014
User avatar
Rick Hendricksen
Administrator
Posts: 2251
  1. import java.util.ArrayList;
  2. import java.util.Arrays;
  3. import java.util.Collections;
  4. import java.util.List;
  5. import java.util.Random;
  /**
   * K-means clustering over arbitrary elements whose coordinates are exposed
   * through a {@link CoordGetter}.
   *
   * <p>Typical use: construct, call {@link #init(int, Random)} once, call
   * {@link #iterate()} a number of times, then read the outcome with
   * {@link #getResults()} or {@link #getCentroidCoordinates()}.
   *
   * <p>The data list is kept by reference and the index list is built once in the
   * constructor, so the list should not change size afterwards.
   *
   * @author Rick Hendricksen
   * @param <T> the type of the elements being clustered
   */
  public class KMeans<T> {

      /**
       * Maps an element to its coordinate along each dimension.
       *
       * @param <T> the element type
       */
      public abstract static class CoordGetter<T> {
          /**
           * @param dimension zero-based dimension index, in {@code [0, getDimensions())}
           * @param element the element to read from
           * @return the coordinate of {@code element} along {@code dimension}
           */
          public abstract double get(int dimension, T element);

          /**
           * @return the number of dimensions every element has
           */
          public abstract int getDimensions();
      }

      /** A cluster centre together with the running mean being built for the next position. */
      private static class Centroid {
          /** Current position, used for nearest-centroid lookups. */
          double[] coords;
          /** Running mean of the elements assigned during the current iteration. */
          double[] aCoords;
          /** The number of nodes the running mean was computed with. */
          int sumCount;

          public Centroid(double[] coords) {
              this.coords = coords;
              // Java zero-initializes the new array and sumCount.
              this.aCoords = new double[coords.length];
          }

          /**
           * Moves the centroid to the accumulated mean and clears the accumulator.
           *
           * <p>If no element was assigned during the iteration the centroid keeps its
           * previous position; copying the (all-zero) accumulator would otherwise move
           * an empty cluster to the origin.
           */
          public void reset() {
              if (sumCount > 0) {
                  System.arraycopy(aCoords, 0, coords, 0, coords.length);
              }
              Arrays.fill(aCoords, 0d);
              sumCount = 0;
          }
      }

      List<T> data;
      int dimensions;
      CoordGetter<T> cg;
      Centroid[] centroids;
      ArrayList<Integer> indexList;
      Random random;

      /**
       * @param data the elements to cluster
       * @param coordGetter a function which returns the coordinates for each data element
       */
      public KMeans(List<T> data, CoordGetter<T> coordGetter) {
          this.data = data;
          this.cg = coordGetter;
          dimensions = cg.getDimensions();
          // Indices 0..n-1; shuffled in init() to pick the starting centroids.
          indexList = new ArrayList<>(data.size());
          for (int i = 0; i < data.size(); i++) {
              indexList.add(i);
          }
      }

      /**
       * Picks {@code k} distinct data elements (by index) as the initial centroids.
       *
       * @param k the number of clusters (must be less than or equal to the number of data elements)
       * @param random source of randomness used to shuffle the candidate indices
       * @throws IndexOutOfBoundsException if {@code k} exceeds the number of data elements
       */
      public void init(int k, Random random) {
          // Fail before any state is modified instead of half-way through the loop below.
          if (k > data.size()) {
              throw new IndexOutOfBoundsException(
                      "k (" + k + ") exceeds the number of data elements (" + data.size() + ")");
          }
          Collections.shuffle(indexList, random);
          centroids = new Centroid[k];
          for (int i = 0; i < k; i++) {
              double[] coords = new double[dimensions];
              T element = data.get(indexList.get(i));
              for (int j = 0; j < coords.length; j++) {
                  coords[j] = cg.get(j, element);
              }
              centroids[i] = new Centroid(coords);
          }
          this.random = random;
      }

      /**
       * Finds the centroid with the smallest squared Euclidean distance to {@code d}.
       * Ties are resolved in favour of the lowest centroid index.
       *
       * @param d the element to classify
       * @return the index of the nearest centroid, or -1 if no centroid compared
       *         closer than {@code Double.MAX_VALUE} (for example when there are no centroids)
       */
      private int getIndexOfNearestCentroid(T d) {
          int nearest = -1;
          double dist = Double.MAX_VALUE;
          for (int i = 0; i < centroids.length; i++) {
              Centroid c = centroids[i];
              double distsq = 0;
              for (int j = 0; j < dimensions; j++) {
                  double v = c.coords[j] - cg.get(j, d);
                  distsq += v * v;
              }
              // If the squared distance is smallest then the distance is smallest.
              if (distsq < dist) {
                  dist = distsq;
                  nearest = i;
              }
          }
          return nearest;
      }

      /**
       * Runs one k-means iteration: assigns every element to its nearest centroid
       * while maintaining a running mean per centroid, then moves each centroid to
       * its mean. Requires {@link #init(int, Random)} to have been called.
       */
      public void iterate() {
          for (T d : data) {
              Centroid c = centroids[getIndexOfNearestCentroid(d)];
              for (int i = 0; i < dimensions; i++) {
                  // Incremental mean: old mean weighted by n/(n+1) plus the new value / (n+1).
                  c.aCoords[i] = c.aCoords[i] * (c.sumCount / (c.sumCount + 1d))
                          + cg.get(i, d) / (c.sumCount + 1);
              }
              c.sumCount++;
          }
          for (Centroid c : centroids) {
              c.reset();
          }
      }

      /**
       * @return an array of length k, each entry a list containing all elements whose
       *         nearest centroid is the one with that index
       */
      public ArrayList<T>[] getResults() {
          // Generic array creation is unavoidable with this return type; the array
          // only ever holds ArrayList<T> instances created right here.
          @SuppressWarnings("unchecked")
          ArrayList<T>[] results = new ArrayList[centroids.length];
          for (int i = 0; i < results.length; i++) {
              results[i] = new ArrayList<>();
          }
          for (T d : data) {
              results[getIndexOfNearestCentroid(d)].add(d);
          }
          return results;
      }

      /**
       * @return a new {@code double[centroids.length][dimensions]} holding a copy of
       *         each centroid's current coordinates
       */
      public double[][] getCentroidCoordinates() {
          double[][] co = new double[centroids.length][];
          for (int i = 0; i < centroids.length; i++) {
              co[i] = Arrays.copyOf(centroids[i].coords, dimensions);
          }
          return co;
      }
  }

Usage:
  1. ArrayList<Node> children = ...; // the elements to cluster (placeholder in the original post)
  2. KMeans<Node> km = new KMeans<>(children, new KMeans.CoordGetter<Node>() {
  3. @Override
  4. public double get(int dimension, Node element) { // maps dimension 0/1/2 to the element's x/y/z field
  5. switch(dimension){
  6. case 0: return element.x;
  7. case 1: return element.y;
  8. case 2: return element.z;
  9. }
  10. throw new IndexOutOfBoundsException(); // any dimension other than 0, 1 or 2
  11. }
  12. @Override
  13. public int getDimensions() {
  14. return 3; // matches the three cases handled in get()
  15. }
  16. });
  17. km.init(6, new Random()); // 6 clusters
  18. int numberOfIterations = 20;
  19. for(int k = 0; k < numberOfIterations; k++){ // fixed number of iterations; no convergence check
  20. km.iterate();
  21. }
  22. ArrayList<Node>[] clustering = km.getResults(); // one list of elements per cluster
  23. for(ArrayList<Node> c : clustering){
  24. System.out.println("Cluster size = " + c.size());
  25. }
1 Posted 03 Dec 2014
User avatar
Kebabbi
Moderator
Posts: 549
Je moet
  1. public void init(int k)
nog toevoegen
2 Posted 03 Dec 2014
User avatar
Rick Hendricksen
Administrator
Posts: 2251
Quoting: Post 1

Je moet
  1. public void init(int k)
nog toevoegen

Ik heb gewoon argument "new Random()" toegevoegd
© Rick Hendricksen