自動微分変分推論でガウス混合モデルのパラメータ推論 〜PyMC3を使った実装〜
【概要】
- 変分推論の自動化アルゴリズムの一つ、ADVI(自動微分変分推論)を使ってガウス混合分布のパラメータ推論をやってみました
- 基本的には、Gaussian Mixture Model with ADVI(PyMC3 Examples)の写経です
【目次】
はじめに
確率モデルの近似解を得る手法として、MCMCアルゴリズムや変分推論という手法がよく用いられます。
当ブログで以前公開した記事では、MCMCアルゴリズムを利用して混合モデルのパラメータ推論をやってみました。MCMCアルゴリズムは無限にサンプルを繰り返せば解析解に収束する性質を持っているのですが、推論に長時間かかることがあります。 実際、実務で用いる際にちょっと複雑なモデルをMCMCアルゴリズムで推論すると、数時間かかったりということがよくあります。
そこで本記事では、変分推論の自動化として提案されている自動微分変分推論(Automatic Differentiation Variational Inference ; ADVI)というアルゴリズムを利用してみます。ADVIを用いて混合モデルのパラメータを推論し、MCMCの結果と比較してみようと思います。
本記事で利用した全コードは実装コード全体に添付しています。基本的にこちらに全て書いているので、こちらを参照してもらえたら良いと思います。
注意
本記事は、Gaussian Mixture Model with ADVI(PyMC3 Examples)をベースにしています。変分推論やADVIの理論的な解説は本記事では行いません。本記事は、とにかく動かしてみたという程度の内容です。
軽い中身の記事なので、詳しく知りたい方は下記参考文献に掲載の資料を参照してください。
問題設定
本記事では、2次元ガウス混合モデルのパラメータ推論を扱います。 MCMCアルゴリズムを利用したガウス混合モデルのパラメータ推論については、当ブログの以下の記事でも扱っています。
learning-with-machine.hatenablog.com
ガウス混合モデル(GMM;Gaussian Mixture Model)
ガウス混合分布とは、ガウス分布が複数組み合わさってできたモデルで、どのガウス分布から生成されたデータであるかを示す潜在変数があり、このによってモデルがスイッチする仕組みと考えることができます。
一般に「混合モデル」と呼ぶ場合、組み合わせる分布はガウス分布に限りません。線形回帰モデルを組み合わせて(混合線形回帰モデル)複数のトレンドが混在した複雑なデータの回帰問題を扱うこともできます(ベイズ推論により混合分布のパラメータ推論をやってみる 〜線形回帰モデルの混合〜)*1。
確率変数の関係性をグラフィカルモデルで表現すると次のようになります。
それぞれの確率変数は以下の確率分布に従うとします。なお、近似ベイズ推論を利用するので共役性などは気にしていません。確率モデルの構築は分析者の主観に委ねられていますのでこのモデルが絶対という訳では無いことに注意してください。
上記のグラフィカルモデルではで表現したxの確率分布のパラメータは、正規分布なので、平均ベクトルと共分散行列です。
近似ベイズ推論
確率モデルのパラメータをベイズ推論する際には、解析解を導出する場合と近似解を求める場合があります。解析解を導出するには、確率変数間に共役性という制約が必要であったり、複雑なモデルではないとなどいくつか制約がつきます。そのため近似推論がよく用いられています。
近似ベイズ推論は大きく分けて2種類の方法がよく用いられています。
- サンプリングに基づく手法:MCMCアルゴリズム(H/M法、NUTSなど)
- 最適化に基づく手法:平均場近似を利用した変分推論
- (多くの場合)計算は高速
- 単純な分布で事後分布を近似するのでモデルの表現力に乏しい
- 当ブログのこの記事が参考になるかも
混合モデルのパラメータ推論を行った前回の記事では、MCMCアルゴリズムを利用していました。MCMCアルゴリズムはとても便利なアルゴリズムなのですが、推論に時間がかかる場合が少なくありません。
一方、変分推論は、高速に推論が行える場合が多いのですが、一般には最適化のための数式を解きたいモデルに対してそれぞれ導出する必要があります。 この変分推論を自動化するアルゴリズムとして自動微分変分推論(Automatic Differentiation Variational Inference ; ADVI)というアルゴリズムが提案されています(arXiv)。PyMC3では、このADVIがすでに実装されており、簡単に利用することができます。
自動微分変分推論
次の特徴を持ったアルゴリズムらしいですが、まださらっとしか論文を読んでいないので間違いがあるかもです。ここの記述を信用するのではなく、原文をあたりましょう。
ADVIとMCMCによるガウス混合モデルのパラメータ推論
適当なパラメータでコンポーネント数3のガウス混合モデルに従うトイデータを作成し、そのトイデータを使ってガウス混合モデルの推論をやってみます。
トイデータ
トイデータはこの記事にあるものと一緒です。
今回は以下のデータを生成しました。パラメータの設定値については添付のnotebookを参照してください。
MCMCアルゴリズムによる推論
実装はこの記事で書いてあることと全く一緒です。モデルの定義で、PyMC3が提供するMixture APIを利用しています。 Mixtureを利用すると、コンポーネントの割り当てを示す潜在変数が周辺化除去されます。MCMCサンプルを得る場合は、計算の効率化や安定化が目的でしたが、今回はこれが必要な条件となります。後述しますが、ADVIは離散変数に対応していないので。
MCMCでのパラメータ推論(MCMCサンプルの取得)は、以下のようにsample関数を呼ぶだけです。
with gmm_2d: tr_2d2 = pm.sample(1000, chains=1)
サンプル系列などはnotebookをご覧いただくとして、クラスタリングの結果だけ貼っておきます。
ADVIによる推論
続いて、ADVIを利用して変分推論してみましょう。
PyMC3で実装されているADVIは、モデルのインターフェースがMCMCと共通になっており、MCMCのために構築したモデルをそのまま使って推論することができます。ただし、ADVIでは離散確率変数は対応していません。そのため、混合モデルをシンプルに実装したモデルでは、ADVIを使って推論することができません。コンポーネントの割り当てを表す潜在変数が離散だからです(カテゴリ分布)。そのため、潜在変数を周辺化除去する必要があります。周辺化除去については、この記事を参考にしてください。PyMC3ではMixture APIを利用すれば細かいことは考えなくても実装できてしまいます。
モデルを構築した後で、ADVIによる推論は以下のコードを実行するだけです。
with gmm_2d: approx = pm.fit(n=10000, obj_optimizer=pm.adagrad(learning_rate=1e-1))
Optimizerにはadagradの他にもSGDなどが用意されています。詳しく知りたい方は公式ドキュメントなどを参照してください*2。上記の設定では、10,000イテレーションで最適化計算を終了するように設定しています。
推論の結果、学習曲線は以下の通り、収束しているようです。
推論結果として、各値の事後分布などはnotebookを参照してもらうとして、先のMCMCの結果と同様に、推論した混合分布でのクラスタリングの結果を以下に示します。
同じような結果になりましたね。
計算時間の比較
MCMCでの推論と変分推論の計算時間を比較してみます。以下の結果は、試行回数1回だけですし、iteretion数(MCMCならサンプル数)で結果は当然変わってきますので、あくまで参考とだけしてください。
アルゴリズム | 処理時間 |
---|---|
MCMC | 59.4 s |
ADVI | 40.9 s |
この数値だけ見るとやはり変分推論の方が速かったです。上記の通り、iteration数などで全然結果は違ってくるので、意味のある比較ではないのですが。。。
実装コード全体(jupyter notebook)
まとめ
ということで今回は、PyMC3に実装されているADVIを使ってガウス混合モデルの推論をやってみました。
PyMC3は直感的にモデルを構築でき、MCMCも変分推論(ADVI)もインターフェースが共通なので、容易に試すことができました。MCMCだと推論が遅くて困るという場合にかるーく試してみるには良いなと思います。(実際に私が扱ってる案件でまさに困っているので早速試してみよう)
今回は、理論的な解説ではなく、PyMC3の関数を触ってみたというだけの軽い内容だったので、次はADVIの仕組みなどを解説していきたいなと思います。
参考文献
機械学習スタートアップシリーズ ベイズ推論による機械学習入門 (KS情報科学専門書)
- 作者:須山 敦志
- 発売日: 2017/10/21
- メディア: 単行本(ソフトカバー)
Pythonによるベイズ統計モデリング: PyMCでのデータ分析実践ガイド
- 作者:マーティン,オズワルド
- 発売日: 2018/06/22
- メディア: 単行本
Automatic Differentiation Variational Inference, 2016, arXiv