Linear Regression__Sample Programs_Graph Model_MaxCompute (Big Data Computing Service)-Alibaba Cloud
In statistics, linear regression is a statistical analysis method for determining the interdependence between two or more variables. Unlike classification algorithms, which make discrete predictions, regression algorithms predict continuous values. The linear regression algorithm defines the loss function as the sum of squared errors over the sample set and solves for the weight vector by minimizing this loss. A commonly used solver is gradient descent (the update formulas are sketched after the following steps):
- Initialize the weight vector, and choose a learning rate and an iteration count (or an iteration convergence condition);
- For each sample, compute the squared error and its gradient;
- Sum the gradients over all samples and update the weights according to the learning rate;
- Repeat until convergence.
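For reference while reading the code, the steps above correspond to the standard least-squares formulation below. The notation (hypothesis h_θ, m samples, learning rate α) is not in the original page and is added here for clarity:

$$h_\theta(x^{(i)}) = \theta_0 + \sum_{j=1}^{n} \theta_j x_j^{(i)}$$

$$J(\theta) = \frac{1}{2m} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right)^2$$

$$\theta_j \leftarrow \theta_j - \frac{\alpha}{m} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right) x_j^{(i)}$$

These are exactly the quantities the sample program accumulates: tmpGradient holds the sums over samples, lost holds J(θ), and the alpha / M factor in terminate() implements α/m.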
Source code

In the Graph implementation below, each row of the input table (one sample, laid out as y, x1, x2, ...) is loaded as a vertex. In every superstep each vertex passes its sample to an Aggregator, which accumulates the gradient and the loss over all samples in aggregate() and merge(); terminate() then applies the weight update and decides whether to stop.
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.Aggregator;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.io.DoubleWritable;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.io.Tuple;
import com.aliyun.odps.io.Writable;
import com.aliyun.odps.io.WritableRecord;

/**
 * LinearRegression input: y,x1,x2,x3,...
 */
public class LinearRegression {

  public static class GradientWritable implements Writable {
    Tuple lastTheta;
    Tuple currentTheta;
    Tuple tmpGradient;
    LongWritable count;
    DoubleWritable lost;

    @Override
    public void readFields(DataInput in) throws IOException {
      lastTheta = new Tuple();
      lastTheta.readFields(in);
      currentTheta = new Tuple();
      currentTheta.readFields(in);
      tmpGradient = new Tuple();
      tmpGradient.readFields(in);
      count = new LongWritable();
      count.readFields(in);
      // update 1: a variable that stores the loss at every iteration
      lost = new DoubleWritable();
      lost.readFields(in);
    }

    @Override
    public void write(DataOutput out) throws IOException {
      lastTheta.write(out);
      currentTheta.write(out);
      tmpGradient.write(out);
      count.write(out);
      lost.write(out);
    }
  }

  public static class LinearRegressionVertex extends
      Vertex<LongWritable, Tuple, NullWritable, NullWritable> {
    @Override
    public void compute(
        ComputeContext<LongWritable, Tuple, NullWritable, NullWritable> context,
        Iterable<NullWritable> messages) throws IOException {
      // Each vertex holds one sample; hand it to the aggregator every superstep.
      context.aggregate(getValue());
    }
  }

  public static class LinearRegressionVertexReader extends
      GraphLoader<LongWritable, Tuple, NullWritable, NullWritable> {
    @Override
    public void load(LongWritable recordNum, WritableRecord record,
        MutationContext<LongWritable, Tuple, NullWritable, NullWritable> context)
        throws IOException {
      // One input record (y, x1, x2, ...) becomes one vertex.
      LinearRegressionVertex vertex = new LinearRegressionVertex();
      vertex.setId(recordNum);
      vertex.setValue(new Tuple(record.getAll()));
      context.addVertexRequest(vertex);
    }
  }

  public static class LinearRegressionAggregator extends
      Aggregator<GradientWritable> {

    @SuppressWarnings("rawtypes")
    @Override
    public GradientWritable createInitialValue(WorkerContext context)
        throws IOException {
      if (context.getSuperstep() == 0) {
        // Set the initial value: all zeros.
        GradientWritable grad = new GradientWritable();
        grad.lastTheta = new Tuple();
        grad.currentTheta = new Tuple();
        grad.tmpGradient = new Tuple();
        grad.count = new LongWritable(1);
        grad.lost = new DoubleWritable(0.0);
        int n = (int) Long.parseLong(context.getConfiguration().get("Dimension"));
        for (int i = 0; i < n; i++) {
          grad.lastTheta.append(new DoubleWritable(0));
          grad.currentTheta.append(new DoubleWritable(0));
          grad.tmpGradient.append(new DoubleWritable(0));
        }
        return grad;
      } else {
        return (GradientWritable) context.getLastAggregatedValue(0);
      }
    }

    public static double vecMul(Tuple value, Tuple theta) {
      // Compute h_theta(x(i)) - y(i) for one sample.
      // value holds one sample, and value.get(0) is y.
      double sum = 0.0;
      for (int j = 1; j < value.size(); j++)
        sum += Double.parseDouble(value.get(j).toString())
            * Double.parseDouble(theta.get(j).toString());
      return Double.parseDouble(theta.get(0).toString()) + sum
          - Double.parseDouble(value.get(0).toString());
    }

    @Override
    public void aggregate(GradientWritable gradient, Object value)
        throws IOException {
      /*
       * Performed for each vertex, i.e. for each sample i, over every dimension j.
       * update 2: the local worker's aggregate() behaves like merge() below, so
       * the variable gradient already holds the previously aggregated value.
       */
      double tmpVar = vecMul((Tuple) value, gradient.currentTheta);
      gradient.tmpGradient.set(0, new DoubleWritable(
          ((DoubleWritable) gradient.tmpGradient.get(0)).get() + tmpVar));
      // Accumulate the squared error. (Bug fix: the original overwrote lost
      // here, keeping only the last sample's error.)
      gradient.lost.set(gradient.lost.get() + Math.pow(tmpVar, 2));
      // Accumulate (h_theta(x(i)) - y(i)) * x(i)(j) for each dimension j.
      for (int j = 1; j < gradient.tmpGradient.size(); j++)
        gradient.tmpGradient.set(j, new DoubleWritable(
            ((DoubleWritable) gradient.tmpGradient.get(j)).get()
                + tmpVar * Double.parseDouble(((Tuple) value).get(j).toString())));
    }

    @Override
    public void merge(GradientWritable gradient, GradientWritable partial)
        throws IOException {
      // Sum the partial results of all workers, dimension by dimension.
      Tuple master = gradient.tmpGradient;
      Tuple part = partial.tmpGradient;
      for (int j = 0; j < gradient.tmpGradient.size(); j++) {
        DoubleWritable s = (DoubleWritable) master.get(j);
        s.set(s.get() + ((DoubleWritable) part.get(j)).get());
      }
      gradient.lost.set(gradient.lost.get() + partial.lost.get());
    }

    @SuppressWarnings("rawtypes")
    @Override
    public boolean terminate(WorkerContext context, GradientWritable gradient)
        throws IOException {
      /*
       * 1. Calculate the new theta.
       * 2. Check the difference between this step and the last step; if it is
       *    smaller than the threshold, stop iterating.
       */
      long M = context.getTotalNumVertices();
      gradient.lost = new DoubleWritable(gradient.lost.get() / (2 * M));
      // Printing the loss helps verify the algorithm is converging (debug aid).
      System.out.println(gradient.count + " lost:" + gradient.lost);
      Tuple tmpGradient = gradient.tmpGradient;
      System.out.println("tmpGra" + tmpGradient);
      Tuple lastTheta = gradient.lastTheta;
      Tuple tmpCurrentTheta = new Tuple(gradient.currentTheta.size());
      System.out.println(gradient.count + " terminate_start_last:" + lastTheta);
      double alpha = 0.07; // learning rate
      // alpha = Double.parseDouble(context.getConfiguration().get("Alpha"));
      /*
       * Perform theta(j) = theta(j) - alpha / M * tmpGradient(j).
       * update 3: the (/ M) factor was missing from the original code.
       */
      for (int j = 0; j < lastTheta.size(); j++) {
        tmpCurrentTheta.set(j, new DoubleWritable(
            Double.parseDouble(lastTheta.get(j).toString())
                - alpha / M * Double.parseDouble(tmpGradient.get(j).toString())));
      }
      System.out.println(gradient.count + " terminate_start_current:" + tmpCurrentTheta);
      // Measure the change in theta (used by the commented-out convergence test).
      double diff = 0.00d;
      for (int j = 0; j < gradient.currentTheta.size(); j++)
        diff += Math.pow(((DoubleWritable) tmpCurrentTheta.get(j)).get()
            - ((DoubleWritable) lastTheta.get(j)).get(), 2);
      if (/* Math.sqrt(diff) < 0.00000000005d || */
          Long.parseLong(context.getConfiguration().get("Max_Iter_Num")) == gradient.count.get()) {
        // Bug fix: write the freshly updated theta; the original wrote the
        // pre-update currentTheta and discarded the final step.
        context.write(tmpCurrentTheta.toArray());
        return true;
      }
      gradient.lastTheta = tmpCurrentTheta;
      gradient.currentTheta = tmpCurrentTheta;
      gradient.count.set(gradient.count.get() + 1);
      int n = (int) Long.parseLong(context.getConfiguration().get("Dimension"));
      /*
       * update 4: important! Graph does not reset the initial value of global
       * variables at the beginning of each iteration, so clear them here.
       */
      for (int i = 0; i < n; i++) {
        gradient.tmpGradient.set(i, new DoubleWritable(0));
      }
      gradient.lost.set(0.0); // reset the accumulated loss as well
      return false;
    }
  }

  public static void main(String[] args) throws IOException {
    GraphJob job = new GraphJob();
    job.setGraphLoaderClass(LinearRegressionVertexReader.class);
    job.setRuntimePartitioning(false);
    job.setNumWorkers(3);
    job.setVertexClass(LinearRegressionVertex.class);
    job.setAggregatorClass(LinearRegressionAggregator.class);
    job.addInput(TableInfo.builder().tableName(args[0]).build());
    job.addOutput(TableInfo.builder().tableName(args[1]).build());
    job.setMaxIteration(Integer.parseInt(args[2])); // number of iterations
    job.setInt("Max_Iter_Num", Integer.parseInt(args[2]));
    job.setInt("Dimension", Integer.parseInt(args[3])); // dimension
    job.setFloat("Alpha", Float.parseFloat(args[4])); // learning rate
    long start = System.currentTimeMillis();
    job.run();
    System.out.println("Job Finished in "
        + (System.currentTimeMillis() - start) / 1000.0 + " seconds");
  }
}
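To make the arithmetic of one iteration easy to check outside MaxCompute, here is a minimal, self-contained sketch of the same batch gradient-descent update on plain Java arrays. It is illustrative only: the class name, the toy data, and the iteration count are hypothetical and are not part of the original sample.

/**
 * A minimal local sketch of batch gradient descent, mirroring the
 * aggregate()/terminate() arithmetic above. All names and data are hypothetical.
 */
public class GradientStepDemo {
  public static void main(String[] args) {
    // Each row is one sample {y, x1, x2}, matching the input layout y,x1,x2,...
    // The toy data follows y = 1 + 2*x1 + 3*x2 exactly.
    double[][] samples = { {6, 1, 1}, {8, 2, 1}, {9, 1, 2}, {11, 2, 2} };
    double[] theta = {0, 0, 0}; // theta[0] is the intercept
    double alpha = 0.07;        // learning rate, as hardcoded in terminate()
    int m = samples.length;

    for (int iter = 1; iter <= 2000; iter++) {
      double[] grad = new double[theta.length];
      double lost = 0;
      for (double[] s : samples) {
        // tmpVar = h_theta(x) - y, as in vecMul() above
        double tmpVar = theta[0] - s[0];
        for (int j = 1; j < theta.length; j++)
          tmpVar += theta[j] * s[j];
        lost += tmpVar * tmpVar;
        grad[0] += tmpVar;          // intercept gradient
        for (int j = 1; j < theta.length; j++)
          grad[j] += tmpVar * s[j]; // weight gradients
      }
      // theta(j) = theta(j) - alpha / m * grad(j), as in terminate()
      for (int j = 0; j < theta.length; j++)
        theta[j] -= alpha / m * grad[j];
      if (iter % 500 == 0)
        System.out.println(iter + " lost:" + lost / (2 * m));
    }
    System.out.println("theta: " + java.util.Arrays.toString(theta));
  }
}

Running this prints a steadily decreasing loss and a theta approaching (1, 2, 3), the coefficients used to generate the toy samples, which is a quick sanity check on the update rule used by the Graph job.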
Last updated: 2016-06-22 12:03:05