いものやま。

雑多な知識の寄せ集め

強化学習とニューラルネットワークを組合せてみた。(その14)

昨日は関数近似のためのHMEの実装を行った。

今日はいよいよHMEを関数近似に使ったSarsa( \lambda)法の実装。

ファイルの整理

ただ、いざ実装しようと思うと、違ってくるのは関数近似の部分だけで、それ以外はまったく同じ。
なので、ちょっとバカらしい感じが。

そこで、これまで実装したファイルを以下のように整理することにした:

ホントはtable_sarsa_com.rbもsarsa_com.rbと共通化したかったんだけど、テーブル型というのは全部の状態に対する重みを持つことになるので、重みのサイズがすごいことになってしまうので断念。
いろいろ工夫すれば、出来なくはないんだろうけど。

SarsaComクラス

ということで、SarsaComクラスの実装。

#====================
# sarsa_com.rb
#--------------------
# 関数近似を使ったSarsa(λ) AI
#====================

require_relative "mark"
require_relative "state"

class SarsaCom
  def initialize(mark, value_network, epsilon=0.1, step_size=0.01, td_lambda=0.6)
    @mark = mark
    @value_network = value_network
    @epsilon = epsilon
    @step_size = step_size
    @td_lambda = td_lambda

    @previous_state = nil
    @current_state = nil
    @accumulated_weights_gradient = nil

    @learn_mode = true
    @debug_mode = false
  end

  attr_reader :mark, :value_network, :epsilon, :step_size, :td_lambda
  attr_reader :learn_mode
  attr_accessor :debug_mode
  alias_method :learn_mode?, :learn_mode
  alias_method :debug_mode?, :debug_mode

  def learn_mode=(value)
    if value
      @learn_mode = true
      @value_network.learn_mode = true
    else
      @learn_mode = false
      @value_network.learn_mode = false
    end
  end

  def self.load(filename)
    sarsa_com = nil
    File.open(filename) do |file|
      sarsa_com = Marshal.load(file)
    end
    if sarsa_com.mark == 1
      sarsa_com.instance_variable_set :@mark, Mark::Maru
    elsif sarsa_com.mark == -1
      sarsa_com.instance_variable_set :@mark, Mark::Batsu
    end

    sarsa_com
  end

  def save(filename)
    # singleton object can't be saved.
    original_mark = @mark
    @mark = @mark.to_i
    File.open(filename, "w") do |file|
      Marshal.dump(self, file)
    end
    @mark = original_mark
  end

  def select_index(state)
    selected_action =
      if (!@learn_mode) || (Random.rand > @epsilon)
        state.valid_actions.max_by do |action|
          new_state = state.set(action, @mark)
          value, _ = @value_network.get_value_and_weights_gradient(new_state.to_array)
          puts "action #{action} value: #{value}" if @debug_mode
          value
        end
      else
        puts "random" if @debug_mode
        state.valid_actions.sample
      end
    @current_state = state.set(selected_action, @mark)
    selected_action
  end

  def learn(reward)
    if @learn_mode && @previous_state
      puts "player #{@mark} learning..." if @debug_mode

      previous_state_array = @previous_state.to_array
      previous_value, weights_gradient = @value_network.get_value_and_weights_gradient(previous_state_array)

      # normalize sensitivity for weights gradient

      alpha = 1.0
      upper_bound = nil
      lower_bound = nil
      sensitivity = 0.0
      previous_sensitivity = 0.0
      10.times do
        new_value = @value_network.get_value_with_weights_gradient(previous_state_array, weights_gradient, alpha)
        previous_sensitivity = sensitivity
        sensitivity = new_value - previous_value
        puts "alpha: #{alpha}, sensitivity: #{sensitivity}" if @debug_mode
        if (sensitivity < 0.0) ||
           (upper_bound.nil? && (sensitivity < previous_sensitivity))
          upper_bound = alpha
          if lower_bound.nil?
            alpha /= 2.0
          else
            alpha = (upper_bound + lower_bound) / 2.0
          end
        elsif sensitivity < 0.9
          lower_bound = alpha
          if upper_bound.nil?
            alpha *= 2.0
          else
            alpha = (upper_bound + lower_bound) / 2.0
          end
        elsif sensitivity > 1.1
          upper_bound = alpha
          if lower_bound.nil?
            alpha /= 2.0
          else
            alpha = (upper_bound + lower_bound) / 2.0
          end
        else
          puts "OK. (alpha: #{alpha})" if @debug_mode
          break
        end
      end
      weights_gradient.map! {|i| i * alpha}

      # calculate accumulated weights gradient

      if @accumulated_weights_gradient.nil?
        @accumulated_weights_gradient = Array.new(weights_gradient.size, 0.0)
      end
      @accumulated_weights_gradient.size.times do |i|
        @accumulated_weights_gradient[i] *= @td_lambda
        @accumulated_weights_gradient[i] += weights_gradient[i]
      end

      # update weights by sarsa(λ)

      value_diff =
        if @current_state
          # normal state
          current_value, _ = @value_network.get_value_and_weights_gradient(@current_state.to_array)
          reward + current_value - previous_value
        else
          # terminal state
          reward - previous_value
        end
      weights_diff = @accumulated_weights_gradient.map{|weight_gradient| @step_size * value_diff * weight_gradient}
      @value_network.add_weights(weights_diff)

      if @debug_mode
        puts "previous value: #{previous_value}"
        puts "value diff: #{value_diff}"
        updated_value, _ = @value_network.get_value_and_weights_gradient(previous_state_array)
        puts "updated value: #{updated_value}"
      end

      if @current_state.nil?
        @accumulated_weights_gradient = nil
      end
    end

    @previous_state = @current_state
    @current_state = nil
  end
end

# 以下、動作確認のコードなので省略

説明は強化学習とニューラルネットワークを組合せてみた。(その5) - いものやま。で行っているので、省略。

ただ、勾配のサイズの調整は少し修正している。
というのも、勾配の方向に進んだときに、感度が0.9〜1.1に届く前に山を越えてしまって値が下がるということがあったから。
この場合、山を越えてしまった時点を上限とするようにしている。
(ただし、これでもホントは不十分・・・ちょっと難しくて、どう実装すればいいのか分からなかった)
それと、元々はsensitivityで割ることでサイズを調整していたけど、それを単に2倍するか、もしくは2で割るように変更している。
というのも、sensitivityが小さすぎる場合に、それで割り算すると、alphaがとんでもないサイズになって、結果、@value_network.get_value_with_weights_gradient()がNaNを返すということがあったから。

あと、このSarsaComクラスに合わせて、ValueNNクラスを少し修正。
必要なインタフェースを用意した。

--- a/RLandNN/TicTacToe/value_nn.rb
+++ b/RLandNN/TicTacToe/value_nn.rb
@@ -63,6 +63,9 @@ class ValueNN
     end
   end
 
+  alias_method :learn_mode, :drop_enabled
+  alias_method :learn_mode=, :drop_enabled=
+
   def get_value_and_weights_gradient(input)
     # select drop units
 

NNSarsaComモジュール

これで、関数近似ニューラルネットワークを使ったSarsaComのインスタンスを作りたかったら、SarsaComのイニシャライザにvalue_networkとしてValueNNのインスタンスを渡せばいいようになったので、NNSarsaComクラスは不要に。
ただ、お手軽にインスタンスを作れるようにするために、NNSarsaComモジュールを用意した。

#====================
# nn_sarsa_com.rb
#--------------------
# ニューラルネットワークによる関数近似を使ったSarsa(λ) AI
#====================

require_relative "mark"
require_relative "state"
require_relative "sarsa_com"
require_relative "value_nn"

module NNSarsaCom
  def self.create(mark, hidden_unit_size=4, drop_unit_size=0, epsilon=0.1, step_size=0.01, td_lambda=0.6)
    value_nn = ValueNN.new(9, hidden_unit_size, -1.0, 1.0, drop_unit_size)
    SarsaCom.new(mark, value_nn, epsilon, step_size, td_lambda)
  end
end

if __FILE__ == $PROGRAM_NAME
  require_relative "game"

  if ARGV.size == 2
    maru_player = SarsaCom.load(ARGV[0])
    batsu_player = SarsaCom.load(ARGV[1])
    maru_player.learn_mode = true
    maru_player.debug_mode = false
    batsu_player.learn_mode = true
    batsu_player.debug_mode = false
  else
    maru_player = NNSarsaCom.create(Mark::Maru, 160, 32, 0.1, 0.01, 0.6)
    batsu_player = NNSarsaCom.create(Mark::Batsu, 160, 32, 0.1, 0.01, 0.6)
  end

  10000.times do |t|
    game = Game.new(maru_player, batsu_player)
    if t % 1000 == 0
      puts "[#{t}]"
      game.start(true)
    else
      game.start(false)
    end
  end

  maru_player.learn_mode = false
  maru_player.debug_mode = true
  batsu_player.learn_mode = false
  batsu_player.debug_mode = true
  game = Game.new(maru_player, batsu_player)
  game.start(true)

  maru_player.save("nn_sarsa_maru.dat")
  batsu_player.save("nn_sarsa_batsu.dat")
end

なお、このファイルをスクリプトとして実行すると、関数近似ニューラルネットワークを使ったSarsaComのインスタンスを生成して、1,000,000回学習を行うようになっている。

HMESarsaComモジュール

NNSarsaComモジュールと同じように、関数近似にHMEを使ったSarsaComを簡単に作れるようにするために、HMESarsaComモジュールを用意した。

#====================
# nn_sarsa_com.rb
#--------------------
# HMEによる関数近似を使ったSarsa(λ) AI
#====================

require_relative "mark"
require_relative "state"
require_relative "sarsa_com"
require_relative "value_nn"
require_relative "value_hme"

module HMESarsaCom
  # structureにはArrayの入れ子を指定する。
  # 数字はValueNNの中間層のユニット数を意味する。
  def self.create(mark, structure, epsilon=0.1, step_size=0.01, td_lambda=0.6)
    value_hme = create_value_hme(structure)
    SarsaCom.new(mark, value_hme, epsilon, step_size, td_lambda)
  end

  def self.create_value_hme(structure)
    experts = Array.new
    structure.each do |item|
      if item.is_a?(Integer)
        value_nn = ValueNN.new(9, item, -1.0, 1.0)
        experts.push(value_nn)
      else
        lower_hme = create_value_hme(item)
        experts.push(lower_hme)
      end
    end
    ValueHME.new(9, experts)
  end
  private_class_method :create_value_hme
end

if __FILE__ == $PROGRAM_NAME
  require_relative "game"

  if ARGV.size == 2
    maru_player = SarsaCom.load(ARGV[0])
    batsu_player = SarsaCom.load(ARGV[1])
    maru_player.learn_mode = true
    maru_player.debug_mode = false
    batsu_player.learn_mode = true
    batsu_player.debug_mode = false
  else
    maru_player = HMESarsaCom.create(Mark::Maru, [64, 64], 0.1, 0.01, 0.6)
    batsu_player = HMESarsaCom.create(Mark::Batsu, [64, 64], 0.1, 0.01, 0.6)
  end

  1000000.times do |t|
    game = Game.new(maru_player, batsu_player)
    if t % 1000 == 0
      puts "[#{t}]"
      game.start(true)
    else
      game.start(false)
    end
  end

  maru_player.learn_mode = false
  maru_player.debug_mode = true
  batsu_player.learn_mode = false
  batsu_player.debug_mode = true
  game = Game.new(maru_player, batsu_player)
  game.start(true)

  maru_player.save("hme_sarsa_maru.dat")
  batsu_player.save("hme_sarsa_batsu.dat")
end

HMESarsaCom.createの引数のstructureには、HMEの構造を指定する。
例えば、1階層で、それぞれのエキスパートネットワークの中間層のユニット数が64なら、[64, 64]、2階層で、それぞれのエキスパートネットワークの中間層のユニット数が32なら、[[32, 32], [32, 32]]といった具合。
ちなみに、[[32, 32, 32, 32], [32, 32]][64, [32, 32]]といった、非対称な構造のHMEを作ることも可能。

そして、このファイルをスクリプトとして実行すると、1階層でそれぞれのエキスパートネットワークの中間層のユニット数が64のHMEを関数近似として使ったSarsaComのインスタンスを生成して、1,000,000回学習を行うようになっている。

今日はここまで!