orizuru

つながる.見える.わかる IoTソリュ-ション

ディープラーニングをやらないPyTorch入門

約 7 分
ディープラーニングをやらないPyTorch入門

こんにちはtetsuです。

今回はPyTorchの入門的な記事となります。PyTorchを使ったディープラーニングのサンプルコードはよくありますが、それとは別の方法で説明していきたいと思います。PyTorchでおこなう処理の流れはディープラーニングを扱う場合と変わりませんので、計算の本質的な部分はある程度この記事で理解できるようになることを目指します。

PyTorchを使って解きたい問題、解き方

問題設定

今回は

    \begin{equation*}y = x^2 - 2\end{equation*}

に対してy=0となるようなxを求める問題を解いていこうと思います。この問題の答えは手計算によってすぐに求まり、

    \begin{equation*} x =  \pm\sqrt{2} \approx \pm 1.41421356\end{equation*}

となりますね。簡単すぎるかもしれませんが、答えが分かっている問題を解いたほうが計算が追いやすく、理解の手助けになるかなと思います。

解き方

先程設定した問題では手計算ですぐに答えが求まるので、次から説明することは回りくどく思うかもしれませんが、ディープラーニングのライブラリの計算方法を理解する上で必要な話になります。

ディープラーニングでは勾配降下法という手法をベースにしたものを使って解を求めていきます。
この勾配降下法を使う際には損失関数というものを定義する必要があります。損失関数をより小さくすることができるxが見つかったとき、より良い答え(\pm \sqrt{2}により近いx)が求まる、というような観点で損失関数を設計します(機械学習においては過学習の話を出すとこの限りではありませんが、置いておきましょう)。今の問題で考えると、y0との差が小さくなるxを求めたいので、例えば次のような損失関数Lを最小化できれば良さそうです。

    \begin{equation*}L=(y - 0)^2 = y^2 = (x^2 - 2)^2.\end{equation*}

この損失関数をグラフにしておきます。赤丸はそれぞれ(\sqrt{2}, 0), (-\sqrt{2}, 0)をあらわします。
treeshap_force

勾配降下法では損失関数Lを微分することで勾配を求めます。今回の場合には勾配は次式であらわされます。

(1)   \begin{equation*}\frac{d L}{d x} = 4x(x^2 - 2). \end{equation*}

実はこの勾配にマイナスを掛けた方向にxを移動させることで、Lを小さくすることができます。そのため、次式のようにxを移動させます。

(2)   \begin{eqnarray*}x \leftarrow x - \gamma \frac{d L}{d x} = x - \gamma 4x(x^2 - 2).\end{eqnarray*}

ここで\gammaは学習率と呼ばれるもので、どれだけxを動かすかを制御する値になります。学習率が大きければ、xは大きく移動することになります。大きすぎると移動しすぎる可能性もありますが、小さすぎるとほとんど移動しないため、いい塩梅の値を見つける必要があります。
ここまで説明したような①「勾配を求める」②「xを移動」という2つの操作を繰り返すことで、Lを小さくする方向に少しずつxを移動させていきます。残念ながら勾配降下法では一回で答えが求まるわけではなく、ちょっとずつ答えに近づけていきますが、この方法によって、人間の手計算が困難な場合でも最小値あるいは極小値をとるxの近似値を求めることができます。

PyTorchのコードと説明

ここまでに説明した解き方をPyTorchを用いたプログラムにしたものが以下です。なおPyTorchのバージョンは1.0を想定しています。

プログラムの内容について確認していきます。

  • 5行目でxの初期値を決めていきます。今回はx=1.0としています。勾配降下法では損失関数が小さくなるようにxが動いていくので、初期値の都合上、ひたすら\sqrt{2}の方向へ向かっていきます。先程示した損失関数のグラフをみても分かるとおり、x=1からx=-\sqrt{2}へ移動するには一度損失関数の値が大きくなる必要があるため、-\sqrt{2}の近くには辿り着きません。またrequires_grad=TrueはdL/dxを計算させるという意味になります。
  • 6行目のloss_func = nn.MSELoss()では損失関数を選択しています。MSELossは解き方のところで出てきた損失関数と同じ役割を担います。
  • 7行目のoptimizer = optim.SGD([x], lr=0.01)ではxを移動させる方法としてSGDを選んでいます。SGDは確率的勾配降下法と呼ばれ、説明した解き方と同じ働きをします。lr=0.01では学習率を0.01と定義しています。[x]の部分はSGDで移動させていく変数になります。ディープラーニングではここに指定する変数がたくさん出てきますが、今回はxのみです。
  • 11行目では勾配の計算結果を0で初期化しています。13行目で損失関数の値を計算し、それを利用して14行目で勾配を計算しているのですが、内部的には「(Lの勾配) +={d L}/{d x}」というように、代入ではなく加算がおこなわれます。このため、11行目の処理で一旦勾配の値の初期化が必要となります。
  • 16行目では計算した勾配を用いてxの移動をおこなっています。

このプログラムは勾配降下法の各反復ごとのxとそのときのydL/dxの値を標準出力しており、実行してみると次のように表示されます。一番上の行は初期値のxで計算した値であることに注意して下さい。

初期値に対応する勾配dL/dxですが、これは式(1)を用いて

    \begin{equation*}\frac{dL}{dx} = 4x(x^2 - 2)= 4 \times 1 \times (1 - 2)= -4\end{equation*}

と計算したものと等しいことが確認できます。
またその次の行でのx1.04ですが、式(2)を用いて

    \begin{equation*}x \leftarrow 1 - 0.01 \times -4 = 1.04\end{equation*}

と計算したものと等しいことが確認できます。
100反復後にはx=1.4142133ですので、おおよそ\sqrt{2}となっていることが確認できます。単精度で計算していますので、精度はこの程度が限界です。

終わりに

今回は簡単な方程式をPyTorchを使って解かせる例を示しました。ディープラーニングを用いたコードも似たような流れで処理をおこないますので、導入として参考になれば幸いです。

About The Author

エンジニアTetsu
機械学習を用いたデータ分析業務に従事。

Leave A Reply

*
*
* (公開されません)