yoppa.org


前橋工科大学 - メディアアート・プログラミング (アプリケーション開発) 2019

ml5.js 実践 – 転移学習3 : 転移学習を活用してゲームを作る

前回は、転移学習 (Transfer Learning) の応用として、画像の特徴を抽出してその結果を別の画像に適用することで回帰分析 (Regression) を行うコードを作成しました。今回はこの仕組みをより実践的に活用して、ゲームのコントローラーとして使用します。今回は簡単なテニスゲームのラケットの左右に移動する動きをジェスチャーにより行うゲームの作成を通して、機械学習の応用について実践します。

スライド資料

サンプルプログラム

let featureExtractor, regressor, loss, predictResult=0.5;
let video;
let status = 'loading...';
let slider, addImageButton, trainButton, predictButton;
let mode = 0; //ゲームモード 0:学習中、1:ゲームプレイ
let game; //ゲーム本体

function setup() {
  createCanvas(windowWidth, windowHeight);
  //新規にゲームを生成
  game = new TennisGame();
  //解析関係の設定
  video = createCapture(VIDEO);
  video.size(320, 240);
  video.hide();  
  featureExtractor = ml5.featureExtractor('MobileNet', modelReady);  
  regressor = featureExtractor.regression(video);
  slider = createSlider(0.0, 1.0, 0.5, 0.01);
  slider.position(20, 40);
  addImageButton = createButton('add image');
  addImageButton.position(slider.x + slider.width + 20, 40);
  addImageButton.mousePressed(addImage);
  trainButton = createButton('train');
  trainButton.position(addImageButton.x + addImageButton.width + 5, 40);
  trainButton.mousePressed(train);
  predictButton = createButton('start predict');
  predictButton.position(trainButton.x + trainButton.width + 5, 40);
  predictButton.mousePressed(predict);
}

function draw(){
  background(0);
  image(video, 0, 0, width, height);
  fill(0, 63);
  noStroke();
  rectMode(CORNER);
  rect(0, 0, 400, 80);
  fill(255);
  textSize(12);
  textAlign(LEFT);
  text(status, 20, 20);
  //もしゲームモードだったら
  if(mode == 1){    
    //ゲームの更新と描画
    game.update();
    game.draw();
    //分析の結果でバーを動かす
    let speed = map(predictResult, 0.0, 1.0, -10, 10);
    game.bar.location.x += speed;
  }
}

function addImage(){
  regressor.addImage(slider.value());
}

function train(){
  regressor.train(function(lossValue) {
    if (lossValue) {
      loss = lossValue;
      status = 'Loss: ' + loss;
    } else {
      status = 'Done Training! Final Loss: ' + loss;
    }
  });
}

function predict(){
  regressor.predict(gotResults);
  mode = 1;
}

function modelReady() {
  status = 'MobileNet Loaded!';
}

function gotResults(err, result) {
  if (err) {
    console.error(err);
    status = err;
  }
  if (result && result.value) {
    predictResult = result.value;
    predict();
  }
}

//----- ゲームに関するクラス : TennisGame、Ball、Bar ----//

class TennisGame {
  constructor(){
    this.ball = new Ball();
    this.bar = new Bar();
    this.score = 0;
  }
  update(){
    this.ball.update();
    // バー(コントローラー)でバウンド
    if (this.ball.location.y > height - 50) {
      if (abs(this.bar.location.x - this.ball.location.x) < 150) {
        this.ball.velocity.y *= -1.0;
        this.ball.velocity.mult(1.2);
        this.score += 100;
      } else {
        this.initGame();
      }
    }
  }
  draw(){
    fill(255);
    this.ball.draw();
    this.bar.draw();
    textSize(40);
    textAlign(RIGHT);
    text(this.score, width-20, 60);
  }
  initGame() {
    // リセット
    this.score = 0;
    this.ball.location.set(width / 2, 40);
    this.ball.velocity = createVector(random(-2, 2), 2);
    this.bar.location.x = width / 2;
  }
}

class Ball {
  constructor() {
    this.location = createVector(width / 2, 40);
    this.velocity = createVector(random(-2, 2), 2);        
  }
  update() {
    this.location.add(this.velocity);
    // 壁でバウンド
    if (this.location.x < 0 || this.location.x > width) {
      this.velocity.x *= -1;
    }
    if (this.location.y < 0) {
      this.velocity.y *= -1;
    }
  }
  draw() {
    fill(255);
    noStroke();
    ellipse(this.location.x, this.location.y, 20);
  }
}

class Bar {
  constructor() {
    this.location = createVector(width / 2, height - 50);
    this.speed = 10;
  }
  draw() {
    fill(255);
    noStroke();
    rectMode(CENTER);
    this.location.x = constrain(this.location.x, 0, width);
    rect(this.location.x, this.location.y, 300, 20);
  }
  moveLeft(){
    this.location.x -= this.speed;
  }
  moveRight(){
    this.location.x += this.speed;
  }
}