

Linear Regression: Sample Program for the Graph Model of MaxCompute (Alibaba Cloud)

In statistics, linear regression is an analysis method used to model the dependence between two or more variables. Unlike classification algorithms, which predict discrete labels, regression algorithms predict continuous values. The linear regression algorithm defines the loss function as the sum of squared errors over the sample set and solves for the weight vector by minimizing this loss. A common solution method is gradient descent:

  • Initialize the weight vector, and set the learning rate and the number of iterations (or a convergence condition);
  • For each sample, compute its squared error;
  • Sum the squared errors over all samples and update the weights according to the learning rate;
  • Repeat the iterations until convergence (a minimal single-machine sketch of these steps follows this list).
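
Before reading the distributed version below, it may help to see the same loop on a single machine. The following is a minimal, illustrative sketch only, not part of the MaxCompute sample; the class, method, and variable names are made up. Each row of samples is assumed to be {y, x1, ..., xn}, theta[0] is the intercept, and the update is theta(j) = theta(j) - alpha / M * gradient(j).

/*
 * Minimal single-machine sketch of batch gradient descent for linear regression
 * (illustrative only; not part of the MaxCompute Graph sample).
 */
public class GradientDescentSketch {
  public static double[] fit(double[][] samples, int dim, double alpha, int maxIter) {
    double[] theta = new double[dim];          // 1. initialize the weight vector to zeros
    long m = samples.length;
    for (int iter = 0; iter < maxIter; iter++) {
      double[] grad = new double[dim];
      double loss = 0.0;
      for (double[] s : samples) {             // 2. per-sample error: h_theta(x) - y
        double err = theta[0];
        for (int j = 1; j < dim; j++)
          err += theta[j] * s[j];
        err -= s[0];
        loss += err * err;
        grad[0] += err;                        // intercept component of the gradient
        for (int j = 1; j < dim; j++)
          grad[j] += err * s[j];               // 3. sum gradient contributions per dimension
      }
      for (int j = 0; j < dim; j++)            // 3. update weights with the learning rate
        theta[j] -= alpha / m * grad[j];
      System.out.println(iter + " loss: " + loss / (2 * m));  // 4. iterate until done
    }
    return theta;
  }
}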

Source Code

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.Aggregator;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.io.DoubleWritable;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.io.Tuple;
import com.aliyun.odps.io.Writable;
import com.aliyun.odps.io.WritableRecord;

/**
 * LinearRegression input: y,x1,x2,x3,......
 **/
public class LinearRegression {

  public static class GradientWritable implements Writable {
    Tuple lastTheta;
    Tuple currentTheta;
    Tuple tmpGradient;
    LongWritable count;
    DoubleWritable lost;

    @Override
    public void readFields(DataInput in) throws IOException {
      lastTheta = new Tuple();
      lastTheta.readFields(in);
      currentTheta = new Tuple();
      currentTheta.readFields(in);
      tmpGradient = new Tuple();
      tmpGradient.readFields(in);
      count = new LongWritable();
      count.readFields(in);
      /* update 1: add a variable to store the loss at every iteration */
      lost = new DoubleWritable();
      lost.readFields(in);
    }

    @Override
    public void write(DataOutput out) throws IOException {
      lastTheta.write(out);
      currentTheta.write(out);
      tmpGradient.write(out);
      count.write(out);
      lost.write(out);
    }
  }

  /* Each vertex holds one sample (y, x1, ..., xn) and feeds it to the aggregator. */
  public static class LinearRegressionVertex extends
      Vertex<LongWritable, Tuple, NullWritable, NullWritable> {
    @Override
    public void compute(
        ComputeContext<LongWritable, Tuple, NullWritable, NullWritable> context,
        Iterable<NullWritable> messages) throws IOException {
      context.aggregate(getValue());
    }
  }

  /* Loads one vertex per input record; the record number is used as the vertex id. */
  public static class LinearRegressionVertexReader extends
      GraphLoader<LongWritable, Tuple, NullWritable, NullWritable> {
    @Override
    public void load(LongWritable recordNum, WritableRecord record,
        MutationContext<LongWritable, Tuple, NullWritable, NullWritable> context)
        throws IOException {
      LinearRegressionVertex vertex = new LinearRegressionVertex();
      vertex.setId(recordNum);
      vertex.setValue(new Tuple(record.getAll()));
      context.addVertexRequest(vertex);
    }
  }

  public static class LinearRegressionAggregator extends
      Aggregator<GradientWritable> {

    @SuppressWarnings("rawtypes")
    @Override
    public GradientWritable createInitialValue(WorkerContext context)
        throws IOException {
      if (context.getSuperstep() == 0) {
        /* set the initial value: all zeros */
        GradientWritable grad = new GradientWritable();
        grad.lastTheta = new Tuple();
        grad.currentTheta = new Tuple();
        grad.tmpGradient = new Tuple();
        grad.count = new LongWritable(1);
        grad.lost = new DoubleWritable(0.0);
        int n = (int) Long.parseLong(context.getConfiguration()
            .get("Dimension"));
        for (int i = 0; i < n; i++) {
          grad.lastTheta.append(new DoubleWritable(0));
          grad.currentTheta.append(new DoubleWritable(0));
          grad.tmpGradient.append(new DoubleWritable(0));
        }
        return grad;
      } else
        return (GradientWritable) context.getLastAggregatedValue(0);
    }

    public static double vecMul(Tuple value, Tuple theta) {
      /* partial computation: hθ(x(i)) − y(i) for one sample */
      /* value holds one sample, and value(0) is y */
      double sum = 0.0;
      for (int j = 1; j < value.size(); j++)
        sum += Double.parseDouble(value.get(j).toString())
            * Double.parseDouble(theta.get(j).toString());
      double tmp = Double.parseDouble(theta.get(0).toString()) + sum
          - Double.parseDouble(value.get(0).toString());
      return tmp;
    }

    @Override
    public void aggregate(GradientWritable gradient, Object value)
        throws IOException {
      /*
       * performed on each vertex, i.e. each sample i: accumulate the gradient
       * contribution of sample i for every dimension j
       */
      double tmpVar = vecMul((Tuple) value, gradient.currentTheta);
      /*
       * update 2: the worker-local aggregate() behaves like merge() below, so
       * the variable gradient holds the previously aggregated value
       */
      gradient.tmpGradient.set(0, new DoubleWritable(
          ((DoubleWritable) gradient.tmpGradient.get(0)).get() + tmpVar));
      /* accumulate the squared error of this sample; merge() sums the per-worker partials */
      gradient.lost.set(gradient.lost.get() + Math.pow(tmpVar, 2));
      /*
       * accumulate (hθ(x(i)) − y(i)) * x(i)(j) for sample i over every
       * dimension j
       */
      for (int j = 1; j < gradient.tmpGradient.size(); j++)
        gradient.tmpGradient.set(j, new DoubleWritable(
            ((DoubleWritable) gradient.tmpGradient.get(j)).get() + tmpVar
                * Double.parseDouble(((Tuple) value).get(j).toString())));
    }

    @Override
    public void merge(GradientWritable gradient, GradientWritable partial)
        throws IOException {
      /* sum every dimension over all samples */
      Tuple master = (Tuple) gradient.tmpGradient;
      Tuple part = (Tuple) partial.tmpGradient;
      for (int j = 0; j < gradient.tmpGradient.size(); j++) {
        DoubleWritable s = (DoubleWritable) master.get(j);
        s.set(s.get() + ((DoubleWritable) part.get(j)).get());
      }
      gradient.lost.set(gradient.lost.get() + partial.lost.get());
    }

    @SuppressWarnings("rawtypes")
    @Override
    public boolean terminate(WorkerContext context, GradientWritable gradient)
        throws IOException {
      /*
       * 1. calculate the new theta. 2. check the difference between the last
       * step and this step; if it is smaller than the threshold, stop iterating
       */
      gradient.lost = new DoubleWritable(gradient.lost.get()
          / (2 * context.getTotalNumVertices()));
      /*
       * the loss can be printed to check that the algorithm is moving in the
       * right direction (for debugging)
       */
      System.out.println(gradient.count + " lost:" + gradient.lost);
      Tuple tmpGradient = gradient.tmpGradient;
      System.out.println("tmpGra" + tmpGradient);
      Tuple lastTheta = gradient.lastTheta;
      Tuple tmpCurrentTheta = new Tuple(gradient.currentTheta.size());
      System.out.println(gradient.count + " terminate_start_last:" + lastTheta);
      double alpha = 0.07; // learning rate
      // alpha =
      // Double.parseDouble(context.getConfiguration().get("Alpha"));
      /* perform theta(j) = theta(j) - alpha / M * tmpGradient(j) */
      long M = context.getTotalNumVertices();
      /*
       * update 3: divide the gradient by M; the original code forgot this step
       */
      for (int j = 0; j < lastTheta.size(); j++) {
        tmpCurrentTheta.set(j,
            new DoubleWritable(Double.parseDouble(lastTheta.get(j).toString())
                - alpha / M
                * Double.parseDouble(tmpGradient.get(j).toString())));
      }
      System.out.println(gradient.count + " terminate_start_current:"
          + tmpCurrentTheta);
      // check whether convergence has been reached
      double diff = 0.00d;
      for (int j = 0; j < gradient.currentTheta.size(); j++)
        diff += Math.pow(((DoubleWritable) tmpCurrentTheta.get(j)).get()
            - ((DoubleWritable) lastTheta.get(j)).get(), 2);
      if (/*
           * Math.sqrt(diff) < 0.00000000005d ||
           */Long.parseLong(context.getConfiguration().get("Max_Iter_Num")) == gradient.count
          .get()) {
        context.write(gradient.currentTheta.toArray());
        return true;
      }
      /* both lastTheta and currentTheta now hold the updated parameters */
      gradient.lastTheta = tmpCurrentTheta;
      gradient.currentTheta = tmpCurrentTheta;
      gradient.count.set(gradient.count.get() + 1);
      int n = (int) Long.parseLong(context.getConfiguration().get("Dimension"));
      /*
       * update 4: important! Graph does not reset the initial value of global
       * variables at the beginning of each iteration, so the accumulators must
       * be cleared here
       */
      for (int i = 0; i < n; i++) {
        gradient.tmpGradient.set(i, new DoubleWritable(0));
      }
      gradient.lost = new DoubleWritable(0.0);
      return false;
    }
  }

  public static void main(String[] args) throws IOException {
    GraphJob job = new GraphJob();
    job.setGraphLoaderClass(LinearRegressionVertexReader.class);
    job.setRuntimePartitioning(false);
    job.setNumWorkers(3);
    job.setVertexClass(LinearRegressionVertex.class);
    job.setAggregatorClass(LinearRegressionAggregator.class);
    job.addInput(TableInfo.builder().tableName(args[0]).build());
    job.addOutput(TableInfo.builder().tableName(args[1]).build());
    job.setMaxIteration(Integer.parseInt(args[2])); // number of iterations
    job.setInt("Max_Iter_Num", Integer.parseInt(args[2]));
    job.setInt("Dimension", Integer.parseInt(args[3])); // dimension
    job.setFloat("Alpha", Float.parseFloat(args[4])); // learning rate
    long start = System.currentTimeMillis();
    job.run();
    System.out.println("Job Finished in "
        + (System.currentTimeMillis() - start) / 1000.0 + " seconds");
  }
}
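
As read by main(), the program arguments are, in order: the input table name, the output table name, the maximum number of iterations, the dimension, and the learning rate. Judging from vecMul(), "Dimension" appears to be the number of feature columns plus one for the intercept theta(0). A hypothetical argument list (the table names are placeholders, not from the sample) might look like:

  lr_input lr_output 100 4 0.07

Note that although the "Alpha" argument is stored in the job configuration, terminate() uses the hard-coded alpha = 0.07 as the code stands, since the line that reads "Alpha" from the configuration is commented out.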

Last updated: 2016-06-22 12:03:05
