昨日は○×ゲームを人がプレイできるようにするところまで実装した。
今日はテーブル型のSarsa()法を実装する。
SarsaComクラス
ということで、さっそく。
#==================== # sarsa_com.rb #-------------------- # テーブル型のSarsa(λ) AI #==================== require_relative "mark" require_relative "state" class SarsaCom # 続く
イニシャライザとアクセサメソッド
まずはイニシャライザとアクセサメソッドから。
# 続き def initialize(mark, epsilon=0.1, step_size=0.1, td_lambda=0.6) @mark = mark @epsilon = epsilon @step_size = step_size @td_lambda = td_lambda @value = Hash.new(0.0) @previous_state = nil @current_state = nil @accumulated_weights = Hash.new(0.0) @learn_mode = true @debug_mode = false end attr_reader :mark, :epsilon, :step_size, :td_lambda attr_accessor :learn_mode, :debug_mode alias_method :learn_mode?, :learn_mode alias_method :debug_mode?, :debug_mode # 続く
Sarsa()法では、直前の状態を保持(バックアップ)しておいて、次の状態と報酬が観測されたら、その価値の差分を使って学習を行う。
そのために、@previous_state
と@current_state
というインスタンス変数を用意している。
また、観測された各状態に対する適格度を保持する必要があるので、@accumulated_weights
というインスタンス変数を用意している。
保存と復元
次に保存と復元。
ファイルにデータを保存し、そこからオブジェクトを復元できるようにしている。
# 続き def self.load(filename) sarsa_com = nil File.open(filename) do |file| sarsa_com = Marshal.load(file) end if sarsa_com.mark == 1 sarsa_com.instance_variable_set :@mark, Mark::Maru elsif sarsa_com.mark == -1 sarsa_com.instance_variable_set :@mark, Mark::Batsu end sarsa_com end def save(filename) # singleton object can't be saved. original_mark = @mark @mark = @mark.to_i File.open(filename, "w") do |file| Marshal.dump(self, file) end @mark = original_mark end # 続く
RubyではMarshalモジュールを使うことで、オブジェクトの保存と復元が簡単に出来るようになっているので、それを使っている。
ただし、特異オブジェクトは保存できないので、マークについては一度数値に置き換えて保存して、復元後にオブジェクトに戻している。
(このとき、instance_variable_setメソッドでちょっと無理をしてる・・・)
アクションの選択
次はアクションの選択。
# 続き def select_index(state) selected_action = if (!@learn_mode) || (Random.rand > @epsilon) state.valid_actions.max_by do |action| new_state = state.set(action, @mark) value = @value[new_state.to_array] puts "action #{action} value: #{value}" if @debug_mode value end else puts "random" if @debug_mode state.valid_actions.sample end @current_state = state.set(selected_action, @mark) selected_action end # 続く
学習していない場合、もしくは、0.0〜1.0の乱数で@epsilon
より大きい値が出た場合には、価値が最も高いアクションを選択するようになっている。
そうでない場合はランダムにアクションを選ぶ。
それと、学習を行えるようにするために、現在の状態としてアクションを行ったあとの状態を保持するようにしている。
学習
そして学習。
# 続き def learn(reward) if @learn_mode && @previous_state puts "player #{@mark} learning..." if @debug_mode previous_state_array = @previous_state.to_array previous_value = @value[previous_state_array] # calculate accumulated weights @accumulated_weights.each_key do |state| @accumulated_weights[state] *= @td_lambda end @accumulated_weights[previous_state_array] += 1.0 # update values by sarsa(λ) value_diff = if @current_state # normal state current_value = @value[@current_state.to_array] reward + current_value - previous_value else # terminal state reward - previous_value end @accumulated_weights.each do |state, weight| @value[state] += @step_size * value_diff * weight end if @debug_mode puts "previous value: #{previous_value}" puts "value diff: #{value_diff}" updated_value = @value[previous_state_array] puts "updated value: #{updated_value}" end if @current_state.nil? @accumulated_weights = Hash.new(0.0) end end @previous_state = @current_state @current_state = nil end end # 続く
まずは適格度の計算。
これまでに観測されている各状態の適格度を@td_lambda
倍したあと、前回観測された状態の適格度に1を追加している。
そのあとは価値の更新。
今回の状態と前回の状態の価値の差分、それと、ステップサイズ、適格度を使って、各状態の価値を更新している。
このとき、学習の最後に今回の状態をnil
にしているのがポイントで、状態が更新された直後とゲームが終わったときにlearn
メソッドは呼ばれるようにしているわけだけど、今回の状態がnil
になっていれば、それはゲームが終わったときのlearn
メソッドの呼び出しだということになる。
最後に、今回の状態を前回の状態にセットして、今回の状態をnil
にしたら、その学習はオシマイ。
動作確認
次のようなコードで動作確認を行った。
1,000,000回自己対戦させて、学習を行っている。
(学習の進捗の様子が見えるように、1,000回ごとに様子を出力させている)
# 続き if __FILE__ == $PROGRAM_NAME require_relative "game" maru_player = SarsaCom.new(Mark::Maru, 0.1, 0.01, 0.6) batsu_player = SarsaCom.new(Mark::Batsu, 0.1, 0.01, 0.6) 1000000.times do |t| game = Game.new(maru_player, batsu_player) if t % 1000 == 0 puts "[#{t}]" game.start(true) else game.start(false) end end maru_player.learn_mode = false maru_player.debug_mode = true batsu_player.learn_mode = false batsu_player.debug_mode = true game = Game.new(maru_player, batsu_player) game.start(true) maru_player.save("sarsa_maru.dat") batsu_player.save("sarsa_batsu.dat") end
これを実行させると、次のような感じ:
$ ruby sarsa_com.rb [0] ... ... ... player o selected 0. o.. ... ... player x selected 1. ox. ... ... player o selected 2. oxo ... ... player x selected 3. oxo x.. ... player o selected 4. oxo xo. ... player x selected 5. oxo xox ... player o selected 6. oxo xox o.. player o win. [1000] ... ... ... player o selected 2. ..o ... ... player x selected 1. .xo ... ... player o selected 6. .xo ... o.. player x selected 4. .xo .x. o.. player o selected 8. .xo .x. o.o player x selected 3. .xo xx. o.o player o selected 5. .xo xxo o.o player o win. 〜省略〜 [999000] ... ... ... player o selected 4. ... .o. ... player x selected 0. x.. .o. ... player o selected 3. x.. oo. ... player x selected 5. x.. oox ... player o selected 2. x.o oox ... player x selected 6. x.o oox x.. player o selected 1. xoo oox x.. player x selected 7. xoo oox xx. player o selected 8. xoo oox xxo draw. ... ... ... action 0 value: 0.09545066928020114 action 1 value: 0.047119087194406266 action 2 value: 0.05029263252526899 action 3 value: 0.057371301967449294 action 4 value: 0.18533061077349378 action 5 value: 0.017655815132714626 action 6 value: 0.07453297520689049 action 7 value: 0.007029664392013386 action 8 value: 0.08874405020094858 player o selected 4. ... .o. ... action 0 value: -0.2195070236015553 action 1 value: -0.8190184559250111 action 2 value: -0.22746690964126007 action 3 value: -0.7703563953273663 action 5 value: -0.7042020965239122 action 6 value: -0.1591522382427485 action 7 value: -0.8103658863552381 action 8 value: -0.21889673563383863 player x selected 6. ... .o. x.. action 0 value: 0.1191473834627852 action 1 value: 0.018151043785580488 action 2 value: 0.09148836300654319 action 3 value: 0.16131567122492052 action 5 value: 0.11094255223526538 action 7 value: 0.10507590814082798 action 8 value: 0.10769901798606155 player o selected 3. ... oo. x.. action 0 value: -0.9654167705302769 action 1 value: -0.9053813591855598 action 2 value: -0.8808791365492502 action 5 value: -0.09241547417887054 action 7 value: -0.900615656533357 action 8 value: -0.8650737523109108 player x selected 5. ... oox x.. action 0 value: -0.8040332683593349 action 1 value: 0.05216871763548499 action 2 value: -0.0048845731200206945 action 7 value: 0.08644690293184881 action 8 value: 0.10186656779105396 player o selected 8. ... oox x.o action 0 value: -0.053733589672292296 action 1 value: -0.9100637894269122 action 2 value: -0.9990313270740123 action 7 value: -0.9425510167439461 player x selected 0. x.. oox x.o action 1 value: 0.04252600215533353 action 2 value: 0.0 action 7 value: 0.05787931090602695 player o selected 7. x.. oox xoo action 1 value: 0.0 action 2 value: -0.9999989310775225 player x selected 1. xx. oox xoo action 2 value: 0.0 player o selected 2. xxo oox xoo draw.
それっぽく学習できていることが分かる。
実際に学習が出来ているのか試すために、次のようなコードを用意した:
# human_com_game.rb require_relative "mark" require_relative "game" require_relative "human_player" require_relative "sarsa_com" com = SarsaCom.load(ARGV[0]) human = HumanPlayer.new(com.mark.opponent) game = case com.mark when Mark::Maru Game.new(com, human) when Mark::Batsu Game.new(human, com) end loop do game.start(true) end
引数にデータが保存されたファイル名を指定して実行すると、次のような感じ:
$ ruby human_com_game.rb sarsa_maru.dat ... ... ... action 0 value: 0.09545066928020114 action 1 value: 0.047119087194406266 action 2 value: 0.05029263252526899 action 3 value: 0.057371301967449294 action 4 value: 0.18533061077349378 action 5 value: 0.017655815132714626 action 6 value: 0.07453297520689049 action 7 value: 0.007029664392013386 action 8 value: 0.08874405020094858 player o selected 4. ... .o. ... <player: x> select index [0,1,2,3,5,6,7,8] 0 player x selected 0. x.. .o. ... action 1 value: 0.13644180657609892 action 2 value: 0.08549107724040744 action 3 value: 0.21413766609509366 action 5 value: 0.047593944000335106 action 6 value: 0.10878094152447267 action 7 value: 0.044813666953132915 action 8 value: 0.09799509153636561 player o selected 3. x.. oo. ... <player: x> select index [1,2,5,6,7,8] 5 player x selected 5. x.. oox ... action 1 value: 0.14027825800797694 action 2 value: 0.10080489625720024 action 6 value: -0.7866803603939236 action 7 value: 0.0655500733545393 action 8 value: 0.03277621016598627 player o selected 1. xo. oox ... <player: x> select index [2,6,7,8] 7 player x selected 7. xo. oox .x. action 2 value: 0.019671436175961916 action 6 value: 0.059774812624815196 action 8 value: 0.0 player o selected 6. xo. oox ox. <player: x> select index [2,8] 2 player x selected 2. xox oox ox. action 8 value: 0.0 player o selected 8. xox oox oxo draw. ... ... ... action 0 value: 0.09545066928020114 action 1 value: 0.047119087194406266 action 2 value: 0.05029263252526899 action 3 value: 0.057371301967449294 action 4 value: 0.18533061077349378 action 5 value: 0.017655815132714626 action 6 value: 0.07453297520689049 action 7 value: 0.007029664392013386 action 8 value: 0.08874405020094858 player o selected 4. ... .o. ... <player: x> select index [0,1,2,3,5,6,7,8] 1 player x selected 1. .x. .o. ... action 0 value: 0.8433980800612721 action 2 value: 0.6596210582724022 action 3 value: 0.6669976820401855 action 5 value: 0.11115157138276906 action 6 value: 0.6996064577978339 action 7 value: 0.006684006392742152 action 8 value: 0.6274293815316326 player o selected 0. ox. .o. ... <player: x> select index [2,3,5,6,7,8] 8 player x selected 8. ox. .o. ..x action 2 value: 0.013285060191656728 action 3 value: 0.9032070733661571 action 5 value: 0.11132507885138036 action 6 value: 0.7142778163449828 action 7 value: -0.02919502734225242 player o selected 3. ox. oo. ..x <player: x> select index [2,5,6,7] 5 player x selected 5. ox. oox ..x action 2 value: 0.04100036627359544 action 6 value: 0.9999999999791402 action 7 value: -0.8621386266690471 player o selected 6. ox. oox o.x player o win. ^C
$ ruby human_com_game.rb sarsa_batsu.dat ... ... ... <player: o> select index [0,1,2,3,4,5,6,7,8] 4 player o selected 4. ... .o. ... action 0 value: -0.2195070236015553 action 1 value: -0.8190184559250111 action 2 value: -0.22746690964126007 action 3 value: -0.7703563953273663 action 5 value: -0.7042020965239122 action 6 value: -0.1591522382427485 action 7 value: -0.8103658863552381 action 8 value: -0.21889673563383863 player x selected 6. ... .o. x.. <player: o> select index [0,1,2,3,5,7,8] 2 player o selected 2. ..o .o. x.. action 0 value: -0.029449525408073062 action 1 value: -0.5002824108595285 action 3 value: -0.46695825566350574 action 5 value: -0.5171741597983368 action 7 value: -0.4596529024799854 action 8 value: -0.05815243844315686 player x selected 0. x.o .o. x.. <player: o> select index [1,3,5,7,8] 3 player o selected 3. x.o oo. x.. action 1 value: -0.6756111093433974 action 5 value: -0.03930645750283344 action 7 value: -0.6855974759884454 action 8 value: -0.6796704678153767 player x selected 5. x.o oox x.. <player: o> select index [1,7,8] 7 player o selected 7. x.o oox xo. action 1 value: 0.0 action 8 value: -0.9999976824120433 player x selected 1. xxo oox xo. <player: o> select index [8] 8 player o selected 8. xxo oox xoo draw. ... ... ... <player: o> select index [0,1,2,3,4,5,6,7,8] 0 player o selected 0. o.. ... ... action 1 value: -0.6148702466409235 action 2 value: -0.45327431434107157 action 3 value: -0.43166695039083997 action 4 value: -0.01336035094095786 action 5 value: -0.5968754603388217 action 6 value: -0.3549832115425484 action 7 value: -0.41367284365126455 action 8 value: -0.3647693887348638 player x selected 4. o.. .x. ... <player: o> select index [1,2,3,5,6,7,8] 8 player o selected 8. o.. .x. ..o action 1 value: 0.04376569574409844 action 2 value: -0.1324238009861316 action 3 value: -0.012110866857599474 action 5 value: -0.007603740509887103 action 6 value: -0.2331071453930608 action 7 value: -0.023215083888458398 player x selected 1. ox. .x. ..o <player: o> select index [2,3,5,6,7] 7 player o selected 7. ox. .x. .oo action 2 value: -0.22574649833988844 action 3 value: -0.10829681905357397 action 5 value: -0.16360918491916404 action 6 value: 0.003943249919481477 player x selected 6. ox. .x. xoo <player: o> select index [2,3,5] 2 player o selected 2. oxo .x. xoo action 3 value: -0.3443407794258561 action 5 value: 0.0 player x selected 5. oxo .xx xoo <player: o> select index [3] 3 player o selected 3. oxo oxx xoo draw. ^C
全部のパターンを試したわけではないけど、うまく学習できていそうな感じ。
あるいは、何パターンか抜けていたとしても、さらに学習させればちゃんと動きそう。
今日はここまで!