昨日はテーブル型のSarsa()法の実装を行った。
今日はいよいよ関数近似にニューラルネットワークを使ったSarsa()法の実装してみる。
NNSarsaComクラス
関数近似にニューラルネットワークを使ったSarsa()法のクラスは、NNSarsaComクラスとした。
#==================== # nn_sarsa_com.rb #-------------------- # ニューラルネットワークによる関数近似を使ったSarsa(λ) AI #==================== require_relative "mark" require_relative "state" require_relative "value_nn" class NNSarsaCom # 続く
イニシャライザとアクセサメソッド
まずはイニシャライザとアクセサメソッドから。
# 続き def initialize(mark, hidden_unit_size=4, epsilon=0.1, step_size=0.01, td_lambda=0.6) @mark = mark @hidden_unit_size = hidden_unit_size @epsilon = epsilon @step_size = step_size @td_lambda = td_lambda @value_nn = ValueNN.new(9, @hidden_unit_size, -1.0, 1.0) @previous_state = nil @current_state = nil @accumulated_weights_gradient = nil @learn_mode = true @debug_mode = false end attr_reader :mark, :hidden_unit_size, :epsilon, :step_size, :td_lambda attr_accessor :learn_mode, :debug_mode alias_method :learn_mode?, :learn_mode alias_method :debug_mode?, :debug_mode # 続く
基本的にはSarsaComと同じなんだけど、中間層(隠れ層)のユニット数が指定できるようにしている。
なお、入力の次元が9なので、次元を落とそうと思ってデフォルトを4にしたけど、少なかったみたい・・・
保存と復元
次は保存と復元。
# 続き def self.load(filename) nn_sarsa_com = nil File.open(filename) do |file| nn_sarsa_com = Marshal.load(file) end if nn_sarsa_com.mark == 1 nn_sarsa_com.instance_variable_set :@mark, Mark::Maru elsif nn_sarsa_com.mark == -1 nn_sarsa_com.instance_variable_set :@mark, Mark::Batsu end nn_sarsa_com end def save(filename) # singleton object can't be saved. original_mark = @mark @mark = @mark.to_i File.open(filename, "w") do |file| Marshal.dump(self, file) end @mark = original_mark end # 続く
これはSarsaComと同じ。
なお、Rubyはダックタイピングに対応してるので、このload
メソッドはSarsaComのデータに対しても使えたり。
逆もまた然り。
(Marshalモジュールで保存されたデータが、自分の型が何なのかを知っているというのも大きい)
アクションの選択
次はアクションの選択。
# 続き def select_index(state) selected_action = if (!@learn_mode) || (Random.rand > @epsilon) state.valid_actions.max_by do |action| new_state = state.set(action, @mark) value, _ = @value_nn.get_value_and_weights_gradient(new_state.to_array) puts "action #{action} value: #{value}" if @debug_mode value end else puts "random" if @debug_mode state.valid_actions.sample end @current_state = state.set(selected_action, @mark) selected_action end # 続く
これはSarsaComとほぼ同じ。
ただし、NNSarsaComの場合、状態の価値をニューラルネットワークから得るようにしている。
学習
そして学習。
# 続き def learn(reward) if @learn_mode && @previous_state puts "player #{@mark} learning..." if @debug_mode previous_state_array = @previous_state.to_array previous_value, weights_gradient = @value_nn.get_value_and_weights_gradient(previous_state_array) # normalize sensitivity for weights gradient alpha = 1.0 upper_bound = nil lower_bound = nil 100.times do new_value = @value_nn.get_value_with_weights_gradient(previous_state_array, weights_gradient, alpha) sensitivity = new_value - previous_value puts "alpha: #{alpha}, sensitivity: #{sensitivity}" if @debug_mode if sensitivity < 0.0 upper_bound = alpha if lower_bound.nil? alpha /= 2.0 else alpha = (upper_bound + lower_bound) / 2.0 end elsif sensitivity < 0.9 lower_bound = alpha if upper_bound.nil? alpha /= sensitivity else alpha = (upper_bound + lower_bound) / 2.0 end elsif sensitivity > 1.1 upper_bound = alpha if lower_bound.nil? alpha /= sensitivity else alpha = (upper_bound + lower_bound) / 2.0 end else puts "OK. (alpha: #{alpha})" if @debug_mode break end end weights_gradient.map! {|i| i * alpha} # calculate accumulated weights gradient if @accumulated_weights_gradient.nil? @accumulated_weights_gradient = Array.new(weights_gradient.size, 0.0) end @accumulated_weights_gradient.size.times do |i| @accumulated_weights_gradient[i] *= @td_lambda @accumulated_weights_gradient[i] += weights_gradient[i] end # update weights by sarsa(λ) value_diff = if @current_state # normal state current_value, _ = @value_nn.get_value_and_weights_gradient(@current_state.to_array) reward + current_value - previous_value else # terminal state reward - previous_value end weights_diff = @accumulated_weights_gradient.map{|weight_gradient| @step_size * value_diff * weight_gradient} @value_nn.add_weights(weights_diff) if @debug_mode puts "previous value: #{previous_value}" puts "value diff: #{value_diff}" updated_value, _ = @value_nn.get_value_and_weights_gradient(previous_state_array) puts "updated value: #{updated_value}" end if @current_state.nil? @accumulated_weights_gradient = nil end end @previous_state = @current_state @current_state = nil end end # 続く
アルゴリズム自体はSarsaComと同じなんだけど、SarsaComと違って、適格度は観測された状態に対してではなく、出力の重みに関する勾配に対して保持するようになっている。
そして、これは自分のやった工夫なんだけど、出力の重みに関する勾配のサイズは、それを重みに加えたときの出力の変化が0.9〜1.1に収まるサイズにスカラー倍してから、適格度に足しこむようにしている。
(このサイズ自体はバイナリサーチで求めている)
こうすることで、状態価値に関する学習のステップサイズが、ニューラルネットワークの具合にあまり依存しないようになることを期待している。
動作確認
SarsaComと同様に、次のようなコードで動作確認を行った。
1,000,000回自己対戦させて、学習を行っている。
(学習の進捗の様子が見えるように、1,000回ごとに様子を出力させている)
# 続き if __FILE__ == $PROGRAM_NAME require "pp" require_relative "game" maru_player = NNSarsaCom.new(Mark::Maru) batsu_player = NNSarsaCom.new(Mark::Batsu) 1000000.times do |t| game = Game.new(maru_player, batsu_player) if t % 1000 == 0 puts "[#{t}]" game.start(true) else game.start(false) end end maru_player.learn_mode = false maru_player.debug_mode = true batsu_player.learn_mode = false batsu_player.debug_mode = true game = Game.new(maru_player, batsu_player) game.start(true) maru_player.save("nn_sarsa_maru.dat") batsu_player.save("nn_sarsa_batsu.dat") end
これを実行させると、次のような感じ:
$ ruby nn_sarsa_com.rb [0] ... ... ... player o selected 8. ... ... ..o player x selected 7. ... ... .xo player o selected 4. ... .o. .xo player x selected 6. ... .o. xxo player o selected 2. ..o .o. xxo player x selected 1. .xo .o. xxo player o selected 0. oxo .o. xxo player o win. [1000] ... ... ... player o selected 3. ... o.. ... player x selected 7. ... o.. .x. player o selected 8. ... o.. .xo player x selected 4. ... ox. .xo player o selected 1. .o. ox. .xo player x selected 5. .o. oxx .xo player o selected 0. oo. oxx .xo player x selected 6. oo. oxx xxo player o selected 2. ooo oxx xxo player o win. 〜省略〜 [999000] ... ... ... player o selected 2. ..o ... ... player x selected 1. .xo ... ... player o selected 4. .xo .o. ... player x selected 5. .xo .ox ... player o selected 6. .xo .ox o.. player o win. ... ... ... action 0 value: 0.7592360718972405 action 1 value: -0.02962809699190294 action 2 value: 1.0010091374988384 action 3 value: 0.2731015090611467 action 4 value: 0.8569475458440898 action 5 value: -0.2696471438720517 action 6 value: -0.2184444357247119 action 7 value: -0.4949949208793559 action 8 value: -0.606859908908191 player o selected 2. ..o ... ... action 0 value: -1.026218720998526 action 1 value: -0.9993140827173385 action 3 value: -1.060684834616774 action 4 value: -1.0487096330739265 action 5 value: -1.0234193800622733 action 6 value: -1.093670386287518 action 7 value: -1.089395661793146 action 8 value: -1.0705706709856997 player x selected 1. .xo ... ... action 0 value: 0.9534550005960045 action 3 value: 0.9387837884586252 action 4 value: 1.0017041725072224 action 5 value: 0.9448860499554803 action 6 value: 0.9301529787361479 action 7 value: 0.9705340885365262 action 8 value: 0.998984222238604 player o selected 4. .xo .o. ... action 0 value: -0.7719020005213131 action 3 value: -1.0119066098145815 action 5 value: -0.7464115526008064 action 6 value: -1.044892161485325 action 7 value: -1.0406174369909533 action 8 value: -1.021792446183507 player x selected 5. .xo .ox ... action 0 value: 0.9575772928727357 action 3 value: 0.9429060807353564 action 6 value: 1.008786541455141 action 7 value: 0.9746563808132576 action 8 value: 1.0003106514515336 player o selected 6. .xo .ox o.. player o win.
すごくダメっぽい(^^;
昨日と同じように、以下のコードでも確認してみる。
# human_com_game.rb require_relative "mark" require_relative "game" require_relative "human_player" require_relative "sarsa_com" require_relative "nn_sarsa_com" com = NNSarsaCom.load(ARGV[0]) human = HumanPlayer.new(com.mark.opponent) game = case com.mark when Mark::Maru Game.new(com, human) when Mark::Batsu Game.new(human, com) end loop do game.start(true) end
保存されたデータを指定して実行してみると、以下のような感じ:
$ ruby human_com_game.rb nn_sarsa_maru.dat ... ... ... action 0 value: 0.7592360718972405 action 1 value: -0.02962809699190294 action 2 value: 1.0010091374988384 action 3 value: 0.2731015090611467 action 4 value: 0.8569475458440898 action 5 value: -0.2696471438720517 action 6 value: -0.2184444357247119 action 7 value: -0.4949949208793559 action 8 value: -0.606859908908191 player o selected 2. ..o ... ... <player: x> select index [0,1,3,4,5,6,7,8] 4 player x selected 4. ..o .x. ... action 0 value: 0.8827367795931977 action 1 value: 0.2863637997239468 action 3 value: 0.3943872830131151 action 5 value: -0.03217893418419632 action 6 value: -0.09715866177274346 action 7 value: 0.0724506087927672 action 8 value: 0.26988853969294124 player o selected 0. o.o .x. ... <player: x> select index [1,3,5,6,7,8] 1 player x selected 1. oxo .x. ... action 3 value: 0.8751970639824053 action 5 value: 0.5139473893158397 action 6 value: 0.4992143180965072 action 7 value: 0.8424632036099864 action 8 value: 0.714006948924472 player o selected 3. oxo ox. ... <player: x> select index [5,6,7,8] 7 player x selected 7. oxo ox. .x. player x win. ... ... ... action 0 value: 0.7592360718972405 action 1 value: -0.02962809699190294 action 2 value: 1.0010091374988384 action 3 value: 0.2731015090611467 action 4 value: 0.8569475458440898 action 5 value: -0.2696471438720517 action 6 value: -0.2184444357247119 action 7 value: -0.4949949208793559 action 8 value: -0.606859908908191 player o selected 2. ..o ... ... <player: x> select index [0,1,3,4,5,6,7,8] 6 player x selected 6. ..o ... x.. action 0 value: 0.6377490481476537 action 1 value: 0.5018942335201403 action 3 value: 0.8431770692159409 action 4 value: 0.7793616190634061 action 5 value: 0.6334926347244589 action 7 value: 0.03724369916404108 action 8 value: 0.10584746827493238 player o selected 3. ..o o.. x.. <player: x> select index [0,1,4,5,7,8] 7 player x selected 7. ..o o.. xx. action 0 value: 0.8480529532318108 action 1 value: 0.90468932762419 action 4 value: 0.9116396777080306 action 5 value: 0.9075173854312992 action 8 value: 0.8882184695746977 player o selected 4. ..o oo. xx. <player: x> select index [0,1,5,8] 8 player x selected 8. ..o oo. xxx player x win. ^C
$ ruby human_com_game.rb nn_sarsa_batsu.dat ... ... ... <player: o> select index [0,1,2,3,4,5,6,7,8] 4 player o selected 4. ... .o. ... action 0 value: -0.11263135897280074 action 1 value: -0.026741732461178887 action 2 value: -0.0015600399785638958 action 3 value: -0.47377071673107196 action 5 value: -0.11229485753504664 action 6 value: -0.7971547383434711 action 7 value: -0.7950680266663384 action 8 value: -0.5586538202865576 player x selected 2. ..x .o. ... <player: o> select index [0,1,3,5,6,7,8] 1 player o selected 1. .ox .o. ... action 0 value: 0.5466392825757119 action 3 value: -0.8169735796909625 action 5 value: -0.6860791771503817 action 6 value: -0.8626589528575106 action 7 value: -0.7582177931750976 action 8 value: -0.4319130709577208 player x selected 0. xox .o. ... action 0 value: 0.5466392825757119 action 3 value: -0.8169735796909625 action 5 value: -0.6860791771503817 action 6 value: -0.8626589528575106 action 7 value: -0.7582177931750976 action 8 value: -0.4319130709577208 player x selected 0. xox .o. ... <player: o> select index [3,5,6,7,8] 7 player o selected 7. xox .o. .o. player o win. ... ... ... <player: o> select index [0,1,2,3,4,5,6,7,8] 8 player o selected 8. ... ... ..o action 0 value: -0.36138807054463506 action 1 value: 0.36634918597083294 action 2 value: 0.7442668263666679 action 3 value: -0.8783767705333496 action 4 value: -0.5391213449136951 action 5 value: -0.6038845137750162 action 6 value: -0.5710405941870056 action 7 value: -0.5282933492432882 player x selected 2. ..x ... ..o <player: o> select index [0,1,3,4,5,6,7] 0 player o selected 0. o.x ... ..o action 1 value: -0.8594460799508006 action 3 value: -1.08879319399123 action 4 value: -0.5426204205094465 action 5 value: -1.063015965885905 action 6 value: -1.0828717975594055 action 7 value: -1.074869481351242 player x selected 4. o.x .x. ..o <player: o> select index [1,3,5,6,7] 6 player o selected 6. o.x .x. o.o action 1 value: -0.23013666375660302 action 3 value: -1.1078864795936505 action 5 value: -1.0614953533855485 action 7 value: -0.4526156005701207 player x selected 1. oxx .x. o.o <player: o> select index [3,5,7] 3 player o selected 3. oxx ox. o.o player o win. ^C
弱すぎw
さて、こうなったときに困るのが、原因が簡単に分からないこと。
SarsaComが学習できていたことを考えると、Sarsa()法自体はおそらく間違ってなくて、ValueNNの動作確認を見ても、勾配の方向に進めばちゃんと出力される値が増えていたことから、ニューラルネットワークの学習アルゴリズム自体もおそらくは間違っていない。
けど、うまくいっていない。
うまくいっていない原因として、次のようなことが考えられる:
- ニューラルネットワークの中間層のユニット数が適切でない
- ステップサイズが適切でない
- の値が適切でない
- 学習回数が足りていない
- ニューラルネットワークが過適合を起こしている
このとき、普通のニューラルネットワークのように、教師あり学習であれば、訓練誤差の様子から学習が収束しているのかどうか判断できるし、テスト誤差の様子から過適合が起きているのかどうか判断することも出来るのだけど、強化学習だとそれが出来ないのが難しいところ。
ステップサイズについては、「BirdHead」の思考ルーチンを作ってみた。(その5) - いものやま。で最初に失敗した、パラメータが途中で+/-∞に発散してしまうようなことは起きていないので(これは自分の工夫の成果だと思っている)、おそらく大きすぎるということはないのだろうけど、小さすぎて学習が遅くなっているという可能性はある。
この判断も難しい。
このあたりの、学習がうまくいっているのかどうか/どれくらい進んでいるのかどうかを測る手法が、強化学習の研究では足りてないように思う。。。
今日はここまで!