正規分布間のKL情報量を計算する
正規分布間のKL情報量を計算する
gist: 正規分布間のKL情報量 · GitHub
KL情報量が1や2といったときに,どのくらいの大きさかよくわからなかったので, 標準正規分布を基準にしたときどれくらいズレた正規分布だとこの大きさになるのか調べてみた.
つまり,下の図を作りたかった.
実行環境
- Windows10
- conda 4.6.12
- python 3.7.0
- conda 4.6.12
KL情報量
確率変数が連続の場合のKL情報量の式は以下になる.
$$ D_{KL}(p || q) = \int_{ -\infty }^{ \infty } p(x) \ln{\frac{ p(x) }{ q(x) }} dx $$
2つの正規分布の確率密度関数
とすると,
とすると,
片方の分布を標準正規分布にして,もう一方の正規分布の平均をズラしていくと以下のGIFになる. この正規分布間のKL情報量を求めていくことを考える.
正規分布間のKL情報量を導出
自己情報量の差$- \log{q} - (-\log{p})$を計算する.
第3項の積分を展開する.
= E[X2] - (E[X])2 \\ E[X2] = V[X] + (E[X])2 \\ = \sigma_12 + \mu_12 ]
よって,
最終的に$D_{KL}$は,
基準の正規分布$p(x)$を標準正規分布にしたときのKL情報量
平均を変数,分散1に固定した正規分布と標準正規分布でのKL情報量
アニメーショングラフの作成
正規分布の平均パラメータをズラしていったときのKL情報量を描画してみる.
import scipy as sp import scipy.stats as stats import matplotlib.pyplot as plt import matplotlib.animation as animation plt.rcParams["animation.convert_path"] = "C:\Program Files\ImageMagick-7.0.8-Q16\magick.exe" %matplotlib inline fig = plt.figure(figsize=(16, 9)) ax1 = fig.add_subplot(211) ax2 = fig.add_subplot(212) ims = [] # ax_1 N x = sp.linspace(-5, 5, 100) pdf = stats.norm(loc=0, scale=1).pdf(x) # line, = ax1.plot(x, pdf, lw=2, label="$N(\mu_2, \sigma_2^2)$") ax1.plot(x, pdf, linestyle="--", label="$N(0, 1^2)$") # ax2 D_KL mu_2 = sp.linspace(-5, 5, 100) D_KL = mu_2**2 / 2 ax2.plot(mu_2, D_KL) ax2.set_xlabel("$\mu_2$") ax2.set_ylabel("$D_{KL}(N(0, 1^2) || N(\mu_2, 1^2))$") ax2.grid() # ani im = ax1.plot(x, stats.norm(loc=0, scale=1).pdf(x), linestyle="-", color="r", label="$N(\mu_2, 1^2)$") ax1.grid(True) ax1.legend(loc='upper left') ims.append(im) for mu in sp.linspace(0, 4): pdf_N_mu = stats.norm(loc=mu, scale=1).pdf(x) im, = ax1.plot(x, pdf_N_mu, linestyle="-", color="r", label="$N(\mu_2, 1^2)$") im_v = ax1.vlines([mu], 0, 0.4, color="k", linestyles="--") D = mu**2/2 im_v_ax2 = ax2.vlines([mu], 0, D, color="k", linestyles="--") im_v_ax2_t = ax2.text(0, 6, "$D_{KL}=" + str(D.round(3)) + "$", ha='left', va='bottom', fontsize=40) ims.append([im, im_v, im_v_ax2, im_v_ax2_t]) plt.show() ani = animation.ArtistAnimation(fig, ims) ani.save('move_mean_N_and_D.gif', writer="imagemagick")
KL情報量から平均パラメータ$\mu_2$を求める式にすると,
%matplotlib inline D_KL = sp.linspace(0, 5, 100) mu_2 = sp.sqrt( 2 * D_KL) plt.plot(D_KL, mu_2) plt.plot(D_KL, -mu_2) plt.xlabel("$D_{KL}(N(0, 1^2) || N(\mu_2, 1^2))$") plt.ylabel("$\mu_2$") plt.grid() plt.savefig("x-DKL_y-mu.png") plt.show()
平均パラメータとKL情報量を代表的な値に変えて値をみる.
pair = set() for mu in sp.arange(0, 10+1): D = mu**2 / 2 pair.add((mu, D)) for D in sp.arange(0, 10+1): mu = sp.sqrt(2 * D) pair.add((mu, D)) pair_l = list(pair) pair_l = sorted(pair_l, key=lambda x: x[0]) pair_l
[(0, 0.0), (1, 0.5), (1.4142135623730951, 1), (2, 2.0), (2.449489742783178, 3), (2.8284271247461903, 4), (3, 4.5), (3.1622776601683795, 5), (3.4641016151377544, 6), (3.7416573867739413, 7), (4, 8.0), (4.242640687119285, 9), (4.47213595499958, 10), (5, 12.5), (6, 18.0), (7, 24.5), (8, 32.0), (9, 40.5), (10, 50.0)]
分散を変数,平均を0に固定した正規分布と標準正規分布でのKL情報量
[tex: N( 0, \sigma_22 \cdot 12 ) ] とおく. 平均をズラした場合と違って,スケール倍を変えていった$D_{KL}$の変化を見ることになる.
import scipy as sp import matplotlib.pyplot as plt %matplotlib inline sigma_2 = sp.linspace(0.5, 5, 100) D_KL = sp.log(sigma_2) + 1 / ( 2 * sigma_2**2 ) - 0.5 plt.plot(sigma_2, D_KL) plt.xlabel("$0.5 \leq \sigma_2 \leq 5$") plt.ylabel("$D_{KL}(N(0, 1^2) || N(0, \sigma_2^2))$") plt.grid() plt.show()
このグラフの特徴として,
- を境にKL情報量の増加・減少傾向(傾き)が違う
- :KL情報量の増加・減少傾向(傾き)が( と比較して)大きい
- :KL情報量の増加・減少傾向(傾き)が(と比較して)小さい
数式を見ると
のとき[tex: \frac{1}{2 \sigma_22} > \log{ \left( \sigma_2 \right)}] 二乗で増加する.
のとき[tex: \log{ ( \sigma_2 ) } > \frac{1}{ 2 \sigma_22} \fallingdotseq 0 $] logスケールで増加する.
この特徴は面白いので横軸を対数にして,範囲 でグラフで見ると KL情報量の差が大きいことがわかる.
これを解釈すると, 最適化によりKL情報量の最小化をすると, KL情報量が小さくなりやすい側の分散パラメータが採択されやすいといえる. つまり,特別精度のいい(分散パラメータが小さい)正規分布でなく, 元の正規分布より精度の悪い(分散の大きい)パラメータが推定されやすい.
これは悪いほうで見積もるという人間の直感に近い考え方に合っている.
# log scale import scipy as sp import matplotlib.pyplot as plt from matplotlib.font_manager import FontProperties %matplotlib inline fig = plt.figure(figsize=(16/1.3, 9/1.3)) sigma_2 = sp.linspace(0.5, 2, 100) D_KL = sp.log(sigma_2) + 1 / ( 2 * sigma_2**2 ) - 0.5 plt.plot(sigma_2, D_KL) plt.vlines([1], 0, .8, linestyles="--") fp = FontProperties(fname=r'C:\WINDOWS\Fonts\YuGothM.ttc', size=14) plt.text(1.02, .75, "正規分布間のKL情報量最小化で最適化すると\n元の分布より大きい分散パラメータが採用されやすい\n(より小さい分散パラメータは採用されにくい)", linespacing=2.5, va="top", fontsize=36, fontproperties=fp) plt.xscale("log", basex=2) plt.xlabel("$0.5 \leq \sigma_2 \leq 2$ (scale log)") plt.ylabel("$D_{KL}(N(0, 1^2) || N(0, \sigma_2^2))$") plt.grid() plt.show()
アニメーション
import scipy as sp import scipy.stats as stats import matplotlib.pyplot as plt import matplotlib.animation as animation plt.rcParams["animation.convert_path"] = "C:\Program Files\ImageMagick-7.0.8-Q16\magick.exe" %matplotlib inline fig = plt.figure(figsize=(16, 9)) ax1 = fig.add_subplot(211) ax2 = fig.add_subplot(212) ims = [] # ax_1 N x = sp.linspace(-4, 4, 200) pdf = stats.norm(loc=0, scale=1).pdf(x) ax1.plot(x, pdf, linestyle="--", label="$N(0, 1^2)$") # ax2 D_KL sigma_2 = sp.linspace(0.1, 4, 100) D_KL = sp.log(sigma_2) + 1 / ( 2 * sigma_2**2 ) - 0.5 ax2.plot(sigma_2, D_KL) ax2.set_xlabel("$\sigma_2$") ax2.set_ylabel("$D_{KL}(N(0, 1^2) || N(0, \sigma_2^2))$") ax2.grid() ax2.set_ylim(0, 1) # ani im = ax1.plot(x, stats.norm(loc=0, scale=1).pdf(x), linestyle="-", color="r", label="$N(0, \sigma_2^2)$") ax1.grid(True) ax1.legend(loc='upper left') ims.append(im) for sigma in sp.linspace(0.5, 4): pdf_N_mu = stats.norm(loc=0, scale=sigma).pdf(x) im, = ax1.plot(x, pdf_N_mu, linestyle="-", color="r", label="$N(0, \sigma_2^2)$") im_t = ax1.text(-3, 0.6, "$\sigma_2 = " + str(sigma.round(3)) + "$", ha='left', va='bottom', fontsize=40) D = sp.log(sigma) + 1 / ( 2 * sigma**2 ) - 0.5 im_v_ax2 = ax2.vlines([sigma], 0, D, linestyles="--") im_v_ax2_t = ax2.text(1, 0.5, "$D_{KL}=" + str(D.round(3)) + "$", ha='left', va='bottom', fontsize=40) ims.append([im, im_t, im_v_ax2, im_v_ax2_t]) plt.show() ani = animation.ArtistAnimation(fig, ims) ani.save('move_std_N_and_D.gif', writer="imagemagick")
分散を代表的な値に変えてKL情報量をみてみると,
pair = set() for sigma in 1 / sp.arange(1, 10+1): D = sp.log(sigma) + 1 / ( 2 * sigma**2 ) - 0.5 pair.add((sigma.round(3), D.round(3))) for sigma in sp.arange(1, 10+1): D = sp.log(sigma) + 1 / ( 2 * sigma**2 ) - 0.5 pair.add((sigma, D.round(3))) pair_l = list(pair) pair_l = sorted(pair_l, key=lambda x: x[0]) pair_l
[(0.1, 47.197), (0.111, 37.803), (0.125, 29.421), (0.143, 22.054), (0.167, 15.708), (0.2, 10.391), (0.25, 6.114), (0.333, 2.901), (0.5, 0.807), (1.0, 0.0), (2, 0.318), (3, 0.654), (4, 0.918), (5, 1.129), (6, 1.306), (7, 1.456), (8, 1.587), (9, 1.703), (10, 1.808)]