spark SVD++源码解析

  上一篇对[SVD++算法][4]的原理进行了总结,本文是对Spark SVDPlusPlus源码的分析总结。源码位置在Spark源码包的org.apache.spark.graphx.lib.SVDPlusPlus,需要引入spark-graphx相关包。

迭代公式推导

  相比于SVD和implicit ALS算法,SVD++的python单机版算法效率明显低很多,因为多了很重的一个子项$|N(u)|^{-\frac{1}{2}}\sum_{j\in N(u)}y_j$,需要汇总用户接触物品集所表达的隐式反馈$y_j$。模型的训练目标如下:

下面是Spark SVDPlusPlus的配置类

1
2
3
4
5
6
7
8
9
10
class Conf(
var rank: Int,//因子向量维度
var maxIters: Int,//最大迭代次数
var minVal: Double,//评分下限
var maxVal: Double,//评分上限
var gamma1: Double,//b*梯度下降学习速率
var gamma2: Double,//q,p,y梯度下降学习速率
var gamma6: Double,//b*正则系数
var gamma7: Double)//q,p,y正则系数
extends Serializable

基于梯度下降(为了并行,注意这里是批量梯度下降,不是随机梯度下降)方法,对目标学习公式求偏导,可得到如下迭代公式(为了方便源码表达,这里系数名称和源码保持一致):

Spark Graphx

  因为Spark SVD++是基于Spark Graphx实现的,所以先对Graphx做简明总结。Graphx是用于图并行计算的spark组件,它基于RDD引入了图抽象:每个顶点和边都绑定属性的多重图(multigraph,两个顶点间有多条边)。Graphx提供了图操作和图算法集合。图的基本单元是点(Vertix)和边(Edge)组成的Triplets,下图中的蓝橙颜色块表示属性。

这里写图片描述
  对于用户的评分数据,A就表示用户,B表示物品,边上的橙色块表示评分,蓝色快储存了我们需要迭代了p、q、y、b*,会在之后详述。

Spark SVD++用到的图操作如下:

操作 解释
aggregateMessages Graphx的核心聚合操作,主要有两步操作
1.sendMsg:将edge中的属性发送到顶点
2.mergeMsg:在顶点处进行merge操作
这里我们可以联想到将物品表达的隐式反馈汇总到用户顶点
outerJoinVertices 是aggregateMessages的小伙伴
聚合操作会返回新的Vertices集合(这也是scala函数式编程的特点)
基于VertixId与原来的Triplets集合进行join,更新对应顶点里的属性

核心源码分析

下面分析关键代码片段,主要关注变量迭代,入口代码如下,Edge保存了Double类型的评分数据,srcVertex是用户,dstVertex是物品

1
def run(edges: RDD[Edge[Double]], conf: Conf)

顶点属性初始化和全局均值$u$计算

1
2
3
4
5
6
7
8
9
10
11
12
// 生成默认的顶点属性 包含四个属性(v1,v2,0,0),v1、v2表示向量
def defaultF(rank: Int): (Array[Double], Array[Double], Double, Double) = {
// TODO: use a fixed random seed
val v1 = Array.fill(rank)(Random.nextDouble())
val v2 = Array.fill(rank)(Random.nextDouble())
(v1, v2, 0.0, 0.0)
}
//计算全局均值u
val (rs, rc) = edges.map(e => (e.attr, 1L)).reduce((a, b) => (a._1 + b._1, a._2 + b._2))
val u = rs / rc
//初始化图g
var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache()

  这里为什么顶点属性是(v1, v2, 0.0, 0.0)=(Array,Array,Double,Double)这个模样,源代码没有注释,阅读了后面的代码细节才能推断出来,觉得这是源码做的不好的地方,虽然整体处理流程设计的很精巧,可读性却一般般,不过跟猜谜语一样,也挺有趣。
我们的目标是迭代$p_u、q_i、y_i、b_i、bu$,其中$p_u,b_i$属于用户属性,$q_i、y_i、b_i$属于物品属性。看完整个代码,位置分布谜底如下,空缺的位置是存放一些中间值。

  • 用户属性
  • 物品属性

下面以user_property表示用户顶点属性,item_property表示物品顶点属性,下标从0开始。随着源码不断更新空缺位置,此时

bias和norm初始化($b_u,b_i$和$|N(u)|^{-\frac{1}{2}}$)

计算每个用户和物品的ratingcount($rc$)、ratingsum $rs$ (*号表示不区分用户和物品)

1
2
3
val t0 = g.aggregateMessages[(Long, Double)](
ctx => { ctx.sendToSrc((1L, ctx.attr)); ctx.sendToDst((1L, ctx.attr)) },
(g1, g2) => (g1._1 + g2._1, g1._2 + g2._2))

初始化

1
2
3
4
5
6
7
8
9
val gJoinT0 = g.outerJoinVertices(t0) {
(vid: VertexId, vd: (Array[Double], Array[Double], Double, Double),
msg: Option[(Long, Double)]) =>
(vd._1, vd._2, msg.get._2 / msg.get._1 - u, 1.0 / scala.math.sqrt(msg.get._1))
}.cache()
//触发spark action操作,便于缓存
materialize(gJoinT0)
g.unpersist()
g = gJoinT0

此时,

这里,物品属性中的$|N(i)|^{-\frac{1}{2}}$并没有什么作用,作者应该是为了代码简洁,对用户属性和物品属性使用了同样的处理作用,出现了这个副产物。至此,准备工作全部ready。

迭代

阶段1 计算

这是为下一步计算预测评分pred,进而计算误差和因子更新迭代做准备。根据公式,需要把用户$u$看过的物品所表达的隐式反馈聚合到用户端(这一点太符合spark图计算了)。

1.1 聚合计算$\sum_{j\in N(u)}y_j$
1
2
3
4
5
6
7
8
9
10
val t1 = g.aggregateMessages[Array[Double]](
//注意聚合用到了物品属性第二个位置的值,可以推断出该位置是隐式反馈yi
ctx => ctx.sendToSrc(ctx.dstAttr._2),
(g1, g2) => {
val out = g1.clone()
//向量相加操作 g1 + g2,使用了blas daxpy api,out=out+g2
blas.daxpy(out.length, 1.0, g2, 1, out, 1)
//和用户u关联的隐式反馈之和
out
})
1.2 更新 $user_property[1]= |N(u)|^{-\frac{1}{2}}\sum_{j\in N(u)}y_j$
1
2
3
4
5
6
7
8
9
10
11
12
val gJoinT1 = g.outerJoinVertices(t1) {
(vid: VertexId, vd: (Array[Double], Array[Double], Double, Double),
msg: Option[Array[Double]]) =>
//物品顶点接收不到,因为聚合信息只发送给了用户顶点
if (msg.isDefined) {
val out = vd._1.clone()
blas.daxpy(out.length, vd._4, msg.get, 1, out, 1)
(vd._1, out, vd._3, vd._4)
} else {
vd
}
}.cache()

此时

阶段2 更新$p_u$,$q_i$,$y_i$

2.1 梯度求解

ctx.sendToSrc ($\Delta p_u,\Delta y_i$, $\Delta b_u$)
ctx.sendToDst ($\Delta q_i,\Delta y_i$, $\Delta b_i$)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def sendMsgTrainF(conf: Conf, u: Double)
(ctx: EdgeContext[
(Array[Double], Array[Double], Double, Double),
Double,
(Array[Double], Array[Double], Double)]) {
val (usr, itm) = (ctx.srcAttr, ctx.dstAttr)
//这里说明了pu,qi的位置
val (p, q) = (usr._1, itm._1)
val rank = p.length
//计算误差
var pred = u + usr._3 + itm._3 + blas.ddot(rank, q, 1, usr._2, 1)
pred = math.max(pred, conf.minVal)
pred = math.min(pred, conf.maxVal)
val err = ctx.attr - pred
// updateP = (err * q - conf.gamma7 * p) * conf.gamma2
val updateP = q.clone()
blas.dscal(rank, err * conf.gamma2, updateP, 1)
blas.daxpy(rank, -conf.gamma7 * conf.gamma2, p, 1, updateP, 1)
// updateQ = (err * usr._2 - conf.gamma7 * q) * conf.gamma2
val updateQ = usr._2.clone()
blas.dscal(rank, err * conf.gamma2, updateQ, 1)
blas.daxpy(rank, -conf.gamma7 * conf.gamma2, q, 1, updateQ, 1)
// updateY = (err * usr._4 * q - conf.gamma7 * itm._2) * conf.gamma2
val updateY = q.clone()
blas.dscal(rank, err * usr._4 * conf.gamma2, updateY, 1)
blas.daxpy(rank, -conf.gamma7 * conf.gamma2, itm._2, 1, updateY, 1)
ctx.sendToSrc((updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1))
ctx.sendToDst((updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1))
}
2.2 合并为
1
2
3
4
5
6
7
8
9
10
val t2 = g.aggregateMessages(
sendMsgTrainF(conf, u),
(g1: (Array[Double], Array[Double], Double), g2: (Array[Double], Array[Double], Double)) =>
{
val out1 = g1._1.clone()
blas.daxpy(out1.length, 1.0, g2._1, 1, out1, 1)
val out2 = g2._2.clone()
blas.daxpy(out2.length, 1.0, g2._2, 1, out2, 1)
(out1, out2, g1._3 + g2._3)
})
2.3 更新
1
2
3
4
5
6
7
8
9
10
11
val gJoinT2 = g.outerJoinVertices(t2) {
(vid: VertexId,
vd: (Array[Double], Array[Double], Double, Double),
msg: Option[(Array[Double], Array[Double], Double)]) => {
val out1 = vd._1.clone()
blas.daxpy(out1.length, 1.0, msg.get._1, 1, out1, 1)
val out2 = vd._2.clone()
blas.daxpy(out2.length, 1.0, msg.get._2, 1, out2, 1)
(out1, out2, vd._3 + msg.get._3, vd._4)
}
}.cache()

总结

  整个源码阅读下来,还是挺有收获的,理解了svd++的并行原理。spark svd++处理逻辑设计的很精巧,代码简洁高效,全篇代码也就200行左右,当然这也和spark、scala语言简洁有关。spark图模块的聚合操作非常契合svd++的迭代计算。唯一觉得有点不好的地方是代码可读性稍微不足,不过也是因为自己水平不足,读起来有点费劲。越精巧的代码就应该多点注释增强可读性,方便维护和迭代。
  如果文中有哪里理解不对的地方,希望大家帮忙指正。

参考

spark svd++源码
spark svd++源码分析