読者です 読者をやめる 読者になる 読者になる

いものやま。

雑多な知識の寄せ集め

強化学習用のニューラルネットワークをSwiftで書いてみた。(その2)

技術 AI 強化学習 ニューラルネットワーク HME Swift

昨日は乱数生成器の実装を行った。

今日は強化学習用のニューラルネットワークの計算を行列で表現する。

強化学習用のニューラルネットワークの計算

説明を簡単にするために、ここでは次のようなニューラルネットワークを考える:

  • 3層ニューラルネットワーク
  • 入力は  \boldsymbol{x} \in \mathbb{R}^{n}、出力は  y \in \mathbb{R}
  • 中間層のユニット数は  m
    • 各ユニットの重みは  \boldsymbol{w}_i^{(h)} \in \mathbb{R}^{n}、バイアスは  b_i^{(h)} \in \mathbb{R}
    • 活性化関数は  f^{(h)} : \mathbb{R} \rightarrow \mathbb{R}
  • 出力層のユニット数は1
    • ユニットの重みは  \boldsymbol{w}^{(o)} \in \mathbb{R}^{m}、バイアスは  b^{(o)} \in \mathbb{R}
    • 活性化関数は  f^{(o)} : \mathbb{R} \rightarrow \mathbb{R}

行列で表現しない計算

まず、おさらいとして、行列で表現していない計算を書いておく:

  1. 入力  \boldsymbol{x} から、中間層の出力  \boldsymbol{z} \in \mathbb{R}^{m} を求める:
    1.  u^{(h)}_i = \boldsymbol{w}_i^{(h)} {}^{\mathrm{T}} \boldsymbol{x} + b^{(h)}_i
    2.  z_i = f^{(h)}(u^{(h)}_i)
  2. 中間層の出力  \boldsymbol{z} から、出力層の出力  y を求める:
    1.  u^{(o)} = \boldsymbol{w}^{(o)} {}^{\mathrm{T}} \boldsymbol{z} + b^{(o)}
    2.  y = f^{(o)}(u^{(o)})
  3. 出力層のデルタ  \delta^{(o)} \in \mathbb{R} を求める:
    1.  \delta^{(o)} = {f^{(o)}}'(u^{(o)})
  4. 中間層のデルタ  \boldsymbol{\delta}^{(h)} \in \mathbb{R}^{m} を求める:
    1.  \delta^{(h)}_i = {f^{(h)}}'(u^{(h)}_i) \delta^{(o)} w^{(o)}_i
  5. 偏微分を求める:
    1.  \frac{\partial y}{\partial \boldsymbol{w}^{(h)}_i} = \delta^{(h)}_i \boldsymbol{x}
    2.  \frac{\partial y}{\partial b^{(h)}_i} = \delta^{(h)}_i
    3.  \frac{\partial y}{\partial \boldsymbol{w}^{(o)}} = \delta^{(o)} \boldsymbol{z}
    4.  \frac{\partial y}{\partial b^{(o)}} = \delta^{(o)}

行列で表現した計算

ここで、中間層の重みを  W^{(h)} \in \mathbb{R}^{m \times n} で次のように表すことにする:

 {
W^{(h)} = \left( \begin{array}{c}
\boldsymbol{w}^{(h)}_1 {}^{\mathrm{T}} \\
\vdots \\
\boldsymbol{w}^{(h)}_m {}^{\mathrm{T}}
\end{array} \right)
}

そして、表記をもう少しスマートにすると、上の計算は次のように書き直すことが出来る:

  1. 入力  \boldsymbol{x} から、中間層の出力  \boldsymbol{z} を求める:
    1.  \boldsymbol{u}^{(h)} = W^{(h)} \boldsymbol{x} + \boldsymbol{b}^{(h)}
    2.  \boldsymbol{z} = f^{(h)}(\boldsymbol{u}^{(h)})
  2. 中間層の出力  \boldsymbol{z} から、出力層の出力  y を求める:
    1.  u^{(o)} = \boldsymbol{w}^{(o)} {}^{\mathrm{T}} \boldsymbol{z} + b^{(o)}
    2.  y = f^{(o)}(u^{(o)})
  3. 出力層のデルタ  \delta^{(o)} \in \mathbb{R} を求める:
    1.  \delta^{(o)} = {f^{(o)}}'(u^{(o)})
  4. 中間層のデルタ  \boldsymbol{\delta}^{(h)} \in \mathbb{R}^{m} を求める:
    1.  \boldsymbol{\delta}^{(h)} = {f^{(h)}}'(\boldsymbol{u}^{(h)}) \odot \left(\delta^{(o)} \boldsymbol{w}^{(o)} \right)
  5. 偏微分を求める:
    1.  \frac{\partial y}{\partial W^{(h)}} = \boldsymbol{\delta}^{(h)} \otimes \boldsymbol{x}
    2.  \frac{\partial y}{\partial \boldsymbol{b}^{(h)}} = \boldsymbol{\delta}^{(h)}
    3.  \frac{\partial y}{\partial \boldsymbol{w}^{(o)}} = \delta^{(o)} \boldsymbol{z}
    4.  \frac{\partial y}{\partial b^{(o)}} = \delta^{(o)}

なお、表記法として、関数  f の引数にベクトル(や行列)が来ている場合、各要素に関数  f を適用したベクトル(や行列)を返すものとする。
また、 \odot は行列(ベクトルを含む)のアダマール積(要素ごとの積)、 \otimes はベクトルの直積(外積)を表すものとする。

ここまでくれば、Swiftでの行列計算について調べてみた。(その4) - いものやま。で作ったクラスを使って計算が出来るようになる。

中間層や出力の次元が増えた場合

おまけとして、中間層が増えた場合や、出力の次元が増えた場合の計算も書いておく。

以下では、各層  (l = 2, \cdots, L) の重みが  W^{(l)}、バイアスが  \boldsymbol{b}^{(l)}、活性化関数が  f^{(l)} とする。
(第1層は入力層、第  L 層は出力層)

  1. 入力層の出力  \boldsymbol{z}^{(1)} \boldsymbol{z}^{(1)} = \boldsymbol{x} とする。
  2.  l (l = 2, \cdots, L) の出力  \boldsymbol{z}^{(l)} を求める:
    1.  \boldsymbol{u}^{(l)} = W^{(l)} \boldsymbol{z}^{(l - 1)} + \boldsymbol{b}^{(l)}
    2.  \boldsymbol{z}^{(l)} = f^{(l)}(\boldsymbol{u}^{(l)})
  3. 出力  \boldsymbol{y} \boldsymbol{y} = \boldsymbol{z}^{(L)} とする。
  4. 出力層のデルタ  \boldsymbol{\delta}^{(L)} を求める:
    1.  \boldsymbol{\delta}^{(L)} = {f^{(L)}}'(\boldsymbol{u}^{(L)})
  5.  l (l = L-1, \cdots, 2) のデルタ  \boldsymbol{\delta}^{(l)} を求める:
    1.  \boldsymbol{\delta}^{(l)} = {f^{(h)}}'(\boldsymbol{u}^{(l)}) \odot \left( W^{(l + 1)} {}^{\mathrm{T}} \boldsymbol{\delta}^{(l + 1)} \right)
  6. 偏微分  (l = 2, \cdots, L) を求める:
    1.  \frac{\partial y}{\partial W^{(l)}} = \boldsymbol{\delta}^{(l)} \otimes \boldsymbol{z}^{(l - 1)}
    2.  \frac{\partial y}{\partial \boldsymbol{b}^{(l)}} = \boldsymbol{\delta}^{(l)}

今日はここまで!