機械と学習する

統計解析、機械学習について学習したことをまとめていきます

ベイズ推論により混合分布のパラメータ推論をやってみる 〜ガウス混合分布のパラメータ推論とクラスタリング〜

【概要】

  • 混合分布(混合モデル)はモデルを潜在変数でスイッチする構造を持ったモデルであり、実用的な観点でも面白いです
  • これから数回にわたって、混合分布を使って遊んでみます
  • 第2弾では、ガウス混合分布のパラメータ推論と、推論結果を使って未知の入力のクラスタリングをやってみます

【目次】


はじめに

機械学習や統計の問題では、手元にあるデータを解釈して応用しようとしますね。 この時、明に暗になんらかの「確率モデル」を仮定しているはずです。

確率モデルの中でも、混合分布(混合モデル)は、複数の確率モデルの組み合わせとして定義されており、複雑なデータ構造を表現できます。 応用としても、クラス分類や多クラスの回帰など面白い実用例があります。

ということで数回にわたって、混合モデルを使って遊んでみたいと思います。
第2弾となる本記事では、混合モデルの中でもよく使われている「ガウス混合モデル(GMM;Gaussian Mixture Model)」のパラメータ推論を扱います。また、ガウス混合モデルの応用例として「クラスタリング」も行います。

ベイズ推論により混合分布のパラメータ推論をやってみる 記事一覧】

learning-with-machine.hatenablog.com

  • 第4弾:

learning-with-machine.hatenablog.com

【トップに戻る】

ガウス混合モデル(GMM;Gaussian Mixture Model)

ガウス混合モデルとは、ガウス分布が複数組み合わさってできたモデルです。 一般には、組み合わせる分布はガウス分布に限定する必要はなく、また、複数の分布を組み合わせても良いです(ガウス分布の組み合わせていない場合は、総称して「混合モデル」と呼びます)。

分布を組み合わせることで、複雑なデータ構造をモデリングできます。潜在的に多クラスあることがわかっているけど、事前にクラスを分類しておくことができないようなデータに適用できます*1

観測変数をx_nとしてN個得られるとします。 x_nガウス分布\mathcal{N}(x_n | \mu_k, \Sigma_k)に従うものとします。 ここで\cdot_kはk番のコンポーネントであることを示します(コンポーネントとは、混合分布を形成する要素、ガウス分布などを示します)。 コンポーネントの所属を示す潜在変数をs_nとします。 コンポーネントの混ざり具合を表すパラメータを\piとします。これは、sがどの程度割り振られているかを示す値です。

確率変数の関係性をグラフィカルモデルで表現すると以下のようになります。

f:id:hippy-hikky:20200308143035p:plain

このモデルの同時分布は次の通りです。

f:id:hippy-hikky:20200308144701p:plain

Xは観測されるデータと考えるので、推論対象の確率変数は \left\{ \pi, S, \Theta \right\}の3種類です。 今回も前回と同様に、パラメータ推論にはPyMC3を利用したMCMCアルゴリズムを利用します*2

【トップに戻る】

ガウス混合モデルのパラメータ推論とクラスタリング

実際の推論については、添付するnotebookを参照ください。重要な部分だけかいつまんで解説します。

1次元ガウス混合モデル

1次元のガウス混合モデルについては、PyMC3のExampleであるGaussian Mixture Modelほぼそのままです。しかし、Exampleではpm.Potentialという関数で何かを書いているのですが、何をやっているのかよくわかりません*3。よくわからないのでここでは使っていません。使わないでも大きな支障はなさそうですし。 詳しくはnotebookを見てください。

推論結果の利用例として、本記事では未知のデータx_*がどのコンポーネントに属するのかを推論する、いわゆる「クラスタリング」を扱います。 クラスタリングについては、新しいデータx_*が与えられた時の、コンポーネントkの確率p(k|x_*)を推論する問題となります。この条件付き確率は、定義に従って計算すると以下のように、確率密度に比例する形になります。なので、コンポーネント毎に確率密度を算出し、正規化することで、p(k|x_*)を得ることにしました(notebook参照)。

f:id:hippy-hikky:20200308151849p:plain:w280

結果は次のようになります。

f:id:hippy-hikky:20200308145720p:plain
コンポーネント数3のガウス混合モデルの推論結果とクラスタリング結果.ヒストグラムは実際に得られたデータ.ヒストグラムのビン毎にどのクラスに属する確率が高いのかを色で表現している. 太線は推論した事後分布の平均で描いたガウス分布

ヒストグラムの色をみると、4付近で二つの分布が接近しており、この付近の確率が青と赤が混じった確率になっていることがわかります。また同様に、−1の付近でも色の重なりが確認できます。

2次元ガウス混合モデル

2次元ガウス混合モデルも1次元の場合とモデル/考え方は同じです。

一つ注意するのは、前回共分散行列の推論には3つのアプローチがあると述べました。ここでは、LKJCholeskyCov APIを利用します*4

ここでは、共分散行列は共通で平均ベクトルだけが異なるガウス分布の混合分布の推論を行います。次の章でそれぞれの共分散行列も推論するモデルに拡張します。

推論した混合モデルを使ってクラスタリングをした結果は次の通りです。

f:id:hippy-hikky:20200308145854p:plain
コンポーネント数2の2次元ガウス混合モデルでのクラスタリング結果.観測データをオレンジと青の点で示す.背景色がクラスの確率を示す.

コンポーネント数が2なので、二つの領域に分離された空間になります。境界付近は確率が混ざっているので、色がぼやけていることがわかりますね*5

潜在変数の周辺化除去

先の例では、共分散行列行列がコンポーネント間で共通であると仮定していました。しかしこれでは全然面白くないので、コンポーネント毎に共分散行列を推論するように拡張します。

前章の素直なアプローチでは、コンポーネント数2で共分散行列が共通なのに、MCMCアルゴリズムでサンプルを1,500点得るのに10分近くかかってしまいました。計算の効率化が必要ですね。

基本的にはこれまでと同じ考え方で推論を行うのですが、計算を効率化するために、コンポーネントの割り当てを示す潜在変数s_nを周辺化除去します*6。 周辺化除去については、PyMC3 Examples Marginalized Gaussian Mixture Modelに詳しく書いてます。また、Stanのドキュメントがベースの議論で、こちらの方が詳しく書かれています。

潜在変数を周辺化除去すると、グラフィカルモデルは次のようになります。

f:id:hippy-hikky:20200308150225p:plain:w300

この潜在変数の周辺化除去については、PyMC3のMixture APIで機能が提供されています。ここでは、このAPIを使って実装します。

3コンポーネントガウス混合モデルの推論結果を利用したクラスタリング結果は次の通りです。

f:id:hippy-hikky:20200308150056p:plain
コンポーネント数3の2次元ガウス混合モデルでのクラスタリング結果.図中の点は観測データを示す.背景色が3つのコンポーネントの所属クラスの確率を示す.

注目は座標(0, 8)のあたり。この付近は、実際のデータでは青と赤のクラスに属するデータが近いですね。しかし、推論結果は緑のクラスの確率が高くなっているようです。

これは、青と赤のクラスの共分散の構造が尖っていることから、この領域は広く分布する形である緑のクラスに所属する確率が高いものとなっていると思われます。 この結果は、「ベイズ推論による機械学習」のp.158にある例を再現できたということなのかなと思っています(間違ってたらごめんなさい)。

ということで実際のnotebook

【トップに戻る】

まとめ

今回は、ガウス混合モデルのパラメータをベイズ推論により推論してみました。前回は単純なガウス分布のパラメータ推論だったので、ようやく実践的なモデルになりましたね。

計算効率化のテクニックとして潜在変数を周辺化除去しましたが、これによってだいぶ計算が効率化されました。ただし、周辺化除去だけの効果なのかMixture APIでもっと色々やっているのかは調べきれてないです。。。
周辺化除去については、単純なアイディアで定式化もすんなりできるので*7参考になりました。

次回は、混合モデルのスイッチングの機能に注目して、複数のクラスが混在する回帰モデルを推論してみたいと思います。

【トップに戻る】

参考文献

パターン認識と機械学習 上

パターン認識と機械学習 上

【トップに戻る】

*1:実際の問題はほとんどがこのような問題ですよね

*2:前回MCMCを利用したのは、混合モデルのコンポーネントの定義をPyMC3でどのように書くかを整理したかったためです

*3:ググったら同じように悩んでる人がいて良かった

*4:PyMC3の公式Exampleなどを見るとこちら手法が推奨されているようだったので

*5:あんまりぼやけなかったので、もうちょっと近づけた方が良かったかもです

*6:こういうのが「崩壊型ギブスサンプリング 」と思って良いのかな??

*7:定式化の詳細は添付してないですが