多項式曲線フィッティングを実装してみた
PRMLの"1.1 例:多項式曲線フィッティング"を読んだ。
実際に実装して、科書通りにフィッティング出来ることを確認した。
多項式曲線フィッティングの実装
xを0~1の間で10個選び、それぞれのxに対してsin(2πx)を計算した値にノイズを加えて訓練集合tを作成する。
そして、この訓練集合tを以下の式でフィッティングする。
フィッティングは、訓練集合tに対して二乗和誤差が最小になるwを見つける事で行う。
二乗和誤差は以下の式で表される。
(1.2)に(1.1)を代入して、各wで偏微分した結果が0になる時に二乗和誤差が最小である。
したがって、
のように書く事が出来て(演習1.1)、M+1個の一次方程式から連立一次方程式が得られる。
その連立一次方程式はAw = Tの行列形式にして解く事が出来る。
ただし、このままでは過学習が発生するので、以下のように罰金項を加える。
(1.4)も、偏微分して0になる条件から最適なwを求める。
(1.4)に(1.1)を代入して各wで偏微分すると、
となり、
とおくと、
が得られ(演習1.2)、ここから連立一次方程式を解くことでwを得る。
この一連の計算を実装したコードは以下の通りである。
import System.Random import Control.Monad import Numeric.LinearAlgebra import Graphics.Gnuplot.Simple import Data.List trainingDataFileName = "training.dat" -- 多項式の係数を計算 mは次数 polynomialCoefficients :: Field a => Bool -> a -> Int -> [a] -> [a] -> [a] polynomialCoefficients isRegularized lambda m xn tn = [mtrxW @@> (i, 0) | i <- [0 .. m]] where m' = m + 1 -- Ai,j と Ti は 演習1.1を参照せよ a i j = sum $ map (^ (i + j)) xn -- 式1.4のように正則化した場合、Ai,jは i == jで係数λが加えられる -- 正則化した場合のAi,jの導出は演習1.2を参照せよ a' i j | i /= j = sum $ map (^ (i + j)) xn | otherwise = lambda + (sum $ map (^ (i + j)) xn) t i = sum $ zipWith (*) tn $ map (^ i) xn -- Aw = T の行列として扱い、連立方程式を解く fA | isRegularized = a' | otherwise = a mtrxA = (m' >< m') [fA i j | i <- [0 .. m], j <- [0 .. m]] mtrxT = (m' >< 1 ) [t i| i <- [0 .. m]] mtrxW = linearSolve mtrxA mtrxT -- ノイズを付与する addNoise :: (RandomGen g, Random a, Num a) => (a, a) -> [a] -> g -> [a] addNoise range xs g = zipWith (+) xs (randomRs range g) -- gnuplotでフィッティング曲線, 訓練データ, sin(x)をpng形式で出力する toFittingCurveScript :: Show a => [[a]] -> String -> String toFittingCurveScript wss attr = "set xrange [-0.5: 2*pi + 0.5]\n" ++ "set yrange [-1.5: 1.5]\n" ++ "set term png\n" ++ (concat $ map order (zip [0..] wss)) where order (m, ws) = "set output \"order" ++ attr ++ show m ++ ".png\"\n" ++ "set title " ++ "\"order " ++ attr ++ show m ++ "\"\n" ++ "plot " ++ (toPolynomial ws) ++ " w l lt 1 t \"fitting\", " ++ "sin(x) lt 2, " ++ "\"" ++ trainingDataFileName ++ "\" w p pt 6 lt 3\n" toPolynomial xs = concat . map toTerm $ zip [0..] xs toTerm (i, coefficient) = "+ (" ++ show coefficient ++ ")" ++ (concat . take i $ repeat "*x") erms :: Floating a => [a] -> [a] -> [a] -> a erms xs ts ws = (2 * e / n) ** 0.5 where f x = (sum . zipWith (*) ws $ scanl (*) 1 $ repeat x) e = sum . map (^2) . zipWith (-) ts $ (map f xs) n = genericLength xs ermsList :: Floating a => [a] -> [a] -> [[a]] -> [a] ermsList xs ts wss = map (erms xs ts) wss valuesToString :: Show a => [a] -> String valuesToString xs = concat $ map (\x -> show x ++ "\n") xs main = do gen <- getStdGen let xs = linearScale 10 (0,2*pi) ts = addNoise (-0.5 :: Double, 0.5 :: Double) (fmap sin xs) gen wss = map (\m -> polynomialCoefficients False 0 m xs ts) [0 .. 10] wss' = map (\m -> polynomialCoefficients True 1.5 m xs ts) [0 .. 10] testxs = linearScale 100 (0,2*pi) testys = addNoise (-0.5 :: Double, 0.5 :: Double) (fmap sin testxs) gen writeFile "plot.plt" $ toFittingCurveScript wss "" writeFile "plot_reg.plt" $ toFittingCurveScript wss' "_regularized" writeFile "erms_train.dat" . valuesToString $ ermsList xs ts wss writeFile "erms_test.dat" . valuesToString $ ermsList testxs testys wss writeFile trainingDataFileName $ concat . map (\(x,y) -> show x ++ " " ++ show y ++ "\n") $ zip xs ts
実行手順は以下の通り。
ghc test.hs ./test gnuplot -p < plot.plt gnuplot -p < plot_reg.plt
Ermsを計算した結果は、erms_train.datとerms_test.datに書きだすので、
gnuplotで以下のように実行して結果を表示する。
set term png set output "erms.png" plot "erms_train.dat" w lp lt 3 pt 6, "erms_test.dat" w lp lt 1 pt 6