多項式曲線フィッティングを実装してみた

PRMLの"1.1 例:多項式曲線フィッティング"を読んだ。
実際に実装して、科書通りにフィッティング出来ることを確認した。

多項式曲線フィッティングの実装

xを0~1の間で10個選び、それぞれのxに対してsin(2πx)を計算した値にノイズを加えて訓練集合tを作成する。
そして、この訓練集合tを以下の式でフィッティングする。

y\left( x,w \right) = w_0 + w_1 x + w_2 x^2 + \ldots + w_M x^M = \sum_{j=0}^{M} w_j x^j \hspace{50}  (1.1)

フィッティングは、訓練集合tに対して二乗和誤差が最小になるwを見つける事で行う。
二乗和誤差は以下の式で表される。

 E(w) = \frac{1}{2} \sum_{n=1}^N \{ y(x_n,w)-t_n\}^2 \hspace{50} (1.2)

(1.2)に(1.1)を代入して、各wで偏微分した結果が0になる時に二乗和誤差が最小である。

 \frac{\partial E(w)}{\partial w_i}=\sum_{n=1}^{N}\{\sum_{j=0}^{M}w_j x_{n}^{j} -t_n\}x_n^i \\ = \sum_{n=1}^{N}\{\sum_{j=0}^{M}w_j x_{n}^{i+j} -t_n x_n^i\} \\ = \sum_{n=1}^{N} \sum_{j=0}^{M} w_j x_n^{i+j} - \sum_{n=1}^{N} t_n x_n^i \\ = \sum_{j=0}^{M}w_j \sum_{n=1}^{N}x_n^{i+j} - \sum_{n=1}^{N}t_n x_n^i \\ = \sum_{j=0}^{M} A_{ij} w_j - T_i = 0

したがって、
 \sum_{j=0}^{M} A_{ij} w_j = T_i
 A_{ij} = \sum_{n=1}^{N} x_n^{i+j} \\ T_i = \sum_{n=1}^{N} t_n x_n^i

のように書く事が出来て(演習1.1)、M+1個の一次方程式から連立一次方程式が得られる。
その連立一次方程式はAw = Tの行列形式にして解く事が出来る。

ただし、このままでは過学習が発生するので、以下のように罰金項を加える。

 \widetilde{E(w)} = \frac{1}{2} \sum_{n=1}^{N} \{y(x_n, w) - t_n\}^2 + \frac{\lambda}{2}\| w\| ^2 \\ \| w \| = w_0^2 + w_1^2 + w_2^2 + \ldots + w_M^2  \hspace{50} (1.4)

(1.4)も、偏微分して0になる条件から最適なwを求める。
(1.4)に(1.1)を代入して各wで偏微分すると、

 \widetilde{E(w)} = \frac{1}{2} \sum_{n=1}^{N} \{ \sum_{j=0}^{M} w_j x^j -t_n\}^2 + \frac{\lambda}{2} \sum_{m=0}^{M} w_m^2

 \frac{\partial \widetilde{E(w)}}{\partial w_i} = \sum_{j=0}^M A_{ij} w_j - T_i + \lambda w_i = 0

となり、
 A_{ij}' = \begin{cases} \sum_{n=1}^{N} x_n^{i+j} & (i \neq j) \\ \lambda + \sum_{n=1}^{N} x_n^(i+i) & (i = j) \end{cases}
とおくと、
 \sum_{j=0}^{M} A_{ij}' w_j = T_i
が得られ(演習1.2)、ここから連立一次方程式を解くことでwを得る。

この一連の計算を実装したコードは以下の通りである。

import System.Random
import Control.Monad
import Numeric.LinearAlgebra
import Graphics.Gnuplot.Simple
import Data.List

trainingDataFileName = "training.dat"

-- 多項式の係数を計算 mは次数
polynomialCoefficients :: Field a => Bool -> a -> Int -> [a] -> [a] -> [a]
polynomialCoefficients isRegularized lambda m xn tn = [mtrxW @@> (i, 0) | i <- [0 .. m]]
  where
    m'    = m + 1
    -- Ai,j と Ti は 演習1.1を参照せよ
    a i j = sum $ map (^ (i + j)) xn
    -- 式1.4のように正則化した場合、Ai,jは i == jで係数λが加えられる
    -- 正則化した場合のAi,jの導出は演習1.2を参照せよ
    a' i j | i /= j    =  sum $ map (^ (i + j)) xn
           | otherwise =  lambda + (sum $ map (^ (i + j)) xn)
    t i   = sum $ zipWith (*) tn $ map (^ i) xn
    -- Aw = T の行列として扱い、連立方程式を解く
    fA | isRegularized = a'
       | otherwise     = a
    mtrxA = (m' >< m') [fA i j | i <- [0 .. m], j <- [0 .. m]]
    mtrxT = (m' >< 1 ) [t i| i <- [0 .. m]]
    mtrxW = linearSolve mtrxA mtrxT

-- ノイズを付与する
addNoise :: (RandomGen g, Random a, Num a) => (a, a) -> [a] -> g -> [a]
addNoise range xs g = zipWith (+) xs (randomRs range g)

-- gnuplotでフィッティング曲線, 訓練データ, sin(x)をpng形式で出力する
toFittingCurveScript :: Show a => [[a]] -> String -> String
toFittingCurveScript wss attr = "set xrange [-0.5: 2*pi + 0.5]\n"
                           ++ "set yrange [-1.5: 1.5]\n"
                           ++ "set term png\n"
                           ++ (concat $ map order (zip [0..] wss))
  where
    order (m, ws) = "set output \"order" ++ attr ++ show m ++ ".png\"\n"
                    ++ "set title " ++ "\"order " ++ attr ++ show m ++ "\"\n"
                    ++ "plot " ++ (toPolynomial ws)
                           ++ " w l lt 1 t \"fitting\", " 
                           ++ "sin(x) lt 2, "
                           ++ "\"" ++ trainingDataFileName ++ "\" w p pt 6 lt 3\n"
    toPolynomial xs = concat . map toTerm $ zip [0..] xs
    toTerm (i, coefficient) = "+ (" ++ show coefficient ++ ")" 
                              ++ (concat . take i $ repeat "*x")

erms :: Floating a => [a] -> [a] -> [a] -> a
erms xs ts ws = (2 * e / n) ** 0.5
  where
    f x = (sum . zipWith (*) ws $ scanl (*) 1 $ repeat x)
    e   = sum . map (^2) . zipWith (-) ts $ (map f xs)
    n   = genericLength xs

ermsList :: Floating a => [a] -> [a] -> [[a]] -> [a]
ermsList xs ts wss = map (erms xs ts) wss

valuesToString :: Show a => [a] -> String
valuesToString xs = concat $ map (\x -> show x ++ "\n") xs

main = do
  gen <- getStdGen
  let xs   = linearScale 10 (0,2*pi)
      ts   = addNoise (-0.5 :: Double, 0.5 :: Double) (fmap sin xs) gen
      wss  = map (\m -> polynomialCoefficients False 0   m xs ts) [0 .. 10]
      wss' = map (\m -> polynomialCoefficients True  1.5 m xs ts) [0 .. 10]
      testxs = linearScale 100 (0,2*pi)
      testys = addNoise (-0.5 :: Double, 0.5 :: Double) (fmap sin testxs) gen

  writeFile "plot.plt" $ toFittingCurveScript wss ""
  writeFile "plot_reg.plt" $ toFittingCurveScript wss' "_regularized"
  writeFile "erms_train.dat" . valuesToString $ ermsList xs ts wss
  writeFile "erms_test.dat"  . valuesToString $ ermsList testxs testys wss
  writeFile trainingDataFileName $ concat . map 
                (\(x,y) -> show x ++ " " ++ show y ++ "\n") $ zip xs ts

実行手順は以下の通り。

ghc test.hs
./test
gnuplot -p < plot.plt
gnuplot -p < plot_reg.plt


正則化無しで過学習する(1.1)式は以下のようになる。




正則化ありで、過学習を抑制した結果は以下の通り。





Ermsを計算した結果は、erms_train.datとerms_test.datに書きだすので、
gnuplotで以下のように実行して結果を表示する。

set term png
set output "erms.png"
plot "erms_train.dat" w lp lt 3 pt 6, "erms_test.dat" w lp lt 1 pt 6