いものやま。

雑多な知識の寄せ集め

強化学習とニューラルネットワークを組合せてみた。(その4)

昨日は○×ゲームを人がプレイできるようにするところまで実装した。

今日はテーブル型のSarsa( \lambda)法を実装する。

SarsaComクラス

ということで、さっそく。

#====================
# sarsa_com.rb
#--------------------
# テーブル型のSarsa(λ) AI
#====================

require_relative "mark"
require_relative "state"

class SarsaCom

# 続く

イニシャライザとアクセサメソッド

まずはイニシャライザとアクセサメソッドから。

# 続き

  def initialize(mark, epsilon=0.1, step_size=0.1, td_lambda=0.6)
    @mark = mark
    @epsilon = epsilon
    @step_size = step_size
    @td_lambda = td_lambda

    @value = Hash.new(0.0)

    @previous_state = nil
    @current_state = nil
    @accumulated_weights = Hash.new(0.0)

    @learn_mode = true
    @debug_mode = false
  end

  attr_reader :mark, :epsilon, :step_size, :td_lambda
  attr_accessor :learn_mode, :debug_mode
  alias_method :learn_mode?, :learn_mode
  alias_method :debug_mode?, :debug_mode

# 続く

Sarsa( \lambda)法では、直前の状態を保持(バックアップ)しておいて、次の状態と報酬が観測されたら、その価値の差分を使って学習を行う。
そのために、@previous_state@current_stateというインスタンス変数を用意している。
また、観測された各状態に対する適格度を保持する必要があるので、@accumulated_weightsというインスタンス変数を用意している。

保存と復元

次に保存と復元。
ファイルにデータを保存し、そこからオブジェクトを復元できるようにしている。

# 続き

  def self.load(filename)
    sarsa_com = nil
    File.open(filename) do |file|
      sarsa_com = Marshal.load(file)
    end
    if sarsa_com.mark == 1
      sarsa_com.instance_variable_set :@mark, Mark::Maru
    elsif sarsa_com.mark == -1
      sarsa_com.instance_variable_set :@mark, Mark::Batsu
    end
    sarsa_com
  end

  def save(filename)
    # singleton object can't be saved.
    original_mark = @mark
    @mark = @mark.to_i
    File.open(filename, "w") do |file|
      Marshal.dump(self, file)
    end
    @mark = original_mark
  end

# 続く

RubyではMarshalモジュールを使うことで、オブジェクトの保存と復元が簡単に出来るようになっているので、それを使っている。
ただし、特異オブジェクトは保存できないので、マークについては一度数値に置き換えて保存して、復元後にオブジェクトに戻している。
(このとき、instance_variable_setメソッドでちょっと無理をしてる・・・)

アクションの選択

次はアクションの選択。

# 続き

  def select_index(state)
    selected_action =
      if (!@learn_mode) || (Random.rand > @epsilon)
        state.valid_actions.max_by do |action|
          new_state = state.set(action, @mark)
          value = @value[new_state.to_array]
          puts "action #{action} value: #{value}" if @debug_mode
          value
        end
      else
        puts "random" if @debug_mode
        state.valid_actions.sample
      end
    @current_state = state.set(selected_action, @mark)
    selected_action
  end

# 続く

学習していない場合、もしくは、0.0〜1.0の乱数で@epsilonより大きい値が出た場合には、価値が最も高いアクションを選択するようになっている。
そうでない場合はランダムにアクションを選ぶ。

それと、学習を行えるようにするために、現在の状態としてアクションを行ったあとの状態を保持するようにしている。

学習

そして学習。

# 続き

  def learn(reward)
    if @learn_mode && @previous_state
      puts "player #{@mark} learning..." if @debug_mode

      previous_state_array = @previous_state.to_array
      previous_value = @value[previous_state_array]

      # calculate accumulated weights

      @accumulated_weights.each_key do |state|
        @accumulated_weights[state] *= @td_lambda
      end
      @accumulated_weights[previous_state_array] += 1.0

      # update values by sarsa(λ)

      value_diff =
        if @current_state
          # normal state
          current_value = @value[@current_state.to_array]
          reward + current_value - previous_value
        else
          # terminal state
          reward - previous_value
        end
      @accumulated_weights.each do |state, weight|
        @value[state] += @step_size * value_diff * weight
      end

      if @debug_mode
        puts "previous value: #{previous_value}"
        puts "value diff: #{value_diff}"
        updated_value = @value[previous_state_array]
        puts "updated value: #{updated_value}"
      end

      if @current_state.nil?
        @accumulated_weights = Hash.new(0.0)
      end
    end

    @previous_state = @current_state
    @current_state = nil
  end
end

# 続く

まずは適格度の計算。
これまでに観測されている各状態の適格度を@td_lambda倍したあと、前回観測された状態の適格度に1を追加している。

そのあとは価値の更新。
今回の状態と前回の状態の価値の差分、それと、ステップサイズ、適格度を使って、各状態の価値を更新している。
このとき、学習の最後に今回の状態をnilにしているのがポイントで、状態が更新された直後とゲームが終わったときにlearnメソッドは呼ばれるようにしているわけだけど、今回の状態がnilになっていれば、それはゲームが終わったときのlearnメソッドの呼び出しだということになる。

最後に、今回の状態を前回の状態にセットして、今回の状態をnilにしたら、その学習はオシマイ。

動作確認

次のようなコードで動作確認を行った。
1,000,000回自己対戦させて、学習を行っている。
(学習の進捗の様子が見えるように、1,000回ごとに様子を出力させている)

# 続き

if __FILE__ == $PROGRAM_NAME
  require_relative "game"

  maru_player = SarsaCom.new(Mark::Maru, 0.1, 0.01, 0.6)
  batsu_player = SarsaCom.new(Mark::Batsu, 0.1, 0.01, 0.6)

  1000000.times do |t|
    game = Game.new(maru_player, batsu_player)
    if t % 1000 == 0
      puts "[#{t}]"
      game.start(true)
    else
      game.start(false)
    end
  end

  maru_player.learn_mode = false
  maru_player.debug_mode = true
  batsu_player.learn_mode = false
  batsu_player.debug_mode = true
  game = Game.new(maru_player, batsu_player)
  game.start(true)

  maru_player.save("sarsa_maru.dat")
  batsu_player.save("sarsa_batsu.dat")
end

これを実行させると、次のような感じ:

$ ruby sarsa_com.rb
[0]
...
...
...
player o selected 0.
o..
...
...
player x selected 1.
ox.
...
...
player o selected 2.
oxo
...
...
player x selected 3.
oxo
x..
...
player o selected 4.
oxo
xo.
...
player x selected 5.
oxo
xox
...
player o selected 6.
oxo
xox
o..
player o win.
[1000]
...
...
...
player o selected 2.
..o
...
...
player x selected 1.
.xo
...
...
player o selected 6.
.xo
...
o..
player x selected 4.
.xo
.x.
o..
player o selected 8.
.xo
.x.
o.o
player x selected 3.
.xo
xx.
o.o
player o selected 5.
.xo
xxo
o.o
player o win.

〜省略〜

[999000]
...
...
...
player o selected 4.
...
.o.
...
player x selected 0.
x..
.o.
...
player o selected 3.
x..
oo.
...
player x selected 5.
x..
oox
...
player o selected 2.
x.o
oox
...
player x selected 6.
x.o
oox
x..
player o selected 1.
xoo
oox
x..
player x selected 7.
xoo
oox
xx.
player o selected 8.
xoo
oox
xxo
draw.
...
...
...
action 0 value: 0.09545066928020114
action 1 value: 0.047119087194406266
action 2 value: 0.05029263252526899
action 3 value: 0.057371301967449294
action 4 value: 0.18533061077349378
action 5 value: 0.017655815132714626
action 6 value: 0.07453297520689049
action 7 value: 0.007029664392013386
action 8 value: 0.08874405020094858
player o selected 4.
...
.o.
...
action 0 value: -0.2195070236015553
action 1 value: -0.8190184559250111
action 2 value: -0.22746690964126007
action 3 value: -0.7703563953273663
action 5 value: -0.7042020965239122
action 6 value: -0.1591522382427485
action 7 value: -0.8103658863552381
action 8 value: -0.21889673563383863
player x selected 6.
...
.o.
x..
action 0 value: 0.1191473834627852
action 1 value: 0.018151043785580488
action 2 value: 0.09148836300654319
action 3 value: 0.16131567122492052
action 5 value: 0.11094255223526538
action 7 value: 0.10507590814082798
action 8 value: 0.10769901798606155
player o selected 3.
...
oo.
x..
action 0 value: -0.9654167705302769
action 1 value: -0.9053813591855598
action 2 value: -0.8808791365492502
action 5 value: -0.09241547417887054
action 7 value: -0.900615656533357
action 8 value: -0.8650737523109108
player x selected 5.
...
oox
x..
action 0 value: -0.8040332683593349
action 1 value: 0.05216871763548499
action 2 value: -0.0048845731200206945
action 7 value: 0.08644690293184881
action 8 value: 0.10186656779105396
player o selected 8.
...
oox
x.o
action 0 value: -0.053733589672292296
action 1 value: -0.9100637894269122
action 2 value: -0.9990313270740123
action 7 value: -0.9425510167439461
player x selected 0.
x..
oox
x.o
action 1 value: 0.04252600215533353
action 2 value: 0.0
action 7 value: 0.05787931090602695
player o selected 7.
x..
oox
xoo
action 1 value: 0.0
action 2 value: -0.9999989310775225
player x selected 1.
xx.
oox
xoo
action 2 value: 0.0
player o selected 2.
xxo
oox
xoo
draw.

それっぽく学習できていることが分かる。

実際に学習が出来ているのか試すために、次のようなコードを用意した:

# human_com_game.rb

require_relative "mark"
require_relative "game"
require_relative "human_player"
require_relative "sarsa_com"

com = SarsaCom.load(ARGV[0])
human = HumanPlayer.new(com.mark.opponent)

game = case com.mark
       when Mark::Maru
         Game.new(com, human)
       when Mark::Batsu
         Game.new(human, com)
       end

loop do
  game.start(true)
end

引数にデータが保存されたファイル名を指定して実行すると、次のような感じ:

$ ruby human_com_game.rb sarsa_maru.dat
...
...
...
action 0 value: 0.09545066928020114
action 1 value: 0.047119087194406266
action 2 value: 0.05029263252526899
action 3 value: 0.057371301967449294
action 4 value: 0.18533061077349378
action 5 value: 0.017655815132714626
action 6 value: 0.07453297520689049
action 7 value: 0.007029664392013386
action 8 value: 0.08874405020094858
player o selected 4.
...
.o.
...
<player: x>
select index [0,1,2,3,5,6,7,8]
0
player x selected 0.
x..
.o.
...
action 1 value: 0.13644180657609892
action 2 value: 0.08549107724040744
action 3 value: 0.21413766609509366
action 5 value: 0.047593944000335106
action 6 value: 0.10878094152447267
action 7 value: 0.044813666953132915
action 8 value: 0.09799509153636561
player o selected 3.
x..
oo.
...
<player: x>
select index [1,2,5,6,7,8]
5
player x selected 5.
x..
oox
...
action 1 value: 0.14027825800797694
action 2 value: 0.10080489625720024
action 6 value: -0.7866803603939236
action 7 value: 0.0655500733545393
action 8 value: 0.03277621016598627
player o selected 1.
xo.
oox
...
<player: x>
select index [2,6,7,8]
7
player x selected 7.
xo.
oox
.x.
action 2 value: 0.019671436175961916
action 6 value: 0.059774812624815196
action 8 value: 0.0
player o selected 6.
xo.
oox
ox.
<player: x>
select index [2,8]
2
player x selected 2.
xox
oox
ox.
action 8 value: 0.0
player o selected 8.
xox
oox
oxo
draw.
...
...
...
action 0 value: 0.09545066928020114
action 1 value: 0.047119087194406266
action 2 value: 0.05029263252526899
action 3 value: 0.057371301967449294
action 4 value: 0.18533061077349378
action 5 value: 0.017655815132714626
action 6 value: 0.07453297520689049
action 7 value: 0.007029664392013386
action 8 value: 0.08874405020094858
player o selected 4.
...
.o.
...
<player: x>
select index [0,1,2,3,5,6,7,8]
1
player x selected 1.
.x.
.o.
...
action 0 value: 0.8433980800612721
action 2 value: 0.6596210582724022
action 3 value: 0.6669976820401855
action 5 value: 0.11115157138276906
action 6 value: 0.6996064577978339
action 7 value: 0.006684006392742152
action 8 value: 0.6274293815316326
player o selected 0.
ox.
.o.
...
<player: x>
select index [2,3,5,6,7,8]
8
player x selected 8.
ox.
.o.
..x
action 2 value: 0.013285060191656728
action 3 value: 0.9032070733661571
action 5 value: 0.11132507885138036
action 6 value: 0.7142778163449828
action 7 value: -0.02919502734225242
player o selected 3.
ox.
oo.
..x
<player: x>
select index [2,5,6,7]
5
player x selected 5.
ox.
oox
..x
action 2 value: 0.04100036627359544
action 6 value: 0.9999999999791402
action 7 value: -0.8621386266690471
player o selected 6.
ox.
oox
o.x
player o win.
^C
$ ruby human_com_game.rb sarsa_batsu.dat
...
...
...
<player: o>
select index [0,1,2,3,4,5,6,7,8]
4
player o selected 4.
...
.o.
...
action 0 value: -0.2195070236015553
action 1 value: -0.8190184559250111
action 2 value: -0.22746690964126007
action 3 value: -0.7703563953273663
action 5 value: -0.7042020965239122
action 6 value: -0.1591522382427485
action 7 value: -0.8103658863552381
action 8 value: -0.21889673563383863
player x selected 6.
...
.o.
x..
<player: o>
select index [0,1,2,3,5,7,8]
2
player o selected 2.
..o
.o.
x..
action 0 value: -0.029449525408073062
action 1 value: -0.5002824108595285
action 3 value: -0.46695825566350574
action 5 value: -0.5171741597983368
action 7 value: -0.4596529024799854
action 8 value: -0.05815243844315686
player x selected 0.
x.o
.o.
x..
<player: o>
select index [1,3,5,7,8]
3
player o selected 3.
x.o
oo.
x..
action 1 value: -0.6756111093433974
action 5 value: -0.03930645750283344
action 7 value: -0.6855974759884454
action 8 value: -0.6796704678153767
player x selected 5.
x.o
oox
x..
<player: o>
select index [1,7,8]
7
player o selected 7.
x.o
oox
xo.
action 1 value: 0.0
action 8 value: -0.9999976824120433
player x selected 1.
xxo
oox
xo.
<player: o>
select index [8]
8
player o selected 8.
xxo
oox
xoo
draw.
...
...
...
<player: o>
select index [0,1,2,3,4,5,6,7,8]
0
player o selected 0.
o..
...
...
action 1 value: -0.6148702466409235
action 2 value: -0.45327431434107157
action 3 value: -0.43166695039083997
action 4 value: -0.01336035094095786
action 5 value: -0.5968754603388217
action 6 value: -0.3549832115425484
action 7 value: -0.41367284365126455
action 8 value: -0.3647693887348638
player x selected 4.
o..
.x.
...
<player: o>
select index [1,2,3,5,6,7,8]
8
player o selected 8.
o..
.x.
..o
action 1 value: 0.04376569574409844
action 2 value: -0.1324238009861316
action 3 value: -0.012110866857599474
action 5 value: -0.007603740509887103
action 6 value: -0.2331071453930608
action 7 value: -0.023215083888458398
player x selected 1.
ox.
.x.
..o
<player: o>
select index [2,3,5,6,7]
7
player o selected 7.
ox.
.x.
.oo
action 2 value: -0.22574649833988844
action 3 value: -0.10829681905357397
action 5 value: -0.16360918491916404
action 6 value: 0.003943249919481477
player x selected 6.
ox.
.x.
xoo
<player: o>
select index [2,3,5]
2
player o selected 2.
oxo
.x.
xoo
action 3 value: -0.3443407794258561
action 5 value: 0.0
player x selected 5.
oxo
.xx
xoo
<player: o>
select index [3]
3
player o selected 3.
oxo
oxx
xoo
draw.
^C

全部のパターンを試したわけではないけど、うまく学習できていそうな感じ。
あるいは、何パターンか抜けていたとしても、さらに学習させればちゃんと動きそう。

今日はここまで!