上一篇对[SVD++算法][4]的原理进行了总结,本文是对Spark SVDPlusPlus源码的分析总结。源码位置在Spark源码包的org.apache.spark.graphx.lib.SVDPlusPlus,需要引入spark-graphx相关包。
迭代公式推导
相比于SVD和implicit ALS算法,SVD++的python单机版算法效率明显低很多,因为多了很重的一个子项$|N(u)|^{-\frac{1}{2}}\sum_{j\in N(u)}y_j$,需要汇总用户接触物品集所表达的隐式反馈$y_j$。模型的训练目标如下:
下面是Spark SVDPlusPlus的配置类
|
|
基于梯度下降(为了并行,注意这里是批量梯度下降,不是随机梯度下降)方法,对目标学习公式求偏导,可得到如下迭代公式(为了方便源码表达,这里系数名称和源码保持一致):
Spark Graphx
因为Spark SVD++是基于Spark Graphx实现的,所以先对Graphx做简明总结。Graphx是用于图并行计算的spark组件,它基于RDD引入了图抽象:每个顶点和边都绑定属性的多重图(multigraph,两个顶点间有多条边)。Graphx提供了图操作和图算法集合。图的基本单元是点(Vertix)和边(Edge)组成的Triplets,下图中的蓝橙颜色块表示属性。
对于用户的评分数据,A就表示用户,B表示物品,边上的橙色块表示评分,蓝色快储存了我们需要迭代了p、q、y、b*,会在之后详述。
Spark SVD++用到的图操作如下:
操作 | 解释 |
---|---|
aggregateMessages | Graphx的核心聚合操作,主要有两步操作 1.sendMsg:将edge中的属性发送到顶点 2.mergeMsg:在顶点处进行merge操作 这里我们可以联想到将物品表达的隐式反馈汇总到用户顶点 |
outerJoinVertices | 是aggregateMessages的小伙伴 聚合操作会返回新的Vertices集合(这也是scala函数式编程的特点) 基于VertixId与原来的Triplets集合进行join,更新对应顶点里的属性 |
核心源码分析
下面分析关键代码片段,主要关注变量迭代,入口代码如下,Edge保存了Double类型的评分数据,srcVertex是用户,dstVertex是物品
|
|
顶点属性初始化和全局均值$u$计算
|
|
这里为什么顶点属性是(v1, v2, 0.0, 0.0)=(Array,Array,Double,Double)这个模样,源代码没有注释,阅读了后面的代码细节才能推断出来,觉得这是源码做的不好的地方,虽然整体处理流程设计的很精巧,可读性却一般般,不过跟猜谜语一样,也挺有趣。
我们的目标是迭代$p_u、q_i、y_i、b_i、bu$,其中$p_u,b_i$属于用户属性,$q_i、y_i、b_i$属于物品属性。看完整个代码,位置分布谜底如下,空缺的位置是存放一些中间值。
- 用户属性
- 物品属性
下面以user_property表示用户顶点属性,item_property表示物品顶点属性,下标从0开始。随着源码不断更新空缺位置,此时
bias和norm初始化($b_u,b_i$和$|N(u)|^{-\frac{1}{2}}$)
计算每个用户和物品的ratingcount($rc$)、ratingsum $rs$ (*号表示不区分用户和物品)
|
|
初始化
|
|
此时,
这里,物品属性中的$|N(i)|^{-\frac{1}{2}}$并没有什么作用,作者应该是为了代码简洁,对用户属性和物品属性使用了同样的处理作用,出现了这个副产物。至此,准备工作全部ready。
迭代
阶段1 计算
这是为下一步计算预测评分pred,进而计算误差和因子更新迭代做准备。根据公式,需要把用户$u$看过的物品所表达的隐式反馈聚合到用户端(这一点太符合spark图计算了)。
1.1 聚合计算$\sum_{j\in N(u)}y_j$
|
|
1.2 更新 $user_property[1]= |N(u)|^{-\frac{1}{2}}\sum_{j\in N(u)}y_j$
|
|
此时
阶段2 更新$p_u$,$q_i$,$y_i$
2.1 梯度求解
ctx.sendToSrc ($\Delta p_u,\Delta y_i$, $\Delta b_u$)
ctx.sendToDst ($\Delta q_i,\Delta y_i$, $\Delta b_i$)
|
|
2.2 合并为
|
|
2.3 更新
|
|
总结
整个源码阅读下来,还是挺有收获的,理解了svd++的并行原理。spark svd++处理逻辑设计的很精巧,代码简洁高效,全篇代码也就200行左右,当然这也和spark、scala语言简洁有关。spark图模块的聚合操作非常契合svd++的迭代计算。唯一觉得有点不好的地方是代码可读性稍微不足,不过也是因为自己水平不足,读起来有点费劲。越精巧的代码就应该多点注释增强可读性,方便维护和迭代。
如果文中有哪里理解不对的地方,希望大家帮忙指正。