ディリクレ(Dirichlet)分布を3Dで可視化する
(3次元の)ディリクレ分布をpythonのmatplotlib.plot_surface
で可視化するメモ.
環境
- Windows10
- Docker Desktop
ディリクレ分布
ベルヌーイ分布をn変量に拡張したもの. 多項分布(categorical分布)のパラメータの分布でよく出てくる. 例えば,サイコロの各出目がその確率になる確率を計算できる. (6次元の確率から確率密度関数pdfの値が求まる) なので確率変数をすべて足すと1になる.(自由度として1つ減る.) $x_{1}+x_{2}+x_{3}+x_{4}+x_{5}+x_{6}=1$
確率密度関数:ガンマ関数を使って表される. ディリクレ分布のパラメータを$\vec{\alpha}$とする. このパラメータはベルヌーイ分布のパラメータをサイコロのような多面体に拡張したものになる. パラメータの意味としてベルヌーイ分布のパラメータ同様に事前に振っておいてどの目がどのくらい出現していたかに相当する. 全体的にパラメータの値が大きくなると出現回数が確認されてるので分散は小さくなっていく.
$$ \mathrm{pdf}(\vec{x} | \vec{\alpha}) = \frac{1}{C} \prod_{k=1}^{K} x^{\alpha_{ k - 1 }} \\ C = \frac{\prod_{k=1}^{K} \Gamma(\alpha_{k}) }{\Gamma(\sum_{k=1}^{K} \alpha_{k})} $$
K=6とするとサイコロの各出目の確率の確率分布を表せる.
.rvs
などでランダムサンプリングすると,
K=6であれば多項分布のパラメータになる$x_1+\cdots+x_{6}=1$となる確率をサンプリングできる.
from scipy import stats alpha = np.array([1, 1, 1, 1, 1, 1]) # specify concentration parameters stats.dirichlet.rvs(alpha, 1) array([[0.1657884 , 0.17995035, 0.0089385 , 0.05918143, 0.09352992, 0.49261141]])
matplotlibで可視化する.
可視化できるように3次元(K=3)のディリクレ分布で試す. 3次元なので3面のサイコロの各確率が出るに相当する. どの点も$x_{1}+x_{2}+x_{3}=1$を満たす. さらに各頂点は,$x_{1},x_{2},x_{3}$のどれかが1で残りは0になっている.
グラフの各軸は
- x軸:確率変数$x_{1}$
- y軸:確率変数$x_{2}$
- z軸: 確率密度関数$pdf(x_{1},x_{2},1-x_{1}-x_{2})$
とする.
ディリクレ分布の密度関数の計算コード:
# ディレクレ import numpy as np from scipy import special class Dirichlet(): def __init__(self,para:list)->None: self.para = np.array(para) def pdf(self, x:list)->np.float: # 正規化定数 Z x_ar = np.array(x) cons = np.prod(special.gamma(self.para))/(special.gamma(np.sum(self.para))) p = (1./cons) * np.prod(x_ar**(self.para-1)) return p def plt_3d(self, zlim=None)->None: xdata = np.linspace(0, 1, 200) ydata = np.linspace(0, 1, 200) X,Y = np.meshgrid(xdata, ydata) z = [] X[X+Y>1] = 0 Y[X+Y>1] = 0 for _x, _y, _z in zip(X.flatten(), Y.flatten(), (1-X-Y).flatten()): z.append(self.pdf([_x, _y, _z])) Z = np.array(z).reshape(X.shape) ax3d = plt.axes(projection='3d') ax3d.plot_surface(X, Y, Z,cmap='plasma') ax3d.set_zlim(zlim) ax3d.set_xlabel("$x_1$") ax3d.set_ylabel("$x_2$") ax3d.set_zlabel("pdf$(x_1,x_2,1-x_1-x_2)$") ax3d.set_title("Dir($\\vec{\\alpha} = $" + "%s)" % self.para) plt.show() diri.pdf([1.,0.,0.])
3Dの可視化では,[0,1]の範囲をmeshgridを使ってx,yの各2次元配列を作成. $x+y+x_3=1$の制約があるので$x+y>1$になるメッシュは$x=0,y=0$として上書きしている.
パラメータ$\vec{\alpha} = [1,1,1]$のときを描画してみる.jupyer labでは%matplotlib widget
でインタラクティブなグラフが描ける.
%matplotlib widget diri = Dirichlet([1.,1.,1.]) diri.plt_3d(zlim=(0, 2.1))
ちなみに $pdf(1,0,0|alpha=[1,1,1])=2$ になる
パラメータ$\vec{\alpha} = [5,5,5]$: $\vec{\alpha} = [1,1,1]$のときより分散が小さいので真ん中らへんが一番高くなる.
%matplotlib widget diri = Dirichlet([5.,5.,5.]) diri.plt_3d()
パラメータ$\vec{\alpha} = [5,1,1]$: $x_1$が一番起きやすいので$x_1$の値が1付近で密度関数が高くなる.
%matplotlib widget diri = Dirichlet([5.,1.,1.]) diri.plt_3d()