機械と学習する

統計解析、機械学習について学習したことをまとめていきます

「ベイズ統計で実践モデリング」をPythonで実装する #2

【概要】

  • ベイズ統計で実践モデリング」という書籍に掲載の事例をPythonで実装してみます
  • 生で書くのではなく、PyMC、Pyroという二つのPPLを使ってを書いていきます
  • 今回は第4章で、ガウス分布のパラメータ推論について3タイプのモデルを扱います

【目次】


はじめに

本シリーズでは、「ベイズ統計で実践モデリング」に記載の事例をPythonを使って実装します。利用するフレームワークはPyMC3とPyroです。なお、Pyroを使いますが、推論はMCMCを使っていく予定です。

第2回目は書籍の4章を扱います。 なお書籍ではRとmatlabをベースに、WinBUGSをMCMCサンプラーとして、その実装も公開されています。サポートページはこちら

筆者は特にPyroは初心者なので、間違いやより良い実装など指摘していただけるとめちゃくちゃ助かります。

なお、Pyroの記法については、以下のページでも扱っています。

learning-with-machine.hatenablog.com

【トップに戻る】

第4章 ガウス分布を使った推論

4章では、広く使われているガウス分布正規分布)のパラメータ推論を扱っています。

細かい実装については実装をGithubで公開していますので、よかったらこちらも確認してもらえたらと思います。

github.com

平均と標準偏差を推論

まずは、単純なガウス分布のパラメータ推論です。N個のデータが共通のガウス分布から生成されていると仮定し、ガウス分布のパラメータ(平均と標準偏差)を推論するという問題です。

notebookはこちら

グラフィカルモデルを描いてみると次のようになります。

f:id:hippy-hikky:20201001111410p:plain

平均\mu標準偏差\sigmaはそれぞれ独立と仮定します。
平均\muの事前分布はガウス分布とし、標準偏差\sigmaの事前分布は十分と予想される範囲を考慮した一様分布とします。標準偏差の共役事前分布は逆ガンマ分布が知られていますが、(詳しくは読んでないのですが)Gelman(2006)の主張によるとガンマ分布の利用には異論があるようです。
この辺りは書籍の仮定をそのまま引継ぎました。

このモデルを推論するために、PyMC3では次のようにモデルを定義しました。

with pm.Model() as model:
    mu = pm.Normal('mu', mu=0.0, tau=0.0001)
    sigma = pm.Uniform('sigma', 0, 10)
    
    x_obs = pm.Normal('x_obs', mu=mu, sigma=sigma, observed=x)

同様にPyroの場合はこうです。

def model_gauss(params=None):
    mu0 = torch.tensor(0.0)
    sig0 = torch.tensor(np.sqrt(1/0.0001))
    mu = pyro.sample('mu', dist.Normal(mu0, sig0))
    
    low = torch.tensor(0.0)
    high = torch.tensor(10.0)
    sigma = pyro.sample('sigma', dist.Uniform(low, high))
    
    x = pyro.sample('x', dist.Normal(mu, sigma))
    return x

cond_model = pyro.condition(
    model_gauss, 
    data={'x':torch.tensor(x)}
)

データは適当なガウス分布(\mathcal{N}(x | 2.0, 1.5^2))を仮定して100点サンプルしました。書籍の例では4点のデータを入れるようになっていましたが、真の分布がわかった方が面白いと思ったので。

推論結果は次の通りです。設定値をだいたい予測できていることがわかります。

f:id:hippy-hikky:20201001112429p:plain
推論結果.左が平均パラメータ,右が標準偏差の事後分布.

七人の科学者:平均は共通で個人毎の分散パラメータを推論

次の例の問題設定は次のとおりです。

  • 7人の科学者がそれぞれ同じ対象からデータを採取(1回づつ)
  • 採取したデータから対象の真の値を推論したい
  • 7人は同じ対象を計測しているので平均は同一だが、個々人の計測能力が標準偏差に現れるとする

ということで、7人それぞれの「計測能力」をガウス分布で表現してやろうというもんだいです。そのため、平均が共通な7つのガウス分布を推論します。

notebookはこちら

まずグラフィカルモデルを考えます。

f:id:hippy-hikky:20201001113412p:plain

二つの確率変数x\sigmaが箱で囲われていますが、これはM人(7人)の科学者それぞれで標準偏差が異なるガウス分布からxが生成されることを示しています。\muは全体で一つしかないので、つまり共通であるということです。

PyMC3でのモデル定義は次のようにしました。

with pm.Model() as model:
    mu = pm.Normal('mu', mu=0.0, tau=0.001)
    
    _lam = pm.Gamma('_lam', 1., 1., shape=len(xs))
    sigma = pm.Deterministic('sigma', 1.0/np.sqrt(_lam))
    
    for i, x in enumerate(xs):
        x_obs = pm.Normal(f'x_obs_{i}', mu=mu, sigma=sigma[i], observed=x)

今回は標準偏差を推論するために、分散の逆数である精度(\lambda)をガンマ分布を事前分布として推論し、精度パラメータを使って標準偏差を導出しています。

同じモデルをPyroでは以下のように書きました。

def model_gauss(x):
    mu0 = torch.tensor(0.0)
    sig0 = torch.tensor(np.sqrt(1/0.0001))
    mu = pyro.sample('mu', dist.Normal(mu0, sig0))
    
    lamdas = []
    for i in range(len(x)):
        lamdas.append(pyro.sample(f'lam_{i}', dist.Gamma(torch.tensor(1.), torch.tensor(1.))))
    
    x_obs = []
    for i, x_data in enumerate(x):
        x_obs.append(pyro.sample(f'x_{i}', dist.Normal(mu, torch.sqrt(1./lamdas[i])), obs=x_data))
    return x

これまでは、データを条件として与える際にpyro.conditionを使っていたのですが、書き方がよくわからず、今回はモデル内部で直接観測値をobsというパラメータで指定しています。

推論の結果は以下の通りです。

f:id:hippy-hikky:20201001115159p:plain
平均の事後分布

ここで、データは書籍にある通りとしました(以下)。

xs = np.array([-27.020, 3.570, 8.191, 9.898, 9.603, 9.945, 10.056])

このデータの単純な平均は3.46です。しかし、一人目と二人目が大きくずれており、算術平均はこの結果にひきづられていることがわかりますね。推論結果を確認すると、事後分布の平均は9.5であり、直感的に正しそうな値になっていることがわかります(真の値は不明なので)。

次に標準偏差の予測結果です。

f:id:hippy-hikky:20201001115646p:plain
標準偏差の推論結果.

一人目がすごく標準偏差が広がっており、二人目も大きめです。その他は1.2程度の値であり、二人目までが誤差が大きいという推論結果であることがわかります。

IQの繰り返し測定:分散を共通として個人毎の平均パラメータを推論

本章最後の例の問題設定は以下の通りです。

  • 被験者N人がIQテストを繰り返しM回受けたとする
  • M回のテストの得点のばらつきはガウス分布に従うと仮定
  • ばらつきの標準偏差はテストツールの正確さに起因するとして、N人の被験者で共通と仮定
  • ばらつきの平均が被験者毎の(潜在的な)真のIQに相当する

今度は、先の例と違って平均はそれぞれ異なるが、標準偏差が共通であるとするモデルです。

notebookはこちら

グラフィカルモデルは次の通りです。

f:id:hippy-hikky:20201001120150p:plain

こちらの問題設定では、一人の被験者がM回の試験を受けるので、xはMの箱で囲っています。平均\muが異なる分布からxが生成されるので、これらがNの箱で囲われています。標準偏差\sigmaは共通なので一つだけです。

このモデルをPyMC3では次のように書きました。

with pm.Model() as model:
    mu_min = 0
    mu_max = 300
    mus = pm.Uniform('mus', mu_min, mu_max, shape=N)
    
    _lam = pm.Gamma('_lam', 1., 1.)
    sigma = pm.Deterministic('sigma', 1.0/_lam)
    
    for i, x_i in enumerate(x):
        x_obs = pm.Normal(f'x_obs{i}', mu=mus[i], sigma=sigma, observed=x_i)

Pyroではこう書きました。今度も、pyro.conditionは使っていないです。

def model_gauss(x):
    mu_min = torch.tensor(0.0)
    mu_max = torch.tensor(300.0)
    mus = []
    for i in range(len(x)):
        mus.append(pyro.sample(f'mu_{i}', dist.Uniform(mu_min, mu_max)))
    
    lam = pyro.sample('lam', dist.Gamma(torch.tensor(1.), torch.tensor(1.)))
    
    x_obs = []
    for i, x_data in enumerate(x):
        x_obs.append(pyro.sample(f'x_{i}', dist.Normal(mus[i], 1./lam), obs=x_data))
    return x_obs

データは、書籍の通り次の値を入力しました。

x = np.array([[90, 95, 100], 
              [105, 110, 115], 
              [150, 155, 160], 
             ])

N=3, M=3です。

共通の標準偏差の推論結果は次の通りです。

f:id:hippy-hikky:20201001120809p:plain
標準偏差の推論結果.

被験者毎のIQ値の推論結果である、平均パラメータの推論結果は次の通りとなりました。

f:id:hippy-hikky:20201001120924p:plain

こちらは、だいたい想像したような結果でしたね。

【トップに戻る】

おわりに

ということで、「ベイズ統計で実践モデリング」の4章をPyMC3とPyroを使って書いてみました。

4章では広く知られているガウス分布正規分布)のパラメータ推論について扱いました。ガウス分布のパラメータ推論は算術平均が最尤解になるので、あまり「パラメータ推論」という感覚を持っていない方もいらっしゃったかと思いますが、今後階層モデルなどを扱っていくにあたって、今回扱った考え方が活きてくると思います。

また、分散が異なるとした二つ目の例のように、標本平均を単純に確認しただけではまずい例もありました。外れ値が混在するでーたの場合、標本平均(最尤解)は大きくずれてしまう場合があるので、この辺りがベイズ推論の強みが出ているのかなと思いました。

次回は5章「データ解析の例」ということで、この辺りからだいぶ実践的な話題に入っていきそうです。時系列データの変化点検知などもあるので楽しみです。

【トップに戻る】

参考資料