赞
踩
参考文献:https://arxiv.org/pdf/1602.05629v4.pdf
对于一个机器学习应用来说,我们需要找到一个目标函数,使其最小化
f
(
w
)
=
1
n
∑
i
=
1
n
f
i
(
w
)
f(w) = \frac{1}{n}\sum_{i=1}^{n}f_i(w)
f(w)=n1i=1∑nfi(w)
上面等式中, f i ( w ) = l ( x i , y i , w ) f_i(w) = l(x_i, y_i, w) fi(w)=l(xi,yi,w),表示参数为 w w w的模型在样本 ( x i , y i ) (x_i, y_i) (xi,yi)上预测的损失。
假设现在有多个设备并行计算模型在某个数据集上的总体预测损失,总样本数为
n
n
n,设备
k
k
k上有
n
k
n_k
nk个样本,那么设备
k
k
k上的损失为:
F
k
(
w
)
=
1
n
k
∑
i
=
1
n
k
f
k
(
w
)
F_{k}(w) = \frac{1}{n_k}\sum_{i=1}^{n_k}f_k(w)
Fk(w)=nk1i=1∑nkfk(w)
那么模型在整个数据集上的预测损失为:
f
(
w
)
=
∑
k
=
1
K
n
k
n
F
k
(
w
)
=
1
n
∑
k
=
1
K
∑
i
=
1
n
k
f
i
(
w
)
f(w) = \sum_{k=1}^K \frac{n_k}{n} F_k(w) = \frac{1}{n}\sum_{k=1}^K\sum_{i=1}^{n_k}f_i(w)
f(w)=k=1∑KnnkFk(w)=n1k=1∑Ki=1∑nkfi(w)
相当于使用同一个模型在整个数据集上跑了一遍,得到了总体的平均损失
为了减少client和server间的通信次数,可以让更多的计算在client上完成。
假设第 t t t轮通信中设备k进行一次梯度下降得到的梯度为: g k = ∇ F k ( w t ) gk=\nabla F_k(w_t) gk=∇Fk(wt), g k g_k gk最终会发送到服务器。 w t w_t wt是第 t t t轮通信中模型的参数。
根据求导法则可知:
∇
f
(
w
t
)
=
∑
k
=
1
K
n
k
n
∇
F
k
(
w
t
)
\nabla f(w_t) = \sum_{k=1}^K \frac{n_k}{n} \nabla F_k(w_t)
∇f(wt)=∑k=1Knnk∇Fk(wt),所以服务器拿到所以client的参数之后,更新下一轮模型的参数为:
w
t
+
1
=
w
t
−
α
∇
f
(
w
t
)
=
w
t
−
α
∑
k
=
1
K
n
k
n
g
k
w_{t+1} = w_t - \alpha \nabla f(w_t) = w_t - \alpha \sum_{k=1}^{K} \frac{n_k}{n} g_k
wt+1=wt−α∇f(wt)=wt−αk=1∑Knnkgk
又因为设备
k
k
k可以用局部数据更新参数:
w
t
+
1
k
=
w
t
−
α
g
k
α
g
k
=
w
t
−
w
t
+
1
k
w^k_{t+1} = w_t - \alpha g_k \\ \alpha g_k = w_t - w^k_{t+1}
wt+1k=wt−αgkαgk=wt−wt+1k
代入上面公式:
w
t
+
1
=
w
t
−
∑
k
=
1
K
n
k
n
(
w
t
−
w
t
+
1
k
)
=
w
t
−
∑
k
=
1
K
n
k
n
w
t
+
∑
k
=
1
K
n
k
n
w
t
+
1
k
w
t
+
1
=
∑
k
=
1
K
n
k
n
w
t
+
1
k
w_{t+1} = w_t - \sum_{k=1}^K \frac{n_k}{n} (w_t - w^k_{t+1}) = w_t - \frac{\sum_{k=1}^K n_k}{n}w_t + \sum_{k=1}^{K}\frac{n_k}{n}w^k_{t+1} \\ w_{t+1} = \sum_{k=1}^{K}\frac{n_k}{n}w^k_{t+1}
wt+1=wt−k=1∑Knnk(wt−wt+1k)=wt−n∑k=1Knkwt+k=1∑Knnkwt+1kwt+1=k=1∑Knnkwt+1k
上面是进行一次梯度下降,如果进行多次梯度下降,设备
k
k
k的更新参数公式为:
w
t
+
1
k
=
w
t
−
α
g
k
1
−
α
g
k
2
−
.
.
.
−
α
g
k
e
p
o
c
h
其中
g
k
i
表示第
i
次梯度下降
w
t
+
1
k
=
w
t
−
α
(
g
k
1
+
g
k
2
+
.
.
.
+
g
k
e
p
o
c
h
)
令
g
k
=
g
k
1
+
g
k
2
+
.
.
.
+
g
k
e
p
o
c
h
,
则有
w
t
+
1
k
=
w
t
−
α
g
k
w_{t+1}^k = w_t - \alpha g_k^1 - \alpha g_k^2 \ \ - ...-\ \ \alpha g_k^{epoch} \ \ 其中 g_k^i表示第i次梯度下降 \\ w_{t+1}^k = w_t - \alpha (g_k^1 + g_k^2 \ \ + ...+g_k^{epoch}) \\ 令 \ \ \ g_k = g_k^1 + g_k^2 \ \ + ...+g_k^{epoch}, \ \ 则有 \ \ \ w_{t+1}^k = w_t-\alpha g_k
wt+1k=wt−αgk1−αgk2 −...− αgkepoch 其中gki表示第i次梯度下降wt+1k=wt−α(gk1+gk2 +...+gkepoch)令 gk=gk1+gk2 +...+gkepoch, 则有 wt+1k=wt−αgk
所以可以在本地进行多次梯度下降并更新本地模型参数,然后将本地模型参数发送给服务器,服务器对这些参数进行加权平均得到全局模型参数,最终发送给各个设备,这样就能减少客户端和服务器间的通信次数。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。