いものやま。

雑多な知識の寄せ集め

強化学習用のニューラルネットワークをSwiftで書いてみた。(その8)

前回は○×ゲームをSwiftで実装した。

今日はSarsaComの実装。

なお、Rubyでの実装は、以下を参照:

SarsaComクラス

ということで、さっそく。

//==============================
// TicTacToe
//------------------------------
// SarsaCom.swift
//==============================

import Foundation

/// A Tic-Tac-Toe computer player that learns with the Sarsa(λ) algorithm.
///
/// Moves are chosen ε-greedily: with probability `epsilon` (only while
/// `isLearning`) a uniformly random valid move is taken; otherwise the move
/// whose resulting state the value network scores highest. After each move,
/// `learn(_:)` applies a Sarsa(λ) TD update to the value network's weights,
/// using an eligibility trace (`accumulatedWeightGradient`) decayed by
/// `tdLambda`.
///
/// NOTE(review): `Player`, `Mark`, `State`, `ValueNetwork`, `Weight`, and
/// `Random` are project types not visible in this file; comments about them
/// below are inferred from how they are used here and should be confirmed
/// against their definitions.
@objc class SarsaCom: NSObject, Player, NSCoding {
  // Archive keys for NSCoding.
  private static let isMaruKey = "isMaru"
  private static let valueNetworkKey = "valueNetwork"
  private static let epsilonKey = "epsilon"
  private static let stepSizeKey = "stepSize"
  private static let tdLambdaKey = "tdLambda"
  private static let isLearningKey = "isLearning"
  
  // The mark (maru "○" or batsu "×") this player places.
  let mark: Mark
  // Function approximator mapping a board-state vector to a value estimate.
  private let valueNetwork: ValueNetwork
  // ε of the ε-greedy policy: probability of taking a random exploratory move.
  private let epsilon: Double
  // Learning rate α used in the weight update.
  private let stepSize: Double
  // λ: decay factor for the eligibility trace.
  private let tdLambda: Double
  
  // State reached by this player's previous move; nil before the first move
  // of an episode.
  private var previousState: State!
  // State reached by the move chosen in the latest selectIndex call; learn()
  // consumes it and sets it back to nil (so it is nil at a terminal update).
  private var currentState: State!
  // Eligibility trace: λ-decayed running sum of (scaled) weight gradients.
  // Reset to nil at the end of each episode.
  private var accumulatedWeightGradient: Weight!
  
  // When false, the player acts purely greedily and learn() does not update
  // the network's weights.
  var isLearning: Bool
  
  /// Creates a learning player.
  /// - Parameters:
  ///   - mark: The mark this player places.
  ///   - valueNetwork: The value-function approximator to use and to train.
  ///   - epsilon: Exploration probability of the ε-greedy policy.
  ///   - stepSize: Learning rate α.
  ///   - tdLambda: Eligibility-trace decay λ.
  init(mark: Mark, valueNetwork: ValueNetwork,
       epsilon: Double = 0.1, stepSize: Double = 0.01, tdLambda: Double = 0.6) {
    self.mark = mark
    self.valueNetwork = valueNetwork
    self.epsilon = epsilon
    self.stepSize = stepSize
    self.tdLambda = tdLambda
    
    // Transient episode state starts empty.
    self.previousState = nil
    self.currentState = nil
    self.accumulatedWeightGradient = nil
    
    self.isLearning = true
    
    super.init()
  }
  
  /// Restores a player from an archive. Only the configuration is archived;
  /// transient episode state (previous/current state, eligibility trace)
  /// starts out nil.
  required init(coder aDecoder: NSCoder) {
    // NOTE: marks are singleton objects.
    // The mark is archived as a single Bool (true = maru) rather than as an
    // object, so decoding maps back onto the shared singletons.
    if aDecoder.decodeBoolForKey(SarsaCom.isMaruKey) {
      self.mark = Mark.Maru
    } else {
      self.mark = Mark.Batsu
    }
    self.valueNetwork = aDecoder.decodeObjectForKey(SarsaCom.valueNetworkKey) as! ValueNetwork
    self.epsilon = aDecoder.decodeDoubleForKey(SarsaCom.epsilonKey)
    self.stepSize = aDecoder.decodeDoubleForKey(SarsaCom.stepSizeKey)
    self.tdLambda = aDecoder.decodeDoubleForKey(SarsaCom.tdLambdaKey)
    self.isLearning = aDecoder.decodeBoolForKey(SarsaCom.isLearningKey)
    super.init()
  }
  
  /// Archives the player's configuration (mark, network, hyperparameters,
  /// learning flag). Transient episode state is intentionally not archived.
  func encodeWithCoder(aCoder: NSCoder) {
    // Encode the mark as a Bool; see the matching decode in init(coder:).
    if self.mark == Mark.Maru {
      aCoder.encodeBool(true, forKey: SarsaCom.isMaruKey)
    } else {
      aCoder.encodeBool(false, forKey: SarsaCom.isMaruKey)
    }
    // workaround: protocol 'ValueNetwork' is not for Objective-C
    // (encodeObject needs a concrete @objc class, so downcast to the known
    // implementations; an unknown implementation is a programmer error).
    if let valueNN = self.valueNetwork as? ValueNN {
      aCoder.encodeObject(valueNN, forKey: SarsaCom.valueNetworkKey)
    } else if let valueHME = self.valueNetwork as? ValueHME {
      aCoder.encodeObject(valueHME, forKey: SarsaCom.valueNetworkKey)
    } else {
      fatalError("not supported value network.")
    }
    aCoder.encodeDouble(self.epsilon, forKey: SarsaCom.epsilonKey)
    aCoder.encodeDouble(self.stepSize, forKey: SarsaCom.stepSizeKey)
    aCoder.encodeDouble(self.tdLambda, forKey: SarsaCom.tdLambdaKey)
    aCoder.encodeBool(self.isLearning, forKey: SarsaCom.isLearningKey)
  }
  
  /// Chooses a cell index for the given state (ε-greedy) and remembers the
  /// resulting state in `currentState` for the next learn() call.
  /// - Parameter state: The current board state. Assumed to have at least
  ///   one valid action — the force-unwrap below crashes otherwise.
  /// - Returns: The chosen cell index.
  func selectIndex(state: State) -> Int {
    let selectedIndex: Int
    // Greedy when not learning, or with probability 1 - ε while learning.
    if (!self.isLearning) || (Random.getRandomProbability() > self.epsilon) {
      let actions = state.validActions()
      // Score each candidate action by the network's value of the state it
      // leads to.
      let actionValues: [Double] = actions.map {
        [unowned self] (index: Int) in
        let newState = state.set(self.mark, atIndex: index)
        return self.valueNetwork.getValue(newState.toVector())
      }
      // Take the action with the highest estimated value.
      (selectedIndex, _) = zip(actions, actionValues).maxElement{$0.1 < $1.1}!
    } else {
      // Exploration: uniformly random valid move.
      selectedIndex = state.validActions().sample()
    }
    self.currentState = state.set(self.mark, atIndex: selectedIndex)
    return selectedIndex
  }
  
  /// Applies one Sarsa(λ) weight update for the transition
  /// previousState → currentState with the observed reward.
  /// When called with no intervening selectIndex (so `currentState` is nil),
  /// this is treated as the terminal update and the eligibility trace is
  /// reset afterwards. No-op on the weights when `isLearning` is false or no
  /// previous state exists yet.
  /// - Parameter reward: Reward observed since this player's previous move.
  func learn(reward: Double) {
    if self.isLearning && (self.previousState != nil) {
      let previousStateVector = self.previousState.toVector()
      let (previousValue, weightGradient) = self.valueNetwork.getValueAndWeightGradient(previousStateVector)
      
      // normalize sensitivity for weight gradient
      //
      // Bisection search (at most 10 steps) for a scale s such that adding
      // s * gradient to the weights changes the network's output by roughly
      // +1 (target band 0.9...1.1). This normalizes the effective step size
      // across states regardless of the raw gradient's magnitude.
      
      var scale = 1.0
      var upperBound: Double! = nil   // smallest scale known to be too large
      var lowerBound: Double! = nil   // largest scale known to be too small
      var sensitivity = 0.0
      var previousSensitivity = 0.0
      for _ in (0..<10) {
        // Probe the output change at the current scale without committing
        // the weight change.
        let newValue = self.valueNetwork.getValue(previousStateVector, withWeightDiff: weightGradient, scale: scale)
        previousSensitivity = sensitivity
        sensitivity = newValue - previousValue
        
        // Shrink when the response is too large, negative, or — while still
        // growing unbounded — stopped increasing (left the monotonic region).
        if (sensitivity > 1.1) ||
           (sensitivity < 0.0) ||
           ((upperBound == nil) && (sensitivity < previousSensitivity)) {
          upperBound = scale
          scale = (lowerBound == nil) ? scale / 2.0 : (upperBound + lowerBound) / 2.0
        } else if sensitivity < 0.9 {
          // Too small: double until an upper bound exists, then bisect.
          lowerBound = scale
          scale = (upperBound == nil) ? scale * 2.0 : (upperBound + lowerBound) / 2.0
        } else {
          // Within the target band 0.9...1.1 — done.
          break
        }
      }
      
      // calculate accumulated weight gradient
      // Eligibility trace update: e ← λ·e + s·∇w (fresh trace on the first
      // step of an episode).
      
      let scaledWeightGradient = weightGradient * scale
      if self.accumulatedWeightGradient == nil {
        self.accumulatedWeightGradient = scaledWeightGradient
      } else {
        self.accumulatedWeightGradient = (self.accumulatedWeightGradient * self.tdLambda
                                            + scaledWeightGradient)
      }
      
      // update weight by sarsa(lambda)
      // TD error δ = r + V(s') - V(s), with V(terminal) = 0.
      // NOTE(review): no discount factor appears, i.e. γ = 1 — presumably
      // intended for this episodic game; confirm.
      
      let valueDiff: Double
      if self.currentState != nil {
        // normal state
        let currentValue = self.valueNetwork.getValue(self.currentState.toVector())
        valueDiff = reward + currentValue - previousValue
      } else {
        // terminal state
        valueDiff = reward - previousValue
      }
      // w ← w + α·δ·e
      let weightDiff = self.stepSize * valueDiff * self.accumulatedWeightGradient
      self.valueNetwork.addWeight(weightDiff)
      
      // finish episode
      // Reset the eligibility trace when the episode has ended.
      
      if self.currentState == nil {
        self.accumulatedWeightGradient = nil
      }
    }
    
    // Shift the window: the state we just moved to becomes "previous" for
    // the next update; currentState is re-set by the next selectIndex call.
    self.previousState = self.currentState
    self.currentState = nil
  }
}

説明は省略。
(過去の記事を参照)


これで一通り実装できたので、あとは実際に学習を行うだけなんだけど、いろいろ問題が。

具体的には、以下のとおり:

  • Accelerateフレームワークの遅延評価のため、実質的なメモリリークが発生し、使用メモリ量が増え続ける。
  • NSKeyedArchiver/NSKeyedUnarchiverでオブジェクトをエンコード/デコードすると、クラス名が(モジュール名).(クラス名)となるため、ある実行ファイルで保存したファイルを他の実行ファイルでロードすると、例外が発生する。

特に、後者は解決がいろいろと面倒・・・

これらについては、明日以降、修正を行っていきたい。

今日はここまで!


ちょっとお知らせ。
いろいろと生活に変化が生じたので、明日以降、更新の頻度が落ちそう・・・