まひろ量子のハックログ

プログラミングや機械学習などの知識を記録・共有します

GANを使って簡単に架空アイドル画像を自動生成(Progressive Growing of GANs)

f:id:twx:20181215154156p:plain
Artificial Idol
この記事で紹介する方法で、このような画像が作れるようになります。

最近趣味でやってる画像生成系のDNNについて簡単にレポートします。

1. Progressive Growing of GANsとは

Paperはこちら。 [1710.10196] Progressive Growing of GANs for Improved Quality, Stability, and Variation

Githubはこちらです。 https://github.com/tkarras/progressive_growing_of_gans

提案されているテクニックは、簡単にいうとGANの学習をする際に「小さいネットワーク」から段階的に「大きいネットワーク」に転移させていくことで、大きな画像においても安定した学習を可能にする、というものです。論文では、4x4の小さい画像から始めて1024x1024の大きな画像を生成することに成功したと述べられています。

f:id:twx:20181215151701p:plain
引用元:Figure 1; PROGRESSIVE GROWING OF GANS FOR IMPROVED QUALITY, STABILITY, AND VARIATION

また、1024x1024の学習に必要となる高画質な学習用画像を得るために様々な工夫がされています。簡単に言うと、顔の位置を揃える操作、超解像、背景のぼかしです。元となる顔画像のデータセットに対して、両目の位置を検出しその座標を起点としてトリミングすることで顔の位置をすべてのデータで合わせます。更に、超解像技術を用いて512x512の画像を1024x1024に高解像度化します。最後に、背景にブロー処理をかけてぼかします。

かなりの手間をかけたこのようなデータのクリーニングは、ハイクオリティな画像を生成するのに必須だと言われています。

今回はこの論文の再現実験として、実在しない架空のアイドルの顔画像を生成してみました。アイドルの生成には既に先駆者がいて、1年ほど前にかなりバズったのを覚えていらっしゃる方もいると思います。なので、目新しさは無いです。

2. データの準備

アイドル画像をひたすらクローリングしまくります。クローリング対象のURLを公開することは迷惑行為になってしまうので、すみませんが非公開とさせてください。ここでは、クローリングを行うコードをいくつか載せるに留めておきます。

2.1 google画像検索結果を保存するコード

Google検索で画像を手に入れる方法です。 google-images-download というpipモジュールを使います。このモジュールはコマンドライン上で pip コマンドを使ってインストールします。

# pipコマンドでインストール
pip install google_images_download

詳しい使用方法は以下のページが詳しいです。

co.bsnws.net

さて、これを使ってアイドル画像をダウンロードします。以下のXXXXXXXXXXXの部分を、任意の検索クエリに書き換えて実行すると大量の画像が手に入ります。 XXXXXXXXXXXには、アイドルの名前を入れると良いです。

  • get_images1.sh *
googleimagesdownload --keywords "XXXXXXXXXXX" --size large
googleimagesdownload --keywords "XXXXXXXXXXX" --size large
・
・
・
googleimagesdownload --keywords "XXXXXXXXXXX" --size large

2.2 アイドル画像が掲載されている特定ページをクローリング

soupなどの一般的なクロール技術を使って、Google検索ではなく、アイドルの写真をたくさん載せているサイトからクローリングします。迷惑行為になりかねないので、手順の公開は控えさせていただきます。

2.3 顔画像のトリミング

上の2.1と2.2の方法で数万枚オーダーの画像を集め終わったら、今度は写真の中から顔画像を検出して適切なサイズにトリミングします。これには以下のツールを使いました。

https://github.com/deepfakes/faceswap

これは元々、Faceswapという、2人の人物の顔を互いに入れ替えるタスクで有名なツールです。このタスクも、事前に学習データ(顔画像)に対して前処理を行う必要があり、顔をトリミングする機能をもつコードも含まれています。これを利用しましょう。

デフォルトでは以下のコマンドを実行すると、srcフォルダの中の全ての画像に対して顔検出を行い、両目の位置がx軸と平行になるように回転補正をかけたうえで、目の位置を起点に正方形の画像をトリミングしてくれます。

run python faceswap.py extract

しかし、問題がいくつかあります。 トリミング後の画像サイズは256x256で固定なので、高解像度画像が必要なPGGANsで使うには少し小さすぎます。また、顔がややズームアップされた状態でトリミングされるため、髪や服装があまり写らないという欠点もあります。更に、回転補正して正方形に切り出すため、もしも顔が画面端で斜めに検出されてしまうと、正方形に切り出した際に四隅にデッドスペースができてしまいます(以下のように)。 これらの点をなんとかして改善する必要があります。

f:id:twx:20181215173339p:plain
回転時に四隅が消えてしまう失敗例

まずは四隅のデッドスペースをなんとかします。これには、元論文にも書かれていますが、元画像の端っこに鏡像反転させた余白を付与するという手法を適用します。

f:id:twx:20181215175059p:plain
境界ミラーリングの例(引用元:PROGRESSIVE GROWING OF GANS FOR IMPROVED QUALITY, STABILITY, AND VARIATION)

以下のコマンドで、元画像が保存されているフォルダに対して、全ての画像の上下左右に10%のマージンを付与します。

import cv2
import numpy as np
from matplotlib import pyplot as plt
import glob

def mirror_padding(img_path):
    img1 = cv2.imread(img_path)
    padding_y = img1.shape[0] // 10
    padding_x = img1.shape[1] // 10
    img2 = cv2.copyMakeBorder(img1, padding_y, padding_y, padding_x, padding_x, cv2.BORDER_REFLECT_101)
    return img2    

image_paths = glob.glob('/Path/To/Src/Images/*')
for image_path in image_paths:
    img_name = image_path.split('/')[-1]
    img = mirror_padding(image_path)
    cv2.imwrite('/Path/To/Output/' + img_name, img)

こうすることで、もしも顔が画面端にあったとしても、ある程度回転角に余裕をもたせられます。

次に、顔がズームアップされてしまう件と、画像サイズが小さい問題を解決します。これは、Faceswapのソースを改造すればOKです。

以下の2箇所をこのように書き換えましょう。

# faceswap/plugins/Extract_Align.py
12c12
<         extracted = self.transform(image, alignment, size, 48)
---
>         extracted = self.transform(image, alignment, size, 48*3)
# faceswap/scripts/extract.py
129c129
<             256,
---
>             512,

これで、うまく顔のトップから、肩くらいまでがトリミングされます。

2.4 良質な画像の選別

数万枚の画像を実際に目で見て、良い画像と悪い画像に分けます。ここは、どうしても気合と根性が必要です。「正面を向いている」「見切れていない」「暗くない」「手で顔が隠れていない」といった条件に満たすものを選別します。

ただ、完全に手作業だとかなり辛いので、以下のような画像仕分け効率化ツールを作りました。

  • 仕分けツール.html *
<!DOCTYPE html>
<html>
<head>
  <title>仕分けツール</title>
  <link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.1.3/css/bootstrap.min.css">
  <script src="https://ajax.googleapis.com/ajax/libs/jquery/1.11.1/jquery.min.js"></script>
</head>
<body>
<script>
  var numFiles = 0;
  var files = new Array();
  var fileNames = new Array();
  var cursor = 0;
  var prevImageName, currImageName, nextImageName;

  var classA = new Array();
  var classB = new Array();
  var chache = new Array();

  function drawImageOnCanvas(file){
    var image = new Image();
    var reader = new FileReader();
    var canvas = $('#cur_canvas');
    var ctx = canvas[0].getContext('2d');
    reader.onload = function(evt) {
      image.onload = function() {
        ctx.clearRect(0, 0, 300, 300);
        ctx.drawImage(image, 0, 0, 300, 300);
      }
      image.src = evt.target.result;
    }
    reader.readAsDataURL(file);
  }

  function fileListDirectory(_files) {
    for (i=0; i<_files.length; i++) {
      var fileType = _files[i].type;
        if (fileType == 'image/jpeg' || fileType == 'image/png' ) {
          files.push(_files[i])
          fileNames.push(_files[i].name );
          numFiles ++;
        }
    }
    resetImage(cursor);
  }

  function resetImage(cursor) {
    prevImageName = (cursor == 0 ? 'なし' : fileNames[cursor-1]);
    currImageName = fileNames[cursor];
    nextImageName = (cursor == numFiles - 1 ? 'なし' : fileNames[cursor+1]);
    document.getElementById('previous').innerHTML = prevImageName;
    document.getElementById('current').innerHTML = currImageName;
    document.getElementById('next').innerHTML = nextImageName;
    drawImageOnCanvas(files[cursor]);
    document.getElementById('progress').innerHTML = (cursor+1) + '/' + numFiles;
  }

  function previous(){
    cursor --;
    if( cursor < 0 ) {
      cursor = 0;
    }
    resetImage(cursor);
  }
    
  function next(){
    cursor ++;
    if( cursor > numFiles-1 ) {
      cursor = numFiles-1;
    }
    resetImage(cursor);
  }

  function undo() {
    if(cursor > 0){
      var which = chache[cursor];
      if(which == 'A') {
        classA.pop();
      } else if (which == 'B') {
        classB.pop();
      }
      chache.pop();
      previous();
      document.getElementById('classA').innerHTML = classA.length;
      document.getElementById('classB').innerHTML = classB.length;
    }
  }

  function downloadData() {
    var hiddenElement = document.createElement('a');
    hiddenElement.href = 'data:attachment/text,' + encodeURI(classA);
    hiddenElement.target = '_blank';
    hiddenElement.download = 'NG.txt';
    hiddenElement.click();
    var hiddenElement = document.createElement('a');
    hiddenElement.href = 'data:attachment/text,' + encodeURI(classB);
    hiddenElement.target = '_blank';
    hiddenElement.download = 'OK.txt';
    hiddenElement.click();
  }

  window.onload = function() {
    function onKeyUp(e) {
      if(e.code=='KeyF') {
        classA.push( fileNames[cursor] );
        chache.push('A');
        next();
        document.getElementById('classA').innerHTML = classA.length;
      } else if(e.code=='KeyJ') {
        classB.push( fileNames[cursor] );
        chache.push('B');
        next();
        document.getElementById('classB').innerHTML = classB.length;
      }
      e.preventDefault();
    };

    // Set up key event handlers
    window.addEventListener('keyup', onKeyUp);
  };

</script>

<div class="container">
  <div class="row">
    <div class="col-sm-12 mt-5">
      <div class="btn btn-success p-0">
        <input class="p-1" type="file" webkitdirectory directory onChange="fileListDirectory(this.files)">
      </div>
      <div id=progress></div>
      <div style="display: none;">前の画像:<span id="previous">結果がここに表示されます。</span></div>
      <div style="display: none;">今の画像:<span id="current">結果がここに表示されます。</span></div>
      <div style="display: none;">次の画像:<span id="next">結果がここに表示されます。</span></div>

      <div>
        <canvas id="cur_canvas" width="300" height="300"></canvas>
      </div>
      <button class="btn btn-success" onclick="undo()">1つ戻る</button-->
      <button class="btn btn-success" onclick="downloadData()">ダウンロード</button-->
    </div>
    <div class="col-sm-6 mt-5">
      <h2>不良データ</h2>
      <div id="classA">
      </div>
    </div>
    <div class="col-sm-6 mt-5">
      <h2>優良データ</h2>
      <div id="classB">
      </div>
    </div>

  </div>
</div>

</body>
</html>

このhtmlをローカルに保存してChromeで開いてください。

f:id:twx:20181215182633g:plain
仕分けツール

フォルダを選択できるボタンがあるので、これまで準備を進めてきた「トリミング済みの顔画像が大量に保存されているフォルダ」を選択してください。すると、画面中央に画像が出現しますので、画像にフォーカスをあてたうえで「F」キーと「J」キーで、「NG」か「OK」かを仕分けしてください。要は、キーボード操作でスピーディーに仕分けができるというツールです。

3秒で1枚をさばけると仮定すると、8時間強で1万枚さばけます。

最後に、「ダウンロードボタン」を押すと、「OK」に仕分けられた画像の名前が列挙されたテキストファイルを得ることができます。あとは、この「OKと判定した画像」だけを別のフォルダにコピーするなりしてください。

こうして、良質な学習データが手に入りました! 私はこれで1万3000枚ほど集めました。

f:id:twx:20181215183848p:plain
あつめた画像たち

3. 学習する

学習にはGPUが必要です。Google ColaboratoryならGPUを無料で使えます(2018年12月現在)。

Google Colaboratoryは、Jupyter notebook風にブラウザ上でコードを実行できるGoogleのサービスです。詳しくは以下からどうぞ。

https://colab.research.google.com

Google ColaboratoryはGoogle Driveと連携できます。つまり、自作の学習データをGoogle Driveに置いておくと、それをGoogle Colaboratoryから読み込むことができます。まず、先程作った画像をDriveにアップロードします。フォルダをzipで圧縮してから送ります。先程の「OKと判定した画像」が大量に入っているフォルダをzip化し、OK_idol.zipとします。

f:id:twx:20181215184709p:plain
Google Driveの画面

私は、Google Driveのルート直下に、dataというフォルダを作り、その中にOK_idol.zipを保存しました。

また、学習済モデルや、生成画像を保存したりするのに使う作業用のフォルダも作っておきます。 Google Driveのルート直下に、workというフォルダを作り、その中にPGGANというフォルダを作っておきます。

f:id:twx:20181215192307p:plain
Google Driveの画面2

ここから、Colaboratory上での操作になります。

まず、Google DriveをColaboratoryにマウントします。そして、PGGANsのソースをColaboratory環境上にクローンします。

from google.colab import drive
drive.mount('/content/drive')

%cd /content/drive/My\ Drive/work/PGGAN
!git clone https://github.com/tkarras/progressive_growing_of_gans.git

クローンが完了すると、Drive上で work/PGGAN/progressive_growing_of_gansの中に、config.pyというファイルが見つかります。

このconfigを以下のように編集します。

class EasyDict(dict):
    def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs)
    def __getattr__(self, name): return self[name]
    def __setattr__(self, name, value): self[name] = value
    def __delattr__(self, name): del self[name]

#----------------------------------------------------------------------------
# Paths.

data_dir = '/content/my_dataset'
result_dir = 'results'

#----------------------------------------------------------------------------
# TensorFlow options.

tf_config = EasyDict()  # TensorFlow session config, set by tfutil.init_tf().
env = EasyDict()        # Environment variables, set by the main program in train.py.

tf_config['graph_options.place_pruned_graph']   = True      # False (default) = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
env.TF_CPP_MIN_LOG_LEVEL                        = '1'       # 0 (default) = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.

#----------------------------------------------------------------------------
# Official training configs, targeted mainly for CelebA-HQ.
# To run, comment/uncomment the lines as appropriate and launch train.py.

desc        = 'pgan'                                        # Description string included in result subdir name.
random_seed = 1000                                          # Global random seed.
dataset     = EasyDict()                                    # Options for dataset.load_dataset().
train       = EasyDict(func='train.train_progressive_gan')  # Options for main training func.
G           = EasyDict(func='networks.G_paper')             # Options for generator network.
D           = EasyDict(func='networks.D_paper')             # Options for discriminator network.
G_opt       = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for generator optimizer.
D_opt       = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for discriminator optimizer.
G_loss      = EasyDict(func='loss.G_wgan_acgan')            # Options for generator loss.
D_loss      = EasyDict(func='loss.D_wgangp_acgan')          # Options for discriminator loss.
sched       = EasyDict()                                    # Options for train.TrainingSchedule.
grid        = EasyDict(size='1080p', layout='random')       # Options for train.setup_snapshot_image_grid().

# Dataset (choose one).
desc += '-idol512';               dataset = EasyDict(tfrecord_dir='OK_idol_for_PGGAN'); train.network_snapshot_ticks = 1; train.mirror_augment = True

# Resume
#train.resume_run_id = '/content/drive/My Drive/work/PGGAN/progressive_growing_of_gans/results/032-pgan-idol512-preset-v2-1gpu-fp32/network-snapshot-xxxxxx.pkl';
#train.resume_kimg = 0

# Config presets (choose one).
desc += '-preset-v2-1gpu'; num_gpus = 1; sched.minibatch_base = 4; sched.minibatch_dict = {4: 128, 8: 128, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8, 512: 4}; sched.G_lrate_dict = {1024: 0.0015}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000

# Numerical precision (choose one).
desc += '-fp32'; sched.max_minibatch_per_gpu = {256: 16, 512: 8, 1024: 4}
#----------------------------------------------------------------------------

コンフィグの各行の意味を詳しく知りたい方はオリジナルのgithubのページをご確認ください。ここで重要なのは、データセットのパスと、resumeの設定です。resumeとは、学習がある程度進んで保存したモデルから、学習を再開することを言います。1番最初に学習を開始するときはresumeは関係ありません。

まず、データセットのパスに注意してください。 data_dir = '/content/my_dataset'のように指定しています。さきほど、Google Driveに保存したので、データは/content/drive/My\ Drive/data/OK_idol.zipに格納されているはずです。ここでzipを展開すれば良い気がしますが、実はこれは良くありません。Google driveは書き込みが非常に遅いため、zipを展開して大量の画像を書き込むのに長時間かかってしまいます。一方、Colaboratory上での展開は高速です。なので、以下のようにしてColaboratory上にディレクトリを作り、そこに画像を展開してください。

!mkdir /content/my_dataset
!unzip /content/drive/My\ Drive/data/OK_idol.zip -d /content/my_dataset > /dev/null 2>&1 & 

展開には数分かかります。以下のコマンドで、画像が何枚展開されたかをカウントできます。

!echo /content/my_dataset/OK_idol/* | xargs ls | wc

全て展開できたことを確認したら、以下のコマンドを実行します。詳しくはPGGANのREADMEに書いてありますが、自作のデータセットを、PGGANが読み込める形に変形するコマンドです。

!python dataset_tool.py create_from_images /content/my_dataset/OK_idol_for_PGGAN /content/my_dataset/OK_idol

以上のコマンドが完了したら、次のコマンドを実行します。これで学習が開始します。

!python train.py

Colaboratoryは、ある条件を満たすと環境が丸ごと削除されてしまいます。「ブラウザのセッションが切れて90分以上経過する」「連続稼働時間が12時間を上回る」のどちらか一方でも満たすと削除されてしまいます。したがって、ブラウザを起動してPCを放置していても、12時間しか学習を回せません。私は、毎晩0時頃、夜寝る前と、12時頃にお昼ご飯を食べるときに、resume機能を使って学習を再開させています。

学習を再開させる方法は、上のコンフィグで、次の行を書き換えればOKです。

# Resume
train.resume_run_id = '/content/drive/My Drive/work/PGGAN/progressive_growing_of_gans/results/032-pgan-idol512-preset-v2-1gpu-fp32/network-snapshot-xxxxxx.pkl';
train.resume_kimg = xxxxxxx

最初に学習を始めるときは、これらの行はコメントアウトされていました。resumeで2回目以降の学習を行うときはコメントを外して、032-pgan-idol512-preset-v2-1gpu-fp32/network-snapshot-xxxxxx.pklの部分を、ご自身のGoogle Driveに保存されている学習済みモデルの名前にリネームしてください。これが、resume時に使用するモデルとなります。

また、

train.resume_kimg = xxxxxxx

の右辺には、network-snapshot-xxxxxx.pkl のxxxxxと同じ値を整数で指定してください。

学習は、12時間×2セットを毎日行い、3週間ほど続ける必要があります。

3週間の学習の末、得られたモデルを使ってアイドル画像を生成してみました。以下がその結果です。

f:id:twx:20181215204301g:plain
生成されたアイドルが変形していく様子

こんな感じで動画も作れます。

いい感じのショットを選んで4x2にアペンドしてみました。

f:id:twx:20181215154156p:plain
Artificial Idol

なかなかの出来ですね!

以上、今回はProgressive Growing of GANsを使って簡単にアイドル画像を自動生成してみました。良い記事だと思っていただいた方は、SNSでのシェア、ブログからのリンク、「読者になる」ボタンのクリック、「★」ボタンのクリック、よろしくお願いします! ではまた次の記事でお会いしましょう!

OpenCV3をインポートするときcv3ではなくcv2なのは何故?

f:id:twx:20180917152722p:plain

ふと気になったので調べてみました。

こちらに答えがありました。↓

answers.opencv.org

cv2convert toの略なんじゃないか?という意見もあるみたいですが、実際は、「2はバージョンを表しているのではなく、C APIをcvというプレフィックスで表し、C++ APIをcv2というプレフィックスで表すようにしている」だけのようです。

Google ColaboratoryのCUDAを9.0にアップグレードする方法【失敗】

f:id:twx:20180915171037p:plain

結論

失敗しました。CUDA9をインストールしても、Nvidiaドライバとの互換性を合わせられなかったり、pipでtensorflowをアップグレードできなかったりと、色々ハマります。Googleが公式でCUDA9をサポートしてくれるのを待つしかなさそうだという結論に至りました。

やったこと

現時点(2018年9月15日)でGoogle ColaroratoryにインストールされているCUDAのバージョンは8です。CUDA9を前提としているプログラムが動かず困っていたのですが、色々試した結果CUDA9.0にアップグレードする方法を見つけました。この記事では、その方法を簡単にご紹介します。

CUDA9.0をインストール

こちらのページを参考にしました。 stackoverflow.com

Colab上で以下を実行します。

!wget https://developer.nvidia.com/compute/cuda/9.0/Prod/local_installers/cuda-repo-ubuntu1604-9-0-local_9.0.176-1_amd64-deb
!dpkg -i cuda-repo-ubuntu1604-9-0-local_9.0.176-1_amd64-deb
!apt-key add /var/cuda-repo-9-0-local/7fa2af80.pub
!apt-get update
!apt-get install cuda

cudnn7をインストール

こちらのページを参考にしました。

qiita.com

Colab上で以下を実行します。

!wget http://developer.download.nvidia.com/compute/redist/cudnn/v7.1.4/cudnn-9.2-linux-x64-v7.1.tgz
!tar -xvzf cudnn-9.2-linux-x64-v7.1.tgz
!cp cuda/lib64/* /usr/local/cuda/lib64/
!cp cuda/include/* /usr/local/cuda/include/

環境変数をセット

Colab上で以下を実行します。

import os 
os.environ["CUDA_HOME"]="/usr/local/cuda"
os.environ["LD_LIBRARY_PATH"]="/usr/local/cuda/lib64"

これで、CUDA9.0、cudnn7.1がインストールされます。

追記 (2018.9.15 21:30)

ドライバのバージョンが合ってないと怒られました。対策を検討中です。

!nvidia-smi

Failed to initialize NVML: Driver/library

追記 (2018.9.16 10:20)

結局、Nvidiaドライバを切り替えることができず断念しました。。。Googleが公式でCUDA9をサポートしてくれる日を待ちましょう。

以上です。 今回は「Google ColaboratoryのCUDAを9.0にアップグレードする方法【失敗】」をご紹介しました。

良い記事だと思っていただいた方は以下の「★+」ボタンのクリック、SNSでのシェア、コメント、「読者になる」ボタンのクリックをお願いします!

それではまたー

Image Inpainting(画像修復)をやってみた

f:id:twx:20180913152447p:plain

画像出典: https://arxiv.org/pdf/1804.07723.pdf

Image Inpaintingとは

不定形に塗りつぶされた画像を修復させるというタスクだ。NVIDAが発表した論文とデモ動画が、その精度の高さで話題となっている。

元論文

元論文はこちら。 Image Inpainting for Irregular Holes Using Partial Convolutions https://arxiv.org/pdf/1804.07723.pdf

実装

サードパーティー製の実装がGithubにある。こちらを使わせていただこう。

GitHub - MathiasGruber/PConv-Keras: Keras implementation of "Image Inpainting for Irregular Holes Using Partial Convolutions"

画像の準備

画像データは自分で用意した。 ImageNet画像を全部で約2万9000枚ほど用意し、train, test ,validationという名前のディレクトリに振り分けた。

trainに約2万7000枚、trainvalidationにそれぞれ約1000枚という配分にした。

これらをzipで固めてGoogle Driveに送り、以下の要領でColaboratory環境に移した。

www.mahirokazuko.com

学習

はじめ、コードをそのまま実行したところサイズが512x512x3の画像を入力として学習がはじまった。しかし、1エポック進むのにかなりの長時間を要しそうだということが分かったため、入力サイズを256x256x3に変更した。

model = PConvUnet(img_rows=256, img_cols=256, weight_filepath='data/logs/')

また、試しに動作確認をしたかったので、現実時間で待てるレベルまでエポック数を下げた。

model.fit(
    train_generator, 
    steps_per_epoch=100,
    validation_data=val_generator,
    validation_steps=50,
    epochs=50,        
    plot_callback=plot_callback,
    callbacks=[
        TensorBoard(log_dir='data/logs/initial_training', write_graph=False)
    ]
)

90分くらいで20エポックまで来た。

Epoch 1/1
Found 27308 images belonging to 1 classes.
 99/100 [============================>.] - ETA: 1s - loss: 1192959.7986Found 1000 images belonging to 1 classes.
100/100 [==============================] - 149s 1s/step - loss: 1192393.0381 - val_loss: 1140785.5613
Epoch 2/2
100/100 [==============================] - 145s 1s/step - loss: 929772.3394 - val_loss: 1083900.0400
Epoch 3/3
100/100 [==============================] - 139s 1s/step - loss: 896846.2562 - val_loss: 918737.5038
Epoch 4/4
100/100 [==============================] - 141s 1s/step - loss: 832707.7262 - val_loss: 912742.8137
Epoch 5/5
100/100 [==============================] - 144s 1s/step - loss: 817528.7878 - val_loss: 907044.2275
Epoch 6/6
100/100 [==============================] - 143s 1s/step - loss: 765230.2653 - val_loss: 918365.0200
Epoch 7/7
100/100 [==============================] - 148s 1s/step - loss: 748308.9103 - val_loss: 863127.1031
Epoch 8/8
100/100 [==============================] - 149s 1s/step - loss: 753139.3806 - val_loss: 886118.9925
Epoch 9/9
100/100 [==============================] - 146s 1s/step - loss: 733665.1872 - val_loss: 820381.9600
Epoch 10/10
100/100 [==============================] - 148s 1s/step - loss: 724513.2947 - val_loss: 807128.7750
Epoch 11/11
100/100 [==============================] - 150s 1s/step - loss: 718573.7362 - val_loss: 842021.4556
Epoch 12/12
100/100 [==============================] - 147s 1s/step - loss: 687009.6784 - val_loss: 792984.8875
Epoch 13/13
100/100 [==============================] - 146s 1s/step - loss: 717200.5312 - val_loss: 793620.3700
Epoch 14/14
100/100 [==============================] - 141s 1s/step - loss: 671126.3444 - val_loss: 784162.4425
Epoch 15/15
100/100 [==============================] - 141s 1s/step - loss: 664130.9928 - val_loss: 768478.3525
Epoch 16/16
100/100 [==============================] - 141s 1s/step - loss: 641558.0719 - val_loss: 734417.9025
Epoch 17/17
100/100 [==============================] - 150s 2s/step - loss: 654648.5581 - val_loss: 778414.4363
Epoch 18/18
100/100 [==============================] - 145s 1s/step - loss: 643995.8372 - val_loss: 727236.1906
Epoch 19/19
100/100 [==============================] - 146s 1s/step - loss: 633081.0981 - val_loss: 703399.7200
Epoch 20/20
100/100 [==============================] - 146s 1s/step - loss: 630721.5150 - val_loss: 742294.0931
Epoch 21/21
100/100 [==============================] - 147s 1s/step - loss: 626138.9844 - val_loss: 715688.0587
Epoch 22/22
100/100 [==============================] - 146s 1s/step - loss: 606827.5503 - val_loss: 690054.2956

ロスのオーダーが105なのが気になる…

TensorBoardで見てみるとこんな感じ↓。

f:id:twx:20180913153940p:plain

20エポック時点でのテスト結果を見てみるとこんな感じ↓。

f:id:twx:20180913154136p:plain f:id:twx:20180913154150p:plain f:id:twx:20180913154201p:plain

いい感じではある。

論文によると、「10日間学習させる必要があった」と書かれているので、もう少し経過を待ってみようと思う。

追記 (50エポック終了)

約3時間で50エポックが終了しました。

f:id:twx:20180913180646p:plain

まだまだロスは下がりそうです。

テストデータでの結果は以下の通り。

f:id:twx:20180913180842p:plain f:id:twx:20180913180850p:plain f:id:twx:20180913180859p:plain

さらに学習は継続していくので、続報をお楽しみに。

以上、本日は「Image Inpaintingをやってみた」ということで、Image InpaintingをGoogle Colaboratoryで回してみました。

良い記事だと思っていただいた方は、以下の「★+」ボタンのクリック、SNSでのシェア、「読者になる」ボタンのクリックをお願いします。 それではまたー!

Google Colaboratory に大量データをアップロードする方法

Google Driveをマウントする

Colab上で以下を実行。/content/driveというディレクトリにGoogle Driveがマウントされる。

!apt-get install -y -qq software-properties-common python-software-properties module-init-tools
!add-apt-repository -y ppa:alessandro-strada/ppa 2>&1 > /dev/null
!apt-get update -qq 2>&1 > /dev/null
!apt-get -y install -qq google-drive-ocamlfuse fuse

from google.colab import auth
from oauth2client.client import GoogleCredentials
import getpass
auth.authenticate_user()
creds = GoogleCredentials.get_application_default()

!google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret} < /dev/null 2>&1 | grep URL
vcode = getpass.getpass()
!echo {vcode} | google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret}

!mkdir -p drive
!google-drive-ocamlfuse drive

Google Driveにアップロードする

xxxxx.zipのようにzipで圧縮してGoogle Driveに普通にアップロードする。

Google DriveからGoogle Colaboratoryにファイルを移動し展開

Colab上で以下を実行。/content/dataというディレクトリに大量データが展開される。

!mkdir /content/data
!unzip /content/drive/train.zip -d /content/data > /dev/null 2>&1 &

ポイントは、> /dev/null 2>&1 &で、標準出力を捨てているところだ。

そうしないと、unzipで大量データを展開したときに大量の標準出力が表示されてしまい、Chromeがクラッシュして落ちてしまうことがある。

解凍後のファイル数を確認するには以下を実行する。

!echo /content/data/train/* | xargs ls | wc

出力例。27308個のファイルが展開されていることがわかる。

  27308   27308 1119628

以上、本日は「Google Colaboratory に大量データをアップロードする方法」を紹介しました。 良い記事だと思っていただいた方は下の「★+」ボタンのクリック、SNSでのシェア、「読者になる」ボタンのクリックをお願いします! それではまたー

RailsアプリをAWS EC2で公開する超簡単な手順 【独自ドメイン/HTTPS対応】

f:id:twx:20180910173108p:plain

やりたいこと

  • Railsアプリをインターネットに公開したい
  • AWS EC2で公開したい
  • 独自ドメインで公開したい
  • HTTPSで公開したい
  • 単一のEC2で複数のアプリを公開したい

この記事では、独自ドメインやHTTPSにも対応した形でEC2でアプリを公開する手順を紹介する。

なお、独自ドメインの発行およびAWSの運用には料金がかかる。 筆者はだいたい月に3000円くらいかかっている。(AWS: 3000円/月くらい, ドメイン: 10円/月くらい)

この記事に書かれていることを実践し何らかの不利益を被ったとしても自己責任であり、筆者は責任を負わない旨をご了承いただきたい。

また、「EC2インスタンスを立ち上げる」「Route53でDNSレコードを登録する」などといった基本的な作業の手順はここには載せない。ググれば易しく解説してくれている記事がたくさん見つかるので、それらを参照していただきたい。

1. AWS EC2の環境を整える

まずはAWSコンソールにログインしよう。

https://ap-northeast-1.console.aws.amazon.com/console/

EC2インスタンスを1台立ち上げる。本格運用しないのであれば安いインスタンスで構わない。 この記事ではOSがUbuntuのイメージを使用したという前提で話を進める。

1.1 ユーザ作成

EC2にログインしたら、まずアプリごとにLinuxユーザを作ろう。 まず、デフォルトユーザ(ユーザ名ubuntu)でログインする。

以下のコマンドを実行してユーザを作っていく。ここではユーザ名はbananaとして話を進めていく。

ubuntuユーザとして実行

sudo useradd -m banana
sudo gpasswd -a banana sudo
sudo passwd banana

パスワードを聞かれれるので任意の文字列を打ち込む。これは今後、bananaユーザのパスワードとなるので覚えておこう。

次に、ログイン時のシェルを指定する。

sudo chsh banana

シェル名を聞かれるので、/bin/bashと打ち込もう。

1.2 sshでログインできるようにする

ubuntuユーザのキーをコピーして使うことにする。

sudo su banana
cd /home/banana
mkdir .ssh
sudo cp ../ubuntu/.ssh/authorized_keys .ssh/
sudo chown banana .ssh/authorized_keys
sudo chgrp banana .ssh/authorized_keys
exit

これで、bananaユーザとしてsshログインできるようになる。いったんEC2からログアウトし、bananaユーザとして入り直そう。以下の操作は全てbananaユーザとして行う。

1.3 もろもろ環境設定

ここは、各自好きなように環境設定すれば良い。

## 誤ってrmコマンドで削除してしまうのを防止する
echo "alias rm='rm -i'" >> ~/.bashrc
exec $SHELL

## sshが切れないようにする
sudo vi /etc/ssh/sshd_config
ClientAliveInterval 30 # 末尾に追加
sudo /etc/init.d/ssh restart

## OS最新化
sudo apt-get update -y
sudo apt-get upgrade -y

## IPv6サポートをoffにする
sudo vi /etc/default/ufw 
IPV6=no # yesからnoに変更


## ポート22, 80, 3000番を開けファイアウォールを有効化
sudo ufw allow 22
sudo ufw allow 80
sudo ufw allow 3000
sudo ufw enable
sudo ufw status

## よく使うコマンドのインストール (自身の好みに合わせてこの辺は自由にインストールしてください)
sudo apt-get install -y pwgen zip unzip nkf screen imagemagick
sudo apt-get install mecab libmecab-dev mecab-ipadic # Mecab
sudo apt-get install mecab-ipadic-utf8
echo "本日は晴天なり" | mecab # => 形態素解析のテスト

## 開発ツールのインストール
sudo apt-get install -y build-essential automake libssl-dev libreadline-dev libyaml-dev libpq-dev libbz2-dev libsqlite3-dev
sudo apt-get upgrade 

## pythonのインストール
git clone https://github.com/pyenv/pyenv.git ~/.pyenv
echo '' >> ~/.bashrc
echo '# pyenv' >> ~/.bashrc
echo 'export PYENV_ROOT="$HOME/.pyenv"' >> ~/.bashrc
echo 'export PATH="$PYENV_ROOT/bin:$PATH"' >> ~/.bashrc
echo 'eval "$(pyenv init -)"' >> ~/.bashrc
exec $SHELL

pyenv install 2.7.11
pyenv global 2.7.11
python --version

## rubyのインストール
git clone https://github.com/rbenv/rbenv.git ~/.rbenv
echo '' >> ~/.bashrc
echo '# rbenv' >> ~/.bashrc
echo 'export RBENV_ROOT="$HOME/.rbenv"' >> ~/.bashrc
echo 'export PATH="$RBENV_ROOT/bin:$PATH"' >> ~/.bashrc
echo 'eval "$(rbenv init -)"' >> ~/.bashrc
exec $SHELL

git clone https://github.com/sstephenson/ruby-build.git ~/.rbenv/plugins/ruby-build
rbenv install --list
rbenv install 2.4.1
rbenv global 2.4.1
ruby --version

## nodeのインストール
curl -L git.io/nodebrew | perl - setup
echo '' >> ~/.bashrc
echo '# nodebrew' >> ~/.bashrc
echo 'export NB_ROOT="$HOME/.nodebrew/current"' >> ~/.bashrc
echo 'export PATH="$NB_ROOT/bin:$PATH"' >> ~/.bashrc
exec $SHELL

nodebrew ls-remote
mkdir .nodebrew/default/src
nodebrew install-binary v8.1.2
nodebrew use v8.1.2
node --version

## 各種パッケージ最新化
pip install --upgrade pip
gem update --system
npm install -g npm

## よく使うgemをグローバルインストール
gem install bundler
gem install unicorn
gem install execjs

2. データベースの設定

PostgreSQLを使用する。

2.1 PostgreSQLのインストール・設定

### (参考) https://www.postgresql.org/download/linux/ubuntu/
sudo vi /etc/apt/sources.list.d/pgdg.list
deb http://apt.postgresql.org/pub/repos/apt/ xenial-pgdg main # この行を追加

wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
sudo apt-get update
sudo apt-get -y install postgresql-9.4
sudo apt autoremove

sudo locale-gen ja_JP.UTF-8
sudo service postgresql restart


## PostgreSQLの設定
sudo vi /etc/postgresql/9.4/main/postgresql.conf # PostgreSQL設定ファイル
shared_buffers = 512MB # パフォーマンス向上のためサイズ拡大(物理メモリの1/4程度が望ましい)
sudo service postgresql restart

DBを作る

アプリごとにDBを作ろう。

ここでは、bananaという名前のPostgreSQLユーザ、banana_productionという名前のデータベースを例として話を進める。 パスワードはpwgenで生成したものから適当に選んで使う。

pwgen  # => Nohwee3k
sudo su - postgres
psql

CREATE DATABASE banana_production WITH TEMPLATE template0 ENCODING = 'UTF-8' LC_COLLATE = 'ja_JP.UTF-8' LC_CTYPE = 'ja_JP.UTF-8';
CREATE USER banana WITH LOGIN PASSWORD 'Nohwee3k';
ALTER USER banana CREATEDB;
GRANT ALL PRIVILEGES ON DATABASE banana_production TO banana;

3. ソースコードをEC2に持ってくる

ローカルで開発したアプリのコードをEC2上に持ってこよう。これはGitリポジトリで経由で行う。 わたしはBitBucketをよく使う。BitBucketは少人数開発であれば無料でプライベートリポジトリを複数作ることができるので便利だ。

bitbucket.org

ただ、事前にBitBucketにSSHキーを登録しておく必要があるので注意。ローカルPCとEC2の両方で、以下のようにしてキーをBitBucketに登録する必要がある。

cd ~/.ssh
ssh-keygen -t rsa -C your_mail_address@example.com
mv id_rsa bitbucket_rsa # わかりやすくするためリネーム
mv id_rsa.pub bitbucket_rsa.pub # わかりやすくするためリネーム
chmod 600 bitbucket_rsa
vi ~/.ssh/config

### 以下を追記
Host bitbucket.org
  HostName bitbucket.org
  IdentityFile ~/.ssh/bitbucket_rsa
  User git
### ここまで

cat ~/.ssh/bitbucket_rsa.pub # ここで表示される文字列をクリップボードへコピーしておく

コピーした公開鍵情報をBitBucketの「SSH鍵」に登録する。 以下のURL (2018.9.10現在)の画面から登録する。

https://bitbucket.org/account/user/YOUR_USER_NAME/ssh-keys/

SSH鍵を登録できたら、BitBucketにリポジトリを作り、ここにローカルアプリをプッシュしよう。

# ローカル等の開発環境で実行
git init
git add .
git commit -m "First commit"
git remote add origin git@bitbucket.org:YOUR_NAME/banana_app.git
git push -u origin master 

EC2の/home/banana上で、リポジトリをクローンする。

# EC2で実行
cd /home/banana
git clone git@bitbucket.org:YOUR_NAME/banana_app.git

4. Production環境で動かす

本番環境でRailsアプリを動かすために、環境変数やデータベース接続の設定を行おう。

4.1 本番用GEMのインストール

unicornとpostgresqlを使うのでGemfileに以下を追加しておくこと。

gem 'pg'
gem 'unicorn'

まずgemをインストールする。

bundle install --path vendor/bundle --without development test # 本番環境依存のパッケージをローカルインストール

インストール時にこんなエラーが出ることがある

Important: You may need to add a javascript runtime to your Gemfile in order for bootstrap's LESS files to compile to CSS.

**********************************************

ExecJS supports these runtimes:

therubyracer - Google V8 embedded within Ruby

therubyrhino - Mozilla Rhino embedded within JRuby

Node.js

**********************************************

対処方法は以下。

gem install execjs

4.2 データベースの接続設定

vi config/database.yml

### 以下を追記
production:
  adapter: postgresql
  encoding: unicode
  database: banana_production
  pool: 5
  username: banana
  password: Nohwee3k
  min_messages: WARNING
### ここまで

ここで、databaseusernamepasswordは先ほどPostgreSQLの設定で使った値と同じものを使うこと。

4.3 環境変数の設定

## 本番環境であることを表す環境変数をセットする
echo "export SECRET_KEY_BASE='production'" >> ~/.bashrc
rake secret
echo "export SECRET_KEY_BASE='435c1b5517e10640b8d661e361cac2682c6dc0c4690ad4c952d6af84098c02538d0a5acb138e34bb313a420b2a4dd97743d182a6307cea12ace817d8db844e6d'" >> ~/.bashrc
exec $SHELL
RAILS_ENV=production bundle exec rake db:setup # DBをproduction環境でセットアップする

## アセットコンパイルを行う
RAILS_ENV=production bundle exec rails assets:precompile

## 起動確認
bundle exec rails s -e production # ブラウザでhttp://xxx.xxx.xxx.xxx:3000を開いて表示されるか確認してみる

ここまで行えば、http://xxx.xxx.xxx.xxx:3000 という風にIPとポート番号でアクセスできるようになる。

次からは独自ドメイン、HTTPSで公開する方法を説明する。

5. Unicorn/Nginxの設定

5.1 Unicornの設定

Unicornというミドルウェアを使用する。Unicornは、次に紹介するNginxというWebサーバとRailsアプリの間を橋渡ししてくれる。

# アプリのルートディレクトリに入って実行
vi config/unicorn.rb # 新規作成

### 以下を追加
# set lets
$app_dir = "/home/banana/banana_app" # Railsアプリケーションのドキュメントルート
$worker  = 2 # HTTPリクエストを処理するプロセスの数. 最低でもCPUコア数以上とすること. コア数は cat /proc/cpuinfo コマンドで確認できる.
$timeout = 30
$listen  = File.expand_path 'tmp/sockets/.unicorn.sock', $app_dir # 任意のパスで良い. Nginxが参照可能な, UNIXドメインソケットのパス
$pid     = File.expand_path 'tmp/pids/unicorn.pid', $app_dir # 同じく任意のパスで良い.
$std_out = File.expand_path 'log/unicorn.stdout.log', $app_dir # 任意のパス. ログの出力先
$std_err = File.expand_path 'log/unicorn.stderr.log', $app_dir # 任意のパス. エラーログの出力先

# set config
working_directory $app_dir
worker_processes  $worker
stdout_path $std_out
stderr_path $std_err
timeout $timeout
listen  $listen
pid $pid

# loading booster
preload_app true

# before starting processes
before_fork do |server, worker|
  defined?(ActiveRecord::Base) and ActiveRecord::Base.connection.disconnect!
  old_pid = "#{server.config[:pid]}.oldbin"
  if old_pid != server.pid
    begin
      Process.kill "QUIT", File.read(old_pid).to_i
    rescue Errno::ENOENT, Errno::ESRCH
    end
  end
end

# after finishing processes
after_fork do |server, worker|
  defined?(ActiveRecord::Base) and ActiveRecord::Base.establish_connection
end
### ここまで

ここで、上部のこのあたりは各自の環境に合わせて変えていただきたい。その他の部分はオマジナイだと思って、そっくりそのままコピペで構わない。

$app_dir = "/home/banana/banana_app" # Railsアプリケーションのドキュメントルート
$worker  = 2 # HTTPリクエストを処理するプロセスの数. 最低でもCPUコア数以上とすること. コア数は cat /proc/cpuinfo コマンドで確認できる.
$timeout = 30

次に、

sudo vi /etc/init.d/banana

### 以下を追加
#!/bin/sh

# BEGIN INIT INFO
# Provides:          banana
# Required-Start:    $local_fs $remote_fs $network $syslog
# Required-Stop:     $local_fs $remote_fs $network $syslog
# Default-Start:     2 3 4 5
# Default-Stop:      0 1 6
# Short-Description: starts banana (a rails app)
# Description:       starts banana (a rails app) using start-stop-daemon
# END INIT INFO

USER=banana
APP_ROOT=/home/banana/banana_app
RAILS_ENV=production
PID_FILE=$APP_ROOT/tmp/pids/unicorn.pid
CONFIG_FILE=$APP_ROOT/config/unicorn.rb
CMD="/home/banana/.rbenv/shims/bundle exec /home/banana/.rbenv/shims/unicorn_rails"
ARGS="-c $CONFIG_FILE -D -E $RAILS_ENV"

export RAILS_SERVE_STATIC_FILES=1
export SECRET_KEY_BASE='435c1b5517e10640b8d661e361cac2682c6dc0c4690ad4c952d6af84098c02538d0a5acb138e34bb313a420b2a4dd97743d182a6307cea12ace817d8db844e6d''
export PATH=/home/banana/.rbenv/shims/:$PATH

case $1 in
  start)
    start-stop-daemon --start --chuid $USER --chdir $APP_ROOT --exec $CMD -- $ARGS || true
    ;;
  stop)
    start-stop-daemon --stop --signal QUIT --pidfile $PID_FILE || true
    ;;
  restart|force-reload)
    start-stop-daemon --stop --signal USR2 --pidfile $PID_FILE || true
    ;;
  status)
    status_of_proc -p $PID_FILE "$CMD" banana && exit 0 || exit $?
    ;;
  *)
    echo >&2 "Usage: $0 <start|stop|restart|force-reload|status>"
    exit 1
    ;;
esac
### ここまで

ここで、banana_appbananaと書かれている箇所だけ、自身の環境に合わせればOK。たとえば、mikanというユーザでmikan_appというアプリを作る場合は以下のように置換すれば良い。

cat  | perl -pe  "s/banana_app/mikan_app/g" | perl -pe "s/banana/mikan/g"

また、Railsアプリ内でENV["HOGE_HOGE"]のように環境変数を使っている場合は、このファイルに環境変数を書いておく必要がある。たとえば、ツイッターのAPIを使っていて、ツイッターのAPI KEYとAPI SECRETを環境変数から呼び出したいときは

# Twitter
export TWITTER_KEY='xxxxxxxxxxxxxxxxxxxxxxxxxxxx'
export TWITTER_SECRET='xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'

みたいに書いておこう。

ここまで出来たら有効化する。

# 有効化
sudo chmod 755 /etc/init.d/banana
sudo update-rc.d banana defaults

### bananaアプリの動作確認
sudo service banana start
sudo service banana stop
sudo service banana restart
sudo service banana status

# ※ statusがrunningになっているかどうか確認する。

5.2 Nginxの設定

Nginxがインストールされていなければ、まずインストールするところから。

### (1) nginxサイトが配布するPGPキーを追加
curl http://nginx.org/keys/nginx_signing.key | sudo apt-key add -

### (2) リポジトリを一覧に追加
sudo sh -c "echo 'deb http://nginx.org/packages/ubuntu/ xenial nginx' >> /etc/apt/sources.list"
sudo sh -c "echo 'deb-src http://nginx.org/packages/ubuntu/ xenial nginx' >> /etc/apt/sources.list"

### (3) アップデート後、nginxをインストール
sudo apt-get update
sudo apt-get install nginx

インストールできたら、設定ファイルをいじっていく。

## バーチャルホストの設定
cd /etc/nginx
sudo rm koi-utf koi-win win-utf # これらはロシア語の文字コード返還テーブル.  ただ邪魔なだけなので削除しておく.

sudo vi nginx.conf

### 以下のように変更
user  nginx;
worker_processes  1; # CPUコア数と同じ値が推奨

error_log  /var/log/nginx/error.log warn;
pid        /var/run/nginx.pid;

events {
    worker_connections  1024; # 同時接続数
}

http {
    include       /etc/nginx/mime.types;
    default_type  application/octet-stream;
    log_format  main  '$remote_addr - $remote_user [$time_local] "$request" '
                      '$status $body_bytes_sent "$http_referer" '
                      '"$http_user_agent" "$http_x_forwarded_for"';

    access_log  /var/log/nginx/access.log  main;
    sendfile        on;
    keepalive_timeout  65;
    include /etc/nginx/sites-enabled/*.conf;
}
### ここまで

ここで、最後の行include /etc/nginx/sites-enabled/*.conf;が重要な役割を担っている。sites-enabledディレクトリの中に入っている*.confというコンフィグファイルがアルファベット順に実行される。

ただし、Ngingの作法としては、sites-enabledに直接コンフィグファイルを置かず、代わりにsites-availableというディレクトリにコンフィグファイルを置くという風にされている。sites-available内に作ったコンフィグファイルに対してシンボリックリンクを張り、そのリンクをsites-enabledに置くというやり方をする。

sudo mkdir sites-available # コンフィグファイルの置き場を作る
sudo mkdir sites-enabled # コンフィグファイルに張るシンボリックリンクを置く場所を作る
sudo cp conf.d/default.conf sites-available/ # デフォルトのコンフィグファイルをコンフィグ置き場にコピーする

デフォルトのコンフィグを以下のように修正する。

vi sites-available/default.conf

###
server {
    listen       80;
    server_name  localhost;

    # 適当なランダムな文字列↓
    location /_LJoilLNe5KJHIy84LL4lKJEZ {
        root   /usr/share/nginx/html;
        index  index.html index.htm;
    }

    error_page   500 502 503 504  /50x.html;
    location = /50x.html {
        root   /usr/share/nginx/html;
    }
}
### ここまで

ここで_LJoilLNe5KJHIy84LL4lKJEZというのは適当なランダムな文字列だ。これは、後ほど紹介するEC2ロードバランサのヘルスチェック用として使う。

以下のヘルスチェック用のページを用意しよう。

sudo vi /usr/share/nginx/html/_rQhFwH9PWgMXZyUtPrrfi2JQ9nVMfPUeyInYcdEZ/index.html

### 以下を追加
<!DOCTYPE html>
<html>
<head>
<title>Check Page</title>
</head>
<body>
Check Page
</body>
</html>
###

次に、banana用のコンフィルファイルを作る。

sudo vi sites-available/banana.conf # 新しいコンフィグ新規作成

### ここから
limit_conn_zone $binary_remote_addr zone=conn_limit_per_ip:10m;
limit_req_zone $binary_remote_addr zone=req_limit_per_ip:10m rate=1r/s;

upstream app_server {
  server unix:/home/banana/banana_app/tmp/sockets/.unicorn.sock;
}

server {
  listen 80;
  server_name example.com; # このドメインでアプリを公開する

  # 静的コンテンツ(public/****)の場所
  root /home/banana/banana_app/public;
  error_log  /home/banana/banana_app/log/nginx.error.log;
  access_log /home/banana/banana_App/log/nginx.access.log;

  # 接続制限関連
  keepalive_timeout 5; # 接続を保つ秒数 (国内であれば5-10でOK)
  client_max_body_size 256k; # クライアントからのリクエストボディは256KBまで許容する.
  client_header_buffer_size 128k; # クライアントからのリクエストヘッダは128KBまで許容する.
  large_client_header_buffers 4 128k; # クライアントからのリクエストヘッダは128KBまで許容する.
  limit_conn conn_limit_per_ip 10; # 1つのIPに対する同時接続は10本まで許容する.
  limit_req zone=req_limit_per_ip burst=5; # リクエスト数/秒の制限にかかったリクエストは5件まで許容する

  # Railsアプリへのルーティング
  try_files $uri/index.html $uri.html $uri @app;
  location @app {
    # HTTP headers
    proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
    proxy_set_header Host $http_host;
    proxy_redirect off;
    proxy_pass http://app_server;
  }

  # Railsエラーページ
  error_page 500 502 503 504 /500.html;
  location = /500.html {
    root /home/banana/banana_app/public;
  }
}
### ここまで

sudo ln -s /etc/nginx/sites-available/default.conf /etc/nginx/sites-enabled/000_default.conf # シンボリックリンクを張る
sudo ln -s /etc/nginx/sites-available/banana.conf /etc/nginx/sites-enabled/ # シンボリックリンクを張る

ここで、例としてexample.comというドメインでアプリを公開するという体で話を進める。ここは、自分が使いたいドメイン名で置き換えていただきたい。

最後に、これまで作ったNginxのファイルが間違っていないかチェックする。

sudo nginx -t # コンフィグが正常かどうかテスト

# syntax ok と出ればOK.
# okと出ない場合は、どこかタイピングミスしていると思われる。

最後に、Nginxを有効化する。

sudo service nginx restart
sudo service nginx status

# ※ statusがrunningになっているかどうか確認する。

6. 独自ドメイン/HTTPSで公開

6.1 お名前.comでドメインを買う

www.onamae.com

6.2 お名前.comでネームサーバをRoute53に設定する

AWSのRoute53で、新たにHosted zoneを作り、先程購入したドメインと同じ名前のhosted zoneを作る。新しく割り当てられた4のネームサーバ (nsと書かれている)をコピーしておく。 https://console.aws.amazon.com/route53/

f:id:twx:20180911105543p:plain

コピーしたネームサーバを、お名前.comのネームサーバに設定する。

https://www.onamae.com/domain/navi/ns_update/input

このページで、「他のネームサーバを利用」というタブををクリックし、先程コピーしておいた4つのNSを入力する。

6.3 AWS Certificate ManagerでSSL証明書を取得する

https://ap-northeast-1.console.aws.amazon.com/acm/home

「証明書のリクエスト」を押して、先程取得したドメイン名と同じドメインの証明書を取得する。その際、「Route53にレコード(CNAME)を追加する」という旨のボタンを押しておく。(AWSのアップデートで、ボタン名の表現は頻繁に変わるので一言一句同じボタンは無いかもしれない。ニュアンスで判断していただきたい。)

※ 10分ほど待つと有効になる。有効にならないと、次の処理が失敗するので必ず有効になっているかどうか確認すること。

6.4 クラシックロードバランサを立てる

https://ap-northeast-1.console.aws.amazon.com/ec2/v2/home#LoadBalancers

クラシックロードバランサはドメインごとに立てる。(アプリを2つ公開したい場合は2つのドメイン、2つのロードバランサを立てる。) クラシックロードバランサを立てるとき、設定画面で「HTTPS:443をHTTP:80にリダイレクトする」ように設定する。また、SSL証明書として、上記で取得したACMの証明書を指定する。また、紐付けるインスタンスは当然、本番運用するEC2とする。

ロードバランサが出来たら、ヘルスチェックの設定を行う。

これは、定期的にロードバランサがEC2の死活をチェックするという機能だ。特定のURLでアクセスしEC2から反応があれば「生きている」と判断される。ここで、「特定のURL」というのが、先程作成した適当な名前(_LJoilLNe5KJHIy84LL4lKJEZとした)のサイトだ。

「ヘルスチェックの編集」画面で以下の様に設定する。

f:id:twx:20180911111315p:plain

Aレコードでロードバランサのエイリアスを指定

Route53で、Aレコードを作成する。 AWSでは、Aレコードを選んだときに「エイリアス」というものを指定できる。

f:id:twx:20180911111722p:plain

関連付け対象に、先程作ったロードバランサを指定する。

以上で準備は全て整った。

7. 確認

https://YOUR_DOMAIN

でアプリが見れるようになった。

以上、今回はRailsアプリをAWS EC2で公開する超簡単な手順 【独自ドメイン/HTTPS対応】を紹介しました。 この記事が役にたったという方は以下の「★+」ボタンのクリック、SNSでのシェア、「読者になる」ボタンのクリックをお願いいたします。 ではまた〜

深層学習の「超解像」でモザイクは除去できるのか (Tensorflow + Keras)

f:id:twx:20180909143610p:plain

画像出典: http://mmlab.ie.cuhk.edu.hk/projects/SRCNN.html

結論から言うと

うまくいかなかった。

あくまで実験記録として記事を書くが、ここにあるコードを真似してもうまく高解像度化できないので注意していただきたい。

超解像をやってみた。

モザイク画像から元画像を推定するAutoEncoderを作ってみた。深層学習フレームワークはTensorflow+Kerasを使用した。 参考にした論文はこれだ。

Learning a Deep Convolutional Network for Image Super-Resolution

SRCNNという、3層で構成されたシンプルなネットワークだ。任意のshapeの画像を受け取り、画像サイズを変えないようにpaddingを付与してCNNを3回繰り返すという処理を行う。活性化関数はReLuとした。

# ネットワークの定義
model = Sequential()
model.add(Conv2D(
    filters=64,
    kernel_size=9,
    padding='same',
    activation='relu',
    input_shape=(None, None, 3)
))
model.add(Conv2D(
    filters=32,
    kernel_size=1,
    padding='same',
    activation='relu'
))
model.add(Conv2D(
    filters=3,
    kernel_size=5,
    padding='same'
))

低解像度画像を作る

元画像を10%ほど粗くしたモザイク画像を作り、モザイク画像から元画像を推定させる。画像をモザイク化する関数は以下の記事を参考にしてOpenCVを用いて実装した。

note.nkmk.me

# 受け取ったnp arrayにモザイクをかけnp arrayとしてreturnする
def drop_resolution(arr, ratio=0.1):    
    tmp = cv2.resize(arr, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST)
    new_img = cv2.resize(tmp, arr.shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
    new_arr = img_to_array(new_img)
    
    return new_arr

試しにモザイク化してみた。こんな感じになった。

f:id:twx:20180909144802p:plain

学習しよう

コード全文を以下に掲載する。Google ColaboratoryでGPUを使って回すと、だいたい30分〜40分くらいで終わる。

Google Colaboratoryで学習するには、画像データをGoogle Driveにアップロードしておき、ColaboratoryでGoogle Driveにアクセスできるように設定しておく必要がある。詳しくは以下の記事を参照していただきたい。

www.mahirokazuko.com

# 各種モジュールのインポート
import os
import glob
import math
import random
import cv2
import numpy as np
from tensorflow.python import keras
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.models import Model, Sequential
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array, array_to_img
from tensorflow.python.keras.layers import Add, Input, Conv2D, Conv2DTranspose, Dense, Input, MaxPooling2D, UpSampling2D, Lambda

# 受け取ったnp arrayにモザイクをかけnp arrayとしてreturnする
def mosaic(arr, ratio=0.1):    
    tmp = cv2.resize(arr, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST)
    new_img = cv2.resize(tmp, arr.shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
    new_arr = img_to_array(new_img)
    return new_arr

# yieldをもちいてミニバッチをreturnするジェネレータ
def data_generator(data_dir, mode, mosaic_ratio=0.1, target_size=(128, 128), batch_size=32, shuffle=True):
    for imgs in ImageDataGenerator().flow_from_directory(
        directory=data_dir,
        classes=[mode],
        class_mode=None,
        color_mode='rgb',
        target_size=target_size,
        batch_size=batch_size,
        shuffle=shuffle
    ):
        x = np.array([mosaic(img_to_array(img), mosaic_ratio) for img in imgs])
        yield x/255., imgs/255.

        
DATA_DIR = 'data/'
BATCH_SIZE = 100

# 上で定義したジェネレータをもちいて学習データとテストデータをロード
train_data_generator = data_generator(DATA_DIR, 'train', batch_size=BATCH_SIZE, shuffle=True)
test_data_generator = data_generator(DATA_DIR, 'test', batch_size=BATCH_SIZE, shuffle=False)


# ネットワークの定義
model = Sequential()
model.add(Conv2D(
    filters=64,
    kernel_size=9,
    padding='same',
    activation='relu',
    input_shape=(None, None, 3)
))
model.add(Conv2D(
    filters=32,
    kernel_size=1,
    padding='same',
    activation='relu'
))
model.add(Conv2D(
    filters=3,
    kernel_size=5,
    padding='same'
))

# ピーク信号対雑音比(PSNR)を「解像度の高さ」の指標とする。この値が20以上くらいになると比較的、解像度が良いとされている。
def psnr(y_true, y_pred):
    return -10*K.log(
        K.mean(K.flatten((y_true - y_pred))**2)
    )/np.log(10)

# モデルをコンパイルする。損失関数は二乗誤差。最適化アルゴリズムはAdam。評価指標は上で定義したpsnr。
model.compile(
    loss='mean_squared_error', 
    optimizer='adam', 
    metrics=[psnr]
)

# 学習させる
model.fit_generator(
    train_data_generator,
    validation_data=test_data_generator,
    validation_steps=1,
    steps_per_epoch=100,
    epochs=50
)

結果を見てみる

# 未知のデータに対してモザイク除去を試してみる

from IPython.display import display_png

unknown_img_t = img_to_array( load_img('pingpong.jpg') )
unknown_img_x = mosaic(unknown_img_t)
unknown_img_y = model.predict(unknown_img_x.reshape(1,128,128,3))[0]

display_png( array_to_img( unknown_img_t ) )
display_png( array_to_img( unknown_img_x ) )
display_png( array_to_img( unknown_img_y ) )

以下のようになった。上から順に、オリジナル画像、モザイク画像、モザイク除去画像だ。

f:id:twx:20180909163941p:plain

確かに、少しはマシになっているのかもしれないが、モザイクが除去できているとは言えない結果となった。

冒頭の元論文では、「少しぼやけた画像」に対して高解像度化の効果が得られたと書かれていた。さすがに、今回のモザイクのようにかなり粗い画像に対しては期待できる効果は得られないようだ。

GAN等を使ったほうが、良い結果になるかもしれない。

以上、今回は超解像を試してみました。 良い記事だと思っていただけた方は下の「★+」ボタンのクリック、SNSでのシェア、「読者になる」ボタンのクリックをお願いいたします。 それではまた〜

Falconを使い超簡単なAPIを作る 【所要時間たったの3分】

Falconとは

f:id:twx:20180908114858p:plain

Falcon - Bare-metal web API framework for Python

Falcon を使えばRESTful な Web API をサクッと作ることができる。業務で簡単なWeb API を用意する必要があったり、ハッカソンなどの短期間開発で疎結合なアークテクチャを作らなくてはならないときに重宝する。Falcon自体はpythonのモジュールだが、肝心のAPI部分は好きな言語で作っても全然構わない (pythonから呼べるようにしておけば良いというだけの話だ)。

この記事では、例として「文を単語分割する」という機能のWeb APIを作ってみようと思う。所要時間はタイトル通り3分だ。

準備 【1分】

「文を単語分割する」とは何か

たとえば、「今日も1日頑張るぞい」という文があったとき、これを単語にスペースで区切ってみると以下のようになる。

今日  も  1  日  頑張る  ぞい

このように、文を単語ごとに区切ることを単語分割という。このような処理は形態素解析ツールによって実現できる。有名なツールに、mecabというものがある。今回はこれを使う。

mecabのインストール

Ubuntuであれば、以下のようにapt-getでインストールできる。

sudo apt-get install mecab libmecab-dev mecab-ipadic
sudo apt-get install mecab-ipadic-utf8
echo "本日は晴天なり" | mecab # => 形態素解析できるかテストしてみよう

falconなどのpipモジュールの準備

以下のようにpipで必要なモジュールをインストールする。ここでは、falcon, falcon-multipart, gunicornをインストールする。

falcon-multipartは、マルチパート形式のPOSTメソッドでパラメータを受け取ることができるようにするモジュールだ。 gunicornは簡易的なWEBサーバだ。

pip install falcon
pip install falcon-multipart
pip gunicorn

サーバ側のコード【1分】

server.py

import falcon
import json
import subprocess
from falcon_multipart.middleware import MultipartMiddleware

# クロスオリジンを許可すする
class CORSMiddleware:
    def process_request(self, req, resp):
        resp.set_header('Access-Control-Allow-Origin', '*')

# shelコマンドを実行する関数
def do_command(command):
    proc = subprocess.Popen(
        command,
        shell  = True,
        stdin  = subprocess.PIPE,
        stdout = subprocess.PIPE,
        stderr = subprocess.PIPE
    )
    return proc.communicate()

class Mecab:

  def on_post(self, req, res):
    # POSTメソッドを受け付ける
    # sentenseという名前のパラメータを受け取る
    sentence = req.get_param('sentence')

    # 'echo hogehoge| mecab -Owakati'というコマンドをたたく。-Owakatiオプションでmecabをたたけば単語分割ができる。
    stdout, stderr = do_command('echo {} | mecab -Owakati'.format(sentence))

    # データを整形し、クライアントにレスポンスを返す
    resp = {
      'stdout': stdout.decode('utf-8').strip(),
      'stderr': stderr.decode('utf-8').strip()
    }
    res.body = json.dumps(resp)

# ミドルウェアを2つ設定している。前者はCORSを許可するために必要だ。これがないと異なるオリジンからのアクセスが許可されない。
api = falcon.API(middleware=[CORSMiddleware(), MultipartMiddleware()])

# /hogeなど任意のエンドポイントを設定できる。今回は / としている。
api.add_route('/', Mecab())

上のコードをサーバとして起動させよう。サーバ側で以下のコマンドを実行する。

gunicorn server:api

これで、デフォルトではhttp://127.0.0.1:8000でサーバが立ち上がる。

IPを指定したい場合は、以下のようにすれば良い。

gunicorn server:api -b xxx.xxx.xxx.xxx

クライアント側のコード【1分】

client.html

<html>
<body>

<p>単語に分割したい文を入力してください</p>
<input type="text" id="sentence"/>
<button id="send">送信</button>

<p>結果</p>
<span id="result"></span>

<script src="https://ajax.googleapis.com/ajax/libs/jquery/2.2.4/jquery.min.js"></script>
<script>
$(document).ready(function () {

  $('#send').click(function (event) {
    var fd = new FormData();
    fd.append('sentence', $('#sentence').val());
    $.ajax({
      url: 'http://127.0.0.1:8000',
      type: 'POST',
      dataType: 'json',
      data: fd,
      processData: false,
      contentType: false
    })
    .done(function( res, textStatus, jqXHR ) {
      $('#result').text(res.stdout)
    });
  });
});
</script>

</body>
</html>

実行結果

f:id:twx:20180909013202p:plain

 ↓えいや

f:id:twx:20180909013210p:plain

できた!

以上、falconでAPIを簡単に作る方法をご紹介しました。

この記事を良いと思っていただけた方は下の「★+」ボタンのクリックと、SNSでのシェア、「読者になる」ボタンのクリックをお願いいたします! それではまた!

Google Colaboratory のTips集その1 (GoogleDriveマウント、セッション継続、TensorBoard接続)

Google Colaboratory、すごいです。 誰でも、Tesla K80のGPUを無料ですぐに使うことができる。Tensorflow環境もすぐに手に入れることができる。

https://colab.research.google.com/

この記事では、Google Colaboratoryを使って深層学習をまわすときのちょっとしたTipsを紹介する。

Google Colaboratoryの時間制限ルール

2018年9月7日時点で、以下のようなルールとなっている

  • インスタンスの連続稼働時間は12時間
  • セッションが途切れて90分立つとセッション初期化

長時間の学習で12時間を超えてしまう場合、途中の学習結果をGoogle Drive等に保存できるようにしておく必要がある。また、セッションが途切れてしまわないように、定期的にアクセスする必要がある。

Google driveをマウントしてモデルを保存

ノートブックに以下のコードを追加する。

!apt-get install -y -qq software-properties-common python-software-properties module-init-tools
!add-apt-repository -y ppa:alessandro-strada/ppa 2>&1 > /dev/null
!apt-get update -qq 2>&1 > /dev/null
!apt-get -y install -qq google-drive-ocamlfuse fuse

from google.colab import auth
from oauth2client.client import GoogleCredentials
import getpass
auth.authenticate_user()
creds = GoogleCredentials.get_application_default()

!google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret} < /dev/null 2>&1 | grep URL
vcode = getpass.getpass()
!echo {vcode} | google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret}

!mkdir -p drive
!google-drive-ocamlfuse drive

途中で2回、Googleにログインするのためのコード入力が要求される。 無事、ログインできれば、driveというディレクトリにgoogle driveがマウントされる。

わたしはいつも、Google driveでプロジェクトごとにディレクトリを分けている。

%cd /content/drive/work/project_A

こんな感じで、ディレクトリ移動できるようになる。

モデルを保存するときは、上述のようにディレクトリを移動して普通にカレントディレクトリに保存するか、モデルの保存先として/content/drive/work/project_AなどのGoogleDriveマウント先のディレクトリを指定すればOK。

セッションが途切れないように定期アクセス

クライアントPCで定期的にアクセスする。以下の例では、Macのopenコマンドで定期的にColaboratoryを開いている。

for i in `seq 0 12`
do
  open https://colab.research.google.com/drive/xxxxxxxxxxxxxxxxxxxxx
  sleep 100
done

AWS等から定期的にアクセスする方法も、のちほど記事化したい。乞うご期待。

Google ColaboratoryでTensorBoardを使う。

ノートブックに以下のコードを追加する。

# ngrokというツールをもちいて、Colaboratoryのローカル環境に立つサーバにインターネットからアクセスできるようにする。
! wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip 
! unzip ngrok-stable-linux-amd64.zip 

get_ipython().system_raw('./ngrok http 6006 &') 
get_ipython().system_raw('tensorboard --logdir=logs --host 0.0.0.0 --port 6006 &')
# ↑ --logdir=logsの右辺にTensoBoardのログの名前を書くこと。

! curl -s http://localhost:4040/api/tunnels | python3 -c "import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])" 

# 表示されるURLにアクセスするとTensorBoardの画面が見れる。

以上、Google ColaboratoryのTipsその1でした。その2では、AWSをもちいたセッション継続や、GPU使用状況の確認などについて書こうと思います。

この記事が役に立ったという方は、下の「★+」ボタンのクリック、SNSでのシェア、「読者になる」ボタンのクリックをお願いします!!

GPU搭載のMacBook Pro (Retina, 15-inch, Mid 2014)でCUDAを動かす

CUDA周りにはあまり詳しくないのだが、MacでCUDAを動かす方法を色々調べたのでまとめておく。

特に、「何が必要か」「今の状況をどうやって確認するか」「何をすれば動くようになるか」を一元的にまとめた記事がなかったので、それを重点的に書こうと思う。

まず、わたしのMacのスペックを載せておく。

  • OS: 10.13.6 (High Sierra)
  • 型番: MacBook Pro (Retina, 15-inch, Mid 2014)
  • メモリ: 16 GB 1600 MHz DDR3
  • GPU: NVIDIA GeForce GT 750M 2048 MB

CUDAを動かすのに必要なもの

  • CUDA Toolkit
  • CUDA driver
  • NVIDIA display driver
  • nvcc
  • cudnn
  • Command Line Tools for Xcode

この他、tensorflowを動かそうと思ったら、最近のtensorflowはMacでのGPUをサポートしないらしいので、自分でビルドする必要があるらしい。このあたりは後日改めてまとめようと思う。

今の状況をどうやって確認するのか

CUDA Toolkit

以下のようにCUDA-9.2のようなディレクトリが存在していれば、CUDAはインストールされている。バージョンは、この例だと9.2だ。

/Developer/NVIDIA/CUDA-9.2

CUDA driver

Macの「システム環境設定」 → 「CUDA」

で表示される。こんな感じだ。

f:id:twx:20180902225906p:plain

この例だと、バージョンは396.148だ。

NVIDIA display driver

Macの「システム環境設定」 → 「CUDA」

で表示される。上の例ではバージョンは387.10.10.10.40.105だ。

nvcc

以下のコマンドで確認可能。CUDA Toolkitをインストールすればnvccも勝手に入っている。

$ nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2018 NVIDIA Corporation
Built on Tue_Jun_12_23:08:12_CDT_2018
Cuda compilation tools, release 9.2, V9.2.148

cudnn

以下のようにlibcudnn.7.dylibcudnn.hのようなファイルが存在していれば、cudnnはインストールされている。ファイル名に7とあるのでバージョンは7だ。cudnn.hの中身を見てみると、この中にもバージョン情報が書かれている。

/usr/local/cuda/lib/libcudnn.7.dylib
/usr/local/cuda/include/cudnn.h

Command Line Tools for Xcode

少し調べたが、バージョンを確かめる方法はわからなかった。

とりあえず、以下のディレクトリが存在していれば、何らかのバージョンのCommand Line Toolsがインストールしていることがわかる。

/Library/Developer/CommandLineTools

そして、正攻法ではないかもしれないが、以下のファイルを見ればバージョンがわかる(たぶん)。

$ cat /Library/Developer/CommandLineTools/Library/PrivateFrameworks/LLDB.framework/Versions/Current/Resources/Info.plist

<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
         …
          (省略)
         …
    <key>DTSDKName</key>
    <string>macosx10.12internal</string>
    <key>DTXcode</key>
    <string>0820</string>
         …
          (省略)
         …
</dict>
</plist>

この部分がバージョンぽい。

  <key>DTXcode</key>
  <string>0820</string>

何をすれば動くのか、どうなっていれば動くのか

上で述べたそれぞれのパーツを、相性が合うバージョンで揃えれば良い。

わたしは以下のバージョンで揃えた。

  • CUDA Toolkit: 9.2
  • CUDA driver: 396.148
  • NVIDIA display driver: 387.10.10.10.40.105
  • nvcc: 9.2
  • cudnn: 7.2
  • Command Line Tools for Xcode: 8.2

どういう組み合わせが相性が合うのか? それは、まず、CUDA Toolkit, CUDA driver, NVIDIA display driverの3つに関しては、NVIDIAのドライバのダウンロードページを見て確認しよう。以下のMac用ドライバのダウンロードページを見てほしい。

MAC アーカイブ用CUDA ドライバ|NVIDIA

いずれかのドライバをダウンロードしようとすると、2018年9月3日現在、このようなページが表示される。

f:id:twx:20180903000244p:plain

ここに、どのOS、どのCUDA Toolkit, どのNVIDIA display driverをサポートしているのかが書かれている。

また、nvccに関しては、CUDA Toolkitをインストールすれば、勝手に同じバージョンのものが手に入る。

cudnnは、cudnnのダウンロードページから、サポートしているCUDAのバージョンを見れば良い。

https://developer.nvidia.com/rdp/cudnn-download

Command Line Tools for Xcodeに関してはよくわからないのだが、最近のXcodeでインストールされるclangではコンパイルできないので、あえて古いバージョンのXcodeをインストールするそうだ。

以下で古いXcodeをダウンロードできる。

Sign in with your Apple ID - Apple Developer

Command Line Tools for Xcode_8.2 のようなファイル名のものをインストールすること。

インストールが終わったら、以下のコマンドを打つとバージョンの切り替えができる。

sudo xcode-select --switch /Library/Developer/CommandLineTools

古くなったか確かめてみる。

$ clang --version
Apple LLVM version 8.0.0 (clang-800.0.42.1)
Target: x86_64-apple-darwin17.7.0
Thread model: posix
InstalledDir: /Library/Developer/CommandLineTools/usr/bin

はい、OK。

CUDAが正しく起動するかどうかを確かめるには以下を実行すればOK。

cd /Developer/NVIDIA/CUDA-9.2/samples/1_Utilities/deviceQuery
sudo make
./deviceQuery

./deviceQuery Starting...

 CUDA Device Query (Runtime API) version (CUDART static linking)

Detected 1 CUDA Capable device(s)

Device 0: "GeForce GT 750M"
  CUDA Driver Version / Runtime Version          9.2 / 9.2
  CUDA Capability Major/Minor version number:    3.0
  Total amount of global memory:                 2048 MBytes (2147024896 bytes)
  ( 2) Multiprocessors, (192) CUDA Cores/MP:     384 CUDA Cores
  GPU Max Clock rate:                            926 MHz (0.93 GHz)
  Memory Clock rate:                             2508 Mhz
  Memory Bus Width:                              128-bit
  L2 Cache Size:                                 262144 bytes
  Maximum Texture Dimension Size (x,y,z)         1D=(65536), 2D=(65536, 65536), 3D=(4096, 4096, 4096)
  Maximum Layered 1D Texture Size, (num) layers  1D=(16384), 2048 layers
  Maximum Layered 2D Texture Size, (num) layers  2D=(16384, 16384), 2048 layers
  Total amount of constant memory:               65536 bytes
  Total amount of shared memory per block:       49152 bytes
  Total number of registers available per block: 65536
  Warp size:                                     32
  Maximum number of threads per multiprocessor:  2048
  Maximum number of threads per block:           1024
  Max dimension size of a thread block (x,y,z): (1024, 1024, 64)
  Max dimension size of a grid size    (x,y,z): (2147483647, 65535, 65535)
  Maximum memory pitch:                          2147483647 bytes
  Texture alignment:                             512 bytes
  Concurrent copy and kernel execution:          Yes with 1 copy engine(s)
  Run time limit on kernels:                     Yes
  Integrated GPU sharing Host Memory:            No
  Support host page-locked memory mapping:       Yes
  Alignment requirement for Surfaces:            Yes
  Device has ECC support:                        Disabled
  Device supports Unified Addressing (UVA):      Yes
  Device supports Compute Preemption:            No
  Supports Cooperative Kernel Launch:            No
  Supports MultiDevice Co-op Kernel Launch:      No
  Device PCI Domain ID / Bus ID / location ID:   0 / 1 / 0
  Compute Mode:
     < Default (multiple host threads can use ::cudaSetDevice() with device simultaneously) >

deviceQuery, CUDA Driver = CUDART, CUDA Driver Version = 9.2, CUDA Runtime Version = 9.2, NumDevs = 1
Result = PASS

こんな感じになればOK!

以上、今回はMacでCUDAを動かすということをやってみました。良い記事だと思った方は、以下の「★+」ボタンと、「読者になる」ボタンのクリック、記事の拡散、お願いします! ではでは、またの記事で。

pip face_recognitionでcmakeエラー(Mac OSX)

Macでpip install face_recognitionが失敗してハマったので、解決策をメモしておく。

まず、筆者がハマったときのエラーを以下に示す。

  -- Looking for sys/types.h - found
  -- Looking for stdint.h
  -- Looking for stdint.h - found
  -- Looking for stddef.h
  -- Looking for stddef.h - found
  -- Check size of void*
  -- Check size of void* - done
  -- Found LAPACK library
  -- Found CBLAS library
  -- Looking for cblas_ddot
  -- Looking for cblas_ddot - found
  -- Looking for sgesv
  -- Looking for sgesv - found
  -- Looking for sgesv_
  -- Looking for sgesv_ - found
  -- Found CUDA: /usr/local/cuda (found suitable version "9.2", minimum required is "7.5")
  -- Looking for cuDNN install...
  -- Found cuDNN: /usr/local/cuda/lib/libcudnn.dylib
  -- Building a CUDA test project to see if your compiler is compatible with CUDA...
  -- Checking if you have the right version of cuDNN installed.
  CMake Error at /private/var/folders/0z/r_3km1sj22gfl2fdx7g09gz40000gn/T/pip-install-s5e1bkz6/dlib/dlib/cmake_utils/use_cpp_11.cmake:74 (target_compile_features):
    target_compile_features The compiler feature "cxx_thread_local" is not
    known to CXX compiler

    "Clang"

    version 6.1.0.6020053.
  Call Stack (most recent call first):
    /private/var/folders/0z/r_3km1sj22gfl2fdx7g09gz40000gn/T/pip-install-s5e1bkz6/dlib/dlib/cmake_utils/test_for_cudnn/CMakeLists.txt:18 (enable_cpp11_for_target)


  CMake Error at /private/var/folders/0z/r_3km1sj22gfl2fdx7g09gz40000gn/T/pip-install-s5e1bkz6/dlib/dlib/CMakeLists.txt:662 (try_compile):
    Failed to configure test project build system.


  -- Configuring incomplete, errors occurred!
  See also "/private/var/folders/0z/r_3km1sj22gfl2fdx7g09gz40000gn/T/pip-install-s5e1bkz6/dlib/build/temp.macosx-10.11-x86_64-3.5/CMakeFiles/CMakeOutput.log".
  Traceback (most recent call last):
    File "<string>", line 1, in <module>
    File "/private/var/folders/0z/r_3km1sj22gfl2fdx7g09gz40000gn/T/pip-install-s5e1bkz6/dlib/setup.py", line 257, in <module>
      'Topic :: Software Development',
    File "/usr/local/var/pyenv/versions/3.5.2/lib/python3.5/distutils/core.py", line 148, in setup
      dist.run_commands()
    File "/usr/local/var/pyenv/versions/3.5.2/lib/python3.5/distutils/dist.py", line 955, in run_commands
      self.run_command(cmd)
    File "/usr/local/var/pyenv/versions/3.5.2/lib/python3.5/distutils/dist.py", line 974, in run_command
      cmd_obj.run()
    File "/usr/local/var/pyenv/versions/3.5.2/lib/python3.5/site-packages/wheel/bdist_wheel.py", line 202, in run
      self.run_command('build')
    File "/usr/local/var/pyenv/versions/3.5.2/lib/python3.5/distutils/cmd.py", line 313, in run_command
      self.distribution.run_command(command)
    File "/usr/local/var/pyenv/versions/3.5.2/lib/python3.5/distutils/dist.py", line 974, in run_command
      cmd_obj.run()
    File "/usr/local/var/pyenv/versions/3.5.2/lib/python3.5/distutils/command/build.py", line 135, in run
      self.run_command(cmd_name)
    File "/usr/local/var/pyenv/versions/3.5.2/lib/python3.5/distutils/cmd.py", line 313, in run_command
      self.distribution.run_command(command)
    File "/usr/local/var/pyenv/versions/3.5.2/lib/python3.5/distutils/dist.py", line 974, in run_command
      cmd_obj.run()
    File "/private/var/folders/0z/r_3km1sj22gfl2fdx7g09gz40000gn/T/pip-install-s5e1bkz6/dlib/setup.py", line 133, in run
      self.build_extension(ext)
    File "/private/var/folders/0z/r_3km1sj22gfl2fdx7g09gz40000gn/T/pip-install-s5e1bkz6/dlib/setup.py", line 170, in build_extension
      subprocess.check_call(cmake_setup, cwd=build_folder)
    File "/usr/local/var/pyenv/versions/3.5.2/lib/python3.5/subprocess.py", line 581, in check_call
      raise CalledProcessError(retcode, cmd)
  subprocess.CalledProcessError: Command '['cmake', '/private/var/folders/0z/r_3km1sj22gfl2fdx7g09gz40000gn/T/pip-install-s5e1bkz6/dlib/tools/python', '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=/private/var/folders/0z/r_3km1sj22gfl2fdx7g09gz40000gn/T/pip-install-s5e1bkz6/dlib/build/lib.macosx-10.11-x86_64-3.5', '-DPYTHON_EXECUTABLE=/usr/local/var/pyenv/versions/3.5.2/bin/python3.5', '-DCMAKE_BUILD_TYPE=Release']' returned non-zero exit status 1

  ----------------------------------------
  Failed building wheel for dlib

解決策

結論から言うと、以下に示す環境にすることでうまくインストールできた。筆者はGPUを使っているが、GPU無しの場合はCUDAやcudnnは不要と思われる。

  • OSX: 10.13.2 (High Sierra)
  • CUDA: 9.2
  • cudnn: 7.2
  • cmake: 3.12.1
  • clang: Apple LLVM version 8.0.0 (clang-800.0.42.1)

CUDAのバージョン確認

 ls /Developer/NVIDIA/CUDA-9.2 # CUDA Toolkitをdmgからインストールするとデフォルトでこのディレクトリに入る

古いバージョンが入っていたら、それをアンインストールしてから新しいCUDA Toolkit dmgをダウンロードしてきて再インストール。

cudnnのバージョン確認

cat /usr/local/cuda/include/cudnn.h | grep CUDNN_M

メジャーバージョンとマイナーバージョンが表示される。メジャーが7、マイナーが2なら、バージョン7.2ということ。上述のディレクトリにcudnn.hが無い場合はそもそもインストールされていない。

NVIDIAのサイトからtgz形式のcudnn7.2をダウンロードしてきて、解凍後、includeの中身を全て/usr/local/cuda/include/に、libの中身を全て/usr/local/cuda/lib/にコピーすればインストール完了。

cmakeのバージョン確認

cmake --version

pipで最新版をインストールする。

pip install cmake

clangのバージョン確認

clang --version

以下の記事を参考に、Command_Line_Tools_macOS_10.12_for_Xcode_8.2.dmg をインストールした。

qiita.com

以下のコマンドで、コマンドラインツールのバージョンを切り替える。

$ sudo xcode-select --switch /Library/Developer/CommandLineTools

$ clang --version
Apple LLVM version 8.0.0 (clang-800.0.42.1)
Target: x86_64-apple-darwin17.3.0
Thread model: posix
InstalledDir: /Library/Developer/CommandLineTools/usr/bin

この状態で、pip install face_recognitionをたたくと成功した。

正規表現をFSTに変換してOpenFSTで扱えるようにする

FSTとは

Finite-state Transducerの略。オートマトンの一種で、複数のノードとアークから成る。

入力系列に対し、その入力系列が受領可能か受領不可能かを返す。受領可能だった場合は、同時に出力系列も返す。 出力系列の候補が2通りあった場合、重みの大小によって出力の優先度が決まる。このような重みが存在するFSTをWFST(Weighted Finite-state Transducer)という。

WFSTを扱うときに便利なOSSがある。OpenFSTというツールだ。

以下の記事でも紹介したので詳しくはコチラを御覧いただきたい。

www.mahirokazuko.com

正規表現をFST化する

FSTの教科書を読むと、「正規表現も実はFSTで表せるよー」とよく書いてある。しかし、正規表現をFSTに変換してくれるツールを探してみても、意外とほとんど見つからなかった。唯一見つかったのがコチラのコード。python版とc++版がある。

GitHub - siddharthasahu/automata-from-regex: A python program to build nfa, dfa and minimised DFA from given regular expression. Uses Tkinter for GUI and GraphViz for graphs.

ただし、このコードはアルファベットをもちいた正規表現にのみ対応しており、そのままでは日本語を扱うことができない。そこで、一部を改造して日本語に対応できるようにした。改造したのはAutomataTheory.pyという名前のソースだ。ライセンスはGPL lv3とする。 (ついでにpython3で動くようにした。)

AutomataTheory.py (LICENSE: GNU GPL lv3)

from os import popen
import time

class Automata:
    """class to represent an Automata"""

    def __init__(self, language = set(['0', '1'])):
        self.states = set()
        self.startstate = None
        self.finalstates = []
        self.transitions = dict()
        self.language = language

    @staticmethod
    def epsilon():
        return ":e:"

    def setstartstate(self, state):
        self.startstate = state
        self.states.add(state)

    def addfinalstates(self, state):
        if isinstance(state, int):
            state = [state]
        for s in state:
            if s not in self.finalstates:
                self.finalstates.append(s)

    def addtransition(self, fromstate, tostate, inp):
        if isinstance(inp, str):
            inp = set([inp])
        self.states.add(fromstate)
        self.states.add(tostate)
        if fromstate in self.transitions:
            if tostate in self.transitions[fromstate]:
                self.transitions[fromstate][tostate] = self.transitions[fromstate][tostate].union(inp)
            else:
                self.transitions[fromstate][tostate] = inp
        else:
            self.transitions[fromstate] = {tostate : inp}

    def addtransition_dict(self, transitions):
        for fromstate, tostates in transitions.items():
            for state in tostates:
                self.addtransition(fromstate, state, tostates[state])

    def gettransitions(self, state, key):
        if isinstance(state, int):
            state = [state]
        trstates = set()
        for st in state:
            if st in self.transitions:
                for tns in self.transitions[st]:
                    if key in self.transitions[st][tns]:
                        trstates.add(tns)
        return trstates

    def getEClose(self, findstate):
        allstates = set()
        states = set([findstate])
        while len(states)!= 0:
            state = states.pop()
            allstates.add(state)
            if state in self.transitions:
                for tns in self.transitions[state]:
                    if Automata.epsilon() in self.transitions[state][tns] and tns not in allstates:
                        states.add(tns)
        return allstates

    def display(self):
        print ("states:", self.states)
        print ("start state: ", self.startstate)
        print ("final states:", self.finalstates)
        print ("transitions:")
        for fromstate, tostates in self.transitions.items():
            for state in tostates:
                for char in tostates[state]:
                    print ("#", fromstate, state, char, char, "1")
            print ()
        for finalstate in self.finalstates:
            print ('#', finalstate)

    def getPrintText(self):
        text = "language: {" + ", ".join(self.language) + "}\n"
        text += "states: {" + ", ".join(map(str,self.states)) + "}\n"
        text += "start state: " + str(self.startstate) + "\n"
        text += "final states: {" + ", ".join(map(str,self.finalstates)) + "}\n"
        text += "transitions:\n"
        linecount = 5
        for fromstate, tostates in self.transitions.items():
            for state in tostates:
                for char in tostates[state]:
                    text += "    " + str(fromstate) + " -> " + str(state) + " on '" + char + "'\n"
                    linecount +=1
        return [text, linecount]

    def newBuildFromNumber(self, startnum):
        translations = {}
        for i in list(self.states):
            translations[i] = startnum
            startnum += 1
        rebuild = Automata(self.language)
        rebuild.setstartstate(translations[self.startstate])
        rebuild.addfinalstates(translations[self.finalstates[0]])
        for fromstate, tostates in self.transitions.items():
            for state in tostates:
                rebuild.addtransition(translations[fromstate], translations[state], tostates[state])
        return [rebuild, startnum]

    def newBuildFromEquivalentStates(self, equivalent, pos):
        rebuild = Automata(self.language)
        for fromstate, tostates in self.transitions.items():
            for state in tostates:
                rebuild.addtransition(pos[fromstate], pos[state], tostates[state])
        rebuild.setstartstate(pos[self.startstate])
        for s in self.finalstates:
            rebuild.addfinalstates(pos[s])
        return rebuild

    def getDotFile(self):
        dotFile = "digraph DFA {\nrankdir=LR\n"
        if len(self.states) != 0:
            dotFile += "root=s1\nstart [shape=point]\nstart->s%d\n" % self.startstate
            for state in self.states:
                if state in self.finalstates:
                    dotFile += "s%d [shape=doublecircle]\n" % state
                else:
                    dotFile += "s%d [shape=circle]\n" % state
            for fromstate, tostates in self.transitions.items():
                for state in tostates:
                    for char in tostates[state]:
                        dotFile += 's%d->s%d [label="%s"]\n' % (fromstate, state, char)
        dotFile += "}"
        return dotFile

class BuildAutomata:
    """class for building e-nfa basic structures"""

    @staticmethod
    def basicstruct(inp):
        state1 = 1
        state2 = 2
        basic = Automata()
        basic.setstartstate(state1)
        basic.addfinalstates(state2)
        basic.addtransition(1, 2, inp)
        return basic

    @staticmethod
    def plusstruct(a, b):
        [a, m1] = a.newBuildFromNumber(2)
        [b, m2] = b.newBuildFromNumber(m1)
        state1 = 1
        state2 = m2
        plus = Automata()
        plus.setstartstate(state1)
        plus.addfinalstates(state2)
        plus.addtransition(plus.startstate, a.startstate, Automata.epsilon())
        plus.addtransition(plus.startstate, b.startstate, Automata.epsilon())
        plus.addtransition(a.finalstates[0], plus.finalstates[0], Automata.epsilon())
        plus.addtransition(b.finalstates[0], plus.finalstates[0], Automata.epsilon())
        plus.addtransition_dict(a.transitions)
        plus.addtransition_dict(b.transitions)
        return plus

    @staticmethod
    def dotstruct(a, b):
        [a, m1] = a.newBuildFromNumber(1)
        [b, m2] = b.newBuildFromNumber(m1)
        state1 = 1
        state2 = m2-1
        dot = Automata()
        dot.setstartstate(state1)
        dot.addfinalstates(state2)
        dot.addtransition(a.finalstates[0], b.startstate, Automata.epsilon())
        dot.addtransition_dict(a.transitions)
        dot.addtransition_dict(b.transitions)
        return dot

    @staticmethod
    def starstruct(a):
        [a, m1] = a.newBuildFromNumber(2)
        state1 = 1
        state2 = m1
        star = Automata()
        star.setstartstate(state1)
        star.addfinalstates(state2)
        star.addtransition(star.startstate, a.startstate, Automata.epsilon())
        star.addtransition(star.startstate, star.finalstates[0], Automata.epsilon())
        star.addtransition(a.finalstates[0], star.finalstates[0], Automata.epsilon())
        star.addtransition(a.finalstates[0], a.startstate, Automata.epsilon())
        star.addtransition_dict(a.transitions)
        return star


class DFAfromNFA:
    """class for building dfa from e-nfa and minimise it"""

    def __init__(self, nfa):
        self.buildDFA(nfa)
        self.minimise()

    def getDFA(self):
        return self.dfa

    def getMinimisedDFA(self):
        return self.minDFA

    def displayDFA(self):
        self.dfa.display()

    def displayMinimisedDFA(self):
        self.minDFA.display()

    def buildDFA(self, nfa):
        allstates = dict()
        eclose = dict()
        count = 1
        state1 = nfa.getEClose(nfa.startstate)
        eclose[nfa.startstate] = state1
        dfa = Automata(nfa.language)
        dfa.setstartstate(count)
        states = [[state1, count]]
        allstates[count] = state1
        count +=  1
        while len(states) != 0:
            [state, fromindex] = states.pop()
            for char in dfa.language:
                trstates = nfa.gettransitions(state, char)
                for s in list(trstates)[:]:
                    if s not in eclose:
                        eclose[s] = nfa.getEClose(s)
                    trstates = trstates.union(eclose[s])
                if len(trstates) != 0:
                    if trstates not in allstates.values():
                        states.append([trstates, count])
                        allstates[count] = trstates
                        toindex = count
                        count +=  1
                    else:
                        toindex = [k for k, v in allstates.items() if v  ==  trstates][0]
                    dfa.addtransition(fromindex, toindex, char)
        for value, state in allstates.items():
            if nfa.finalstates[0] in state:
                dfa.addfinalstates(value)
        self.dfa = dfa

    def acceptsString(self, string):
        currentstate = self.dfa.startstate
        for ch in string:
            if ch==":e:":
                continue
            st = list(self.dfa.gettransitions(currentstate, ch))
            if len(st) == 0:
                return False
            currentstate = st[0]
        if currentstate in self.dfa.finalstates:
            return True
        return False

    def minimise(self):
        states = list(self.dfa.states)
        n = len(states)
        unchecked = dict()
        count = 1
        distinguished = []
        equivalent = dict(zip(range(len(states)), [{s} for s in states]))
        pos = dict(zip(states,range(len(states))))
        for i in range(n-1):
            for j in range(i+1, n):
                if not ([states[i], states[j]] in distinguished or [states[j], states[i]] in distinguished):
                    eq = 1
                    toappend = []
                    for char in self.dfa.language:
                        s1 = self.dfa.gettransitions(states[i], char)
                        s2 = self.dfa.gettransitions(states[j], char)
                        if len(s1) != len(s2):
                            eq = 0
                            break
                        if len(s1) > 1:
                            raise BaseException("Multiple transitions detected in DFA")
                        elif len(s1) == 0:
                            continue
                        s1 = s1.pop()
                        s2 = s2.pop()
                        if s1 != s2:
                            if [s1, s2] in distinguished or [s2, s1] in distinguished:
                                eq = 0
                                break
                            else:
                                toappend.append([s1, s2, char])
                                eq = -1
                    if eq == 0:
                        distinguished.append([states[i], states[j]])
                    elif eq == -1:
                        s = [states[i], states[j]]
                        s.extend(toappend)
                        unchecked[count] = s
                        count += 1
                    else:
                        p1 = pos[states[i]]
                        p2 = pos[states[j]]
                        if p1 != p2:
                            st = equivalent.pop(p2)
                            for s in st:
                                pos[s] = p1
                            equivalent[p1] = equivalent[p1].union(st)
        newFound = True
        while newFound and len(unchecked) > 0:
            newFound = False
            toremove = set()
            for p, pair in list(unchecked.items()):
                for tr in pair[2:]:
                    if [tr[0], tr[1]] in distinguished or [tr[1], tr[0]] in distinguished:
                        unchecked.pop(p)
                        distinguished.append([pair[0], pair[1]])
                        newFound = True
                        break
        for pair in unchecked.values():
            p1 = pos[pair[0]]
            p2 = pos[pair[1]]
            if p1 != p2:
                st = equivalent.pop(p2)
                for s in st:
                    pos[s] = p1
                equivalent[p1] = equivalent[p1].union(st)
        if len(equivalent) == len(states):
            self.minDFA = self.dfa
        else:
            self.minDFA = self.dfa.newBuildFromEquivalentStates(equivalent, pos)

class NFAfromRegex:
    """class for building e-nfa from regular expressions"""

    def __init__(self, regex):
        self.star = '*'
        self.plus = '+'
        self.dot = '.'
        self.openingBracket = '('
        self.closingBracket = ')'
        self.operators = [self.plus, self.dot]
        self.regex = regex
        self.alphabet = []
        self.alphabet = [c.strip() for c in open('charlist')]
        self.buildNFA()

    def getNFA(self):
        return self.nfa

    def displayNFA(self):
        self.nfa.display()

    def buildNFA(self):
        language = set()
        self.stack = []
        self.automata = []
        previous = "::e::"
        for char in self.regex:
            if char in self.alphabet:
                language.add(char)
                if previous != self.dot and (previous in self.alphabet or previous in [self.closingBracket,self.star]):
                    self.addOperatorToStack(self.dot)
                self.automata.append(BuildAutomata.basicstruct(char))
            elif char  ==  self.openingBracket:
                if previous != self.dot and (previous in self.alphabet or previous in [self.closingBracket,self.star]):
                    self.addOperatorToStack(self.dot)
                self.stack.append(char)
            elif char  ==  self.closingBracket:
                if previous in self.operators:
                    raise BaseException("Error processing '%s' after '%s'" % (char, previous))
                while(1):
                    if len(self.stack) == 0:
                        raise BaseException("Error processing '%s'. Empty stack" % char)
                    o = self.stack.pop()
                    if o == self.openingBracket:
                        break
                    elif o in self.operators:
                        self.processOperator(o)
            elif char == self.star:
                if previous in self.operators or previous  == self.openingBracket or previous == self.star:
                    raise BaseException("Error processing '%s' after '%s'" % (char, previous))
                self.processOperator(char)
            elif char in self.operators:
                if previous in self.operators or previous  == self.openingBracket:
                    raise BaseException("Error processing '%s' after '%s'" % (char, previous))
                else:
                    self.addOperatorToStack(char)
            else:
                raise BaseException("Symbol '%s' is not allowed" % char)
            previous = char
        while len(self.stack) != 0:
            op = self.stack.pop()
            self.processOperator(op)
        if len(self.automata) > 1:
            print (self.automata)
            raise BaseException("Regex could not be parsed successfully")
        self.nfa = self.automata.pop()
        self.nfa.language = language

    def addOperatorToStack(self, char):
        while(1):
            if len(self.stack) == 0:
                break
            top = self.stack[len(self.stack)-1]
            if top == self.openingBracket:
                break
            if top == char or top == self.dot:
                op = self.stack.pop()
                self.processOperator(op)
            else:
                break
        self.stack.append(char)

    def processOperator(self, operator):
        if len(self.automata) == 0:
            raise BaseException("Error processing operator '%s'. Stack is empty" % operator)
        if operator == self.star:
            a = self.automata.pop()
            self.automata.append(BuildAutomata.starstruct(a))
        elif operator in self.operators:
            if len(self.automata) < 2:
                raise BaseException("Error processing operator '%s'. Inadequate operands" % operator)
            a = self.automata.pop()
            b = self.automata.pop()
            if operator == self.plus:
                self.automata.append(BuildAutomata.plusstruct(b,a))
            elif operator == self.dot:
                self.automata.append(BuildAutomata.dotstruct(b,a))

def drawGraph(automata, file = ""):
    """From https://github.com/max99x/automata-editor/blob/master/util.py"""
    f = popen(r"dot -Tpng -o graph%s.png" % file, 'w')
    try:
        f.write(automata.getDotFile())
    except:
        raise BaseException("Error creating graph")
    finally:
        f.close()

def isInstalled(program):
    """From http://stackoverflow.com/questions/377017/test-if-executable-exists-in-python"""
    import os
    def is_exe(fpath):
        return os.path.isfile(fpath) and os.access(fpath, os.X_OK)
    fpath, fname = os.path.split(program)
    if fpath:
        if is_exe(program) or is_exe(program+".exe"):
            return True
    else:
        for path in os.environ["PATH"].split(os.pathsep):
            exe_file = os.path.join(path, program)
            if is_exe(exe_file) or is_exe(exe_file+".exe"):
                return True
    return False

ここで、日本語対応するにあたり、このツールで扱えるようにする文字群をcharlistというファイルで定義するようにした。charlistは「1文字が1行に並んだテキストファイル」とする。以下に例を示す。

charlist

a
b
c
d
e
f
g
h
i
j
・
・
(省略)
・
・
枠
鷲
亙
亘
鰐
詫
藁
蕨
椀
湾
碗
腕
々
ー
/
+
−

このcharlistAutomataTheory.pyと同じ階層に置いておく。

では早速、適当な正規表現をWFST化してみよう。以下のように引数に正規表現を受け取り、標準出力でWFSTの情報を出力するpythonコードを用意。

convert-regex2fst.py

from AutomataTheory import *
import sys

def main():
    if len(sys.argv)>1:
        inp = sys.argv[1]
    else:
        return -1;

    print ("Regular Expression: ", inp)
    nfaObj = NFAfromRegex(inp)
    nfa = nfaObj.getNFA()
    dfaObj = DFAfromNFA(nfa)
    dfa = dfaObj.getDFA()
    minDFA = dfaObj.getMinimisedDFA()

    #print "\nNFA: "
    #nfaObj.displayNFA()

    #print ("\nDFA: ")
    #dfaObj.displayDFA()

    print ("\nMinimised DFA: ")
    dfaObj.displayMinimisedDFA()

if __name__  ==  '__main__':
    t = time.time()
    try:
        main()
    except BaseException as e:
        print ("\nFailure:", e)
    print ("\nExecution time: ", time.time() - t, "seconds")

以下のコマンドを実行。

$ python convert-regex2fst.py "今日も1日頑張る(ぞい+ZOI)"
Regular Expression:  今日も1日頑張る(ぞい+ZOI)

Minimised DFA:
states: {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}
start state:  0
final states: [11]
transitions:
# 0 1 今 今 1

# 1 2 日 日 1

# 2 3 も も 1

# 3 4 1 1 1

# 4 5 日 日 1

# 5 6 頑 頑 1

# 6 7 張 張 1

# 7 8 る る 1

# 8 9 Z Z 1
# 8 10 ぞ ぞ 1

# 10 11 い い 1

# 9 12 O O 1

# 12 11 I I 1

# 11

Execution time:  0.009211540222167969 seconds

はい、できた。ちなみに、上述の正規表現では、「または」を表す記号に"|"ではなく"+"が使われていることに注意。

#から始まる行が、それぞれのアークを表している。最初の#で始まる行は、0番のノードから1番のノードにアークが張られており、入力シンボルは"今"、出力シンボルも"今"、重みは1ということを意味している。

これを、OpenFSTで扱える形に整形してみる。

今度は以下のようにコマンドを実行。

$ python convert-regex2fst.py "今日も1日頑張る(ぞい+ZOI)" | grep "^#" | perl -pe "s/# //g"
0 1 今 今 1
1 2 日 日 1
2 3 も も 1
3 4 1 1 1
4 5 日 日 1
5 6 頑 頑 1
6 7 張 張 1
7 8 る る 1
8 9 ぞ ぞ 1
8 10 Z Z 1
10 11 O O 1
11 12 I I 1
9 12 い い 1
12

はい、#で始まる行だけを抽出できた。これを、適当な名前で保存しておこう。

$ python convert-regex2fst.py "今日も1日頑張る(ぞい+ZOI)" | grep "^#" | perl -pe "s/# //g" > zoi

zoiという名前で保存した。

これを、OpenFSTのバイナリ形式にする。大まかな手順はこうだ。

  • OpenFSTには、入力シンボルと出力シンボルを定義するためのファイルが必要なので、まずこれを作る
  • 上述のアークの情報と、入力シンボル、出力シンボルから、FSTのバイナリを作る。
  • FSTを図示する。

なお、OpenFSTの使い方や、日本語対応などは以下の記事にまとめてあるので参考にされたし。

www.mahirokazuko.com

さて、手順通り進めよう。

まずは、入力シンボルと出力シンボルの準備だ。 これは、さきほど定義したcharlistをそのまま使いたい。だが、charlistにはイプシロンが定義されていないので、これを追加しよう。以下のようなコマンドを作った。

charlist2symbollist.sh

awk '{print $0, NR-1}' <(echo "<eps>" && cat charlist) # charlistの先頭に<eps>を追加したあと、先頭の文字からナンバリングして出力するだけ

これを使ってシンボルリストを作ろう。

./charlist2symbollist.sh  > charlist.ins   # insはinput symbolという意味で命名した
./charlist2symbollist.sh  > charlist.ous    # ousはoutput symbolという意味で命名した

次に、fstのバイナリを作ろう。

fstcompile2 charlist.ins charlist.ous zoi

これで、zoi.binができる。なお、fstcompile2は独自定義したコマンドである。上述の別記事に記載されているので見ていただきたい。

次に、このfstをgraphvizで描画してみよう。

fstdraw2 charlist.ins charlist.ous zoi.bin 500

これで、zoi.bin.pngができる。fstdraw2も例によって独自コマンドだ。

f:id:twx:20180901162533p:plain

たしかに、正規表現として正しい形になっていることが確認できる。

以上。今回は正規表現をOpenFSTで扱える形に変換してみました。

いい記事だと思っていただければ、以下の「★ + 」ボタンのクリックと、「読者になる」のボタンのクリックをお願いします!!

独自学習したモデルをTensorFlow.jsで使う

1. TensorFlow.js

Tensorflowをブラウザで動かす技術が登場した。その名もTensorFlow.jsだ。 js.tensorflow.org

この記事では、独自に学習したモデルをTensorflowで使うときの流れを紹介する。

2. 学習するモデル(デブ判定器)

身長と体重を入力すると、どのくらい太っているのかを返すモデルを作ろうと思う。 「どのくらい太っているのか」は、身長と体重から計算されるBMIという値を使って定義する。

www.jpm1960.org

以下のように定義した。

  • 18.5未満 やせ (クラス番号0)
  • 18.5~24.9 ふつう (クラス番号1)
  • 25.0~29.9 肥満1度 (クラス番号2)
  • 30.0~34.9 肥満2度 (クラス番号3)
  • 35.0~39.9 肥満3度 (クラス番号4)
  • 40.0以上 肥満4度 (クラス番号5)

学習にもちいるデータセットの形式はこんな感じだ。

177.5890062,64.71439732,20.51959369,1
164.8972845,63.23919414,23.25730132,1
184.4277219,61.55920972,18.09841875,0
154.4941492,80.30581669,33.64518567,3
158.6124011,73.35584165,29.15818083,2
・
・
・

左から順に、身長(cm), 体重(kg), BMI, クラス番号である。カンマで区切られている。 実際の学習データと評価データは記事の末尾に載せておくので参考にしていただきたい。

ここでは、身長と体重の情報だけを使って、クラス番号を予測するモデルを作る。

bmi_tf.py

from sklearn import preprocessing
import tensorflow as tf
import numpy as np

training_data = 'bmi_training.csv'
test_data = 'bmi_test.csv'

# dim次元のOne-hotベクトルを表すlistを返す関数
def one_hot(idx, dim):
    return [1 if i == idx else 0 for i in range(dim)]

# データを読み込む関数
def load_data(filename):
    with open(filename) as f:
        rows = [[float(elem.strip()) for elem in row.split(',')] for row in f.readlines() ]
        x = [row[0:2] for row in rows] # データセットの中から, 身長と体重だけを取り出す。
        t = [one_hot(row[3], 6) for row in rows] # データセットの中から、クラス番号だけを取り出す。分類先のクラス数は6個。

        x = preprocessing.normalize(np.array(x, dtype=np.float32)) # 正規化
        return x, t

x = tf.placeholder(tf.float32, [None, 2], name='x') # 身長と体重の2次元ベクトルを受け取れるようにする。
t = tf.placeholder(tf.float32, [None, 6], name='t') # 「やせ」「ふつう」「肥満1度」「肥満2度」「肥満3度」「肥満4度」の6次元ベクトルを受け取れるようにする。
keep_prob = tf.placeholder(tf.float32, name='keep_prob') # ドロップアウトのためのハイパーパラメータ

# ネットワーク定義
w1 = tf.Variable(tf.random_normal([2,30], mean=0.0, stddev=0.5))
b1 = tf.Variable(tf.random_normal([30], mean=0.0, stddev=0.5))
h1 = tf.matmul(x, w1) + b1
h1 = tf.nn.dropout(h1, keep_prob) # ドロップアウト。keep_probで指定した割合のノードが生き残り、それ以外のノードは値が0になる。
h1 = tf.nn.tanh(h1)

w2 = tf.Variable(tf.random_normal([30,30], mean=0.0, stddev=0.5))
b2 = tf.Variable(tf.random_normal([30], mean=0.0, stddev=0.5))
h2 = tf.matmul(h1, w2) + b2
h2 = tf.nn.dropout(h2, keep_prob) # ドロップアウト
h2 = tf.nn.tanh(h2)

w3 = tf.Variable(tf.random_normal([30,6], mean=0.0, stddev=0.5))
b3 = tf.Variable(tf.random_normal([6], mean=0.0, stddev=0.5))
y = tf.matmul(h2, w3) + b3

# 誤差関数の定義
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=t, logits=y))

# 正解率の定義
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(t, 1))
acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# 推論結果の定義
predict = tf.argmax(y, 1, name='predict') # 'predict'という名前をつけた。

# 学習アルゴリズム定義
train = tf.train.AdamOptimizer().minimize(loss)

# 初期化
init = tf.global_variables_initializer()

# データセットのロード
train_x, train_t = load_data(training_data)
test_x, test_t = load_data(test_data)

with tf.Session() as sess:
    sess.run(init)
    print('Epoch\tTraining loss\tTest loss\tTraining acc\tTest acc')
    for epoch in range(3000):
        sess.run(train, feed_dict={
            x: train_x,
            t: train_t,
            keep_prob: 0.5
        })
        if (epoch+1) % 5 == 0:
            print('{}\t{}\t{}\t{}\t{}'.format(
                str(epoch+1),
                str(sess.run(loss, feed_dict={x:train_x, t:train_t, keep_prob:1.0})),
                str(sess.run(loss, feed_dict={x:test_x, t:test_t, keep_prob:1.0})),
                str(sess.run(acc, feed_dict={x:train_x, t:train_t, keep_prob:1.0})),
                str(sess.run(acc, feed_dict={x:test_x, t:test_t, keep_prob:1.0})),
            ))

    tf.saved_model.simple_save(sess, 'saved_model_debu', inputs={'x': x, 'keep_prob': keep_prob}, outputs={'predict': predict})

以下のコマンドで学習を開始する。

python bmi_tf.py

学習した結果、こうなった↓

Epoch    Training loss   Test loss   Training acc    Test acc
5   3.1208303   3.2619486   0.3511111   0.36
10  2.9821303   3.114055    0.3511111   0.36
15  2.8266895   2.9496677   0.3511111   0.36
20  2.6841924   2.7964628   0.3511111   0.36
25  2.5317469   2.634027    0.3511111   0.36
30  2.3827975   2.4757023   0.3511111   0.36
35  2.2274723   2.3102322   0.3511111   0.36
40  2.0915022   2.1640873   0.3511111   0.36
・
・
(省略)
・
・
3000    0.58536345  0.5924296   0.7511111   0.68

図示するとこんな感じだ。だいたい精度70%であることがわかる。 f:id:twx:20180831153711p:plain

3. TensorFlow.js形式への変換

学習済のモデルは以下のディレクトリ階層で保存される。

saved_model_debu/
├── saved_model.pb
└── variables
    ├── variables.data-00000-of-00001
    └── variables.index

このモデルを、TensorFlow.jsで使える形に変換しよう。変換にはtensorflowjs_converterというツールを使う。tensorflowjs_converterは以下のコマンドでインストールできる。

pip install tensorflowjs

インストールできたら、早速変換してみよう。

tensorflowjs_converter --input_format=tf_saved_model --output_node_names='predict' --saved_model_tags=serve ./saved_model_debu ./web_model_debu

ここで、--output_node_namesには、学習済みモデルの出力層につけた名前を指定すること。2つの引数./saved_model_debu./web_model_debuは、前者が学習済みモデルのディレクトリ、後者が変換後のモデル名(任意)だ。

変換後のモデルは以下のディレクトリ階層で保存される。

web_model_debu/
├── group1-shard1of1
├── tensorflowjs_model.pb
└── weights_manifest.json

4. TensorFlow.jsアプリの作り方

先程生成した変換後のモデルを全て、インターネット上にアップロードしよう。 筆者はAmazon AWSのS3にアップした。

https://s3-ap-northeast-1.amazonaws.com/********************/weights_manifest.json

のようなURLでアクセスできるように、AWS側でアクセス権の緩和と、CORSの設定をしておく。CORSの詳細は以下の記事を参考にした。

qiita.com

モデルをアップロードし、外からアクセス可能であることが確認できたら、次に以下のようなhtmlファイルを作ろう。

index.html

<html>
<head>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.12.5"></script>
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.2.1/jquery.min.js"></script>
</head>
<body>

<p>身長(cm)</p>
<input id="height">

<p>体重(kg)</p>
<input id="weight">

<button id="predict"/>推定</button>
<p id="result"></p>

<script type="text/javascript">

  const MODEL_URL = 'https://s3-ap-northeast-1.amazonaws.com/********************/tensorflowjs_model.pb';
  const WEIGHTS_URL = 'https://s3-ap-northeast-1.amazonaws.com/********************/weights_manifest.json';
  const debuNet = tf.loadFrozenModel(MODEL_URL, WEIGHTS_URL);

  $(function(){
    $('#predict').click(function(){
      var list = [
        parseFloat($('#height').val()),
        parseFloat($('#weight').val())
      ];

      // 正規化
      var norm = 0;
      for(var i=0; i<list.length; i++){
        var elem = list[i];
        norm += elem*elem;
      }
      norm = Math.sqrt(norm);
      for(var i=0; i<list.length; i++){
        list[i] = list[i]/norm;
      }

      var x = tf.tensor([list]);
      var keep_prob = tf.tensor(1);

      debuNet
      .then(function(model) {
          return model.predict({
            x: x,
            keep_prob :keep_prob
          }).data();
      })
      .then(function(result){
        if(result == 0) {
          $('#result').text('やせ')
        } else if(result ==1) {
          $('#result').text('ふつう')
        } else if(result ==2) {
          $('#result').text('肥満1度')
        } else if(result ==3) {
          $('#result').text('肥満2度')
        } else if(result ==4) {
          $('#result').text('肥満3度')
        } else if(result ==5) {
          $('#result').text('肥満4度')
        } else {

        }
      });
    });
  });

</script>
</body>
</html>

ブラウザでみると、こんな画面になる。

f:id:twx:20180901130157p:plain

特に難しい点はない。初心者が気をつけるポイントとしては以下が挙げられる。

  • tensorflow.jsをheadタグ内でロードする
  • S3にアップしたモデルを読み込む
  • scriptタグ内で、身長と体重を受け取り正規化したあと、その値をモデルに入力して推定(predict)する
  • モデルはPromiseの形式になっているので、predictを行うときや結果を取り出すときは then メソッドを使う

画面に試しに身長と体重を入力して「推定ボタン」を推してみると・・・

f:id:twx:20180901130206p:plain

ふぇぇ (´・ω・`)

以上。TensorFlow.jsで自作モデルを動かしてみました。 いい記事だと思っていただければ、「スター☆」ボタンのクリックと、「読者になる」のボタンのクリックをお願いします!! 最後に、学習と評価に使ったデータセットを以下においておきます。

bmi_training.csv

157.7155327,89.68050674,36.05364012,4
149.0121203,64.41123934,29.00804543,2
180.8459627,55.09253775,16.84516007,0
148.1409148,58.21515077,26.52686836,2
149.2502218,53.00152538,23.79350368,1
184.8827907,86.88905118,25.41979705,2
148.836688,71.11034448,32.10057323,3
167.876071,73.54321637,26.09546421,2
177.8102206,50.07939998,15.83965344,0
150.449792,63.65515523,28.12227159,2
149.8739218,80.70286729,35.92831268,4
149.6785531,68.83859997,30.72648466,3
173.5408125,57.99238695,19.25607336,1
146.4555397,69.72566599,32.50731235,3
177.0621681,75.50536945,24.0838704,1
151.9656812,89.98033185,38.96337319,4
173.9581647,61.11309788,20.19503755,1
161.9356154,51.38360914,19.59475571,1
183.6585108,63.37645285,18.7890928,1
149.2184957,86.90654869,39.03077622,4
177.408315,84.17821087,26.74556329,2
157.6604586,50.83106006,20.44954248,1
152.7515651,75.97009171,32.55901914,3
173.8752019,84.72934364,28.02583414,2
150.2700977,53.63825253,23.75360241,1
177.5344335,79.65344496,25.27198563,2
145.5893331,62.96287424,29.70471786,2
179.2373856,60.37189685,18.79219975,1
184.8458864,66.59516262,19.49050552,1
172.8220754,83.34467776,27.90483374,2
164.1720909,70.22624545,26.05558277,2
165.7871527,53.74288128,19.55326114,1
174.1596466,68.32979815,22.52760737,1
174.5970877,87.13738126,28.58449429,2
179.2134466,63.90002495,19.89572838,1
165.4931726,82.13920995,29.99095932,2
166.7130118,64.7062284,23.28129274,1
170.9222462,88.16626659,30.17902939,3
155.4404346,61.4285431,25.42390234,2
161.1070384,81.16486086,31.27080205,3
175.4938469,50.52338884,16.40471482,0
173.6004887,50.95757405,16.90856913,0
166.7265635,88.32838536,31.77537571,3
152.7375314,66.19004502,28.3727311,2
172.5711501,88.01113292,29.55297623,2
159.4249706,62.74471303,24.68677997,1
165.6377195,51.96867408,18.94188308,1
169.210192,62.37068417,21.78349055,1
152.0001686,50.79446732,21.98509116,1
168.0976016,81.25903212,28.75733329,2
170.1208967,78.28826519,27.05087658,2
178.3777173,69.50922592,21.84546272,1
145.3744007,62.87172388,29.74948785,2
167.6526607,64.36569707,22.89989951,1
177.8696125,62.96025472,19.90045282,1
182.9627614,54.05407072,16.14741447,0
170.4067472,82.56212355,28.43199156,2
173.50718,59.56403492,19.78559919,1
153.9613623,87.00660585,36.70529648,4
169.992288,64.99083206,22.49021767,1
171.8726179,67.62774648,22.89346511,1
149.1961418,82.31398231,36.97927806,4
148.5367581,53.84273354,24.40389899,1
164.4137849,50.3552604,18.62809276,1
183.2855421,72.62546189,21.61884514,1
149.6538687,63.18539533,28.21245053,2
150.6515524,80.87792831,35.63549499,4
150.0463927,57.66190822,25.61166972,2
179.390308,69.34457456,21.5483756,1
168.6063504,67.31180665,23.67791142,1
164.715269,75.71062652,27.90545577,2
157.6131655,88.96907978,35.81410607,4
180.6176576,80.64710438,24.72113278,1
155.7455737,67.45877817,27.81038782,2
169.5264842,50.5105705,17.57547872,0
183.376559,76.88244008,22.86333186,1
172.1076611,76.53237814,25.83716171,2
182.4525469,78.52615519,23.5892685,1
152.4416425,70.28994314,30.24725397,3
175.0827847,76.36906384,24.91326089,1
148.9325606,50.40144284,22.72289473,1
169.8609411,57.74390422,20.01331784,1
148.0900142,80.17583936,36.55880284,4
145.9226146,71.74693768,33.69443388,3
159.7269954,57.76592443,22.6420155,1
184.3214793,75.06659062,22.09504025,1
170.6399787,54.662222,18.77265528,1
148.4289187,55.17332476,25.04333185,2
184.5440594,52.19220818,15.32517818,0
162.1844097,57.65434794,21.91865272,1
161.9099959,89.63481422,34.19238531,3
181.0545639,81.81706614,24.9588717,1
172.9925781,59.32444994,19.82342893,1
153.5270401,75.09438561,31.85940894,3
168.0620045,55.64229267,19.69998693,1
159.0545163,68.76936709,27.18335249,2
182.2331352,59.02123891,17.7727089,0
161.2793743,85.68580379,32.94209499,3
152.6140681,61.52894501,26.41741571,2
156.2272884,51.67716953,21.17312339,1
148.6560908,80.69228625,36.5146048,4
172.8285406,66.76435121,22.35186201,1
148.1289169,75.62562337,34.46587218,3
145.5838673,74.24987407,35.03234571,4
176.4481979,88.13375372,28.30791047,2
166.3855633,58.06165102,20.97288129,1
159.7874675,85.86434008,33.63004199,3
158.6001423,83.0993259,33.03622045,3
155.6849648,86.72531051,35.78100361,4
182.0380679,81.83331672,24.69481168,1
151.0529471,57.73475946,25.30340432,2
183.8554464,55.81207933,16.51106943,0
146.0337252,53.14171151,24.91891964,1
175.3983385,57.83828842,18.80028714,1
184.9982516,63.22194314,18.47279419,0
159.3768643,84.30043029,33.18785893,3
145.9164074,53.89040686,25.31064458,2
156.9363765,60.67406266,24.63518533,1
155.4874739,81.30182693,33.62865429,3
155.4100094,74.98071056,31.04499388,3
153.8766917,57.49420776,24.28166219,1
182.1612558,81.83728795,24.66261966,1
171.4525982,78.5244881,26.71264989,2
178.9225733,65.39993003,20.42899549,1
180.2905305,88.71783925,27.29387024,2
167.4993216,83.59944472,29.7973257,2
178.3726076,61.48050138,19.32329028,1
153.594292,60.74658652,25.74967349,2
159.3989171,69.99456954,27.54822451,2
167.2124346,89.74665363,32.09823288,3
160.7251328,54.42312257,21.06763907,1
158.3090676,78.59805222,31.36174371,3
146.2168134,57.68851931,26.98328141,2
149.5912478,86.27764039,38.55546006,4
147.1582954,76.81850828,35.47287368,4
149.4123965,67.45203913,30.21494584,3
163.3200818,50.19543108,18.81850033,1
153.7814331,61.94603791,26.1942354,2
155.5895322,78.40873408,32.38945883,3
169.2698674,65.12670279,22.73001821,1
154.1585034,65.08940062,27.38894485,2
167.8883344,64.03186102,22.71721473,1
152.9047756,80.15882862,34.28540334,3
179.820623,54.4641647,16.84348102,0
163.5407149,72.08616904,26.95256873,2
162.7653018,64.14076394,24.21087429,1
176.945171,60.88480941,19.4460555,1
184.014674,86.65779044,25.59190992,2
156.5266603,56.08205813,22.89007971,1
169.5587512,58.51553614,20.35310953,1
175.3063965,89.09586175,28.99092067,2
148.9111549,81.24573858,36.63921224,4
151.6049608,78.86193774,34.31157224,3
164.6174398,62.04237867,22.89479473,1
157.9983343,79.91732085,32.01370211,3
172.441642,67.62475641,22.74162086,1
170.1356003,82.54416927,28.51648586,2
170.5969597,56.81603835,19.52218276,1
160.1619961,88.09289876,34.34171308,3
182.6271006,71.5375181,21.44882077,1
184.0416223,69.38383399,20.48453872,1
170.2841579,86.50081983,29.83126922,2
175.1466346,80.41120407,26.21277329,2
169.339176,74.30209615,25.91111975,2
160.9370743,77.67668167,29.99013773,2
182.3850435,79.8669827,24.0098162,1
146.7053531,73.04376137,33.93838774,3
168.0845942,71.47918406,25.30018904,2
170.6004502,69.61179557,23.9178694,1
178.3041772,71.91893524,22.62143827,1
162.3017291,53.32948874,20.24515596,1
172.4008266,51.39286873,17.29116174,0
149.4934094,90.99993305,40.71898834,5
147.6022913,69.52411995,31.91165306,3
151.9466991,52.3702797,22.6830972,1
156.7888829,59.06078923,24.0252942,1
172.4115667,78.54564514,26.42343745,2
148.651563,78.87955444,35.6964879,4
167.6532432,83.85123979,29.8322198,2
178.7979229,79.92009286,24.99948381,1
159.5780316,82.00740674,32.20378148,3
180.5910297,72.75337582,22.30800849,1
149.3074282,79.3107191,35.5769782,4
151.3716546,75.1254063,32.7867008,3
166.1840818,82.99459665,30.05183357,3
184.7692166,71.41154119,20.9174708,1
149.6300372,68.98412328,30.81141059,3
173.5499971,69.92581004,23.21604755,1
161.4384153,76.94770134,29.5244543,2
182.9679794,61.04245163,18.233993,0
156.1995881,78.27606955,32.08257674,3
182.3643508,84.78544291,25.49420086,2
181.0499511,62.35027309,19.02135873,1
167.0806891,65.62502791,23.50807056,1
168.5063332,80.75420945,28.4402044,2
164.1788409,69.70613119,25.86048161,2
181.4215994,53.1252378,16.14072002,0
151.9402083,78.70155436,34.09085353,3
168.2577001,71.65916784,25.31173201,2
152.0237422,79.72960205,34.49822564,3
175.0427094,64.78010501,21.14236635,1
177.3033687,76.28344327,24.26589499,1
171.7354741,75.93022702,25.74510524,2
179.9685707,82.63429176,25.51331989,2
153.3040287,71.94715896,30.61304382,3
152.6615533,63.01105518,27.03693135,2
145.8459363,79.48174547,37.36618245,4
146.4958156,53.58907889,24.97042028,1
173.1265168,80.58151243,26.884888,2
158.3202433,75.76344234,30.22642621,3
173.0441548,57.15814026,19.08816682,1
158.333763,85.9455012,34.28278396,3
182.9343224,79.80056503,23.84600191,1
167.3083282,88.7581202,31.7083013,3
155.4539857,61.61662447,25.49729919,2
154.1209788,78.49089725,33.04424069,3
184.4653564,85.03876528,24.9912118,1
155.0318437,83.39911716,34.69921338,3
152.5907669,86.97857267,37.35560418,4
149.0609161,61.95202051,27.88225555,2
171.9372368,77.30724682,26.1505195,2
170.601271,52.94429641,18.19091969,0
153.2182459,79.32503159,33.79008479,3
170.659357,73.56253416,25.25785535,2
172.3643416,72.12418905,24.2765011,1
168.5763064,72.2965142,25.44042077,2
174.0869549,82.00927773,27.06017057,2
175.74006,77.15672636,24.98229038,1
146.2420239,66.26265887,30.98307232,3
145.489392,75.1165553,35.48730391,4
160.5201071,70.67651216,27.42939494,2
174.8082885,88.32422522,28.90385667,2
160.2481918,60.52577291,23.56970075,1
148.0434908,74.90199009,34.17548785,3
181.7853065,89.12081347,26.96879787,2
181.9521881,73.89992278,22.32181062,1
168.1214277,63.3882205,22.42654753,1
184.1070588,75.8561311,22.37947059,1
165.7434423,65.45845082,23.82829751,1
163.921465,75.17112768,27.97560013,2
163.1878674,75.04894726,28.18180924,2
184.1514241,58.99138912,17.39556182,0
182.8329938,69.36432068,20.750425,1
182.2906849,63.87738932,19.22286893,1
149.8682582,50.01043963,22.26595631,1
182.2094183,63.43718942,19.10743056,1
147.0194119,86.25199189,39.90430682,4
167.574818,77.5330136,27.61017453,2
161.9812326,52.24765211,19.91303147,1
167.6740848,82.15761916,29.22240505,2
154.118762,81.68941892,34.39179022,3
181.8022972,80.42652513,24.33327545,1
176.3578257,60.59207313,19.48167707,1
159.5974037,51.84091934,20.35265405,1
172.6465786,76.91316822,25.80386427,2
178.7631678,55.0387295,17.22313938,0
172.1310153,74.93066169,25.28956195,2
165.6504023,65.34991449,23.8155181,1
183.4936715,50.8949689,15.11585343,0
179.6551575,58.01877767,17.97584011,0
169.0473656,75.92859509,26.56981616,2
172.7146392,80.20031664,26.8854793,2
178.2950252,77.99774604,24.53599088,1
159.2760076,88.18386302,34.7606907,3
150.9588277,58.11316379,25.5010164,2
182.950318,75.39919468,22.52684495,1
157.3406697,54.83425809,22.14980887,1
152.4065117,66.33777838,28.55971553,2
177.4641784,54.53854525,17.31737828,0
175.717313,86.30359482,27.95115629,2
172.4872564,50.54806869,16.98988888,0
181.3106385,64.24570114,19.543282,1
184.9831686,89.02859927,26.01747297,2
159.975983,89.19376545,34.85177679,3
152.2490568,67.73641361,29.22220458,2
154.4820317,80.95117427,33.9208876,3
148.9685523,80.39530118,36.22775917,4
167.62042,55.67636625,19.81604747,1
158.3735751,87.65824569,34.94840289,3
180.6064409,64.07336001,19.64314536,1
156.0508291,84.53370915,34.71344889,3
149.0075481,73.9006028,33.28368811,3
179.7071368,70.2424338,21.75048279,1
161.5705073,80.22228081,30.73058501,3
159.5416722,80.71150914,31.70933897,3
174.3045058,77.05127965,25.36078126,2
150.3736104,70.93578243,31.37054827,3
181.2550919,71.21812747,21.67754612,1
175.5010334,78.37750267,25.44673491,2
177.590069,64.80105457,20.54682499,1
163.66572,80.19388556,29.93820831,2
174.6346543,50.36624012,16.51500365,0
183.8126828,86.95996376,25.73762107,2
168.8061071,72.85117331,25.56585029,2
168.9084438,61.22302023,21.45912595,1
168.0549212,68.71393805,24.33001462,1
150.627658,75.11132889,33.10518342,3
173.1173238,51.04018084,17.03064691,0
166.0085831,79.87424048,28.98315476,2
171.1838987,63.28320221,21.59546803,1
182.0776954,65.79373398,19.84591099,1
156.8419633,82.05200599,33.35528558,3
175.5000343,72.38378393,23.5010286,1
171.9697225,67.0503105,22.67236437,1
167.594816,67.56062571,24.05318,1
150.7517843,56.82400824,25.0038533,2
167.7826929,81.38352545,28.90960637,2
177.0478577,83.75032304,26.71807341,2
145.5779619,78.52835032,37.05400913,4
163.152313,84.52254277,31.75309655,3
179.6280211,87.45370859,27.10379227,2
166.0409623,86.74383311,31.46357925,3
172.8257661,64.54318441,21.60893721,1
155.9954628,89.16971087,36.64320268,4
179.2596599,89.18705305,27.75470896,2
164.7431365,59.16750178,21.80060631,1
158.0807508,88.83636325,35.54944224,4
162.7486243,69.70147519,26.31523891,2
174.3393437,76.53706005,25.18146323,2
158.410351,67.77739588,27.00957594,2
172.9661805,87.42263075,29.22143106,2
151.5870944,87.07661541,37.89457871,4
173.302157,74.92502995,24.94703801,1
177.0968798,76.62601077,24.43173996,1
176.3212754,75.0916669,24.15362435,1
171.8038901,52.63785661,17.83332032,0
174.0420899,51.37380626,16.96028573,0
173.2049595,54.38778488,18.12928702,0
168.8166411,58.50679479,20.52937826,1
170.5796392,51.79879565,17.80185582,0
161.0195515,52.43331794,20.22321109,1
168.0551411,79.50085161,28.1493373,2
146.2433242,87.52450675,40.92395242,5
154.5407026,75.98272466,31.81479273,3
172.4660046,69.92307091,23.50788128,1
174.0313115,80.05359544,26.43175841,2
147.8605655,60.79959762,27.8096792,2
145.4332319,78.23109588,36.98725514,4
175.2361815,54.15221739,17.63472463,0
148.2657267,72.99963073,33.20772486,3
162.3534235,74.14012178,28.12745206,2
157.1223385,75.92337329,30.75385957,3
181.2029486,58.6381576,17.8586981,0
179.2092745,56.08266545,17.46255081,0
148.9418644,55.49571146,25.01645986,2
182.986381,76.76870269,22.92697015,1
153.3085052,78.19933319,33.27135984,3
177.9805621,83.67255136,26.41422445,2
159.3210844,57.14116349,22.51140345,1
165.9615903,77.25883561,28.05000758,2
146.1174498,85.9862305,40.27399791,5
159.154756,62.41325402,24.639823,1
175.9788186,74.92647418,24.19438018,1
149.8231288,82.53009217,36.76669612,4
183.1112068,79.15329579,23.60691048,1
174.854012,59.90646113,19.59397106,1
182.9240039,69.47266013,20.76215987,1
164.9759126,67.39014217,24.7602653,1
167.6688455,76.45539041,27.19589541,2
176.3464495,82.30658632,26.46678255,2
174.5285875,50.24554942,16.4954608,0
183.3561583,68.14734187,20.27019463,1
183.7558575,51.50061773,15.25211442,0
162.075959,86.28202245,32.84604751,3
155.957132,71.26326178,29.29916318,2
160.1795587,76.93054418,29.98365821,2
147.6248672,59.34823102,27.2325769,2
181.3829508,71.73579676,21.80434276,1
154.8249707,80.68011227,33.65770356,3
177.8784041,75.75758154,23.94305996,1
167.9123709,83.86121185,29.74375166,2
175.5386515,87.82911632,28.50315915,2
172.8805992,80.81401083,27.03921899,2
174.3321336,84.19801098,27.70428468,2
149.4782888,78.38342296,35.08067893,4
169.1932897,67.87375715,23.71022203,1
174.9127973,83.09751802,27.16094535,2
149.9777468,53.85939414,23.94461257,1
157.6901,67.58749653,27.18050384,2
149.8611359,61.2001462,27.25049651,2
183.7267805,84.48437459,25.02830644,2
161.1953623,86.1595461,33.15876273,3
153.5181619,69.35334107,29.42712702,2
155.6317648,86.45351537,35.69325648,4
169.9845004,76.60473879,26.51166415,2
176.8340627,50.81973709,16.25176614,0
175.0697603,55.23408735,18.02124986,0
182.2243265,74.17226373,22.33720482,1
151.1684906,88.54046905,38.74532115,4
168.1994292,78.35154858,27.69482014,2
150.9651206,63.64761493,27.92729659,2
179.9674734,65.84629508,20.33027741,1
147.9210768,72.72361748,33.2365054,3
154.7921768,81.82589135,34.15015895,3
164.7096753,88.75038795,32.71387864,3
152.1781317,52.79752692,22.79864752,1
149.982159,79.94577965,35.53991134,4
165.2595737,84.32591244,30.87648063,3
179.8377603,75.3279897,23.29134708,1
176.7450689,62.62620267,20.04755746,1
173.296572,64.51815255,21.48334878,1
173.4572788,75.59740841,25.12591288,2
174.505952,71.3026303,23.41450939,1
157.8317687,70.35301931,28.24189889,2
182.280835,89.64614962,26.98047693,2
159.8341186,71.37517257,27.93882839,2
163.2056594,67.49663993,25.34029841,2
175.369942,89.22555119,29.01208379,2
156.4252993,87.6635424,35.82655315,4
169.8816604,81.32900021,28.18074299,2
171.1144927,53.38125057,18.23119581,0
168.8614906,65.14963914,22.84813611,1
176.4838219,68.0626426,21.8523932,1
148.2570204,88.63959365,40.32712127,5
152.3758229,70.22619111,30.24593301,3
160.2629815,58.11471752,22.62662063,1
153.3379902,82.04809498,34.89546329,3
152.3362552,67.88776432,29.25397973,2
148.6005181,74.18673024,33.59584609,3
184.6203628,79.20092313,23.23651769,1
174.9782669,80.64383213,26.33922156,2
166.7766028,68.71267061,24.70396037,1
148.1258255,83.70953852,38.15165482,4
184.5434548,69.5752964,20.42950161,1
171.9460395,76.22336771,25.78123865,2
158.6319375,89.72704574,35.65677917,4
178.8705627,88.85236123,27.77098138,2
163.4693003,58.87708535,22.03301321,1
165.9491547,85.63276738,31.09495194,3
183.0362117,89.24582279,26.63875113,2
149.0118872,86.66500511,39.03030427,4
148.1957429,86.87868553,39.5586794,4
153.1556163,87.06557239,37.11766134,4
152.3601273,65.65493287,28.28295014,2
171.7113937,83.98450965,28.48399917,2
162.8425068,74.77283928,28.19735119,2
183.5581212,63.65249121,18.89157635,1
182.9890081,69.904489,20.87637356,1
158.9937181,65.96703862,26.09558622,2
147.9565227,77.11510724,35.22664113,4
155.5827786,67.21372696,27.76738172,2
147.6850494,79.20831283,36.31597408,4
184.5470772,54.84662395,16.10406657,0
150.6249304,78.21400773,34.47393207,3
174.7982124,62.96715665,20.6082023,1
151.8402525,55.82792759,24.21462079,1
175.6415194,53.93923076,17.48438568,0
166.9464257,59.92177302,21.49960123,1
150.5567741,53.66719352,23.67599738,1
170.9130725,59.93013043,20.51609302,1

bmi_test.csv

177.5890062,64.71439732,20.51959369,1
164.8972845,63.23919414,23.25730132,1
184.4277219,61.55920972,18.09841875,0
154.4941492,80.30581669,33.64518567,3
158.6124011,73.35584165,29.15818083,2
166.0370229,82.36169692,29.87551569,2
169.6612532,82.39583608,28.62463278,2
172.8582082,52.272359,17.49411939,0
183.7597861,73.43739126,21.74784793,1
178.3102594,84.42059402,26.55190832,2
145.7058367,88.84342396,41.84768034,5
174.9631237,63.90942817,20.87718232,1
165.3872153,83.06574655,30.36813372,3
173.4048449,66.75339003,22.19989249,1
179.6483265,58.74047555,18.20082635,0
175.8814376,76.5235827,24.73747004,1
150.43772,88.93210917,39.29570701,4
156.0780163,50.15016044,20.58680037,1
164.5276807,60.66220104,22.40991518,1
172.3360904,75.772929,25.51300431,2
175.7425384,84.70157933,27.42443685,2
153.4043271,78.98498113,33.56365891,3
181.4151475,73.55314109,22.34879601,1
173.3434861,60.09424361,19.99943931,1
165.0101735,74.37733123,27.31612864,2
181.0326434,78.43490605,23.9329148,1
166.9788122,72.89473459,26.14408364,2
184.4072401,77.5883363,22.81605229,1
158.9369714,63.38543563,25.09225044,2
154.0814563,85.96921674,36.21114195,4
164.7194599,76.60178479,28.23248262,2
183.6692317,79.3862389,23.53273521,1
174.8219975,56.41324809,18.45818309,0
147.936031,84.47716627,38.60036506,4
150.107019,51.9918355,23.07454516,1
159.2579522,68.60528806,27.04925684,2
168.1818128,80.40685083,28.42725953,2
171.7174646,72.95312311,24.74086922,1
172.1301013,85.02582818,28.69704787,2
145.4192007,58.97900827,27.89034966,2
150.2171378,72.46398354,32.11317451,3
163.2451122,64.42814556,24.17660182,1
164.2772281,67.65596147,25.06982704,2
182.2175676,54.06466818,16.2829486,0
152.9064199,70.45371725,30.13370081,3
158.4670756,54.28863847,21.61876413,1
176.06228,78.99466036,25.48385309,2
150.7407019,77.22001955,33.98355754,3
169.2401197,80.0370143,27.94372068,2
181.8093354,76.64634827,23.18777671,1

世界一わかりやすいTensorflowのチュートリアル(を目指す)

Tensorflowの公式チュートリアル、わかりにくくないっすか…?

Get Start with Tensorflow!さぁ、チュートリアルを始めよう! そう思って1番最初のBasic classificationに足を踏み入れた瞬間、そこにはFashion Mnistをmatplotlibで描画するGoogle先生の姿が…( ゚д゚)

しかも、最初のモデルを作ってみようのコーナーで

model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation=tf.nn.relu),
    keras.layers.Dense(10, activation=tf.nn.softmax)
])

…( ゚д゚)

いきなり抽象化されたkerasの関数使ってますやん。便利であることは間違いないのですが、ブラックボックスであるが故に中でどんな計算が行われているのか全然わかりません。 機械学習の教科書を真面目に読んできた初心者たちは、「あれ? わいは入力ベクトルと行列の掛け算がニューラルネットワークやって習ったで? Flattenってなんや? Denseってなんや?」となってしまいそうです。

この記事は、そんな初心者でもピュアなTensorflowの使い方を1発で理解できるチュートリアルとなっております。「ピュアな」というのは、抽象化されたkeras関数を使わず、行列計算を組み合わせてニューラルネットワークをスクラッチで作っていくということです。

世界一わかりやすい記事を目指してがんばります。

最も基本的なTensorflowの流れ

何をするにも「大きな流れ」というものがあります。カレーを作るには、材料をカットして炒めて煮てカレー粉を入れる、という大きな流れがあります。まずは、この「基本」をマスターしましょう。「10分煮込むより15分煮込んだ方が美味しいんじゃないか」といった細かい研究は後で好きなだけやってください。

Tensorflowでは以下の流れが基本です。

  • プレースホルダーの定義
  • ニューラルネットワークの定義
  • 誤差関数の定義
  • 正解率の定義
  • 学習アルゴリズムの選定
  • 教師データの読み込み
  • セッションを張る
  • 学習の実行

それぞれが何を表しているのかについて、これから1つずつ見ていきます。

まずは、この基本をおさえた最も単純なコードを以下に示します。

import tensorflow as tf

# プレースホルダーの定義
x = tf.placeholder(tf.float32, [None, 2])
t = tf.placeholder(tf.float32, [None, 2])

# ネットワークの定義
w = tf.Variable(tf.random_normal([2,2], mean=0.0, stddev=1.0))
b = tf.Variable(tf.random_normal([2], mean=0.0, stddev=1.0))
y = tf.matmul(x, w) + b

# 誤差関数の定義
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=t, logits=y))

# 正解率の定義
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(t, 1))
acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# 学習アルゴリズムの選定
train = tf.train.MomentumOptimizer(0.01, 0.9).minimize(loss)

# おまじない
init = tf.global_variables_initializer()

# 教師データの読み込み
#  ここでは、(1,1) が入力されたときのみ (0,1) を出力し、
#  (0,0), (1,0), (0,1) が入力されたときは (1,0) を出力するAND回路を
#  題材としました。
train_x = [ [0,0], [0,1], [1,0], [1,1] ]
train_t = [ [1,0], [1,0], [1,0], [0,1] ]
test_x  = [ [0,0], [0,1], [1,0], [1,1] ]
test_t  = [ [1,0], [1,0], [1,0], [0,1] ]

# セッションを張る
with tf.Session() as sess:
    sess.run(init) # おまじない

    print('Epoch\tTraining loss\tTest loss\tTraining acc\tTest acc')

    # 学習の実行
    for epoch in range(200):
        sess.run(train, feed_dict={
            x: train_x,
            t: train_t
        })

        # 5エポックに1回の頻度で、学習の状況を描画する
        if (epoch+1) % 5 == 0:
            print('{}\t{}\t{}\t{}\t{}'.format(
                str(epoch+1),
                str(sess.run(loss, feed_dict={x:train_x, t:train_t})),
                str(sess.run(loss, feed_dict={x:test_x, t:test_t})),
                str(sess.run(acc, feed_dict={x:train_x, t:train_t})),
                str(sess.run(acc, feed_dict={x:test_x, t:test_t})),
            ))

では、それぞれの要素を順番に見ていきます。

プレースホルダーの定義

# プレースホルダーの定義
x = tf.placeholder(tf.float32, [None, 2])
t = tf.placeholder(tf.float32, [None, 2])

Tensorflowにはプレースホルダーという概念があります。よく文字列のプレースホルダーとして「*」という記号を使いますよね。あれは「*には色々な文字が入りますよ」ということを表しています。 Tensorflowのプレースホルダーも同じです。x = tf.placeholder(…)のように書くと「xには色々な値が入りますよ」ということを宣言したことになります。「色々な値」とは、教師データのことです。 上のように定義したプレースホルダーxtは、それぞれ「入力データ」と「出力データ」を入れるための箱ということになります。( x は数学で変数名としてよく使うエックス、tは目的変数(target variable)の頭文字から来ています。)

引数にtf.float32, [None, 2]と書いてありますが、これは「どんな値が入るのか」を指定しています。第一引数で値の型を指定し、第二引数で値の shape (何×何 行列か)を指定します。

shapeの1つ目がNoneとなっているのは「可変」であることを意味しています。なぜ可変で良いのかというと、バッチサイズが関係しているのですが、初心者の方は「そういうものか」と受け流してもらっても良いです。 shapeの2つ目が2となっているのは、今回の入力データおよび出力データの次元が2次元だからです。

ネットワークの定義

# ネットワークの定義
w = tf.Variable(tf.random_normal([2,2], mean=0.0, stddev=1.0))
b = tf.Variable(tf.random_normal([2], mean=0.0, stddev=1.0))
y = tf.matmul(x, w) + b

tf.Variable()で、学習可能な重みパラメータを定義できます。ここでは、2x2の重み行列としてw、 2次元のバイアスとしてbを学習可能なパラメータとして定義しました。 yはニューラルネットワークの出力となります。ここでは、さきほど定義したxwの積(matmul)を計算し、バイアスbを加えるという、最も単純なニューラルネットワークを定義しました。

誤差関数の定義

# 誤差関数の定義
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=t, logits=y))

softmax_cross_entropy_with_logits()は、2つの行列のソフトマックスクロスエントロピーを計算します。最終的に誤差lossはスカラーにする必要があるため、各バッチの平均をとるためにtf.reduce_meanを呼んでいます。

教師データの出力値tと、ニューラルネットワークの出力値yを、それぞれlabelslogitsに指定します。

正解率の定義

# 正解率の定義
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(t, 1))
acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

教師データtと、ニューラルネットワークの出力値yから得られる予測とを比較し、予測が当たっていたかどうかを判定します。そのためには、1hotベクトルtの1が立っているインデックスと、ベクトルyの最も大きい要素のインデックスが等しいかどうかを見ればOKです。それをやっているのがtf.equal(tf.argmax(y, 1), tf.argmax(t, 1))です。これも、出力は行列になるので平均をとってスカラー化するためにtf.reduce_meanを呼んでいます。

学習アルゴリズムの選定

# 学習アルゴリズムの選定
train = tf.train.MomentumOptimizer(0.01, 0.9).minimize(loss)

Tensorflowでは、任意のOptimizerのがもつminimizeというメソッドに誤差関数を渡します。この左辺trainを使って、今後、学習を実行していくことになります。上の例ではOptimizerとしてMomentumOptimizerを採用していますが他のものを選定してもらっても構いません。

教師データの読み込み

# 教師データの読み込み
#  ここでは、(1,1) が入力されたときのみ (0,1) を出力し、
#  (0,0), (1,0), (0,1) が入力されたときは (1,0) を出力するAND回路を
#  題材としました。
train_x = [ [0,0], [0,1], [1,0], [1,1] ]
train_t = [ [1,0], [1,0], [1,0], [0,1] ]
test_x  = [ [0,0], [0,1], [1,0], [1,1] ]
test_t  = [ [1,0], [1,0], [1,0], [0,1] ]

(1,1) が入力されたときのみ (0,1) を出力し、(0,0), (1,0), (0,1) が入力されたときは (1,0) を出力するAND回路を学習させようと思います。

セッションを張る

# セッションを張る
with tf.Session() as sess:

ほとんど、おまじないと思ってもらってOKです。Tensorflowではセッションというものを張り、セッション中でrun(…)というメソッドを叩くことで、ニューラルネットワークを動かしていきます。

学習の実行

    # 学習の実行
    for epoch in range(200):
        sess.run(train, feed_dict={
            x: train_x,
            t: train_t
        })

        # 5エポックに1回の頻度で、学習の状況を描画する
        if (epoch+1) % 5 == 0:
            print('{}\t{}\t{}\t{}\t{}'.format(
                str(epoch+1),
                str(sess.run(loss, feed_dict={x:train_x, t:train_t})),
                str(sess.run(loss, feed_dict={x:test_x, t:test_t})),
                str(sess.run(acc, feed_dict={x:train_x, t:train_t})),
                str(sess.run(acc, feed_dict={x:test_x, t:test_t})),
            ))

200エポックのループを回し、各ループで sess.run(train, feed_dict={ x: train_x, t: train_t }) というメソッドを呼んでいます。第一引数には#学習アルゴリズムの選定 のところで定義したtrainという変数を、第二引数にはfeed_dictという辞書オブジェクトを指定します。このfeed_dictこそが、プレースホルダーに与える「教師データ」となります。

実行してみる

上述のコードをtrain.pyと名付けて、以下のコマンドを実行しましょう。

python train.py

以下のような出力が得られると思います。(実際の出力は、乱数によって少し変わります。)

Epoch    Training loss   Test loss   Training acc    Test acc
5   1.0601006   1.0601006   0.25    0.25
10  1.0230795   1.0230795   0.25    0.25
15  0.97862685  0.97862685  0.25    0.25
20  0.93293065  0.93293065  0.5 0.5
25  0.8886891   0.8886891   0.5 0.5
30  0.8467394   0.8467394   0.5 0.5
35  0.8071697   0.8071697   0.5 0.5
40  0.7698909   0.7698909   0.5 0.5
45  0.73484254  0.73484254  0.75    0.75
50  0.7020121   0.7020121   0.75    0.75
55  0.67139405  0.67139405  0.75    0.75
60  0.6429507   0.6429507   0.75    0.75
65  0.61659825  0.61659825  0.75    0.75
70  0.59221315  0.59221315  0.75    0.75
75  0.5696492   0.5696492   0.75    0.75
80  0.54875386  0.54875386  0.75    0.75
85  0.52937984  0.52937984  0.75    0.75
90  0.5113907   0.5113907   0.75    0.75
95  0.49466205  0.49466205  0.75    0.75
100 0.47908068  0.47908068  0.75    0.75
105 0.46454394  0.46454394  0.75    0.75
110 0.45095852  0.45095852  0.75    0.75
115 0.43823957  0.43823957  0.75    0.75
120 0.42631024  0.42631024  0.75    0.75
125 0.4151013   0.4151013   0.75    0.75
130 0.40455014  0.40455014  0.75    0.75
135 0.39460057  0.39460057  1.0 1.0
140 0.38520193  0.38520193  1.0 1.0
145 0.3763088   0.3763088   1.0 1.0
150 0.3678801   0.3678801   1.0 1.0
155 0.35987893  0.35987893  1.0 1.0
160 0.35227215  0.35227215  1.0 1.0
165 0.3450296   0.3450296   1.0 1.0
170 0.33812422  0.33812422  1.0 1.0
175 0.33153147  0.33153147  1.0 1.0
180 0.32522893  0.32522893  1.0 1.0
185 0.3191966   0.3191966   1.0 1.0
190 0.31341603  0.31341603  1.0 1.0
195 0.30787042  0.30787042  1.0 1.0
200 0.30254456  0.30254456  1.0 1.0

プロットするとこんな感じです。

f:id:twx:20180821030519p:plain

精度100%で学習できましたね。

大まかな流れは以上となります。

次の記事では、より難しいデータセットで学習する方法や、精度を上げるための様々なテクニックをご紹介する予定です。

Kozuko Mahiro's Hacklog ―― Copyright © 2018 Mahiro Kazuko