Linear Regression (Sample Program, Graph Model) - MaxCompute - Alibaba Cloud
In statistics, linear regression is a statistical analysis method for determining the interdependence between two or more variables. Unlike classification algorithms, which predict discrete labels, regression algorithms predict continuous values. Linear regression defines its loss function as the sum of squared errors over the sample set and solves for the weight vector by minimizing that loss. A common solver is gradient descent (formalized after the list below):
- Initialize the weight vector, and choose a learning rate and an iteration count (or a convergence condition);
- For each sample, compute its squared error;
- Sum the squared errors and update the weights according to the learning rate;
- Repeat until convergence.
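With m samples, n features, hypothesis h_θ, and learning rate α, each iteration evaluates the loss and applies the batch update (these are the quantities the code below tracks as lost and tmpGradient):

  h_\theta(x) = \theta_0 + \sum_{j=1}^{n} \theta_j x_j

  J(\theta) = \frac{1}{2m} \sum_{i=1}^{m} \bigl( h_\theta(x^{(i)}) - y^{(i)} \bigr)^2

  \theta_j \leftarrow \theta_j - \frac{\alpha}{m} \sum_{i=1}^{m} \bigl( h_\theta(x^{(i)}) - y^{(i)} \bigr) x_j^{(i)}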
Source Code
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.Aggregator;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.io.DoubleWritable;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.io.Tuple;
import com.aliyun.odps.io.Writable;
import com.aliyun.odps.io.WritableRecord;

/**
 * LinearRegression input: y,x1,x2,x3,......
 */
public class LinearRegression {

  public static class GradientWritable implements Writable {
    Tuple lastTheta;
    Tuple currentTheta;
    Tuple tmpGradient;
    LongWritable count;
    DoubleWritable lost;

    @Override
    public void readFields(DataInput in) throws IOException {
      lastTheta = new Tuple();
      lastTheta.readFields(in);
      currentTheta = new Tuple();
      currentTheta.readFields(in);
      tmpGradient = new Tuple();
      tmpGradient.readFields(in);
      count = new LongWritable();
      count.readFields(in);
      /* update 1: add a variable to store lost at every iteration */
      lost = new DoubleWritable();
      lost.readFields(in);
    }

    @Override
    public void write(DataOutput out) throws IOException {
      lastTheta.write(out);
      currentTheta.write(out);
      tmpGradient.write(out);
      count.write(out);
      lost.write(out);
    }
  }

  public static class LinearRegressionVertex extends
      Vertex<LongWritable, Tuple, NullWritable, NullWritable> {
    @Override
    public void compute(
        ComputeContext<LongWritable, Tuple, NullWritable, NullWritable> context,
        Iterable<NullWritable> messages) throws IOException {
      context.aggregate(getValue());
    }
  }

  public static class LinearRegressionVertexReader extends
      GraphLoader<LongWritable, Tuple, NullWritable, NullWritable> {
    @Override
    public void load(LongWritable recordNum, WritableRecord record,
        MutationContext<LongWritable, Tuple, NullWritable, NullWritable> context)
        throws IOException {
      LinearRegressionVertex vertex = new LinearRegressionVertex();
      vertex.setId(recordNum);
      vertex.setValue(new Tuple(record.getAll()));
      context.addVertexRequest(vertex);
    }
  }

  public static class LinearRegressionAggregator extends
      Aggregator<GradientWritable> {

    @SuppressWarnings("rawtypes")
    @Override
    public GradientWritable createInitialValue(WorkerContext context)
        throws IOException {
      if (context.getSuperstep() == 0) {
        /* set initial value, all 0 */
        GradientWritable grad = new GradientWritable();
        grad.lastTheta = new Tuple();
        grad.currentTheta = new Tuple();
        grad.tmpGradient = new Tuple();
        grad.count = new LongWritable(1);
        grad.lost = new DoubleWritable(0.0);
        int n = (int) Long.parseLong(context.getConfiguration()
            .get("Dimension"));
        for (int i = 0; i < n; i++) {
          grad.lastTheta.append(new DoubleWritable(0));
          grad.currentTheta.append(new DoubleWritable(0));
          grad.tmpGradient.append(new DoubleWritable(0));
        }
        return grad;
      } else
        return (GradientWritable) context.getLastAggregatedValue(0);
    }

    public static double vecMul(Tuple value, Tuple theta) {
      /* perform this partial computation: hθ(x(i)) − y(i) for one sample */
      /* value denotes one sample and value(0) is y */
      double sum = 0.0;
      for (int j = 1; j < value.size(); j++)
        sum += Double.parseDouble(value.get(j).toString())
            * Double.parseDouble(theta.get(j).toString());
      return Double.parseDouble(theta.get(0).toString()) + sum
          - Double.parseDouble(value.get(0).toString());
    }

    @Override
    public void aggregate(GradientWritable gradient, Object value)
        throws IOException {
      /*
       * performed on each vertex--each sample i: accumulate the gradient
       * component for each dimension
       */
      double tmpVar = vecMul((Tuple) value, gradient.currentTheta);
      /*
       * update 2: local worker aggregate() performs like merge() below. This
       * means the variable gradient denotes the previous aggregated value
       */
      gradient.tmpGradient.set(0, new DoubleWritable(
          ((DoubleWritable) gradient.tmpGradient.get(0)).get() + tmpVar));
      /* accumulate the squared error (the posted code overwrote it per sample) */
      gradient.lost.set(gradient.lost.get() + Math.pow(tmpVar, 2));
      /*
       * calculate (hθ(x(i)) − y(i)) * x(i)(j) for each sample i for each
       * dimension j
       */
      for (int j = 1; j < gradient.tmpGradient.size(); j++)
        gradient.tmpGradient.set(j, new DoubleWritable(
            ((DoubleWritable) gradient.tmpGradient.get(j)).get() + tmpVar
                * Double.parseDouble(((Tuple) value).get(j).toString())));
    }

    @Override
    public void merge(GradientWritable gradient, GradientWritable partial)
        throws IOException {
      /* perform SumAll on each dimension for all samples. */
      Tuple master = gradient.tmpGradient;
      Tuple part = partial.tmpGradient;
      for (int j = 0; j < gradient.tmpGradient.size(); j++) {
        DoubleWritable s = (DoubleWritable) master.get(j);
        s.set(s.get() + ((DoubleWritable) part.get(j)).get());
      }
      gradient.lost.set(gradient.lost.get() + partial.lost.get());
    }

    @SuppressWarnings("rawtypes")
    @Override
    public boolean terminate(WorkerContext context, GradientWritable gradient)
        throws IOException {
      /*
       * 1. calculate new theta 2. judge the diff between the last step and this
       * step; if smaller than the threshold, stop iterating
       */
      gradient.lost = new DoubleWritable(gradient.lost.get()
          / (2 * context.getTotalNumVertices()));
      /*
       * we can print lost to make sure the algorithm is moving in the right
       * direction (for debugging)
       */
      System.out.println(gradient.count + " lost:" + gradient.lost);
      Tuple tmpGradient = gradient.tmpGradient;
      System.out.println("tmpGradient:" + tmpGradient);
      Tuple lastTheta = gradient.lastTheta;
      Tuple tmpCurrentTheta = new Tuple(gradient.currentTheta.size());
      System.out.println(gradient.count + " terminate_start_last:" + lastTheta);
      double alpha = 0.07; // learning rate
      // alpha = Double.parseDouble(context.getConfiguration().get("Alpha"));
      /* perform theta(j) = theta(j) - alpha/M * tmpGradient(j) */
      long M = context.getTotalNumVertices();
      /*
       * update 3: add (/M) to the code. The original code forgot this step
       */
      for (int j = 0; j < lastTheta.size(); j++) {
        tmpCurrentTheta.set(j, new DoubleWritable(
            Double.parseDouble(lastTheta.get(j).toString()) - alpha / M
                * Double.parseDouble(tmpGradient.get(j).toString())));
      }
      System.out.println(gradient.count + " terminate_start_current:"
          + tmpCurrentTheta);
      // judge whether convergence is happening.
      double diff = 0.00d;
      for (int j = 0; j < gradient.currentTheta.size(); j++)
        diff += Math.pow(((DoubleWritable) tmpCurrentTheta.get(j)).get()
            - ((DoubleWritable) lastTheta.get(j)).get(), 2);
      if (/* Math.sqrt(diff) < 0.00000000005d || */
      Long.parseLong(context.getConfiguration().get("Max_Iter_Num")) == gradient.count
          .get()) {
        context.write(gradient.currentTheta.toArray());
        return true;
      }
      gradient.lastTheta = tmpCurrentTheta;
      gradient.currentTheta = tmpCurrentTheta;
      gradient.count.set(gradient.count.get() + 1);
      int n = (int) Long.parseLong(context.getConfiguration().get("Dimension"));
      /*
       * update 4: Important!!! Remember this step. Graph won't reset the
       * initial value for global variables at the beginning of each iteration
       */
      for (int i = 0; i < n; i++) {
        gradient.tmpGradient.set(i, new DoubleWritable(0));
      }
      return false;
    }
  }

  public static void main(String[] args) throws IOException {
    GraphJob job = new GraphJob();
    job.setGraphLoaderClass(LinearRegressionVertexReader.class);
    job.setRuntimePartitioning(false);
    job.setNumWorkers(3);
    job.setVertexClass(LinearRegressionVertex.class);
    job.setAggregatorClass(LinearRegressionAggregator.class);
    job.addInput(TableInfo.builder().tableName(args[0]).build());
    job.addOutput(TableInfo.builder().tableName(args[1]).build());
    job.setMaxIteration(Integer.parseInt(args[2])); // number of iterations
    job.setInt("Max_Iter_Num", Integer.parseInt(args[2]));
    job.setInt("Dimension", Integer.parseInt(args[3])); // dimension
    job.setFloat("Alpha", Float.parseFloat(args[4])); // learning rate
    long start = System.currentTimeMillis();
    job.run();
    System.out.println("Job Finished in "
        + (System.currentTimeMillis() - start) / 1000.0 + " seconds");
  }
}
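To make the distributed flow concrete, here is a minimal single-process sketch of the same batch gradient-descent update that aggregate(), merge(), and terminate() perform together. It is not part of the sample: the class name and toy data are made up for illustration, and everything runs in one loop instead of being spread across workers.

import java.util.Arrays;

public class GradientDescentSketch {
  public static void main(String[] args) {
    // toy data: each row is {y, x1}, generated from y = 1 + 2*x1
    double[][] samples = { { 3, 1 }, { 5, 2 }, { 7, 3 } };
    int dim = samples[0].length;      // theta[0] is the intercept, as in vecMul()
    double[] theta = new double[dim]; // all zeros, as in superstep 0
    double alpha = 0.07;              // learning rate, the sample's default
    for (int iter = 0; iter < 5000; iter++) {
      double[] grad = new double[dim]; // plays the role of tmpGradient
      for (double[] s : samples) {
        double err = theta[0] - s[0];  // h_theta(x) - y, as computed by vecMul()
        for (int j = 1; j < dim; j++)
          err += theta[j] * s[j];
        grad[0] += err;                // intercept component of the gradient
        for (int j = 1; j < dim; j++)
          grad[j] += err * s[j];       // (h_theta(x) - y) * x_j, summed over samples
      }
      for (int j = 0; j < dim; j++)
        theta[j] -= alpha / samples.length * grad[j]; // theta(j) -= alpha/M * grad(j)
    }
    System.out.println(Arrays.toString(theta)); // approaches [1.0, 2.0]
  }
}

For the real job, main() expects five arguments: the input table (one record per sample, columns y, x1, x2, ...), the output table, the maximum number of iterations, the dimension (the number of θ components, i.e. the intercept plus one weight per feature, which equals the number of input columns), and the learning rate. Note that the posted code hard-codes alpha = 0.07 in terminate() and leaves the Alpha argument unused.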
Last updated: 2016-06-22 12:03:05