読者です 読者をやめる 読者になる 読者になる

いものやま。

雑多な知識の寄せ集め

強化学習とニューラルネットワークを組合せてみた。(その5)

技術 AI 強化学習 ニューラルネットワーク Ruby

昨日はテーブル型のSarsa( \lambda)法の実装を行った。

今日はいよいよ関数近似ニューラルネットワークを使ったSarsa( \lambda)法の実装してみる。

NNSarsaComクラス

関数近似ニューラルネットワークを使ったSarsa( \lambda)法のクラスは、NNSarsaComクラスとした。

#====================
# nn_sarsa_com.rb
#--------------------
# ニューラルネットワークによる関数近似を使ったSarsa(λ) AI
#====================

require_relative "mark"
require_relative "state"
require_relative "value_nn"

class NNSarsaCom

# 続く

イニシャライザとアクセサメソッド

まずはイニシャライザとアクセサメソッドから。

# 続き

  def initialize(mark, hidden_unit_size=4, epsilon=0.1, step_size=0.01, td_lambda=0.6)
    @mark = mark
    @hidden_unit_size = hidden_unit_size
    @epsilon = epsilon
    @step_size = step_size
    @td_lambda = td_lambda

    @value_nn = ValueNN.new(9, @hidden_unit_size, -1.0, 1.0)

    @previous_state = nil
    @current_state = nil
    @accumulated_weights_gradient = nil

    @learn_mode = true
    @debug_mode = false
  end

  attr_reader :mark, :hidden_unit_size, :epsilon, :step_size, :td_lambda
  attr_accessor :learn_mode, :debug_mode
  alias_method :learn_mode?, :learn_mode
  alias_method :debug_mode?, :debug_mode

# 続く

基本的にはSarsaComと同じなんだけど、中間層(隠れ層)のユニット数が指定できるようにしている。
なお、入力の次元が9なので、次元を落とそうと思ってデフォルトを4にしたけど、少なかったみたい・・・

保存と復元

次は保存と復元。

# 続き

  def self.load(filename)
    nn_sarsa_com = nil
    File.open(filename) do |file|
      nn_sarsa_com = Marshal.load(file)
    end
    if nn_sarsa_com.mark == 1
      nn_sarsa_com.instance_variable_set :@mark, Mark::Maru
    elsif nn_sarsa_com.mark == -1
      nn_sarsa_com.instance_variable_set :@mark, Mark::Batsu
    end
    nn_sarsa_com
  end

  def save(filename)
    # singleton object can't be saved.
    original_mark = @mark
    @mark = @mark.to_i
    File.open(filename, "w") do |file|
      Marshal.dump(self, file)
    end
    @mark = original_mark
  end

# 続く

これはSarsaComと同じ。

なお、Rubyはダックタイピングに対応してるので、このloadメソッドはSarsaComのデータに対しても使えたり。
逆もまた然り。
(Marshalモジュールで保存されたデータが、自分の型が何なのかを知っているというのも大きい)

アクションの選択

次はアクションの選択。

# 続き

  def select_index(state)
    selected_action =
      if (!@learn_mode) || (Random.rand > @epsilon)
        state.valid_actions.max_by do |action|
          new_state = state.set(action, @mark)
          value, _ = @value_nn.get_value_and_weights_gradient(new_state.to_array)
          puts "action #{action} value: #{value}" if @debug_mode
          value
        end
      else
        puts "random" if @debug_mode
        state.valid_actions.sample
      end
    @current_state = state.set(selected_action, @mark)
    selected_action
  end

# 続く

これはSarsaComとほぼ同じ。
ただし、NNSarsaComの場合、状態の価値をニューラルネットワークから得るようにしている。

学習

そして学習。

# 続き

  def learn(reward)
    if @learn_mode && @previous_state
      puts "player #{@mark} learning..." if @debug_mode

      previous_state_array = @previous_state.to_array
      previous_value, weights_gradient = @value_nn.get_value_and_weights_gradient(previous_state_array)

      # normalize sensitivity for weights gradient

      alpha = 1.0
      upper_bound = nil
      lower_bound = nil
      100.times do
        new_value = @value_nn.get_value_with_weights_gradient(previous_state_array, weights_gradient, alpha)
        sensitivity = new_value - previous_value
        puts "alpha: #{alpha}, sensitivity: #{sensitivity}" if @debug_mode
        if sensitivity < 0.0
          upper_bound = alpha
          if lower_bound.nil?
            alpha /= 2.0
          else
            alpha = (upper_bound + lower_bound) / 2.0
          end
        elsif sensitivity < 0.9
          lower_bound = alpha
          if upper_bound.nil?
            alpha /= sensitivity
          else
            alpha = (upper_bound + lower_bound) / 2.0
          end
        elsif sensitivity > 1.1
          upper_bound = alpha
          if lower_bound.nil?
            alpha /= sensitivity
          else
            alpha = (upper_bound + lower_bound) / 2.0
          end
        else
          puts "OK. (alpha: #{alpha})" if @debug_mode
          break
        end
      end
      weights_gradient.map! {|i| i * alpha}

      # calculate accumulated weights gradient

      if @accumulated_weights_gradient.nil?
        @accumulated_weights_gradient = Array.new(weights_gradient.size, 0.0)
      end
      @accumulated_weights_gradient.size.times do |i|
        @accumulated_weights_gradient[i] *= @td_lambda
        @accumulated_weights_gradient[i] += weights_gradient[i]
      end

      # update weights by sarsa(λ)

      value_diff =
        if @current_state
          # normal state
          current_value, _ = @value_nn.get_value_and_weights_gradient(@current_state.to_array)
          reward + current_value - previous_value
        else
          # terminal state
          reward - previous_value
        end
      weights_diff = @accumulated_weights_gradient.map{|weight_gradient| @step_size * value_diff * weight_gradient}
      @value_nn.add_weights(weights_diff)

      if @debug_mode
        puts "previous value: #{previous_value}"
        puts "value diff: #{value_diff}"
        updated_value, _ = @value_nn.get_value_and_weights_gradient(previous_state_array)
        puts "updated value: #{updated_value}"
      end

      if @current_state.nil?
        @accumulated_weights_gradient = nil
      end
    end

    @previous_state = @current_state
    @current_state = nil
  end
end

# 続く

アルゴリズム自体はSarsaComと同じなんだけど、SarsaComと違って、適格度は観測された状態に対してではなく、出力の重みに関する勾配に対して保持するようになっている。

そして、これは自分のやった工夫なんだけど、出力の重みに関する勾配のサイズは、それを重みに加えたときの出力の変化が0.9〜1.1に収まるサイズにスカラー倍してから、適格度に足しこむようにしている。
(このサイズ自体はバイナリサーチで求めている)
こうすることで、状態価値に関する学習のステップサイズが、ニューラルネットワークの具合にあまり依存しないようになることを期待している。

動作確認

SarsaComと同様に、次のようなコードで動作確認を行った。
1,000,000回自己対戦させて、学習を行っている。
(学習の進捗の様子が見えるように、1,000回ごとに様子を出力させている)

# 続き

if __FILE__ == $PROGRAM_NAME
  require "pp"
  require_relative "game"

  maru_player = NNSarsaCom.new(Mark::Maru)
  batsu_player = NNSarsaCom.new(Mark::Batsu)

  1000000.times do |t|
    game = Game.new(maru_player, batsu_player)
    if t % 1000 == 0
      puts "[#{t}]"
      game.start(true)
    else
      game.start(false)
    end
  end

  maru_player.learn_mode = false
  maru_player.debug_mode = true
  batsu_player.learn_mode = false
  batsu_player.debug_mode = true
  game = Game.new(maru_player, batsu_player)
  game.start(true)

  maru_player.save("nn_sarsa_maru.dat")
  batsu_player.save("nn_sarsa_batsu.dat")
end

これを実行させると、次のような感じ:

$ ruby nn_sarsa_com.rb
[0]
...
...
...
player o selected 8.
...
...
..o
player x selected 7.
...
...
.xo
player o selected 4.
...
.o.
.xo
player x selected 6.
...
.o.
xxo
player o selected 2.
..o
.o.
xxo
player x selected 1.
.xo
.o.
xxo
player o selected 0.
oxo
.o.
xxo
player o win.
[1000]
...
...
...
player o selected 3.
...
o..
...
player x selected 7.
...
o..
.x.
player o selected 8.
...
o..
.xo
player x selected 4.
...
ox.
.xo
player o selected 1.
.o.
ox.
.xo
player x selected 5.
.o.
oxx
.xo
player o selected 0.
oo.
oxx
.xo
player x selected 6.
oo.
oxx
xxo
player o selected 2.
ooo
oxx
xxo
player o win.

〜省略〜

[999000]
...
...
...
player o selected 2.
..o
...
...
player x selected 1.
.xo
...
...
player o selected 4.
.xo
.o.
...
player x selected 5.
.xo
.ox
...
player o selected 6.
.xo
.ox
o..
player o win.
...
...
...
action 0 value: 0.7592360718972405
action 1 value: -0.02962809699190294
action 2 value: 1.0010091374988384
action 3 value: 0.2731015090611467
action 4 value: 0.8569475458440898
action 5 value: -0.2696471438720517
action 6 value: -0.2184444357247119
action 7 value: -0.4949949208793559
action 8 value: -0.606859908908191
player o selected 2.
..o
...
...
action 0 value: -1.026218720998526
action 1 value: -0.9993140827173385
action 3 value: -1.060684834616774
action 4 value: -1.0487096330739265
action 5 value: -1.0234193800622733
action 6 value: -1.093670386287518
action 7 value: -1.089395661793146
action 8 value: -1.0705706709856997
player x selected 1.
.xo
...
...
action 0 value: 0.9534550005960045
action 3 value: 0.9387837884586252
action 4 value: 1.0017041725072224
action 5 value: 0.9448860499554803
action 6 value: 0.9301529787361479
action 7 value: 0.9705340885365262
action 8 value: 0.998984222238604
player o selected 4.
.xo
.o.
...
action 0 value: -0.7719020005213131
action 3 value: -1.0119066098145815
action 5 value: -0.7464115526008064
action 6 value: -1.044892161485325
action 7 value: -1.0406174369909533
action 8 value: -1.021792446183507
player x selected 5.
.xo
.ox
...
action 0 value: 0.9575772928727357
action 3 value: 0.9429060807353564
action 6 value: 1.008786541455141
action 7 value: 0.9746563808132576
action 8 value: 1.0003106514515336
player o selected 6.
.xo
.ox
o..
player o win.

すごくダメっぽい(^^;

昨日と同じように、以下のコードでも確認してみる。

# human_com_game.rb

require_relative "mark"
require_relative "game"
require_relative "human_player"
require_relative "sarsa_com"
require_relative "nn_sarsa_com"

com = NNSarsaCom.load(ARGV[0])
human = HumanPlayer.new(com.mark.opponent)

game = case com.mark
       when Mark::Maru
         Game.new(com, human)
       when Mark::Batsu
         Game.new(human, com)
       end

loop do
  game.start(true)
end

保存されたデータを指定して実行してみると、以下のような感じ:

$ ruby human_com_game.rb nn_sarsa_maru.dat
...
...
...
action 0 value: 0.7592360718972405
action 1 value: -0.02962809699190294
action 2 value: 1.0010091374988384
action 3 value: 0.2731015090611467
action 4 value: 0.8569475458440898
action 5 value: -0.2696471438720517
action 6 value: -0.2184444357247119
action 7 value: -0.4949949208793559
action 8 value: -0.606859908908191
player o selected 2.
..o
...
...
<player: x>
select index [0,1,3,4,5,6,7,8]
4
player x selected 4.
..o
.x.
...
action 0 value: 0.8827367795931977
action 1 value: 0.2863637997239468
action 3 value: 0.3943872830131151
action 5 value: -0.03217893418419632
action 6 value: -0.09715866177274346
action 7 value: 0.0724506087927672
action 8 value: 0.26988853969294124
player o selected 0.
o.o
.x.
...
<player: x>
select index [1,3,5,6,7,8]
1
player x selected 1.
oxo
.x.
...
action 3 value: 0.8751970639824053
action 5 value: 0.5139473893158397
action 6 value: 0.4992143180965072
action 7 value: 0.8424632036099864
action 8 value: 0.714006948924472
player o selected 3.
oxo
ox.
...
<player: x>
select index [5,6,7,8]
7
player x selected 7.
oxo
ox.
.x.
player x win.
...
...
...
action 0 value: 0.7592360718972405
action 1 value: -0.02962809699190294
action 2 value: 1.0010091374988384
action 3 value: 0.2731015090611467
action 4 value: 0.8569475458440898
action 5 value: -0.2696471438720517
action 6 value: -0.2184444357247119
action 7 value: -0.4949949208793559
action 8 value: -0.606859908908191
player o selected 2.
..o
...
...
<player: x>
select index [0,1,3,4,5,6,7,8]
6
player x selected 6.
..o
...
x..
action 0 value: 0.6377490481476537
action 1 value: 0.5018942335201403
action 3 value: 0.8431770692159409
action 4 value: 0.7793616190634061
action 5 value: 0.6334926347244589
action 7 value: 0.03724369916404108
action 8 value: 0.10584746827493238
player o selected 3.
..o
o..
x..
<player: x>
select index [0,1,4,5,7,8]
7
player x selected 7.
..o
o..
xx.
action 0 value: 0.8480529532318108
action 1 value: 0.90468932762419
action 4 value: 0.9116396777080306
action 5 value: 0.9075173854312992
action 8 value: 0.8882184695746977
player o selected 4.
..o
oo.
xx.
<player: x>
select index [0,1,5,8]
8
player x selected 8.
..o
oo.
xxx
player x win.
^C
$ ruby human_com_game.rb nn_sarsa_batsu.dat
...
...
...
<player: o>
select index [0,1,2,3,4,5,6,7,8]
4
player o selected 4.
...
.o.
...
action 0 value: -0.11263135897280074
action 1 value: -0.026741732461178887
action 2 value: -0.0015600399785638958
action 3 value: -0.47377071673107196
action 5 value: -0.11229485753504664
action 6 value: -0.7971547383434711
action 7 value: -0.7950680266663384
action 8 value: -0.5586538202865576
player x selected 2.
..x
.o.
...
<player: o>
select index [0,1,3,5,6,7,8]
1
player o selected 1.
.ox
.o.
...
action 0 value: 0.5466392825757119
action 3 value: -0.8169735796909625
action 5 value: -0.6860791771503817
action 6 value: -0.8626589528575106
action 7 value: -0.7582177931750976
action 8 value: -0.4319130709577208
player x selected 0.
xox
.o.
...
action 0 value: 0.5466392825757119
action 3 value: -0.8169735796909625
action 5 value: -0.6860791771503817
action 6 value: -0.8626589528575106
action 7 value: -0.7582177931750976
action 8 value: -0.4319130709577208
player x selected 0.
xox
.o.
...
<player: o>
select index [3,5,6,7,8]
7
player o selected 7.
xox
.o.
.o.
player o win.
...
...
...
<player: o>
select index [0,1,2,3,4,5,6,7,8]
8
player o selected 8.
...
...
..o
action 0 value: -0.36138807054463506
action 1 value: 0.36634918597083294
action 2 value: 0.7442668263666679
action 3 value: -0.8783767705333496
action 4 value: -0.5391213449136951
action 5 value: -0.6038845137750162
action 6 value: -0.5710405941870056
action 7 value: -0.5282933492432882
player x selected 2.
..x
...
..o
<player: o>
select index [0,1,3,4,5,6,7]
0
player o selected 0.
o.x
...
..o
action 1 value: -0.8594460799508006
action 3 value: -1.08879319399123
action 4 value: -0.5426204205094465
action 5 value: -1.063015965885905
action 6 value: -1.0828717975594055
action 7 value: -1.074869481351242
player x selected 4.
o.x
.x.
..o
<player: o>
select index [1,3,5,6,7]
6
player o selected 6.
o.x
.x.
o.o
action 1 value: -0.23013666375660302
action 3 value: -1.1078864795936505
action 5 value: -1.0614953533855485
action 7 value: -0.4526156005701207
player x selected 1.
oxx
.x.
o.o
<player: o>
select index [3,5,7]
3
player o selected 3.
oxx
ox.
o.o
player o win.
^C

弱すぎw

さて、こうなったときに困るのが、原因が簡単に分からないこと。

SarsaComが学習できていたことを考えると、Sarsa( \lambda)法自体はおそらく間違ってなくて、ValueNNの動作確認を見ても、勾配の方向に進めばちゃんと出力される値が増えていたことから、ニューラルネットワークの学習アルゴリズム自体もおそらくは間違っていない。
けど、うまくいっていない。

うまくいっていない原因として、次のようなことが考えられる:

このとき、普通のニューラルネットワークのように、教師あり学習であれば、訓練誤差の様子から学習が収束しているのかどうか判断できるし、テスト誤差の様子から過適合が起きているのかどうか判断することも出来るのだけど、強化学習だとそれが出来ないのが難しいところ。

ステップサイズについては、「BirdHead」の思考ルーチンを作ってみた。(その5) - いものやま。で最初に失敗した、パラメータが途中で+/-∞に発散してしまうようなことは起きていないので(これは自分の工夫の成果だと思っている)、おそらく大きすぎるということはないのだろうけど、小さすぎて学習が遅くなっているという可能性はある。
この判断も難しい。

このあたりの、学習がうまくいっているのかどうか/どれくらい進んでいるのかどうかを測る手法が、強化学習の研究では足りてないように思う。。。

今日はここまで!