654
阿裏雲
K-均值聚類__示例程序_圖模型_大數據計算服務-阿裏雲
k-均值聚類(Kmeans) 算法是非常基礎並大量使用的聚類算法。算法基本思想是:以空間中 k 個點為中心進行聚類,對最靠近他們的點進行歸類。通過迭代的方法,逐次更新各聚類中心的值,直至得到最好的聚類結果。
假設要把樣本集分為 k 個類別,算法描述如下:
- 適當選擇 k 個類的初始中心
- 在第 i 次迭代中,對任意一個樣本,求其到 k 個中心的距離,將該樣本歸到距離最短的中心所在的類
- 利用均值等方法更新該類的中心值
- 對於所有的 k 個聚類中心,如果利用上兩步的迭代法更新後,值保持不變或者小於某個閾值,則迭代結束,否則繼續迭代
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.log4j.Logger;
import com.aliyun.odps.io.WritableRecord;
import com.aliyun.odps.graph.Aggregator;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.io.DoubleWritable;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.io.Text;
import com.aliyun.odps.io.Tuple;
import com.aliyun.odps.io.Writable;
public class Kmeans {
private final static Logger LOG = Logger.getLogger(Kmeans.class);
public static class KmeansVertex extends
Vertex<Text, Tuple, NullWritable, NullWritable> {
@Override
public void compute(
ComputeContext<Text, Tuple, NullWritable, NullWritable> context,
Iterable<NullWritable> messages) throws IOException {
context.aggregate(getValue());
}
}
public static class KmeansVertexReader extends
GraphLoader<Text, Tuple, NullWritable, NullWritable> {
@Override
public void load(LongWritable recordNum, WritableRecord record,
MutationContext<Text, Tuple, NullWritable, NullWritable> context)
throws IOException {
KmeansVertex vertex = new KmeansVertex();
vertex.setId(new Text(String.valueOf(recordNum.get())));
vertex.setValue(new Tuple(record.getAll()));
context.addVertexRequest(vertex);
}
}
public static class KmeansAggrValue implements Writable {
Tuple centers = new Tuple();
Tuple sums = new Tuple();
Tuple counts = new Tuple();
@Override
public void write(DataOutput out) throws IOException {
centers.write(out);
sums.write(out);
counts.write(out);
}
@Override
public void readFields(DataInput in) throws IOException {
centers = new Tuple();
centers.readFields(in);
sums = new Tuple();
sums.readFields(in);
counts = new Tuple();
counts.readFields(in);
}
@Override
public String toString() {
return "centers " + centers.toString() + ", sums " + sums.toString()
+ ", counts " + counts.toString();
}
}
public static class KmeansAggregator extends Aggregator<KmeansAggrValue> {
@SuppressWarnings("rawtypes")
@Override
public KmeansAggrValue createInitialValue(WorkerContext context)
throws IOException {
KmeansAggrValue aggrVal = null;
if (context.getSuperstep() == 0) {
aggrVal = new KmeansAggrValue();
aggrVal.centers = new Tuple();
aggrVal.sums = new Tuple();
aggrVal.counts = new Tuple();
byte[] centers = context.readCacheFile("centers");
String lines[] = new String(centers).split("n");
for (int i = 0; i < lines.length; i++) {
String[] ss = lines[i].split(",");
Tuple center = new Tuple();
Tuple sum = new Tuple();
for (int j = 0; j < ss.length; ++j) {
center.append(new DoubleWritable(Double.valueOf(ss[j].trim())));
sum.append(new DoubleWritable(0.0));
}
LongWritable count = new LongWritable(0);
aggrVal.sums.append(sum);
aggrVal.counts.append(count);
aggrVal.centers.append(center);
}
} else {
aggrVal = (KmeansAggrValue) context.getLastAggregatedValue(0);
}
return aggrVal;
}
@Override
public void aggregate(KmeansAggrValue value, Object item) {
int min = 0;
double mindist = Double.MAX_VALUE;
Tuple point = (Tuple) item;
for (int i = 0; i < value.centers.size(); i++) {
Tuple center = (Tuple) value.centers.get(i);
// use Euclidean Distance, no need to calculate sqrt
double dist = 0.0d;
for (int j = 0; j < center.size(); j++) {
double v = ((DoubleWritable) point.get(j)).get()
- ((DoubleWritable) center.get(j)).get();
dist += v * v;
}
if (dist < mindist) {
mindist = dist;
min = i;
}
}
// update sum and count
Tuple sum = (Tuple) value.sums.get(min);
for (int i = 0; i < point.size(); i++) {
DoubleWritable s = (DoubleWritable) sum.get(i);
s.set(s.get() + ((DoubleWritable) point.get(i)).get());
}
LongWritable count = (LongWritable) value.counts.get(min);
count.set(count.get() + 1);
}
@Override
public void merge(KmeansAggrValue value, KmeansAggrValue partial) {
for (int i = 0; i < value.sums.size(); i++) {
Tuple sum = (Tuple) value.sums.get(i);
Tuple that = (Tuple) partial.sums.get(i);
for (int j = 0; j < sum.size(); j++) {
DoubleWritable s = (DoubleWritable) sum.get(j);
s.set(s.get() + ((DoubleWritable) that.get(j)).get());
}
}
for (int i = 0; i < value.counts.size(); i++) {
LongWritable count = (LongWritable) value.counts.get(i);
count.set(count.get() + ((LongWritable) partial.counts.get(i)).get());
}
}
@SuppressWarnings("rawtypes")
@Override
public boolean terminate(WorkerContext context, KmeansAggrValue value)
throws IOException {
// compute new centers
Tuple newCenters = new Tuple(value.sums.size());
for (int i = 0; i < value.sums.size(); i++) {
Tuple sum = (Tuple) value.sums.get(i);
Tuple newCenter = new Tuple(sum.size());
LongWritable c = (LongWritable) value.counts.get(i);
for (int j = 0; j < sum.size(); j++) {
DoubleWritable s = (DoubleWritable) sum.get(j);
double val = s.get() / c.get();
newCenter.set(j, new DoubleWritable(val));
// reset sum for next iteration
s.set(0.0d);
}
// reset count for next iteration
c.set(0);
newCenters.set(i, newCenter);
}
// update centers
Tuple oldCenters = value.centers;
value.centers = newCenters;
LOG.info("old centers: " + oldCenters + ", new centers: " + newCenters);
// compare new/old centers
boolean converged = true;
for (int i = 0; i < value.centers.size() && converged; i++) {
Tuple oldCenter = (Tuple) oldCenters.get(i);
Tuple newCenter = (Tuple) newCenters.get(i);
double sum = 0.0d;
for (int j = 0; j < newCenter.size(); j++) {
double v = ((DoubleWritable) newCenter.get(j)).get()
- ((DoubleWritable) oldCenter.get(j)).get();
sum += v * v;
}
double dist = Math.sqrt(sum);
LOG.info("old center: " + oldCenter + ", new center: " + newCenter
+ ", dist: " + dist);
// converge threshold for each center: 0.05
converged = dist < 0.05d;
}
if (converged || context.getSuperstep() == context.getMaxIteration() - 1) {
// converged or reach max iteration, output centers
for (int i = 0; i < value.centers.size(); i++) {
context.write(((Tuple) value.centers.get(i)).toArray());
}
// true means to terminate iteration
return true;
}
// false means to continue iteration
return false;
}
}
private static void printUsage() {
System.out.println("Usage: <in> <out> [Max iterations (default 30)]");
System.exit(-1);
}
public static void main(String[] args) throws IOException {
if (args.length < 2)
printUsage();
GraphJob job = new GraphJob();
job.setGraphLoaderClass(KmeansVertexReader.class);
job.setRuntimePartitioning(false);
job.setVertexClass(KmeansVertex.class);
job.setAggregatorClass(KmeansAggregator.class);
job.addInput(TableInfo.builder().tableName(args[0]).build());
job.addOutput(TableInfo.builder().tableName(args[1]).build());
// default max iteration is 30
job.setMaxIteration(30);
if (args.length >= 3)
job.setMaxIteration(Integer.parseInt(args[2]));
long start = System.currentTimeMillis();
job.run();
System.out.println("Job Finished in "
+ (System.currentTimeMillis() - start) / 1000.0 + " seconds");
}
}
代碼說明
Kmeans 源代碼包括以下幾部分:
- 38行:定義 KmeansVertexReader 類,加載圖,將表中每一條記錄解析為一個點,點標識無關緊要,這裏取傳入的 recordNum 序號作為標識,點值為記錄的所有列組成的 Tuple
- 30行:定義 KmeansVertex,compute() 方法非常簡單,隻是調用上下文對象的 aggregate 方法,傳入當前點的取值(Tuple 類型,向量表示)
- 83行:定義 KmeansAggregator,這個類封裝了 Kmeans 算法的主要邏輯,其中:
- createInitialValue 為每一輪迭代創建初始值(k 類中心點),若是第一輪迭代(superstep=0),該取值為初始中心點,否則取值為上一輪結束時的新中心點;
- aggregate 方法為每個點計算其到各個類中心的距離,並歸為距離最短的類,並更新該類的 sum 和 count;
- merge 方法合並來自各個 worker 收集的 sum 和 count;
- terminate 方法根據各個類的 sum 和 count 計算新的中心點,若新中心點與之前的中心點距離小於某個閾值或者迭代次數到達最大迭代次數設置,則終止迭代(返回 false),寫最終的中心點到結果表
- 236行:主程序(main函數),定義 GraphJob,指定 Vertex/GraphLoader/Aggregator 等的實現,以及最大迭代次數(默認 30),並指定輸入輸出表。
- 243行:job.setRuntimePartitioning(false),對於 Kmeans 算法,加載圖是不需要進行點的分發,設置 RuntimePartitioning 為false,提升加載圖時的性能。
最後更新:2016-09-21 11:03:08
上一篇:
PageRank__示例程序_圖模型_大數據計算服務-阿裏雲
下一篇:
BiPartiteMatchiing__示例程序_圖模型_大數據計算服務-阿裏雲
常見錯誤說明__附錄_大數據計算服務-阿裏雲
發送短信接口__API使用手冊_短信服務-阿裏雲
接口文檔__Android_安全組件教程_移動安全-阿裏雲
運營商錯誤碼(聯通)__常見問題_短信服務-阿裏雲
設置短信模板__使用手冊_短信服務-阿裏雲
OSS 權限問題及排查__常見錯誤及排除_最佳實踐_對象存儲 OSS-阿裏雲
消息通知__操作指南_批量計算-阿裏雲
設備端快速接入(MQTT)__快速開始_阿裏雲物聯網套件-阿裏雲
查詢API調用流量數據__API管理相關接口_API_API 網關-阿裏雲
使用STS訪問__JavaScript-SDK_SDK 參考_對象存儲 OSS-阿裏雲