閱讀654 返回首頁    go 阿裏雲


K-均值聚類__示例程序_圖模型_大數據計算服務-阿里雲

k-均值聚類(Kmeans) 算法是非常基礎並被大量使用的聚類算法。算法基本思想是:以空間中 k 個點為中心進行聚類,對最靠近它們的點進行歸類。通過迭代的方法,逐次更新各聚類中心的值,直至得到最好的聚類結果。

假設要把樣本集分為 k 個類別,算法描述如下:

  • 適當選擇 k 個類的初始中心
  • 在第 i 次迭代中,對任意一個樣本,求其到 k 個中心的距離,將該樣本歸到距離最短的中心所在的類
  • 利用均值等方法更新該類的中心值
  • 對於所有的 k 個聚類中心,如果利用上兩步的迭代法更新後,值保持不變或者變化量小於某個閾值,則迭代結束,否則繼續迭代
  1. import java.io.DataInput;
  2. import java.io.DataOutput;
  3. import java.io.IOException;
  4. import org.apache.log4j.Logger;
  5. import com.aliyun.odps.io.WritableRecord;
  6. import com.aliyun.odps.graph.Aggregator;
  7. import com.aliyun.odps.graph.ComputeContext;
  8. import com.aliyun.odps.graph.GraphJob;
  9. import com.aliyun.odps.graph.GraphLoader;
  10. import com.aliyun.odps.graph.MutationContext;
  11. import com.aliyun.odps.graph.Vertex;
  12. import com.aliyun.odps.graph.WorkerContext;
  13. import com.aliyun.odps.io.DoubleWritable;
  14. import com.aliyun.odps.io.LongWritable;
  15. import com.aliyun.odps.io.NullWritable;
  16. import com.aliyun.odps.data.TableInfo;
  17. import com.aliyun.odps.io.Text;
  18. import com.aliyun.odps.io.Tuple;
  19. import com.aliyun.odps.io.Writable;
  20. public class Kmeans {
  21. private final static Logger LOG = Logger.getLogger(Kmeans.class);
  22. public static class KmeansVertex extends
  23. Vertex<Text, Tuple, NullWritable, NullWritable> {
  24. @Override
  25. public void compute(
  26. ComputeContext<Text, Tuple, NullWritable, NullWritable> context,
  27. Iterable<NullWritable> messages) throws IOException {
  28. context.aggregate(getValue());
  29. }
  30. }
  31. public static class KmeansVertexReader extends
  32. GraphLoader<Text, Tuple, NullWritable, NullWritable> {
  33. @Override
  34. public void load(LongWritable recordNum, WritableRecord record,
  35. MutationContext<Text, Tuple, NullWritable, NullWritable> context)
  36. throws IOException {
  37. KmeansVertex vertex = new KmeansVertex();
  38. vertex.setId(new Text(String.valueOf(recordNum.get())));
  39. vertex.setValue(new Tuple(record.getAll()));
  40. context.addVertexRequest(vertex);
  41. }
  42. }
  43. public static class KmeansAggrValue implements Writable {
  44. Tuple centers = new Tuple();
  45. Tuple sums = new Tuple();
  46. Tuple counts = new Tuple();
  47. @Override
  48. public void write(DataOutput out) throws IOException {
  49. centers.write(out);
  50. sums.write(out);
  51. counts.write(out);
  52. }
  53. @Override
  54. public void readFields(DataInput in) throws IOException {
  55. centers = new Tuple();
  56. centers.readFields(in);
  57. sums = new Tuple();
  58. sums.readFields(in);
  59. counts = new Tuple();
  60. counts.readFields(in);
  61. }
  62. @Override
  63. public String toString() {
  64. return "centers " + centers.toString() + ", sums " + sums.toString()
  65. + ", counts " + counts.toString();
  66. }
  67. }
  68. public static class KmeansAggregator extends Aggregator<KmeansAggrValue> {
  69. @SuppressWarnings("rawtypes")
  70. @Override
  71. public KmeansAggrValue createInitialValue(WorkerContext context)
  72. throws IOException {
  73. KmeansAggrValue aggrVal = null;
  74. if (context.getSuperstep() == 0) {
  75. aggrVal = new KmeansAggrValue();
  76. aggrVal.centers = new Tuple();
  77. aggrVal.sums = new Tuple();
  78. aggrVal.counts = new Tuple();
  79. byte[] centers = context.readCacheFile("centers");
  80. String lines[] = new String(centers).split("n");
  81. for (int i = 0; i < lines.length; i++) {
  82. String[] ss = lines[i].split(",");
  83. Tuple center = new Tuple();
  84. Tuple sum = new Tuple();
  85. for (int j = 0; j < ss.length; ++j) {
  86. center.append(new DoubleWritable(Double.valueOf(ss[j].trim())));
  87. sum.append(new DoubleWritable(0.0));
  88. }
  89. LongWritable count = new LongWritable(0);
  90. aggrVal.sums.append(sum);
  91. aggrVal.counts.append(count);
  92. aggrVal.centers.append(center);
  93. }
  94. } else {
  95. aggrVal = (KmeansAggrValue) context.getLastAggregatedValue(0);
  96. }
  97. return aggrVal;
  98. }
  99. @Override
  100. public void aggregate(KmeansAggrValue value, Object item) {
  101. int min = 0;
  102. double mindist = Double.MAX_VALUE;
  103. Tuple point = (Tuple) item;
  104. for (int i = 0; i < value.centers.size(); i++) {
  105. Tuple center = (Tuple) value.centers.get(i);
  106. // use Euclidean Distance, no need to calculate sqrt
  107. double dist = 0.0d;
  108. for (int j = 0; j < center.size(); j++) {
  109. double v = ((DoubleWritable) point.get(j)).get()
  110. - ((DoubleWritable) center.get(j)).get();
  111. dist += v * v;
  112. }
  113. if (dist < mindist) {
  114. mindist = dist;
  115. min = i;
  116. }
  117. }
  118. // update sum and count
  119. Tuple sum = (Tuple) value.sums.get(min);
  120. for (int i = 0; i < point.size(); i++) {
  121. DoubleWritable s = (DoubleWritable) sum.get(i);
  122. s.set(s.get() + ((DoubleWritable) point.get(i)).get());
  123. }
  124. LongWritable count = (LongWritable) value.counts.get(min);
  125. count.set(count.get() + 1);
  126. }
  127. @Override
  128. public void merge(KmeansAggrValue value, KmeansAggrValue partial) {
  129. for (int i = 0; i < value.sums.size(); i++) {
  130. Tuple sum = (Tuple) value.sums.get(i);
  131. Tuple that = (Tuple) partial.sums.get(i);
  132. for (int j = 0; j < sum.size(); j++) {
  133. DoubleWritable s = (DoubleWritable) sum.get(j);
  134. s.set(s.get() + ((DoubleWritable) that.get(j)).get());
  135. }
  136. }
  137. for (int i = 0; i < value.counts.size(); i++) {
  138. LongWritable count = (LongWritable) value.counts.get(i);
  139. count.set(count.get() + ((LongWritable) partial.counts.get(i)).get());
  140. }
  141. }
  142. @SuppressWarnings("rawtypes")
  143. @Override
  144. public boolean terminate(WorkerContext context, KmeansAggrValue value)
  145. throws IOException {
  146. // compute new centers
  147. Tuple newCenters = new Tuple(value.sums.size());
  148. for (int i = 0; i < value.sums.size(); i++) {
  149. Tuple sum = (Tuple) value.sums.get(i);
  150. Tuple newCenter = new Tuple(sum.size());
  151. LongWritable c = (LongWritable) value.counts.get(i);
  152. for (int j = 0; j < sum.size(); j++) {
  153. DoubleWritable s = (DoubleWritable) sum.get(j);
  154. double val = s.get() / c.get();
  155. newCenter.set(j, new DoubleWritable(val));
  156. // reset sum for next iteration
  157. s.set(0.0d);
  158. }
  159. // reset count for next iteration
  160. c.set(0);
  161. newCenters.set(i, newCenter);
  162. }
  163. // update centers
  164. Tuple oldCenters = value.centers;
  165. value.centers = newCenters;
  166. LOG.info("old centers: " + oldCenters + ", new centers: " + newCenters);
  167. // compare new/old centers
  168. boolean converged = true;
  169. for (int i = 0; i < value.centers.size() && converged; i++) {
  170. Tuple oldCenter = (Tuple) oldCenters.get(i);
  171. Tuple newCenter = (Tuple) newCenters.get(i);
  172. double sum = 0.0d;
  173. for (int j = 0; j < newCenter.size(); j++) {
  174. double v = ((DoubleWritable) newCenter.get(j)).get()
  175. - ((DoubleWritable) oldCenter.get(j)).get();
  176. sum += v * v;
  177. }
  178. double dist = Math.sqrt(sum);
  179. LOG.info("old center: " + oldCenter + ", new center: " + newCenter
  180. + ", dist: " + dist);
  181. // converge threshold for each center: 0.05
  182. converged = dist < 0.05d;
  183. }
  184. if (converged || context.getSuperstep() == context.getMaxIteration() - 1) {
  185. // converged or reach max iteration, output centers
  186. for (int i = 0; i < value.centers.size(); i++) {
  187. context.write(((Tuple) value.centers.get(i)).toArray());
  188. }
  189. // true means to terminate iteration
  190. return true;
  191. }
  192. // false means to continue iteration
  193. return false;
  194. }
  195. }
  196. private static void printUsage() {
  197. System.out.println("Usage: <in> <out> [Max iterations (default 30)]");
  198. System.exit(-1);
  199. }
  200. public static void main(String[] args) throws IOException {
  201. if (args.length < 2)
  202. printUsage();
  203. GraphJob job = new GraphJob();
  204. job.setGraphLoaderClass(KmeansVertexReader.class);
  205. job.setRuntimePartitioning(false);
  206. job.setVertexClass(KmeansVertex.class);
  207. job.setAggregatorClass(KmeansAggregator.class);
  208. job.addInput(TableInfo.builder().tableName(args[0]).build());
  209. job.addOutput(TableInfo.builder().tableName(args[1]).build());
  210. // default max iteration is 30
  211. job.setMaxIteration(30);
  212. if (args.length >= 3)
  213. job.setMaxIteration(Integer.parseInt(args[2]));
  214. long start = System.currentTimeMillis();
  215. job.run();
  216. System.out.println("Job Finished in "
  217. + (System.currentTimeMillis() - start) / 1000.0 + " seconds");
  218. }
  219. }

代碼說明

Kmeans 源代碼包括以下幾部分:

  • KmeansVertexReader 類:負責加載圖,將表中每一條記錄解析為一個點。點標識無關緊要,這裏取傳入的 recordNum 序號作為標識;點值為記錄的所有列組成的 Tuple
  • KmeansVertex 類:compute() 方法非常簡單,只是調用上下文對象的 aggregate 方法,傳入當前點的取值(Tuple 類型,向量表示)
  • KmeansAggregator 類:這個類封裝了 Kmeans 算法的主要邏輯,其中:
    • createInitialValue 為每一輪迭代創建初始值(k 類中心點),若是第一輪迭代(superstep=0),該取值為初始中心點,否則取值為上一輪結束時的新中心點;
    • aggregate 方法為每個點計算其到各個類中心的距離,並歸為距離最短的類,並更新該類的 sum 和 count;
    • merge 方法合併來自各個 worker 收集的 sum 和 count;
    • terminate 方法根據各個類的 sum 和 count 計算新的中心點,若新中心點與之前的中心點距離小於某個閾值或者迭代次數到達最大迭代次數設置,則終止迭代(返回 true),並寫最終的中心點到結果表;否則返回 false,繼續迭代
  • 主程序(main 函數):定義 GraphJob,指定 Vertex/GraphLoader/Aggregator 等的實現,以及最大迭代次數(默認 30),並指定輸入輸出表。
  • main 函數中的 job.setRuntimePartitioning(false):對於 Kmeans 算法,加載圖時不需要進行點的分發,將 RuntimePartitioning 設置為 false,可提升加載圖時的性能。

最後更新:2016-09-21 11:03:08

  上一篇:go PageRank__示例程序_圖模型_大數據計算服務-阿里雲
  下一篇:go BiPartiteMatching__示例程序_圖模型_大數據計算服務-阿里雲