yoppa.org


前橋工科大学 - メディアアート・プログラミング (アプリケーション開発) 2019

ml5.js 実践 – 転移学習2 : 特徴抽出による画像の回帰分析

今回も前回に引き続き転移学習 (Transfer Learning) を応用した実践的な機械学習のプログラミングに挑戦していきます。

前回はImageNetのデータセットから画像の特徴量を抽出して、その結果から新たな学習を行い画像をクラス分けしてラベルとその確度を計算することができました。今回はさらに単純にラベルづけするのではなく、複数の画像を連続的に入力してその傾向をなめらかなグラデーションとして学習させます。そして、そのデータを元にカメラの映像を分析します。このために今回は回帰分析 (Regression) という手法を用います。

まず機械学習における回帰分析とは何かを解説した上で、転移学習を応用して特徴抽出による画像の回帰分析を行っていきます。

スライド資料

サンプルプログラム

サンプルコード

Image Feature Extractor Regression – sketch.js

let featureExtractor;
let regressor; //classifierをregressorに
let video;
let loss;
let status = ''; //現在の状態を左上に表示
let showResult = ''; //クラス分け結果を中央に表示
let slider; //スライダー
let addImageButton, trainButton, predictButton; //ボタン

function setup() {
  createCanvas(windowWidth, windowHeight);
  video = createCapture(VIDEO);
  video.size(320, 240);
  video.hide();
  //特徴量抽出
  featureExtractor = ml5.featureExtractor('MobileNet', modelReady);  
  //classificationからrecressionへ
  regressor = featureExtractor.regression(video);
  //スライダーを配置
  slider = createSlider(0.0, 1.0, 0.5, 0.01);
  slider.position(20, 40);
  //ボタンを配置
  addImageButton = createButton('add image');
  addImageButton.position(slider.x + slider.width + 20, 40);
  addImageButton.mousePressed(addImage);
  trainButton = createButton('train');
  trainButton.position(addImageButton.x + addImageButton.width + 5, 40);
  trainButton.mousePressed(train);
  predictButton = createButton('start predict');
  predictButton.position(trainButton.x + trainButton.width + 5, 40);
  predictButton.mousePressed(predict);
}

function draw(){
  background(0);
  //ビデオ映像をフルクスリーン表示
  image(video, 0, 0, width, height);
  //現在の状態(status)を表示
  fill(255);
  textSize(12);
  textAlign(LEFT);
  text(status, 20, 20);
  //回帰分析の結果を表示
  fill(255, 255, 0);
  textSize(100);
  textAlign(CENTER);
  text(showResult, width/2, height/2);
}

//画像を追加
function addImage(){
  regressor.addImage(slider.value());
}

//訓練開始
function train(){
  regressor.train(function(lossValue) {
    if (lossValue) {
      loss = lossValue;
      status = 'Loss: ' + loss;
    } else {
      status = 'Done Training! Final Loss: ' + loss;
    }
  });
}

//回帰分析開始
function predict(){
  //結果が出たらgotResultsを実行
  regressor.predict(gotResults);
}

//モデルの読み込み完了
function modelReady() {
  status = 'MobileNet Loaded!';
}

function gotResults(err, result) {
  //エラー表示
  if (err) {
    console.error(err);
    status = err;
  }
  //クラス分けした結果を表示
  if (result && result.value) {
    showResult = result.value;
    predict();
  }
}