正規逆ガンマ分布の確率密度関数の導出・可視化・サンプリング
正規逆ガンマ分布の確率密度関数の導出・可視化・サンプリング
Udemyの「ベイズ推定とグラフィカルモデル:コンピュータビジョン基礎1」の授業で, ベイズ統計で正規分布のパラメータの分布に使われる正規逆ガンマ分布の紹介があったので,導出と可視化のメモ.
資料として以下を使用しているので記号はそれに合わせる.
"Computer vision: models, learning and inference" by Simon Prince
https://www.udemy.com/computervision/learn/lecture/617408#questions/7270346
正規逆ガンマ分布について
Normal-inverse-gamma distribution - Wikipedia
ベイズ統計で,正規分布のパラメータ(平均$\mu$, 分散)の同時分布 として使われる. 何故使われるかというと,正規分布の尤度に対して共役事前分布の関係になっているためで, 事後分布も正規逆ガンマ分布になり,分布の更新計算が容易なため.
事後分布 | 尤度 | 共役事前分布 |
---|---|---|
正規逆ガンマ分布 | 正規分布 | 正規逆ガンマ分布 |
導出
正規逆ガンマ分布は,正規分布のパラメータを以下のように考えた同時分布になっている.
分散パラメータが逆ガンマ分布に従う,つまり,
平均パラメータが,上の分散パラメータを使った正規分布に従う,厳密には,
以上の同時分布により,逆ガンマ分布と正規分布の積による密度関数を導出できる.
改めて,逆ガンマ分布と正規分布を以下のように定義する.
正規逆ガンマ分布は上記の分布の積により,以下の式で表される.
密度関数を導出すると,
密度計算のコード
この式より,手打ちした以下のコードで確率密度を計算できる.
import scipy as sp from scipy import special def norm_inv_gamma_pdf_mamually(mu, sigma2, alpha, beta, gamma, delta): ''' only one side vec: mu: vec sigma2: vec ''' c = sp.sqrt(gamma / (2 * sp.pi * sigma2)) g = beta**alpha / special.gamma(alpha) acc = 1 / (sigma2**(alpha + 1)) e = sp.exp(- (2 * beta + gamma * (delta - mu)**2) / (2 * sigma2)) return c * g * acc * e
上記コードでもいいが,正規逆ガンマ分布はもともと,逆ガンマ分布と正規分布との積の同時分布なので, 密度も積で計算できる.
scipy.stats
には逆ガンマ分布と正規分布の密度関数pdf
を計算できるので,
スカラ量を与えて密度を計算すると以下のコードになる.
para_alpha = 1. para_beta = 1. para_delta = 0 para_gamma = 1. var_sigma2 = 1. var_mu = 0 scipy.stats.invgamma(a=para_alpha, scale=para_beta).pdf(var_sigma2) * scipy.stats.norm(loc=para_delta, scale=sp.sqrt(var_sigma2/para_gamma)).pdf(var_mu)
これを踏まえて,変数$\mu$と$\sigma^{2}$をスカラでなく配列で与えたときに, 密度行列を返すコードを作る.
import scipy as sp from scipy import stats def norm_inv_gamma_pdf(mu, sigma2, alpha, beta, gamma, delta): ''' N(mu |δ,σ2/λ)・IG(σ2|α,β) return pdf_Matrix ''' col = len(mu) row = len(sigma2) sigma2 = sp.sort(sigma2)[::-1]# desc IG = stats.invgamma(a=alpha, scale=beta) Mpdf_sigma2 = IG.pdf(x=sigma2).reshape(-1, 1) Mpdf_sigma2 = sp.repeat(Mpdf_sigma2, repeats=col, axis=1) Mpdf_N = sp.zeros_like(Mpdf_sigma2) for r in sp.arange(row): N = stats.norm(loc=delta, scale=sp.sqrt(sigma2[r]/gamma)) Mpdf_N[r] = N.pdf(x=mu) return Mpdf_N * Mpdf_sigma2
正規逆ガンマ分布をヒートマップで可視化
正規逆ガンマ分布の密度関数を可視化する. 平均パラメータ$\mu$と分散パラメータ$\sigma^{2}$との2次元になるため, ヒートマップ(2次元ヒストグラム)で可視化する.
matplotlibの場合,matplotlib.pyplot.imshow
に密度行列を渡せば描画できる.
上記コードで密度行列を計算できたのでこれを渡せばいい.
plt.imshow(cmap="hot")
を指定すれば,密度の高いところが白く,低いところが黒く,中間が赤になる.
$(\alpha, \beta, \gamma, \delta) = (1., 1., 1., 0)$の正規逆ガンマ分布で $\sigma^{2} = (0.01, 7), \mu = (-5, 5)$ の範囲の確率密度を描画する.
import scipy as sp from scipy import stats import matplotlib.pyplot as plt %matplotlib inline fig = plt.figure(figsize=(5, 5)) def norm_inv_gamma_pdf(mu, sigma2, alpha, beta, gamma, delta): ''' N(mu |δ,σ2/λ)・IG(σ2|α,β) return pdf_Matrix ''' col = len(mu) row = len(sigma2) sigma2 = sp.sort(sigma2)[::-1]# desc IG = stats.invgamma(a=alpha, scale=beta) Mpdf_sigma2 = IG.pdf(x=sigma2).reshape(-1, 1) Mpdf_sigma2 = sp.repeat(Mpdf_sigma2, repeats=col, axis=1) Mpdf_N = sp.zeros_like(Mpdf_sigma2) for r in sp.arange(row): N = stats.norm(loc=delta, scale=sp.sqrt(sigma2[r]/gamma)) Mpdf_N[r] = N.pdf(x=mu) return Mpdf_N * Mpdf_sigma2 alpha, beta, gamma, delta = 1., 1., 1., 0 sigma2 = sp.linspace(7, 0.01, 500) mu = sp.linspace(-5, 5, 500) M=norm_inv_gamma_pdf(mu, sigma2, alpha, beta, gamma, delta) plt.imshow(M, cmap="hot", extent=[mu[0], mu[-1], sigma2[-1], sigma2[0]], aspect=10/7) plt.colorbar(shrink=0.75) plt.xlabel("$\mu$") plt.ylabel("$\sigma^{2}$") plt.xlim(-5, 5) plt.ylim(0, 7) plt.tight_layout() # plt.savefig("n-inv-gamma.png") plt.show()
パラメータを変えて, "Computer vision: models, learning and inference" by Simon Prince p.41 Fig 3.6と同じ図を描画してみる.
matplotlibの場合,matplotlib.gridspec
を使うとgridlayoutのように3x2で図を配置などして描画できる.
左から - (alpha, beta, gamma, delta)=(1., 1., 1., 0) - (0.5., 1., 1., 0) - (2., 1., 1., 0) - (1., 0.5., 1., 0) - (1., 2., 1., 0) - (1., 1., 0.4., 0) - (1., 1., 4., 0) - (1., 1., 1., -2.) - (1., 1., 1., 2.)
import scipy as sp from scipy import stats from scipy import special import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec %matplotlib inline def norm_inv_gamma_pdf_mamually(mu, sigma2, alpha, beta, gamma, delta): ''' only one side vec: mu: vec sigma2: vec ''' c = sp.sqrt(gamma / (2 * sp.pi * sigma2)) g = beta**alpha / special.gamma(alpha) acc = 1 / (sigma2**(alpha + 1)) e = sp.exp(- (2 * beta + gamma * (delta - mu)**2) / (2 * sigma2)) return c * g * acc * e def norm_inv_gamma_pdf(mu, sigma2, alpha, beta, gamma, delta): ''' N(mu |δ,σ2/λ)・IG(σ2|α,β) return pdf_Matrix ''' col = len(mu) row = len(sigma2) sigma2 = sp.sort(sigma2)[::-1]# desc IG = stats.invgamma(a=alpha, scale=beta) Mpdf_sigma2 = IG.pdf(x=sigma2).reshape(-1, 1) Mpdf_sigma2 = sp.repeat(Mpdf_sigma2, repeats=col, axis=1) Mpdf_N = sp.zeros_like(Mpdf_sigma2) for r in sp.arange(row): N = stats.norm(loc=delta, scale=sp.sqrt(sigma2[r]/gamma)) Mpdf_N[r] = N.pdf(x=mu) return Mpdf_N * Mpdf_sigma2 def plot_heat(fig, ax, M, vmin, vmax, extent, cmap="hot", shrink=0.7): im = None if vmin and vmax: im = ax.imshow(M, cmap=cmap, vmin=vmin, vmax=vmax, extent=extent) else: im = ax.imshow(M, cmap=cmap, extent=extent) fig.colorbar(im, ax=ax, shrink=shrink) ax.set_aspect(1./ax.get_data_ratio()) return im def plot_heat_log(fig, ax, M, vmin, vmax, extent, cmap="hot", shrink=0.7): return plot_heat(fig, ax, sp.log(M+1e-9), vmin, vmax, extent, cmap, shrink) def plot_heat_norm_inv_gamma(alpha, beta, gamma, delta, fig, ax, mu=sp.linspace(-5, 5, 500), sigma2 = sp.linspace(7, 0.01, 500), cmap="hot", vmin=None, vmax=None, extent=[-5, 5, 0.01, 7], shrink=0.7, titleFontSize=17): M = norm_inv_gamma_pdf(mu, sigma2, alpha, beta, gamma, delta) plot_heat(fig, ax, M, cmap=cmap, vmin=vmin, vmax=vmax, extent=extent, shrink=shrink) ax.set_xlabel("$\mu$") ax.set_ylabel("$\sigma^{2}$") ax.set_title(f'({alpha:.1f}, {beta:.1f}, {gamma:.1f}, {delta:.1f})'.format( alpha, beta, gamma, delta), fontsize=titleFontSize) def plot_heat_log_norm_inv_gamma(alpha, beta, gamma, delta, fig, ax, mu=sp.linspace(-5, 5, 500), sigma2 = sp.linspace(7, 0.01, 500), cmap="hot", vmin=-20, vmax=-3, extent=[-5, 5, 0.01, 7], shrink=0.7, titleFontSize=17): M = norm_inv_gamma_pdf(mu, sigma2, alpha, beta, gamma, delta) plot_heat_log(fig, ax, M, cmap=cmap, vmin=vmin, vmax=vmax, extent=extent, shrink=shrink) ax.set_xlabel("$\mu$") ax.set_ylabel("$\sigma^{2}$") ax.set_title(f'({alpha:.1f}, {beta:.1f}, {gamma:.1f}, {delta:.1f}) log density'.format( alpha, beta, gamma, delta), fontsize=titleFontSize) fig = plt.figure(figsize=(30, 10)) gs = gridspec.GridSpec(2, 6) ax = fig.add_subplot(gs[0:2, 0:2]) plot_heat_norm_inv_gamma(1, 1, 1, 0, fig, ax) ax = fig.add_subplot(gs[0, 2]) plot_heat_norm_inv_gamma(0.5, 1, 1, 0, fig, ax) ax = fig.add_subplot(gs[1, 2]) plot_heat_norm_inv_gamma(2, 1, 1, 0, fig, ax) ax = fig.add_subplot(gs[0, 3]) plot_heat_norm_inv_gamma(1, 0.5, 1, 0, fig, ax) ax = fig.add_subplot(gs[1, 3]) plot_heat_norm_inv_gamma(1, 2, 1, 0, fig, ax) ax = fig.add_subplot(gs[0, 4]) plot_heat_norm_inv_gamma(1, 1, 0.4, 0, fig, ax) ax = fig.add_subplot(gs[1, 4]) plot_heat_norm_inv_gamma(1, 1, 4, 0, fig, ax) ax = fig.add_subplot(gs[0, 5]) plot_heat_norm_inv_gamma(1, 1, 1, -2, fig, ax) ax = fig.add_subplot(gs[1, 5]) plot_heat_norm_inv_gamma(1, 1, 1, 2, fig, ax) plt.tight_layout() plt.savefig("normal-inv-gamma_grid.png") plt.show()
確率密度0から1の中で差がわかるように対数密度にする. "Computer vision: models, learning and inference" by Simon Prince p.41 Fig 3.6もlogの文字はないが対数密度になっている.
%matplotlib inline fig = plt.figure(figsize=(30, 10)) gs = gridspec.GridSpec(2, 6) ax = fig.add_subplot(gs[0:2, 0:2]) plot_heat_log_norm_inv_gamma(1, 1, 1, 0, fig, ax) ax = fig.add_subplot(gs[0, 2]) plot_heat_log_norm_inv_gamma(0.5, 1, 1, 0, fig, ax) ax = fig.add_subplot(gs[1, 2]) plot_heat_log_norm_inv_gamma(2, 1, 1, 0, fig, ax) ax = fig.add_subplot(gs[0, 3]) plot_heat_log_norm_inv_gamma(1, 0.5, 1, 0, fig, ax) ax = fig.add_subplot(gs[1, 3]) plot_heat_log_norm_inv_gamma(1, 2, 1, 0, fig, ax) ax = fig.add_subplot(gs[0, 4]) plot_heat_log_norm_inv_gamma(1, 1, 0.4, 0, fig, ax) ax = fig.add_subplot(gs[1, 4]) plot_heat_log_norm_inv_gamma(1, 1, 4, 0, fig, ax) ax = fig.add_subplot(gs[0, 5]) plot_heat_log_norm_inv_gamma(1, 1, 1, -2, fig, ax) ax = fig.add_subplot(gs[1, 5]) plot_heat_log_norm_inv_gamma(1, 1, 1, 2, fig, ax) plt.tight_layout() plt.savefig("normal-inv-gamma_grid_log.png") plt.show()
正規逆ガンマ分布からのサンプリング
scipy.stats
には正規逆ガンマ分布の関数はないが,
逆ガンマ分布から分散パラメータをサンプリング,
そのサンプリングした分散パラメータをもった正規分布から平均パラメータをサンプリングとすれば,
正規逆ガンマ分布からのサンプリングが可能となる.
from scipy import stats def NormInvGamma_sampling(alpha, beta, delta, gamma, N=1): IG = stats.invgamma(scale=beta, a=alpha) sigma2 = IG.rvs(size=N) mu = sp.array([]) for s2 in sigma2: normal = stats.norm(loc=delta, scale=sp.sqrt(s2/gamma)) mu = sp.append(mu, normal.rvs(size=1)) return sp.vstack((mu, sigma2)) # NormInvGamma_sampling(1, 1, 0, 1, N=10) %matplotlib inline import matplotlib.pyplot as plt samples = NormInvGamma_sampling(1, 1, 0, 1, N=10000) plt.scatter(samples[0], samples[1]) plt.xlabel("$\mu$") plt.ylabel("$\sigma^2$") plt.xlim(-12, 12) plt.ylim(0, 34) plt.savefig("sampling.png") plt.show()
特徴
分散パラメータで積分するとt分布になる
これは,予測分布$P(x^{*} | \vec{x}=\vec{d})$がt分布になることを証明するのに使える.