逆誤差伝播法(バックプロパゲーション)で関数の微分値を求める
逆誤差伝播法(バックプロパゲーション)で関数の微分値を求める
TheanoやPytorchのforward関数やbackward関数がどういう計算しているのか知りたくなったので, そのメモ.
今回はニューラルネットワークの線形和部分は無視して, 関数の微分値を求めることを考える.
逆誤差伝播法の考え方
逆誤差伝播法は微分の連鎖律(chain rule)を利用する(正直これが全てといっても過言ではない),
微分の連鎖律は,高校数学の微積では"合成関数の微分"や"媒介変数の微分"として紹介される.
合成関数の微分がどんなものだったかというと, 関数$f(x)$が別の関数$A(x)$で表されるとき, $x$での微分は,以下の式で表せる.
$$ \frac{df}{dx} = \frac{df}{dA}\frac{dA}{dx} $$
つまり,複雑な関数$f(x)$を直接微分できなくても,(小さい)関数$A(x)$に分けてその微分ができれば求めることができるというものである.
さらに,関数$A(x)$がまだ複雑な形で, より簡易な関数$B(x)$で表すことができる場合(つまり$A(B)=B(x)$,同様にして求められる.
$$ \frac{df}{dx} = \frac{df}{dA}\frac{dA}{dB}\frac{dB}{dx} $$
また,入れ子になる関数を増やしても同様の積で求められる.
微分の連鎖律
微分の連鎖律は,合成関数の微分を一般化しており, 変数の数を任意の数$n$,入れ子になる関数の数$k$を任意の数にして拡張しても成立することを表している.
複数変数は和で繋ぐ,入れ子になる関数は積で繋ぐ.
関数$f( x_{1}, \cdots, x_{n} )$が関数$g_{11}( g_{12}(x_{1}, \cdots x_{n} ) ), \cdots, g_{n1}(g_{n2}( x_1, \cdots x_{n} ))$で表せ,これらが関数$g_{12}( g_{13}(x_{1}, \cdots x_{ n } )), \cdots, g_{n2}(g_{n3}( x_1, \cdots x_{n} ))$で表せ,最終的に$g_{ 1k }( x_{1}, \cdots x_{ n } ), \cdots, g_{nk}(x_1, \cdots x_{n} )$が$x_{1}, \cdots x_{n}$で表せるなら,$x_{1}$での微分は
$$ \frac{df}{dx_1} = \frac{df}{dg_{11}}\frac{dg_{11}}{dg_{12}} \cdots \frac{dg_{1k}}{dx_1} \\ + \frac{df}{dg_{21}}\frac{dg_{21}}{dg_{22}} \cdots \frac{dg_{2k}}{dx_1} \\ + \cdots \\ + \frac{df}{dg_{n1}}\frac{dg_{n1}}{dg_{n2}} \cdots \frac{dg_{nk}}{dx_1} $$
($x_2$から$x_n$での微分も同様にして求まる.)
これはヤコビ行列でも求まる.
参考:
逆誤差伝播では
逆誤差伝播法では,微分の連鎖律をニューラルネットワークで使われるネットワーク図(グラフ図)に置き換えて説明される.
(置き換えてるだけで計算は同じ.)
連鎖律で現れる各関数をノード(ユニット)として捉える.
関数f(x)
が次のよう関数で構成されると考える.
x -> A(x) -> B(A) -> C(B) -> f(C)
forward処理(順伝播)とbackward処理(逆伝播)の2つを行う.
- forward処理(順伝播):入力側から,入力値を通常通り関数で計算して微分値に代入するための値を求めておく.
-
backward処理(逆伝播):出力側から考える.出力$f$を入力$C$で偏微分した関数をあらかじめ求めておき,
$\frac{\partial f}{\partial C}$は$C$の関数になっているのでforwardで求めておいた$C$の値を代入して微分値を求める.
- 次のユニット$C$でも同様に考え,出力$C$を入力$B$で偏微分した関数 $\frac{\partial C}{\partial B}$に代入して微分値を求める.ただ,このままだと関数$f$との微分ではないので出力側から微分値をもらって積を計算する. $\frac{\partial f}{\partial C}\frac{\partial C}{\partial B} = \frac{\partial f}{\partial B}$,これにより関数$f$の$B$での微分値が求まる.
- 次のユニット$C$でも同様に考え,出力$C$を入力$B$で偏微分した関数 $\frac{\partial C}{\partial B}$に代入して微分値を求める.ただ,このままだと関数$f$との微分ではないので出力側から微分値をもらって積を計算する. $\frac{\partial f}{\partial C}\frac{\partial C}{\partial B} = \frac{\partial f}{\partial B}$,これにより関数$f$の$B$での微分値が求まる.
- 同様にして,入力側のユニット$A$まで計算すると,$\frac{\partial f}{\partial A}\frac{\partial A}{\partial x} = \frac{\partial f}{\partial x}$より,微分値が求まる.
逆伝播の場合,これまで計算した微分値を出力側からもらって計算するので,この部分が逆伝播になっている. 出力側からもらう値は連鎖律の前半の項の値になっている.
下記で例として,シグモイド関数を使って説明する.
シグモイド関数を(数式で)微分
$$ f( x ) = \frac{1}{1 + \exp{( -x ) } } $$
シグモイド関数の微分値を逆誤差伝播法で求める
シグモイド関数を以下の図のような関数に分けて考える.
出力側から考える.
- C -> D:
Cの値はforwardにより事前に求まっている.
$$ \frac{\partial D}{\partial C} = (\frac{1}{C})' \\ = - \frac{1}{C^2}\\ = - \frac{1}{1.37^2}\\ = -0.533 $$
出力側には,さらにDを入力してDが出力されるユニットがあると拡張して考えると, その連鎖律は,
$$ \frac{\partial D}{\partial D} \frac{\partial D}{\partial C} = 1 \cdot \frac{\partial D}{\partial C} = -0.533 $$
- B -> C:
$$ \frac{\partial C}{\partial B} = (B+1)' = 1 $$
連鎖律より,(出力側の微分値が入力されたと考えることもでき,)
$$ \frac{\partial D}{\partial B} = \frac{\partial D}{\partial C}\frac{\partial C}{\partial B} = (-0.533) \cdot 1 = -0.533 $$
- A -> B:
偏微分する.
$$ \frac{\partial B}{\partial A} = (\exp{A})' = \exp{A} = \exp{-1} = 0.368 $$
同様に,連鎖律より,
$$ \frac{\partial D}{\partial A} = \frac{\partial D}{\partial B}\frac{\partial B}{\partial A} = -0.533 \cdot 0.368 = -0.196 $$
- x -> A:
偏微分する.
$$ \frac{\partial A}{\partial x} = (-x)' = -1 $$
同様に,連鎖律より,
$$ \frac{\partial D}{\partial x} = \frac{\partial D}{\partial A}\frac{\partial A}{\partial x} = (-0.196) \cdot (-1) = 0.196 $$
出力側の微分値と入力された値との偏微分の2つがわかれば,実際の出力との微分が求められる.
実装(python)
何がわかると実装できるかをまとめる.
- forward:入力値を受け取り,関数の値を導出,backwardで使うので保存
- backward: 出力側の微分値を受け取り,そのユニットの出力を入力で偏微分した値との積で,出力関数とユニットの入力での微分値が求まるのでそれを返す.
シグモイド関数を表現するのにどんなユニットがあればいいか.
- 定数倍する関数:$ax$,微分は$a$,シグモイド関数では$-1 \cdot x$に利用
- 指数関数:$\exp{x}$,微分は同じ$exp{x}$,シグモイド関数では$ \exp{-x}$に利用
- 定数を加算する関数:$a + x$,微分は$1$,シグモイド関数では$1 + \exp{-x}$に利用
- べき関数:$xn$,微分は$nx^{n-1}$,シグモイド関数では$(1+ \exp{-x})^{-1}$に利用
classを作ると以下のようになる.
import scipy as sp class FB: def __init__(self): pass class Sigmoid: def __init__(self): self._output = None def forward(self, input): self._output = 1 / (1 + sp.exp(- input)) return self._output def backward(self, diff_input=1): sig = self._output return sig * (1 - sig) * diff_input class X_nthPower: def __init__(self, nthPower = 1): self._output = None self._nthPower = nthPower def forward(self, input): self._output = input**self._nthPower return self._output def backward(self, diff_input=1): ''' x^n' = n x^{n-1} x^1 = (x^n)^(1/n) ''' x = self._output ** (1/self._nthPower) return self._nthPower * self._output / x * diff_input class X_addCons: def __init__(self, constant = 0): self._output = None self._constant = constant def forward(self, input): ''' X+c ''' self._output = input + self._constant return self._output def backward(self, diff_input=1): return 1 * diff_input class X_scaleCons: def __init__(self, constant = 1): self._output = None self._constant = constant def forward(self, input): ''' c x ''' self._output = self._constant * input return self._output def backward(self, diff_input=1): return self._constant * diff_input class Exp: def __init__(self): self._output = None def forward(self, input): ''' exp(x) ''' self._output = sp.exp(input) return self._output def backward(self, diff_input=1): return self._output * diff_input class _Network: def __init__(self, units): ''' units: list ''' self._units = units def forward(self,input): x = input for u in self._units: fw = u.forward(x) x = fw return fw def backward(self): diff_x = 1 for u in self._units[::-1]: bw = u.backward(diff_x) diff_x = bw return bw def createNetwork(self, units): return self._Network(units)
実際に実行すると
$ fb = FB() # forward x = 1 fb_xScaleMinus1 = fb.X_scaleCons(-1) fw1 = fb_xScaleMinus1.forward(1) $ fw1 -1 fb_exp = fb.Exp() fw2 = fb_exp.forward(fw1) $ fw2 0.36788 fb_xadd1 = fb.X_addCons(constant=1) fw3 = fb_xadd1.forward(fw2) $ fw3 1.36788 fb_xpowerMinus1 = fb.X_nthPower(nthPower=-1) $ fb_xpowerMinus1.forward(fw3) 0.7310585786300049
# backward bw1 = fb_xpowerMinus1.backward(1) $ bw1 -0.534446645388523 bw2 = fb_xadd1.backward(bw1) $ bw2 -0.534446645388523 bw3 = fb_exp.backward(bw2) $ bw3 -0.19661193324148188 bw4 = fb_xScaleMinus1.backward(bw3) $ bw4 0.19661193324148188