機械と学習する

統計解析、機械学習について学習したことをまとめていきます

Pyroを使った確率モデルの近似推論

【概要】

  • PPLの一つであるPyroを触ってみました
  • Pyroでの変分推論とMCMC(NUTS)でのパラメータ推論についてのメモです

【目次】


はじめに

当ブログの他の記事を見ていただくと分かる通り、確率モデルの近似推論を行う際、普段私はPyMC3を使っています。ですが、以下の背景から別のPPL(確率プログラミング言語; Probabilistic Programming Language)も扱えるようにならんとなと思ってます。

  • PyMC3はTheanoベース。Theanoは有志によるメンテは続いているとのことだが、、、
  • PyMC3の後継であるPyMC4はまだまだ開発中で現時点ではちょっと使いにくい、、、

そこで今回、PyTorchベースのPyroというPPLで確率モデルの近似推論(変分推論、MCMC)を実施するための書き方についてまとめてみました。Pyroはまだまだ触り始めた段階なので、あくまでもメモ程度の内容ですのでそこは注意してください。 なお、今回紹介する全コード(jupyter notebook)を最後に添付していますので、そちらだけ読んでいただいても問題ないと思います。

間違いや改善案など指摘していただけるとめちゃくちゃ助かります。

【トップに戻る】

ベルヌーイ分布のパラメータ推論

まずは最も基本的な確率モデルとして、どんな教科書にも出てくるベルヌーイ分布のパラメータ推論問題を題材にします。

ベルヌーイ分布は、0か1の2値を生成する確率分布で、2クラス分類問題の確率モデルなどとして頻繁に登場する分布です。統計関係の大体の教科書に掲載されているのでご存知の方は多いと思います。

ベルヌーイ分布含め有名な確率分布について私も以下のブログ記事を書いているので、もしよければ参考にしてもらえたらと。

learning-with-machine.hatenablog.com

モデル

ベルヌーイ分布のパラメータ推論は、具体例として、コインの表が出る確率やチョコボールのエンゼルが当たる確率といった問題が該当します。

モデルは以下のようになります。

f:id:hippy-hikky:20200909211827p:plain:h20

それぞれの確率分布は以下の通りとしました。

f:id:hippy-hikky:20200909211945p:plain:h50

p(x|\theta)は問題設定からベルヌーイ分布としますが、p(\theta)はベータ分布でなくても0~1の範囲に収まるようならOKです。

このモデルにおいて、観測としてX=[x_1, x_2, \cdots, x_N]が与えられるとして、p(\theta | X)が知りたい分布です。

このモデルは多くの書籍にある通り、解析的にp(\theta | X)を導出でき、導出結果はベータ分布となります。詳しくは、参考文献などを見てください(Googleで検索すればいくらでも導出が出てくると思います)。

解析解は導出できるのですが、Pyroを使って近似推論してみます。

Pyroによるモデルの定義

上記のモデルをPyroで実装します。以下のように、各確率変数に対して確率分布を定義していくだけです。この辺りはPyMC3も同じで、PPLって感じですね。一点注意として、入力するデータはPytorchのtensorの必要があります。

def model_coin(param=None):
    alpha0 = torch.tensor(1.0)
    beta0 = torch.tensor(1.0)
    theta = pyro.sample('theta', dist.Beta(alpha0, beta0))
    x_smp = pyro.sample('x', dist.Bernoulli(theta))
    return x_smp

このようにモデルを定義すれば、あとは変分推論でもMCMCでの推論でも扱えます。また、条件として固定したい変数を以下のようにpyro.conditionで定義することで、条件付き分布からサンプルを取得することができます。

y_sim = [pyro.sample('y', 
                     pyro.condition(model_coin, data={'theta':0.7})).item() 
         for _ in range(20)]

これは確率\thetaを固定した場合のサンプルです(つまりp(x|\theta))。pyro.conditionで条件を指定しなくてもpyro.sampleでサンプルを取得できますが、その場合は事前分布(初期のモデル)からのサンプルになります。

確率モデルでは、このようにデータを生成することができるので、シミュレーションを行えることが面白いところの一つだと思います。

変分推論

事後分布の近似推論を行うには大きく分けて二つの手法があります。代表的なものに変分推論とMCMCによる推論があります。

まずは変分推論を実装してみます。なお変分推論については、参考文献を参照してください。一応、当ブログでも扱ってるのでその記事を貼っておきます。

learning-with-machine.hatenablog.com

近似分布の定義

変分推論では、変分パラメータ\xiを使って分布p(Z;\xi)を事後分布p(Z|X)に近づけていきます。そのため、この近似分布p(Z;\xi)を定義してあげることが必要になります。
(というのを「ガイド関数」というので良いのかな?)

def posterior():
    alpha_q = pyro.param("alpha_q", torch.tensor(1.0),
                         constraint=constraints.positive)
    beta_q = pyro.param("beta_q", torch.tensor(1.0),
                        constraint=constraints.positive)
    return dist.Beta(alpha_q, beta_q)

def guide(data):
    pyro.sample('theta', posterior())

posteriorで近似分布を定義しています。近似分布はベータ分布とし、alpha_qbeta_qが変分パラメータでこれが推論対象となります。ベータ分布のパラメータは非負なので、constraints.positiveで非負条件を設定しています。 なおベータ分布を指定したのは、解析解がベータ分布になることを知っているということもありますが、\thetaは確率なので、確率を表現する分布ということでベータ分布を採用しています。

ちなみに、ガイド関数はpyroチュートリアルなどでは一つの関数として書いてあることが多いのですが、「まるまるにっき」さんのこちらの記事の書き方がとてもわかりやすかったので、採用させていただきました。

推論

あとはソルバーを指定して推論します。

cond_model = pyro.condition(
    model_coin, 
    data={'x':data})
optimizer = pyro.optim.SGD({"lr": 0.001, "momentum":0.1})
svi = pyro.infer.SVI(model=cond_model, 
                     guide=guide, 
                     optim=optimizer, 
                     loss=pyro.infer.Trace_ELBO())
n_steps = 10000
for step in range(n_steps):
    svi.step(data)

観測データXを条件としてモデルに入れるので、pyro.conditionでデータを指定しています。

推論結果

推論結果はこのようになりました。設定値の付近にピークがあるような分布になりましたね。

f:id:hippy-hikky:20200909222504p:plain
近似事後分布からのサンプルのヒストグラム.太い破線は設定値.

MCMC

次は、同じモデルをMCMCアルゴリズムを用いて推論します。

MCMCアルゴリズムの詳細については、例えば当ブログの以下の記事などを参考にしていただけたらと思います。

learning-with-machine.hatenablog.com

確率遷移核の指定

MCMCは同時確率(モデル)を定義すれば計算ができます。

なので、あとはカーネルを指定します。PyroではNUTS(No-U-Turn Sampler)とHMC(Hamiltonian Monte Carlo)が実装されているそうです。ここではNUTSを使っています。

nuts_kernel = infer.NUTS(cond_model, 
                         adapt_step_size=True, 
                         jit_compile=True, 
                         ignore_jit_warnings=True)

推論

あとは、MCMCサンプルの取得を行います。

mcmc = infer.MCMC(nuts_kernel, 
                  num_samples=1500,
                  warmup_steps=500,
                  num_chains=1)
mcmc.summary()

これだけです。簡単ですね。

推論結果

MCMCによる\thetaの事後分布の推論結果は以下の通りとなりました。

f:id:hippy-hikky:20200909223755p:plain
MCMCサンプルのヒストグラム.太い破線は設定値.

【トップに戻る】

ガウス分布正規分布)のパラメータ推論

次はガウス分布のパラメータ推論を行います。

ガウス分布正規分布)はとても有名な分布でいろいろなところで現れてきますね。

ガウス分布は、平均と分散(標準偏差)の二つのパラメータで定義される確率分布です。先のベルヌーイ分布の例は、推論対象の確率変数は一つだけでしたが、今度は複数の確率変数を推論する場合の例になっています。

モデル

まずは同時確率(モデル)を考えます。

f:id:hippy-hikky:20200910003735p:plain:h25

ここで、それぞれの確率分布は以下の通りとします。

f:id:hippy-hikky:20200910003747p:plain:h90

ベルヌーイ分布の例と同様、xについての確率分布はガウス分布ですが、それぞれのパラメータの分布についてはその値の範囲からある程度自由に選択することができます。今回は、共役関係にある分布を採用しています。

こちらも、このモデルに従えば解析的に事後分布p(\mu, \sigma | X)を導出することがきます。ですが、ここはPyroの練習なので、あえて変分推論とMCMCで推論します。

Pyroによるモデル定義

Pyroの記法に従ってモデルを実装します。

def model_gaussian(params=None):
    mu0 = torch.tensor(0.0)
    sig0 = torch.tensor(10.0)
    a0 = torch.tensor(1.0)
    b0 = torch.tensor(2.0)
    mu = pyro.sample('mu', dist.Normal(mu0, sig0))
    sig = pyro.sample('sig', dist.InverseGamma(a0, b0))
    x = pyro.sample('x', dist.Normal(mu, sig))
    return x

ベルヌーイ分布の場合と同様に、上記のようにモデルを実装すれば、パラメータを条件として与えた場合のシミュレーションを行うことができます。コードは添付のnotebookを参照ください。

変分推論

まずは変分推論を実装します。

近似分布の定義

まずは近似分布p(Z;\xi)を定義します。ここで、ガウス分布は推論対象のパラメータが2つあります。これは正確には依存関係があり、ガウスガンマ分布などを使うことで二つ同時にサンプルすることができます。しかしここでは、平均と標準偏差は独立と仮定し、それぞれ独立に近似分布を設計します(平均場近似)。

def posterior_mu():
    mu_q = pyro.param("mu_q", torch.tensor(0.))
    sig_q = pyro.param("sig_q", torch.tensor(1.), constraint=constraints.positive)
    return dist.Normal(mu_q, sig_q)

def posterior_sig():
    a_q = pyro.param('a_q', torch.tensor(1.0), constraint=constraints.positive)
    b_q = pyro.param('b_q', torch.tensor(2.0), constraint=constraints.positive)
    return dist.InverseGamma(a_q, b_q)

def guide(data):
    pyro.sample('mu', posterior_mu())
    pyro.sample('sig', posterior_sig())

推論

ここまで定義できればあとはベルヌーイ分布の例と同じです。

data = {'x':torch.tensor(xs)}
cond_gauss = pyro.condition(model_gaussian, 
                            data=data)

pyro.clear_param_store()
optimizer = pyro.optim.SGD({"lr": 0.0001, "momentum":0.1})
svi = pyro.infer.SVI(model=cond_gauss, 
                     guide=guide, 
                     optim=optimizer, 
                     loss=pyro.infer.Trace_ELBO())
n_steps = 1000
for step in range(n_steps):
    svi.step(torch.tensor(xs))

推論結果

推論結果はこのようになりました。

f:id:hippy-hikky:20200910005247p:plain
近似事後分布からのサンプルのヒストグラム.上は平均、下は標準偏差.太い破線は設定値.

ということで、概ね設定値周りで事後分布を予測できていますね。

MCMC

次はMCMCで推論してみます。

確率遷移核の指定と推論

こちらでもNUTSを使ってMCMCサンプルを得ることにします。MCMCサンプルの取得については確率モデルにほぼ依存せずに書けるようです。

nuts_kernel = infer.NUTS(cond_gauss, 
                         adapt_step_size=True, 
                         jit_compile=True, 
                         ignore_jit_warnings=True)
mcmc = infer.MCMC(nuts_kernel, 
                  num_samples=1500,
                  warmup_steps=500,
                  num_chains=1)
mcmc.run()

推論結果

MCMCによるそれぞれのパラメータの事後分布は以下の通りです。

f:id:hippy-hikky:20200910005707p:plain
MCMCサンプルのヒストグラム.上は平均,下は標準偏差の事後分布.太い破線は設定値.

変分推論と比較するとだいぶ尖った(確信の度合いが高い)分布になっていますね。

【トップに戻る】

(参考)全コード

【トップに戻る】

おわりに

ということで、Pyroを使ってベルヌーイ分布とガウス分布のパラメータ推論をやってみました。

確率モデルで重要なのは、(事前でも事後でも)確率モデルを使ってシミュレーションできることだと思います。その点では、モデルを定義して、pyro.conditionで条件(データ)を指定できるのはすごくわかりやすいなと思いました。同時確率(model)から条件付き確率(pyro.condition)を作るという流れが式を立てて考えていくそのままの流れになるので。

今回は、ベルヌーイ分布とガウス分布のパラメータ推論なので解析解も容易に導出できる問題でした。次は、もうちょっと複雑なモデルでも同様に書けるのか試してみようと思います。

【トップに戻る】

参考文献

  1. 須山敦志, 機械学習プロフェッショナルシリーズ ベイズ深層学習, 講談社, 2019

ベイズ深層学習 (機械学習プロフェッショナルシリーズ)

ベイズ深層学習 (機械学習プロフェッショナルシリーズ)

  • 作者:須山 敦志
  • 発売日: 2019/08/08
  • メディア: 単行本(ソフトカバー)

  1. C.M.ビショップ(著), パターン認識機械学習 上 (Pattern Recognition and Machine Learning), Springer, 2007

パターン認識と機械学習 上

パターン認識と機械学習 上

  1. C.M.ビショップ(著), パターン認識機械学習 下 (Pattern Recognition and Machine Learning), Springer, 2007

【トップに戻る】