いものやま。

雑多な知識の寄せ集め

強化学習とニューラルネットワークを組合せてみた。(その3)

昨日は強化学習関数近似として使うニューラルネットワークの実装を行った。

さっそくSarsa( \lambda)法と組合せたいところなんだけど、その前にいろいろ実装。

Markモジュール

まずはマーク(○、×、空白)を表すためのMarkモジュールから。

#====================
# mark.rb
#--------------------
# マーク
#====================

module Mark
  Empty = Object.new
  Maru = Object.new
  Batsu = Object.new

  class << Empty
    def empty?
      true
    end

    def opponent
      Empty
    end

    def to_i
      0
    end

    def to_s
      "."
    end
  end

  class << Maru
    def empty?
      false
    end

    def opponent
      Batsu
    end

    def to_i
      1
    end

    def to_s
      "o"
    end
  end

  class << Batsu
    def empty?
      false
    end

    def opponent
      Maru
    end

    def to_i
      -1
    end

    def to_s
      "x"
    end
  end
end

マークはシングルトンなので、素のオブジェクトを作って、それぞれに特異クラスとして、便利なメソッドの追加を行っている。
Rubyの柔軟さ、凄いよねw)

Stateクラス

次に、状態を表すStateクラス。

#====================
# state.rb
#--------------------
# 状態
#====================

require_relative "mark"

class State
  def initialize
    @state = Array.new(9, Mark::Empty)
  end

  def valid_actions
    (0..8).each_with_object(Array.new) do |i, valid_actions|
      if @state[i].empty?
        valid_actions.push i
      end
    end
  end

  def get(index)
    @state[index]
  end

  def set(index, mark)
    new_state = State.new
    new_state.state = @state.dup
    new_state.state[index] = mark
    new_state
  end

  def to_array
    @state.map(&:to_i)
  end

  def win?(mark)
    [
      [0, 1, 2], [3, 4, 5], [6, 7, 8],
      [0, 3, 6], [1, 4, 7], [2, 5, 8],
      [0, 4, 8], [2, 4, 6],
    ].each do |i, j, k|
      if (@state[i] == mark) &&
         (@state[j] == mark) &&
         (@state[k] == mark)
        return true
      end
    end
    false
  end

  def draw?
    self.valid_actions.empty? && (!self.win?(Mark::Maru)) && (!self.win?(Mark::Batsu))
  end

  def end?
    self.win?(Mark::Maru) || self.win?(Mark::Batsu) || self.valid_actions.empty?
  end

  def print
    @state.each_slice(3) do |a, b, c|
      puts "#{a}#{b}#{c}"
    end
  end

  protected

  attr_accessor :state
end

# 以下、動作確認のコードは省略

特に難しいことはなく。
ただ、不変なオブジェクトになっていることに注意。
setメソッドは、新しいオブジェクトを作って返す)

Gameクラス

そして、実際にゲームを行うためのクラス。

#====================
# game.rb
#--------------------
# ゲーム
#====================

require_relative "mark"
require_relative "state"

class Game
  def initialize(maru_player, batsu_player)
    @players = {
      Mark::Maru => maru_player,
      Mark::Batsu => batsu_player,
    }
  end

  def start(verbose=false)
    state = State.new
    current_player_mark = Mark::Maru
    result = nil
    loop do
      current_player = @players[current_player_mark]

      state.print if verbose
      index = current_player.select_index(state)
      puts "player #{current_player_mark} selected #{index}." if verbose

      state = state.set(index, current_player_mark)
      current_player.learn(0)

      if state.win?(current_player_mark)
        result = current_player_mark
        current_player.learn(1)
        @players[current_player_mark.opponent].learn(-1)
        if verbose
          state.print
          puts "player #{current_player_mark} win."
        end
        break
      elsif state.draw?
        result = Mark::Empty
        @players.each do |_, player|
          player.learn(0)
        end
        if verbose
          state.print
          puts "draw."
        end
        break
      end

      current_player_mark = current_player_mark.opponent
    end
    result
  end
end

ここでは、プレイヤーがselect_indexというメソッドとlearnというメソッドを持っているものとしている。
型に厳格な言語の場合、インタフェースやプロトコルとして定義するんだろうけど、Rubyはその辺りが緩いので書くのは楽チン。

なお、learnメソッドを呼び出すタイミングは、Sarsa( \lambda)法とちょっと関係があって、実際に選択された手を実行した後と(このときは報酬は常に0)、終端状態に至って勝敗が決したとき(このときの報酬は、勝ち/負け/引き分けに応じて、それぞれ+1/-1/0)。
BirdHeadを実装したときにもこうしたけど、この実装方法が簡単で分かりやすいと思う。

HumanPlayerクラス

あとは実際のプレイヤーを実装する。

まずは人間のプレイヤーから。

#====================
# human_player.rb
#--------------------
# 人間のプレイヤー
#====================

require_relative "mark"
require_relative "state"

class HumanPlayer
  def initialize(mark)
    @mark = mark
  end

  attr_reader :mark

  def select_index(state)
    puts "<player: #{self.mark}>"
    actions = state.valid_actions
    loop do
      puts "select index [#{actions.join(',')}]"
      index = $stdin.gets.chomp.to_i
      if actions.include?(index)
        break index
      end
    end
  end

  def learn(reward)
    # 何もしない
  end
end

if __FILE__ == $PROGRAM_NAME
  require_relative "game"

  maru_player = HumanPlayer.new(Mark::Maru)
  batsu_player = HumanPlayer.new(Mark::Batsu)
  game = Game.new(maru_player, batsu_player)
  game.start(true)
end

特に難しいことはなく。
動作確認すると、次のような感じ:

$ ruby human_player.rb
...
...
...
<player: o>
select index [0,1,2,3,4,5,6,7,8]
4
player o selected 4.
...
.o.
...
<player: x>
select index [0,1,2,3,5,6,7,8]
2
player x selected 2.
..x
.o.
...
<player: o>
select index [0,1,3,5,6,7,8]
6
player o selected 6.
..x
.o.
o..
<player: x>
select index [0,1,3,5,7,8]
5
player x selected 5.
..x
.ox
o..
<player: o>
select index [0,1,3,7,8]
8
player o selected 8.
..x
.ox
o.o
<player: x>
select index [0,1,3,7]
0
player x selected 0.
x.x
.ox
o.o
<player: o>
select index [1,3,7]
7
player o selected 7.
x.x
.ox
ooo
player o win.

実際にプレイできているのが分かると思う。

本当はテーブル型のSarsa( \lambda)法の実装も書くつもりだったんだけど、だいぶ長くなってきたので、それはまた明日。

今日はここまで!