Haskellの線形代数用ライブラリで多変量ガウス分布を計算した際のメモ

Haskellで行列計算したい場合にはhmatrixという線形代数用ライブラリが使える。
この記事では、行列計算でよく使いそうな手続きを紹介する。
また、それらの例として多変量ガウス分布を実装する。

hmarixのインストール

例によってcabalでインストールする。

cabal update
cabal install hmatrix

hmatrixの簡単な使い方

例えば3×3の行列を作る場合には、以下のように実行する。

*Main> :m Numeric.LinearAlgebra
Prelude Numeric.LinearAlgebra> (3 >< 3) ([1..9] :: [Double])
(3><3)
 [ 1.0, 2.0, 3.0
 , 4.0, 5.0, 6.0
 , 7.0, 8.0, 9.0 ]


2×4の行列を作る場合には、以下のようにする。

Prelude Numeric.LinearAlgebra> (2 >< 4) ([1..8] :: [Double])
(2><4)
 [ 1.0, 2.0, 3.0, 4.0
 , 5.0, 6.0, 7.0, 8.0 ]


fromListsを使えば、aから行列を作ることが出来る。

Prelude Numeric.LinearAlgebra> fromLists ([[1,0], [0,1]] :: [[Double]])
(2><2)
 [ 1.0, 0.0
 , 0.0, 1.0 ]


toListsで行列をaに戻す事も出来る。

Prelude Numeric.LinearAlgebra> let m3x3 = (3 >< 3) ([1..9] :: [Double])
Prelude Numeric.LinearAlgebra> toLists m3x3
[[1.0,2.0,3.0],[4.0,5.0,6.0],[7.0,8.0,9.0]]


行列の行数はrowsで求める事が出来る。
行列の列数はcolsで求める事が出来る。

Prelude Numeric.LinearAlgebra> let m2x4 = (2 >< 4) ([1..8] :: [Double])
Prelude Numeric.LinearAlgebra> rows m2x4
2
Prelude Numeric.LinearAlgebra> cols m2x4
4

転置行列はtransで求める。
行列式はdetで求める。
逆行列はinvで求める。

Prelude Numeric.LinearAlgebra> trans m2x4
(4><2)
 [ 1.0, 5.0
 , 2.0, 6.0
 , 3.0, 7.0
 , 4.0, 8.0 ]
Prelude Numeric.LinearAlgebra> det m2x4
*** Exception: det of nonsquare (2><4) matrix
Prelude Numeric.LinearAlgebra> det m3x3
6.661338147750939e-16
Prelude Numeric.LinearAlgebra> inv m3x3
(3><3)
 [ -4.503599627370498e15,   9.007199254740992e15,  -4.503599627370496e15
 ,  9.007199254740996e15, -1.8014398509481984e16,   9.007199254740991e15
 , -4.503599627370498e15,   9.007199254740992e15, -4.5035996273704955e15 ]


行列の積はmultiplyで計算する。

Prelude Numeric.LinearAlgebra> let m1x2 = (1 >< 2) ([1,2]  :: [Double])
Prelude Numeric.LinearAlgebra> let m2x1 = (2 >< 1) ([2,-1] :: [Double])
Prelude Numeric.LinearAlgebra> m1x2 `multiply` m2x1
(1><1)
 [ 0.0 ]


行列から位置を指定して値を取り出す時には@@>演算子を使う。

Prelude Numeric.LinearAlgebra> m3x3 @@> (0,0)
1.0
Prelude Numeric.LinearAlgebra> m3x3 @@> (2,2)
9.0


連立1次方程式も解ける。
以下のような方程式が与えられていた場合には、

 \left\{ 2x_1 + 3x_2 - x_3 = -3\\ -x_1 + 2x_2 + 2x_3 = 1\\ x_1 + x_2 + x_3 = -2\right.

次のように実行して方程式を解くことが出来る。

Prelude Numeric.LinearAlgebra> let a = (3><3) ([ 2, 3,-1,-1, 2, 2,1, 1,-1] :: [Double])
Prelude Numeric.LinearAlgebra> let b = (3><1) ([-3, 1, -2]:: [Double])
Prelude Numeric.LinearAlgebra> linearSolve a b
(3><1)
 [  1.0
 , -1.0
 ,  2.0 ]

多変量ガウス分布

hmatrixを使って多変量ガウス分布を計算する。
このプログラムをコンパイルして実行すると、"gauss.dat"というファイルに計算結果が出力される。

import Numeric.LinearAlgebra
import Graphics.Gnuplot.Simple

-- mahalanobisDistance :: Matrix a -> Matrix a -> Matrix a -> a
mahalanobisDistance avrg covariance x = ((trans y)
                                         `multiply` (inv covariance)
                                         `multiply` y) @@> (0,0)
  where
    y = x - avrg

-- gaussDist :: Matrix a -> Matrix a -> Matrix a -> a
gaussDist mu sigma x = c * exp a
  where
    d = fromIntegral $ cols x
    c = 1.0 / (((2 * pi) ** (d / 0.5)) * ((det sigma) ** 0.5))
    y = x - mu
    a = -0.5 * mahalanobisDistance mu sigma x

xs =  [(2 >< 1) ([x1, x2] :: [Double]) |
        x1 <- linearScale 200 (-1, 1)
      , x2 <- linearScale 200 (-1, 1)]

averageVector    = (2 >< 1) ([0.0, 0.0]:: [Double])
covarianceMatrix = (2 >< 2) ([0.1, 0.05, 0.05, 0.1] :: [Double])

calcGaussDist avrgVec covMatrix xs = xyz
  where
    values = map (gaussDist averageVector covarianceMatrix) xs
    xl     = map (concat . toLists) xs
    xyz    = zipWith (\(x:y:[]) v -> x:y:v:[]) xl values

toSpaceSepString :: Show a => [a] -> String
toSpaceSepString xs = foldr (\x acc -> (show x) ++ " " ++ acc) "\n" xs

main = do
  let results = calcGaussDist averageVector covarianceMatrix xs
      str     = concat $ map toSpaceSepString results
  writeFile "gauss.dat" str
  return ()

"gauss.dat"が作成されたら、以下のように実行してgnuplotで表示する。

gnuplot> set pm3d
gnuplot> set dgrid3d 51,51
gnuplot> set hidden3d
gnuplot> set terminal png
Terminal type set to 'png'
Options are 'nocrop font /usr/share/fonts/dejavu/DejaVuSans.ttf 12 size 640,480 '
gnuplot>  set output 'gauss1.png'
gnuplot> splot "gauss.dat" u 1:2:3 with lines

画像が作成されたら以下のように画像表示コマンドを実行して画像を表示して確認する。

eog gauss1.png