はしくれエンジニアもどきのメモ

情報系技術・哲学・デザインなどの勉強メモ・備忘録です。

正規分布間のKL情報量を計算する

正規分布間のKL情報量を計算する

gist: 正規分布間のKL情報量 · GitHub

KL情報量が1や2といったときに,どのくらいの大きさかよくわからなかったので, 標準正規分布を基準にしたときどれくらいズレた正規分布だとこの大きさになるのか調べてみた.

つまり,下の図を作りたかった.

平均をズラしていった時のKL情報量

実行環境

  • Windows10

KL情報量

確率変数が連続の場合のKL情報量の式は以下になる.

$$ D_{KL}(p || q) = \int_{ -\infty }^{ \infty } p(x) \ln{\frac{ p(x) }{ q(x) }} dx $$

2つの正規分布確率密度関数

 X_1 \sim N( \mu_1, \sigma_1^2 ) とすると,


p(x) =  \frac{ 1 }{ \sqrt{ 2 \pi } \sigma_1 } \exp{ \left\{ -\frac{(x - \mu_1)^2}{2 \sigma_1^2} \right\} }

 X_2 \sim N( \mu_2, \sigma_2^2 )とすると,


q(x) =  \frac{ 1 }{ \sqrt{ 2 \pi } \sigma_2 } \exp{ \left\{ -\frac{(x - \mu_2)^2}{2 \sigma_2^2} \right\} }

片方の分布を標準正規分布にして,もう一方の正規分布の平均をズラしていくと以下のGIFになる. この正規分布間のKL情報量を求めていくことを考える.

正規分布をズラしていくイメージ

正規分布間のKL情報量を導出

自己情報量の差$- \log{q} - (-\log{p})$を計算する.

 { 
\begin{eqnarray}
- \log{q( x )} - ( - \log{p( x )} ) 
&=& \log{p( x )} - \log{q( x )} \\\\
&=& \log{ \frac{ p(x) }{ q(x) } }\\\\
&=& \log{ \left\{ \frac{ \frac{1}{\sqrt{2 \pi} \sigma\_1} \exp{ \left\{ - \frac{(x - \mu\_1)^2}{2 \sigma\_1^2} \right\} } }{ \frac{1}{\sqrt{2 \pi} \sigma\_2} \exp{ \left\{ - \frac{(x - \mu\_2)^2}{2 \sigma\_2^2} \right\} } } \right\} } \\\\
&=& \log{ \left[ \frac{\frac{1}{ \sigma\_1}}{\frac{1}{ \sigma\_2}} \exp{ \left\{ -\frac{(x - \mu\_1)^2}{2 \sigma\_1^2} + \frac{(x - \mu\_2)^2}{2 \sigma\_2^2} \right\} } \right] } \\\\
&=& \log{ \left( \frac{\sigma\_2}{\sigma\_1} \right) } + \frac{1}{2} \left( \frac{(x - \mu\_2)^2}{\sigma\_2^2} - \frac{(x - \mu\_1)^2}{\sigma\_1^2}  \right) \\\\
\end{eqnarray}
}

第3項の積分を展開する.


\begin{eqnarray}
\int_{-\infty}^{\infty} (x - \mu\_2)^2 p(x) dx
&=& \int_{-\infty}^{\infty} x^2 p(x) dx - 2\mu\_2 \int_{-\infty}^{\infty} x p(x) dx + \mu\_2 \int_{-\infty}^{\infty} p(x) dx \\\\
&=& \int_{-\infty}^{\infty} x^2 p(x) dx - 2\mu\_2 \mu\_1+ \mu\_2 \cdot 1
\end{eqnarray}


V[X = E[X2] - (E[X])2 \\ E[X2] = V[X] + (E[X])2 \\ = \sigma_12 + \mu_12 ]

よって,


\begin{eqnarray}
\int_{ - \infty }^{ \infty } ( x - \mu_2 )^2 p(x) dx
&=& (\sigma_1^2 + \mu_1^2) - 2\mu_2 \mu_1 + \mu_2 \\\\
&=& \sigma_1^2 + (\mu_1 - \mu_2)^2  \\\\
\end{eqnarray}

最終的に$D_{KL}$は,


\begin{eqnarray}
D_{KL}(p || q)
= \log{ \left( \frac{\sigma_2}{\sigma_1} \right)} - \frac{1}{2} + \frac{1}{2 \sigma_2^2 } (\sigma_1^2 + (\mu_1 - \mu_2)^2 ) \\\\
\end{eqnarray}

基準の正規分布$p(x)$を標準正規分布にしたときのKL情報量


\begin{eqnarray}
D_{KL}(N(0, 1^2) || q)
&=& \log{ \left( \frac{\sigma_2}{\sigma_1} \right)}
+ \frac{1}{2\sigma_2^2} (\sigma_1^2 + (\mu_1 - \mu_2)^2 )
- \frac{1}{2} \\\\
&=& \log{ \left( \frac{\sigma_2}{1} \right)}
+ \frac{1}{2 \sigma_2^2} (1^2 + (0 - \mu_2)^2 )
- \frac{1}{2} \\\\
&=& \log{ \left( \sigma_2\right)}
+ \frac{1}{2 \sigma_2^2} (1 + \mu_2^2 )
- \frac{1}{2} \\\\
\end{eqnarray}

平均を変数,分散1に固定した正規分布と標準正規分布でのKL情報量

.
\begin{eqnarray}
D_{KL}(N(0, 1^2) || N(\mu_2, 1^2))
&=& \log{ \left( \sigma_2\right)}
+ \frac{1}{2 \sigma_2^2} (1 + \mu_2^2 )
- \frac{1}{2} \\\\
&=& \log{ \left( 1 \right)} + \frac{1}{2 \cdot 1^2} (1 + \mu_2^2 ) - \frac{1}{2} \\\\
&=& 0 + \frac{1}{2} (1 + \mu_2^2 ) - \frac{1}{2} \\\\
&=&  \frac{\mu_2^2 }{2} + \frac{1}{2}  - \frac{1}{2} \\\\
&=&  \frac{\mu_2^2 }{2}
\end{eqnarray}

平均を変えたときのKL情報量

アニメーショングラフの作成

正規分布の平均パラメータをズラしていったときのKL情報量を描画してみる.

import scipy as sp
import scipy.stats as stats
import matplotlib.pyplot as plt
import matplotlib.animation as animation
plt.rcParams["animation.convert_path"] = "C:\Program Files\ImageMagick-7.0.8-Q16\magick.exe"

%matplotlib inline

fig = plt.figure(figsize=(16, 9))
ax1 = fig.add_subplot(211)
ax2 = fig.add_subplot(212)

ims = []

# ax_1 N
x = sp.linspace(-5, 5, 100)
pdf = stats.norm(loc=0, scale=1).pdf(x)
# line, = ax1.plot(x, pdf, lw=2, label="$N(\mu_2, \sigma_2^2)$")
ax1.plot(x, pdf, linestyle="--", label="$N(0, 1^2)$")

# ax2 D_KL
mu_2 = sp.linspace(-5, 5, 100)
D_KL = mu_2**2 / 2
ax2.plot(mu_2, D_KL)
ax2.set_xlabel("$\mu_2$")
ax2.set_ylabel("$D_{KL}(N(0, 1^2) || N(\mu_2, 1^2))$")
ax2.grid()

# ani
im = ax1.plot(x, stats.norm(loc=0, scale=1).pdf(x), linestyle="-", color="r", label="$N(\mu_2, 1^2)$")
ax1.grid(True)
ax1.legend(loc='upper left')
ims.append(im)
for mu in sp.linspace(0, 4):
    pdf_N_mu = stats.norm(loc=mu, scale=1).pdf(x)
    im, = ax1.plot(x, pdf_N_mu, linestyle="-", color="r", label="$N(\mu_2, 1^2)$")
    im_v = ax1.vlines([mu], 0, 0.4, color="k", linestyles="--")
    D = mu**2/2
    im_v_ax2 = ax2.vlines([mu], 0, D, color="k", linestyles="--")
    im_v_ax2_t = ax2.text(0, 6, "$D_{KL}=" + str(D.round(3)) + "$", ha='left', va='bottom', fontsize=40)
    ims.append([im, im_v, im_v_ax2, im_v_ax2_t])

plt.show()

ani = animation.ArtistAnimation(fig, ims)
ani.save('move_mean_N_and_D.gif', writer="imagemagick")

平均をズラしていった時のKL情報量(アニメーション)

KL情報量から平均パラメータ$\mu_2$を求める式にすると,


\mu_2^2 = 2 D_{KL}( N( 0, 1^2) \|\| N( \mu_2, 1^2 ) ) \\\\
\mu_2 = \pm \sqrt{ 2 D_{KL}( N(0, 1^2 ) \|\| N( \mu_2, 1^2 ) ) }
%matplotlib inline

D_KL = sp.linspace(0, 5, 100)
mu_2 = sp.sqrt( 2 * D_KL)

plt.plot(D_KL, mu_2)
plt.plot(D_KL, -mu_2)
plt.xlabel("$D_{KL}(N(0, 1^2) || N(\mu_2, 1^2))$")
plt.ylabel("$\mu_2$")
plt.grid()
plt.savefig("x-DKL_y-mu.png")
plt.show()

横軸をKL情報量,縦軸を平均パラメータにしたとき

平均パラメータとKL情報量を代表的な値に変えて値をみる.

pair = set()
for mu in sp.arange(0, 10+1):
    D = mu**2 / 2
    pair.add((mu, D))

for D in sp.arange(0, 10+1):
    mu = sp.sqrt(2 * D)
    pair.add((mu, D))

pair_l = list(pair)
pair_l = sorted(pair_l, key=lambda x: x[0])
pair_l
[(0, 0.0),
 (1, 0.5),
 (1.4142135623730951, 1),
 (2, 2.0),
 (2.449489742783178, 3),
 (2.8284271247461903, 4),
 (3, 4.5),
 (3.1622776601683795, 5),
 (3.4641016151377544, 6),
 (3.7416573867739413, 7),
 (4, 8.0),
 (4.242640687119285, 9),
 (4.47213595499958, 10),
 (5, 12.5),
 (6, 18.0),
 (7, 24.5),
 (8, 32.0),
 (9, 40.5),
 (10, 50.0)]

分散を変数,平均を0に固定した正規分布と標準正規分布でのKL情報量

[tex: N( 0, \sigma_22 \cdot 12 ) ] とおく. 平均をズラした場合と違って,スケール倍を変えていった$D_{KL}$の変化を見ることになる.


\begin{eqnarray}
D_{KL}(N(0, 1^2) || N(0, \sigma_2^2))
&=& \log{ \left( \sigma_2 \right)} + \frac{1}{2 \sigma_2^2} (1 + \mu_2^2 ) - \frac{1}{2} \\\\
&=& \log{ \left( \sigma_2 \right)} + \frac{1}{2 \sigma_2^2} (1 + 0^2 )
- \frac{1}{2} \\\\
&=& \log{ \left( \sigma_2 \right)} + \frac{1}{2 \sigma_2^2} - \frac{1}{2} \\\\
\end{eqnarray}
import scipy as sp
import matplotlib.pyplot as plt

%matplotlib inline

sigma_2 = sp.linspace(0.5, 5, 100)
D_KL = sp.log(sigma_2) + 1 / ( 2 * sigma_2**2 ) - 0.5

plt.plot(sigma_2, D_KL)
plt.xlabel("$0.5 \leq \sigma_2 \leq 5$")
plt.ylabel("$D_{KL}(N(0, 1^2) || N(0, \sigma_2^2))$")
plt.grid()
plt.show()

分散を変えたときのKL情報量

このグラフの特徴として,

  •  \sigma = 1 を境にKL情報量の増加・減少傾向(傾き)が違う
  •  \sigma \lt 1:KL情報量の増加・減少傾向(傾き)が( \sigma \gt 1 と比較して)大きい
  •  \sigma \gt 1:KL情報量の増加・減少傾向(傾き)が( \sigma \lt 1と比較して)小さい

数式を見ると


\begin{eqnarray}
D_{KL}(N(0, 1^2) \|\| N(0, \sigma_2^2))
&=& \log{ \left( \sigma_2 \right) }+ \frac{1}{2 \sigma_2^2} - \frac{1}{2} \\\\
\end{eqnarray}
  •  \sigma_2 \ll 1 のとき[tex: \frac{1}{2 \sigma_22} > \log{ \left( \sigma_2 \right)}] 二乗で増加する.

  •  \sigma_2 \gg 1のとき[tex: \log{ ( \sigma_2 ) } > \frac{1}{ 2 \sigma_22} \fallingdotseq 0 $] logスケールで増加する.

この特徴は面白いので横軸を対数にして,範囲  0.1 \leq \sigma_2 \leq 10 でグラフで見ると KL情報量の差が大きいことがわかる.

これを解釈すると, 最適化によりKL情報量の最小化をすると, KL情報量が小さくなりやすい \sigma \gt 1側の分散パラメータが採択されやすいといえる. つまり,特別精度のいい(分散パラメータが小さい)正規分布でなく, 元の正規分布より精度の悪い(分散の大きい)パラメータが推定されやすい.

これは悪いほうで見積もるという人間の直感に近い考え方に合っている.

# log scale
import scipy as sp
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties

%matplotlib inline

fig = plt.figure(figsize=(16/1.3, 9/1.3))
sigma_2 = sp.linspace(0.5, 2, 100)
D_KL = sp.log(sigma_2) + 1 / ( 2 * sigma_2**2 ) - 0.5

plt.plot(sigma_2, D_KL)
plt.vlines([1], 0, .8, linestyles="--")
fp = FontProperties(fname=r'C:\WINDOWS\Fonts\YuGothM.ttc', size=14)
plt.text(1.02, .75, "正規分布間のKL情報量最小化で最適化すると\n元の分布より大きい分散パラメータが採用されやすい\n(より小さい分散パラメータは採用されにくい)", linespacing=2.5, va="top", fontsize=36, fontproperties=fp)
plt.xscale("log", basex=2)
plt.xlabel("$0.5 \leq \sigma_2 \leq 2$ (scale log)")
plt.ylabel("$D_{KL}(N(0, 1^2) || N(0, \sigma_2^2))$")

plt.grid()
plt.show()

横軸をlogスケールに

アニメーション

import scipy as sp
import scipy.stats as stats
import matplotlib.pyplot as plt
import matplotlib.animation as animation
plt.rcParams["animation.convert_path"] = "C:\Program Files\ImageMagick-7.0.8-Q16\magick.exe"

%matplotlib inline

fig = plt.figure(figsize=(16, 9))
ax1 = fig.add_subplot(211)
ax2 = fig.add_subplot(212)

ims = []

# ax_1 N
x = sp.linspace(-4, 4, 200)
pdf = stats.norm(loc=0, scale=1).pdf(x)
ax1.plot(x, pdf, linestyle="--", label="$N(0, 1^2)$")

# ax2 D_KL
sigma_2 = sp.linspace(0.1, 4, 100)
D_KL = sp.log(sigma_2) + 1 / ( 2 * sigma_2**2 ) - 0.5
ax2.plot(sigma_2, D_KL)
ax2.set_xlabel("$\sigma_2$")
ax2.set_ylabel("$D_{KL}(N(0, 1^2) || N(0, \sigma_2^2))$")
ax2.grid()
ax2.set_ylim(0, 1)

# ani
im = ax1.plot(x, stats.norm(loc=0, scale=1).pdf(x), linestyle="-", color="r", label="$N(0, \sigma_2^2)$")
ax1.grid(True)
ax1.legend(loc='upper left')
ims.append(im)
for sigma in sp.linspace(0.5, 4):
    pdf_N_mu = stats.norm(loc=0, scale=sigma).pdf(x)
    im, = ax1.plot(x, pdf_N_mu, linestyle="-", color="r", label="$N(0, \sigma_2^2)$")
    im_t = ax1.text(-3, 0.6, "$\sigma_2 = " + str(sigma.round(3)) + "$", ha='left', va='bottom', fontsize=40)
    D = sp.log(sigma) + 1 / ( 2 * sigma**2 ) - 0.5
    im_v_ax2 = ax2.vlines([sigma], 0, D, linestyles="--")
    im_v_ax2_t = ax2.text(1, 0.5, "$D_{KL}=" + str(D.round(3)) + "$", ha='left', va='bottom', fontsize=40)
    ims.append([im, im_t, im_v_ax2, im_v_ax2_t])

plt.show()

ani = animation.ArtistAnimation(fig, ims)
ani.save('move_std_N_and_D.gif', writer="imagemagick")

分散をズラしていった時のKL情報量(アニメーション)

分散を代表的な値に変えてKL情報量をみてみると,

pair = set()
for sigma in 1 / sp.arange(1, 10+1):
    D = sp.log(sigma) + 1 / ( 2 * sigma**2 ) - 0.5
    pair.add((sigma.round(3), D.round(3)))

for sigma in sp.arange(1, 10+1):
    D = sp.log(sigma) + 1 / ( 2 * sigma**2 ) - 0.5
    pair.add((sigma, D.round(3)))

pair_l = list(pair)
pair_l = sorted(pair_l, key=lambda x: x[0])
pair_l
[(0.1, 47.197),
 (0.111, 37.803),
 (0.125, 29.421),
 (0.143, 22.054),
 (0.167, 15.708),
 (0.2, 10.391),
 (0.25, 6.114),
 (0.333, 2.901),
 (0.5, 0.807),
 (1.0, 0.0),
 (2, 0.318),
 (3, 0.654),
 (4, 0.918),
 (5, 1.129),
 (6, 1.306),
 (7, 1.456),
 (8, 1.587),
 (9, 1.703),
 (10, 1.808)]