いものやま。

雑多な知識の寄せ集め

強化学習用のニューラルネットワークをSwiftで書いてみた。(その7)

昨日はValueNetworkの保存とロードの実装をした。
(ただ、いろいろ問題があったので、後で修正する予定)

これで実際に学習をするために、今日は○×ゲームをSwiftで実装する。

Markクラス

まずはマークを表すMarkクラスから。
enumで実装するのも一つの手だけど(というか、その方がSwiftっぽくはある)、今回はJavaのenumっぽく実装した。
(Swiftのenumだと、switchを書くのが面倒)

//==============================
// TicTacToe
//------------------------------
// Mark.swift
//==============================

import Foundation

/// A board mark, modelled Java-enum style: the initializer is private and the
/// only instances are the three shared constants `Empty`, `Maru` and `Batsu`.
/// Equality is therefore plain object identity.
class Mark: CustomStringConvertible, Hashable {
  /// An unoccupied cell, printed as ".".
  static let Empty = Mark(mark: ".", value: 0.0, hash: 0)
  /// The "o" mark.
  static let Maru = Mark(mark: "o", value: 1.0, hash: 1)
  /// The "x" mark.
  static let Batsu = Mark(mark: "x", value: -1.0, hash: 2)

  /// Text used when printing this mark.
  let mark: String
  /// Numeric encoding of this mark.
  let value: Double
  /// Fixed value reported through `hashValue`.
  let hash: Int

  private init(mark: String, value: Double, hash: Int) {
    self.mark = mark
    self.value = value
    self.hash = hash
  }

  /// The printable form of the mark (same as `mark`).
  var description: String {
    return mark
  }

  /// The per-instance constant supplied at construction.
  var hashValue: Int {
    return hash
  }

  /// `true` only for the shared `Empty` instance.
  var isEmpty: Bool {
    return Mark.Empty === self
  }

  /// The other player's mark: `Maru` <-> `Batsu`.
  /// Any other mark (i.e. `Empty`) maps to `Empty`.
  var opponent: Mark {
    if self === Mark.Maru {
      return Mark.Batsu
    }
    if self === Mark.Batsu {
      return Mark.Maru
    }
    return Mark.Empty
  }
}

/// Two marks are equal exactly when they are the same instance.
func ==(left: Mark, right: Mark) -> Bool {
  return left === right
}

Stateクラス

次は状態を表すStateクラス。

//==============================
// TicTacToe
//------------------------------
// State.swift
//==============================

import Foundation

/// An immutable tic-tac-toe board: nine marks stored in a flat array,
/// three consecutive entries per row. "Modifying" a board yields a new
/// `State`; the receiver is never changed.
class State: CustomStringConvertible {
  /// Index triples (rows, columns, diagonals) that make a winning line.
  private static let winningLines: [[Int]] = [
    [0, 1, 2], [3, 4, 5], [6, 7, 8],
    [0, 3, 6], [1, 4, 7], [2, 5, 8],
    [0, 4, 8], [2, 4, 6],
  ]

  private let state: [Mark]

  /// Creates a board whose nine cells are all `Mark.Empty`.
  convenience init() {
    self.init(state: [Mark](count: 9, repeatedValue: Mark.Empty))
  }

  private init(state: [Mark]) {
    self.state = state
  }

  /// The board as three lines of three mark characters, joined by "\n".
  var description: String {
    var rows: [String] = []
    for row in 0..<3 {
      let start = 3 * row
      let cells = self.state[start..<(start + 3)]
      rows.append(cells.map { $0.mark }.joinWithSeparator(""))
    }
    return rows.joinWithSeparator("\n")
  }

  /// The mark stored at `index`.
  subscript(index: Int) -> Mark {
    return self.state[index]
  }

  /// Returns a new board equal to this one except that the cell at `index`
  /// holds `mark`. No check is made that the cell was empty.
  func set(mark: Mark, atIndex index: Int) -> State {
    var cells = self.state
    cells[index] = mark
    return State(state: cells)
  }

  /// Builds a `Vector` from the `value` of each cell, in cell order.
  func toVector() -> Vector {
    let values = self.state.map { $0.value }
    return Vector.fromArray(values)
  }

  /// Indices of all empty cells, in ascending order.
  func validActions() -> [Int] {
    var openIndices: [Int] = []
    for (index, cell) in self.state.enumerate() where cell.isEmpty {
      openIndices.append(index)
    }
    return openIndices
  }

  /// Whether `mark` occupies every cell of at least one winning line.
  func win(mark: Mark) -> Bool {
    for line in State.winningLines {
      let owned = line.filter { self.state[$0] == mark }
      if owned.count == line.count {
        return true
      }
    }
    return false
  }

  /// `true` when no cell is empty and neither side has a winning line.
  var isDraw: Bool {
    let someoneWon = self.win(Mark.Maru) || self.win(Mark.Batsu)
    return !someoneWon && self.validActions().isEmpty
  }

  /// `true` when either side has won or no empty cell remains.
  var isEnd: Bool {
    if self.win(Mark.Maru) || self.win(Mark.Batsu) {
      return true
    }
    return self.validActions().isEmpty
  }
}

Playerプロトコル

そして、プレイヤーを表すPlayerプロトコル。

//==============================
// TicTacToe
//------------------------------
// Player.swift
//==============================

import Foundation

/// A participant that `Game` asks for moves and notifies with rewards.
protocol Player {
  /// The mark this player plays as.
  var mark: Mark { get }
  /// Learning on/off flag. Nothing in the code visible here reads it;
  /// presumably learning players consult it — confirm against implementers.
  var isLearning: Bool { get set }
  /// Returns the board index the player chooses for the given state.
  /// `Game` places the player's mark at the returned index without
  /// validating it.
  func selectIndex(state: State) -> Int
  /// Receives a reward signal from `Game`: 0.0 after each of the player's
  /// own moves, then 1.0 / -1.0 to the winner / loser, or 0.0 to both
  /// players on a draw.
  func learn(reward: Double)
}

HumanPlayerクラス

人間プレイヤーを表すHumanPlayerクラス。

//==============================
// TicTacToe
//------------------------------
// HumanPlayer.swift
//==============================

import Foundation

/// A `Player` whose moves are typed in on standard input. It never learns.
class HumanPlayer: Player {
  let mark: Mark

  /// Always `false`; assignments are silently discarded.
  var isLearning: Bool {
    get { return false }
    set { /* ignore */ }
  }

  init(mark: Mark) {
    self.mark = mark
  }

  /// Prompts repeatedly until the user enters one of the board's valid
  /// indices, then returns it.
  func selectIndex(state: State) -> Int {
    let inputHandle = NSFileHandle.fileHandleWithStandardInput()

    print("<player: \(self.mark)>")
    let actions = state.validActions()
    // The valid actions do not change while waiting for input, so the
    // prompt text can be built once.
    let choices = actions.map { $0.description }.joinWithSeparator(",")

    while true {
      print("select index [\(choices)]")
      // NOTE(review): `readString()` is not declared in this file —
      // presumably an NSFileHandle extension defined elsewhere in the project.
      let input = inputHandle.readString()
      guard let selectedIndex = Int(input) else {
        print("invalid input.")
        continue
      }
      if actions.contains(selectedIndex) {
        return selectedIndex
      }
      print("invalid number.")
    }
  }

  /// Human players ignore rewards.
  func learn(reward: Double) {
    // do nothing
  }
}

Gameクラス

最後に、ゲームを行うためのGameクラス。

//==============================
// TicTacToe
//------------------------------
// Game.swift
//==============================

import Foundation

/// Runs a single tic-tac-toe game between two players.
class Game {
  /// Output sink used when verbose output is requested: prints the text.
  private class func verboseOutput(description: String) {
    print(description)
  }

  /// Output sink used when verbose output is off: discards the text.
  private class func emptyOutput(description: String) {
    // do nothing
  }

  /// The two players, keyed by the mark each one plays.
  let players: [Mark: Player]

  init(maruPlayer: Player, batsuPlayer: Player) {
    var table: [Mark: Player] = [:]
    table[Mark.Maru] = maruPlayer
    table[Mark.Batsu] = batsuPlayer
    self.players = table
  }

  /// Plays one game to completion, `Maru` moving first.
  ///
  /// - Parameter verbose: when `true`, the board and each move are printed.
  /// - Returns: the winner's mark, or `Mark.Empty` on a draw.
  func start(verbose: Bool = false) -> Mark {
    // Choose the output sink once instead of testing `verbose` at every call.
    let output: (String) -> Void
    if verbose {
      output = Game.verboseOutput
    } else {
      output = Game.emptyOutput
    }

    var board = State()
    var turn = Mark.Maru

    while true {
      // Both marks are registered in `init`, so these lookups cannot fail.
      let mover = self.players[turn]!

      output(board.description)
      let index = mover.selectIndex(board)
      output("player \(turn) selected \(index).")

      board = board.set(turn, atIndex: index)
      // NOTE(review): this reward is sent after every move, including the
      // final one, so on a win or draw the mover gets a second `learn`
      // call below — confirm this is intended.
      mover.learn(0.0)

      if board.win(turn) {
        mover.learn(1.0)
        self.players[turn.opponent]!.learn(-1.0)
        output(board.description)
        output("player \(turn) win.")
        return turn
      }

      if board.isDraw {
        for player in self.players.values {
          player.learn(0.0)
        }
        output(board.description)
        output("draw.")
        return Mark.Empty
      }

      turn = turn.opponent
    }
  }
}

ちょっと説明が必要かもしれないところは、冗長出力の切り替え。
冗長出力を行うクラスメソッドと行わないクラスメソッドを用意しておいて、関数ポインタの指し先を切り替えることで、冗長出力を行うかどうかを切り替えている。

動作確認

とりあえずこれで人同士でプレイできるようにしたコードが以下:

//==============================
// TicTacToe
//------------------------------
// main.swift
//==============================

import Foundation

// Human v.s. Human

// Two human players share the terminal; verbose mode prints the board
// and every move. None of these bindings is reassigned, hence `let`.
let maruPlayer: Player = HumanPlayer(mark: .Maru)
let batsuPlayer: Player = HumanPlayer(mark: .Batsu)
let game = Game(maruPlayer: maruPlayer, batsuPlayer: batsuPlayer)
game.start(true)

この実行例は、以下:

...
...
...
<player: o>
select index [0,1,2,3,4,5,6,7,8]
4
player o selected 4.
...
.o.
...
<player: x>
select index [0,1,2,3,5,6,7,8]
0
player x selected 0.
x..
.o.
...
<player: o>
select index [1,2,3,5,6,7,8]
8
player o selected 8.
x..
.o.
..o
<player: x>
select index [1,2,3,5,6,7]
2
player x selected 2.
x.x
.o.
..o
<player: o>
select index [1,3,5,6,7]
1
player o selected 1.
xox
.o.
..o
<player: x>
select index [3,5,6,7]
7
player x selected 7.
xox
.o.
.xo
<player: o>
select index [3,5,6]
3
player o selected 3.
xox
oo.
.xo
<player: x>
select index [5,6]
5
player x selected 5.
xox
oox
.xo
<player: o>
select index [6]
6
player o selected 6.
xox
oox
oxo
draw.

今日はここまで!