はしくれエンジニアもどきのメモ

情報系技術・哲学・デザインなどの勉強メモ・備忘録です。

正規逆ガンマ分布の確率密度関数の導出・可視化・サンプリング

正規逆ガンマ分布の確率密度関数の導出・可視化・サンプリング

Udemyの「ベイズ推定とグラフィカルモデル:コンピュータビジョン基礎1」の授業で, ベイズ統計で正規分布のパラメータの分布に使われる正規逆ガンマ分布の紹介があったので,導出と可視化のメモ.

資料として以下を使用しているので記号はそれに合わせる.

正規逆ガンマ分布について

Normal-inverse-gamma distribution - Wikipedia

ベイズ統計で,正規分布のパラメータ(平均$\mu$, 分散 \sigma^{2})の同時分布 Pr(\mu, \sigma^{2}) として使われる. 何故使われるかというと,正規分布の尤度に対して共役事前分布の関係になっているためで, 事後分布も正規逆ガンマ分布になり,分布の更新計算が容易なため.

事後分布 尤度 共役事前分布
正規逆ガンマ分布 正規分布 正規逆ガンマ分布

導出

正規逆ガンマ分布は,正規分布のパラメータを以下のように考えた同時分布になっている.

  • 分散パラメータが逆ガンマ分布に従う,つまり, \sigma^{2} \sim \text{InvGamma}(\alpha, \beta) = Pr(\sigma^{2}|\alpha, \beta)

  • 平均パラメータが,上の分散パラメータを使った正規分布に従う,厳密には, \mu \sim \text{Norm}(\delta, \frac{\sigma^{2}}{\gamma}) = Pr(\mu | \delta, \frac{\sigma^{2}}{\gamma})

以上の同時分布により,逆ガンマ分布と正規分布の積による密度関数を導出できる.

改めて,逆ガンマ分布と正規分布を以下のように定義する.


\text{InvGamm}( \sigma^{2} | \alpha, \beta )
= \frac{ \beta^{\alpha} }{ \Gamma( \alpha ) }\sigma^{ 2 ( - \alpha - 1 ) }
\exp{ ( \frac{-\beta}{\sigma^{2}} )}

\text{Norm}( \mu | \delta, \frac{\sigma^{2}}{\gamma} )
= \frac{ 1 }{ \sqrt{2 \pi \frac{\sigma^{2}}{\gamma} }} \exp{[ - 0.5 \frac{ ( \mu - \delta )^{2}}{ \frac{\sigma^{2}}{\gamma} } ] }

正規逆ガンマ分布は上記の分布の積により,以下の式で表される.


\text{NormInvGamm}(\mu, \sigma^{2} | \alpha, \beta, \delta, \gamma ) = \text{Norm}(\mu|\delta, \frac{\sigma^{2}}{\gamma}) \cdot \text{InvGamm}(\sigma^{2}|\alpha, \beta) \\

密度関数を導出すると,


\begin{eqnarray}
\text{NormInvGamm}(\mu, \sigma^{2} | \alpha, \beta, \delta, \gamma )
&=& \text{Norm}(\mu|\delta, \frac{\sigma^{2}}{\gamma}) \cdot \text{InvGamm}(\sigma^{2}|\alpha, \beta) \\\\
&=& \frac{1}{\sqrt{2 \pi \frac{\sigma^{2}}{\gamma} }} \exp{ [ -0.5\frac{(\mu - \delta)^{2}}{ \frac{\sigma^{2}}{\gamma} } ] } \cdot \frac{\beta^{\alpha}}{\Gamma(\alpha)}\sigma^{2(-\alpha-1)}
\exp{ ( \frac{ - \beta}{\sigma^{2}} ) } \\\\
&=& \frac{\sqrt{\gamma}}{\sqrt{2 \pi \sigma^{2}}} \frac{ \beta^{\alpha}}{\Gamma(\alpha)} \left( \frac{1}{\sigma^{2}} \right)^{\alpha + 1} \exp{ ( - \frac{\gamma( \mu - \delta)^{2}}{ 2\sigma^{2} } ) } \exp{ ( \frac{-\beta}{\sigma^{2}} ) } \\\\
&=& \frac{\sqrt{\gamma}}{\sqrt{2 \pi \sigma^{2}}} \frac{\beta^{\alpha}}{\Gamma( \alpha ) } \left(\frac{1}{\sigma^{2}} \right)^{\alpha + 1} \exp{ ( -\frac{2\beta + \gamma(\mu - \delta )^{2}}{ 2\sigma^{2} } ) }
\end{eqnarray}

密度計算のコード

この式より,手打ちした以下のコードで確率密度を計算できる.

import scipy as sp
from scipy import special

def norm_inv_gamma_pdf_mamually(mu, sigma2, alpha, beta, gamma, delta):
    '''
    only one side vec:
    mu: vec
    sigma2: vec
    '''
    c = sp.sqrt(gamma / (2 * sp.pi * sigma2))
    g = beta**alpha / special.gamma(alpha)
    acc = 1 / (sigma2**(alpha + 1))
    e = sp.exp(- (2 * beta + gamma * (delta - mu)**2) / (2 * sigma2))
    return c * g * acc * e

上記コードでもいいが,正規逆ガンマ分布はもともと,逆ガンマ分布と正規分布との積の同時分布なので, 密度も積で計算できる.

scipy.statsには逆ガンマ分布と正規分布の密度関数pdfを計算できるので, スカラ量を与えて密度を計算すると以下のコードになる.

para_alpha = 1.
para_beta = 1.
para_delta = 0
para_gamma = 1.

var_sigma2 = 1.
var_mu = 0

scipy.stats.invgamma(a=para_alpha, scale=para_beta).pdf(var_sigma2) * scipy.stats.norm(loc=para_delta, scale=sp.sqrt(var_sigma2/para_gamma)).pdf(var_mu)

これを踏まえて,変数$\mu$と$\sigma^{2}$をスカラでなく配列で与えたときに, 密度行列を返すコードを作る.

import scipy as sp
from scipy import stats

def norm_inv_gamma_pdf(mu, sigma2, alpha, beta, gamma, delta):
    '''
    N(mu |δ,σ2/λ)・IG(σ2|α,β)
    return pdf_Matrix
    '''
    col = len(mu)
    row = len(sigma2)
    sigma2 = sp.sort(sigma2)[::-1]# desc
    IG = stats.invgamma(a=alpha, scale=beta)
    Mpdf_sigma2 = IG.pdf(x=sigma2).reshape(-1, 1)
    Mpdf_sigma2 = sp.repeat(Mpdf_sigma2, repeats=col, axis=1)

    Mpdf_N = sp.zeros_like(Mpdf_sigma2)
    for r in sp.arange(row):
        N = stats.norm(loc=delta, scale=sp.sqrt(sigma2[r]/gamma))
        Mpdf_N[r] = N.pdf(x=mu)
    return Mpdf_N * Mpdf_sigma2

正規逆ガンマ分布をヒートマップで可視化

正規逆ガンマ分布の密度関数を可視化する. 平均パラメータ$\mu$と分散パラメータ$\sigma^{2}$との2次元になるため, ヒートマップ(2次元ヒストグラム)で可視化する.

matplotlibの場合,matplotlib.pyplot.imshowに密度行列を渡せば描画できる. 上記コードで密度行列を計算できたのでこれを渡せばいい.

plt.imshow(cmap="hot")を指定すれば,密度の高いところが白く,低いところが黒く,中間が赤になる.

$(\alpha, \beta, \gamma, \delta) = (1., 1., 1., 0)$の正規逆ガンマ分布で $\sigma^{2} = (0.01, 7), \mu = (-5, 5)$ の範囲の確率密度を描画する.

import scipy as sp
from scipy import stats
import matplotlib.pyplot as plt

%matplotlib inline
fig = plt.figure(figsize=(5, 5))

def norm_inv_gamma_pdf(mu, sigma2, alpha, beta, gamma, delta):
    '''
    N(mu |δ,σ2/λ)・IG(σ2|α,β)
    return pdf_Matrix
    '''
    col = len(mu)
    row = len(sigma2)
    sigma2 = sp.sort(sigma2)[::-1]# desc
    IG = stats.invgamma(a=alpha, scale=beta)
    Mpdf_sigma2 = IG.pdf(x=sigma2).reshape(-1, 1)
    Mpdf_sigma2 = sp.repeat(Mpdf_sigma2, repeats=col, axis=1)

    Mpdf_N = sp.zeros_like(Mpdf_sigma2)
    for r in sp.arange(row):
        N = stats.norm(loc=delta, scale=sp.sqrt(sigma2[r]/gamma))
        Mpdf_N[r] = N.pdf(x=mu)
    return Mpdf_N * Mpdf_sigma2

alpha, beta, gamma, delta = 1., 1., 1., 0
sigma2 = sp.linspace(7, 0.01, 500)
mu = sp.linspace(-5, 5, 500)
M=norm_inv_gamma_pdf(mu, sigma2, alpha, beta, gamma, delta)

plt.imshow(M, cmap="hot", extent=[mu[0], mu[-1], sigma2[-1], sigma2[0]], aspect=10/7)
plt.colorbar(shrink=0.75)
plt.xlabel("$\mu$")
plt.ylabel("$\sigma^{2}$")
plt.xlim(-5, 5)
plt.ylim(0, 7)
plt.tight_layout()
# plt.savefig("n-inv-gamma.png")
plt.show()

(\alpha, \beta, \gamma, \delta) = (1., 1., 1., 0)の正規逆ガンマ分布のヒートマップ

パラメータを変えて, "Computer vision: models, learning and inference" by Simon Prince p.41 Fig 3.6と同じ図を描画してみる.

matplotlibの場合,matplotlib.gridspecを使うとgridlayoutのように3x2で図を配置などして描画できる.

左から - (alpha, beta, gamma, delta)=(1., 1., 1., 0) - (0.5., 1., 1., 0) - (2., 1., 1., 0) - (1., 0.5., 1., 0) - (1., 2., 1., 0) - (1., 1., 0.4., 0) - (1., 1., 4., 0) - (1., 1., 1., -2.) - (1., 1., 1., 2.)

import scipy as sp
from scipy import stats
from scipy import special
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

%matplotlib inline

def norm_inv_gamma_pdf_mamually(mu, sigma2, alpha, beta, gamma, delta):
    '''
    only one side vec:
    mu: vec
    sigma2: vec
    '''
    c = sp.sqrt(gamma / (2 * sp.pi * sigma2))
    g = beta**alpha / special.gamma(alpha)
    acc = 1 / (sigma2**(alpha + 1))
    e = sp.exp(- (2 * beta + gamma * (delta - mu)**2) / (2 * sigma2))
    return c * g * acc * e

def norm_inv_gamma_pdf(mu, sigma2, alpha, beta, gamma, delta):
    '''
    N(mu |δ,σ2/λ)・IG(σ2|α,β)
    return pdf_Matrix
    '''
    col = len(mu)
    row = len(sigma2)
    sigma2 = sp.sort(sigma2)[::-1]# desc
    IG = stats.invgamma(a=alpha, scale=beta)
    Mpdf_sigma2 = IG.pdf(x=sigma2).reshape(-1, 1)
    Mpdf_sigma2 = sp.repeat(Mpdf_sigma2, repeats=col, axis=1)

    Mpdf_N = sp.zeros_like(Mpdf_sigma2)
    for r in sp.arange(row):
        N = stats.norm(loc=delta, scale=sp.sqrt(sigma2[r]/gamma))
        Mpdf_N[r] = N.pdf(x=mu)
    return Mpdf_N * Mpdf_sigma2

def plot_heat(fig, ax, M, vmin, vmax, extent, cmap="hot", shrink=0.7):
    im = None
    if vmin and vmax:
        im = ax.imshow(M, cmap=cmap, vmin=vmin, vmax=vmax, extent=extent)
    else:
        im = ax.imshow(M, cmap=cmap, extent=extent)
    fig.colorbar(im, ax=ax, shrink=shrink)
    ax.set_aspect(1./ax.get_data_ratio())
    return im

def plot_heat_log(fig, ax, M, vmin, vmax, extent, cmap="hot", shrink=0.7):
    return plot_heat(fig, ax, sp.log(M+1e-9), vmin, vmax, extent, cmap, shrink)

def plot_heat_norm_inv_gamma(alpha, beta, gamma, delta, fig, ax,  mu=sp.linspace(-5, 5, 500), sigma2 = sp.linspace(7, 0.01, 500), cmap="hot",
                             vmin=None, vmax=None,
                             extent=[-5, 5, 0.01, 7], shrink=0.7, titleFontSize=17):
    M = norm_inv_gamma_pdf(mu, sigma2, alpha, beta, gamma, delta)
    plot_heat(fig, ax, M, cmap=cmap, vmin=vmin, vmax=vmax, extent=extent, shrink=shrink)
    ax.set_xlabel("$\mu$")
    ax.set_ylabel("$\sigma^{2}$")
    ax.set_title(f'({alpha:.1f}, {beta:.1f}, {gamma:.1f}, {delta:.1f})'.format(
        alpha, beta, gamma, delta), fontsize=titleFontSize)

def plot_heat_log_norm_inv_gamma(alpha, beta, gamma, delta, fig, ax,  mu=sp.linspace(-5, 5, 500), sigma2 = sp.linspace(7, 0.01, 500), cmap="hot",
                             vmin=-20, vmax=-3,
                             extent=[-5, 5, 0.01, 7], shrink=0.7, titleFontSize=17):
    M = norm_inv_gamma_pdf(mu, sigma2, alpha, beta, gamma, delta)
    plot_heat_log(fig, ax, M, cmap=cmap, vmin=vmin, vmax=vmax, extent=extent, shrink=shrink)
    ax.set_xlabel("$\mu$")
    ax.set_ylabel("$\sigma^{2}$")
    ax.set_title(f'({alpha:.1f}, {beta:.1f}, {gamma:.1f}, {delta:.1f}) log density'.format(
        alpha, beta, gamma, delta), fontsize=titleFontSize)

fig = plt.figure(figsize=(30, 10))
gs = gridspec.GridSpec(2, 6)

ax = fig.add_subplot(gs[0:2, 0:2])
plot_heat_norm_inv_gamma(1, 1, 1, 0, fig, ax)

ax = fig.add_subplot(gs[0, 2])
plot_heat_norm_inv_gamma(0.5, 1, 1, 0, fig, ax)
ax = fig.add_subplot(gs[1, 2])
plot_heat_norm_inv_gamma(2, 1, 1, 0, fig, ax)

ax = fig.add_subplot(gs[0, 3])
plot_heat_norm_inv_gamma(1, 0.5, 1, 0, fig, ax)
ax = fig.add_subplot(gs[1, 3])
plot_heat_norm_inv_gamma(1, 2, 1, 0, fig, ax)

ax = fig.add_subplot(gs[0, 4])
plot_heat_norm_inv_gamma(1, 1, 0.4, 0, fig, ax)
ax = fig.add_subplot(gs[1, 4])
plot_heat_norm_inv_gamma(1, 1, 4, 0, fig, ax)

ax = fig.add_subplot(gs[0, 5])
plot_heat_norm_inv_gamma(1, 1, 1, -2, fig, ax)
ax = fig.add_subplot(gs[1, 5])
plot_heat_norm_inv_gamma(1, 1, 1, 2, fig, ax)

plt.tight_layout()
plt.savefig("normal-inv-gamma_grid.png")
plt.show()

いろいろパラメータを変えた正規逆ガンマ分布のヒートマップ

確率密度0から1の中で差がわかるように対数密度にする. "Computer vision: models, learning and inference" by Simon Prince p.41 Fig 3.6もlogの文字はないが対数密度になっている.

%matplotlib inline

fig = plt.figure(figsize=(30, 10))
gs = gridspec.GridSpec(2, 6)

ax = fig.add_subplot(gs[0:2, 0:2])
plot_heat_log_norm_inv_gamma(1, 1, 1, 0, fig, ax)

ax = fig.add_subplot(gs[0, 2])
plot_heat_log_norm_inv_gamma(0.5, 1, 1, 0, fig, ax)
ax = fig.add_subplot(gs[1, 2])
plot_heat_log_norm_inv_gamma(2, 1, 1, 0, fig, ax)

ax = fig.add_subplot(gs[0, 3])
plot_heat_log_norm_inv_gamma(1, 0.5, 1, 0, fig, ax)
ax = fig.add_subplot(gs[1, 3])
plot_heat_log_norm_inv_gamma(1, 2, 1, 0, fig, ax)

ax = fig.add_subplot(gs[0, 4])
plot_heat_log_norm_inv_gamma(1, 1, 0.4, 0, fig, ax)
ax = fig.add_subplot(gs[1, 4])
plot_heat_log_norm_inv_gamma(1, 1, 4, 0, fig, ax)

ax = fig.add_subplot(gs[0, 5])
plot_heat_log_norm_inv_gamma(1, 1, 1, -2, fig, ax)
ax = fig.add_subplot(gs[1, 5])
plot_heat_log_norm_inv_gamma(1, 1, 1, 2, fig, ax)

plt.tight_layout()
plt.savefig("normal-inv-gamma_grid_log.png")
plt.show()

いろいろパラメータを変えた正規逆ガンマ分布のヒートマップ(対数密度)

正規逆ガンマ分布からのサンプリング

scipy.statsには正規逆ガンマ分布の関数はないが, 逆ガンマ分布から分散パラメータをサンプリング, そのサンプリングした分散パラメータをもった正規分布から平均パラメータをサンプリングとすれば, 正規逆ガンマ分布からのサンプリングが可能となる.

from scipy import stats

def NormInvGamma_sampling(alpha, beta, delta, gamma, N=1):
    IG = stats.invgamma(scale=beta, a=alpha)
    sigma2 = IG.rvs(size=N)

    mu = sp.array([])
    for s2 in sigma2:
        normal = stats.norm(loc=delta, scale=sp.sqrt(s2/gamma))
        mu = sp.append(mu, normal.rvs(size=1))
    return sp.vstack((mu, sigma2))

# NormInvGamma_sampling(1, 1, 0, 1, N=10)

%matplotlib inline
import matplotlib.pyplot as plt

samples = NormInvGamma_sampling(1, 1, 0, 1, N=10000)
plt.scatter(samples[0], samples[1])

plt.xlabel("$\mu$")
plt.ylabel("$\sigma^2$")
plt.xlim(-12, 12)
plt.ylim(0, 34)
plt.savefig("sampling.png")
plt.show()

10,000個サンプリングした散布図

特徴

分散パラメータで積分するとt分布になる

これは,予測分布$P(x^{*} | \vec{x}=\vec{d})$がt分布になることを証明するのに使える.

cartman0.hatenablog.com