ひろこま Hack Log

プログラミングや機械学習などの知識を記録・共有します

説明可能な人工知能「LIME」を使って初音ミクのイラストを分析してみた

f:id:twx:20191002235714p:plain
説明可能な人工知能「LIME」を使って初音ミクのイラストを分析してみた

1. LIMEとは

論文はこちら https://arxiv.org/abs/1602.04938

任意の判定モデルに対して、何らかの結果が出たときに「なぜその結果が出たのか?」を説明できるツールです。

LIMEはLocal Interpretable Model-agnostic Explanationsの略です。意訳すると「局所的な、モデルを知らずとも可能な説明」という感じですが、要は

・モデルやデータの形式、モデルの構造に依存的である

・任意の結果を包括的に説明するのではなく、ある単一の結果をピンポイントで説明する

ということです。

仕組みは意外とシンプルで、特定の入力に対して摂動(=わずかなノイズ)を与えたときの出力の変化を観察することで、どのような入力が出力に最も強く影響を与えているのかを算出します。そして、その「最も強く効いている入力」と出力との関係性を線形モデルで表すというアイデアです。

f:id:twx:20190415021758p:plain
https://arxiv.org/abs/1602.04938から引用

2. LIMEのチュートリアルをやってみる

前述のように、LIMEはモデルやデータ形式に依存しません。実際、LIMEの公式Githubでは、テキスト分類や画像認識など様々なタスクや様々なデータに対してLIMEを適用した例が掲載されています。

ここでは、チュートリアルとして画像認識へのLIME適用を試してみます。

以下では、GPU環境にGoogle Colaboratory (Python3, GPU使用)をします。なお、本記事で書いたコードはGoogle Colaboratory上で公開していますので、同じことをやりたい方は参考にしてください!

ついでに過去記事も読んでいただけると、よりわかりやすいと思います!

2.1 画像認識

LIMEのチュートリアルに従って進めていきましょう。

まずは普通に画像認識モデルを動かしてみます。

import os

# Kerasで画像分類モデル(Inception v3)を準備
import keras
from keras.applications import inception_v3 as inc_net
from keras.preprocessing import image
from keras.applications.imagenet_utils import decode_predictions
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
inet_model = inc_net.InceptionV3()

# 画像のパスリストを渡して画像のピクセル配列を返してくれる関数
def transform_img_fn(path_list):
    out = []
    for img_path in path_list:
        img = image.load_img(img_path, target_size=(299, 299))
        x = image.img_to_array(img)
        x = np.expand_dims(x, axis=0)
        x = inc_net.preprocess_input(x)
        out.append(x)
    return np.vstack(out)

画像はこちらを使います。

!wget https://raw.githubusercontent.com/marcotcr/lime/master/doc/notebooks/data/cat_mouse.jpg
# Inception v3 モデルでこの画像に写っているものを判定
images = transform_img_fn([os.path.join('cat_mouse.jpg')])
plt.imshow(images[0] / 2 + 0.5)
preds = inet_model.predict(images)

for x in decode_predictions(preds)[0]:
    print(x)

以下のような結果が得られます。

('n02133161', 'American_black_bear', 0.6371622)
('n02105056', 'groenendael', 0.03181786)
('n02104365', 'schipperke', 0.029944154)
('n01883070', 'wombat', 0.028509287)
('n01877812', 'wallaby', 0.025093317)

1位がスコア0.637で American_black_bear (アメリカグマ) ※黒い熊

2位がスコア0.031で groenendael(グローネンダール)※黒い犬

という感じ。

では、このモデルにLIMEを適用していきます。

2.2 LIME

pipでlimeをインストールします。(※ Google Colabolatory上での実行を前提としているため、シェルコマンドの先頭に ! をつけて表記しています。Colabolatoryを使わない場合は ! を外して読んでください。)

!pip install lime
import lime
from lime import lime_image
explainer = lime_image.LimeImageExplainer()
explanation = explainer.explain_instance(images[0], inet_model.predict, top_labels=5, hide_color=0, num_samples=1000)

ここで、入力画像に摂動を加えて出力の変化を観察しています。

最後の関数の引数について説明すると、

・第1引数: images[0] は入力画像です。直前に作った画像のリストのうち、0番目(最初)の画像を使用します。

・第2引数: inet_model.predict は説明対象のモデルです。今回は画像判定モデル(Inception v3)です。

・第3引数: top_labels=5 は画像認識の結果の上位何件に対して、説明を行うかを表しています。

・第4引数: hide_color=0 は「オフにするスーパーピクセルの色」なんですが、細かい話なので何でも良いです。(画像であまり使われていない色が良い)

・第5引数: num_samples=1000 は何回摂動させるかを表しています。この数字が大きいほど、精確な説明が得られます。(ただし時間がかかります)

上記のコードの実行には少し時間がかかります。

さて、実行が終わると、モデルの判定結果 上位5件に対して「なぜその5件が得られたのか」の理由が生成されます。

ちょっと解釈が難しいのですが、簡単にいうと「画像中の どの部分が 最もそれっぽかったのか」が可視化されるということです。今回の例でいうと、American_black_bear (アメリカグマ) っぽい部分はどこっだのか?ということですね。

「AIモデルが最も注目した箇所」と言い換えても良いと思います。

仕組みは、冒頭で解説した通りです。入力画像の一部にノイズを加えたときの出力(画像認識結果)の変化を観察しています。今回の場合、例えば背景の草原が多少ノイズで乱されても、問題なくアメリカグマと認識されるはずです。しかし、アメリカグマが写っている部分をノイズで乱されてしまったら、おそらくアメリカグマと認識しにくくなるはずですね。このような「入力に与えるノイズ と 出力結果の変化」から、アメリカグマと認識するために重要な箇所を割り出しているのです。

百聞は一見にしかず。結果を見てみましょう。

from skimage.segmentation import mark_boundaries
temp, mask = explanation.get_image_and_mask(295, positive_only=False, num_features=5, hide_rest=False)
plt.imshow(mark_boundaries(temp / 2 + 0.5, mask))

get_image_and_maskの第1引数の295は何かというと、1位のラベル(American_black_bear)のIDです。

このIDはこちらで調べることができます。

https://github.com/thekevinscott/dataset-tutorial-for-image-classification/blob/master/imagenet_labels.json

結果はこちら。

f:id:twx:20191002230807p:plain
LIME適用結果

緑に光っている部分が「この画像を American_black_bear と判定した根拠」の部分です。 一方、赤くなっている部分が「この画像を American_black_bear と判定するにあたって、最も邪魔だった箇所」です。確かに、この部分には他の動物が写っているので、もしこの部分が無ければもっと高スコアでAmerican_black_bear と判定されていたはずです。

2.3 LIMEをillust2vecに応用

過去の記事で、イラストを分類できる illust2vec というモデルを紹介しました。

www.mahirokazuko.com

次は、これにLIMEを適用してみましょう。基本的は使い方は同じです。

まずはillust2vecをインストールします。

!git clone https://github.com/rezoo/illustration2vec.git
%cd illustration2vec/
!pip install -r requirements.txt

!wget https://github.com/rezoo/illustration2vec/releases/download/v2.0.0/illust2vec_tag_ver200.caffemodel
!wget https://github.com/rezoo/illustration2vec/releases/download/v2.0.0/illust2vec_ver200.caffemodel
!wget https://github.com/rezoo/illustration2vec/releases/download/v2.0.0/tag_list.json.gz
!gzip -d tag_list.json.gz

初音ミクを判定してみましょう。

import i2v
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

illust2vec = i2v.make_i2v_with_chainer(
    "illust2vec_tag_ver200.caffemodel", "tag_list.json")
img = Image.open("images/miku.jpg")
plt.imshow(np.array(img))
illust2vec.estimate_plausible_tags([img], threshold=0.5)

f:id:twx:20191002231405p:plain
ミクさんを判定

twintails aqua hair detached sleeves skirt など、なるほどなーと納得感のある結果が得られました。

個人的には detached sleeves (分離した袖)っていうタグがマニアックでツボりましたw

さて、このモデルに対してLIMEを適用します。

まずはモデルを再定義します。上でやったやり方では、スコアが threshold=0.5 を上回る結果しか表示しないようになっていましたが、出力ラベルのIDと、出力ベクトルの次元数を一致させたいので、以下では threshold=-1 として全てのラベルが出力されるようにしています。

import i2v
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

illust2vec = i2v.make_i2v_with_chainer(
    "illust2vec_tag_ver200.caffemodel", "tag_list.json")
img = Image.open("images/miku.jpg")
plt.imshow(np.array(img))
result = illust2vec.estimate_plausible_tags([img], threshold=-1)[0]

all_features = []
all_features.extend(result['general'])

all_features = sorted(all_features, key=lambda x: x[0])

tags = [f[0] for f in all_features]
for i, tag in enumerate(tags):
    print(i, tag)
print(all_features)
print(len(all_features))

次に、LIMEが受け取れる形に、モデルを整形しましょう。

def my_predict(images):
    images = [Image.fromarray(np.uint8(img)) for img in images]
    result = illust2vec.estimate_plausible_tags(images, threshold=-1)[0]
    all_features = []
    all_features.extend(result['general'])
    all_features = sorted(all_features, key=lambda x: x[0])
    scores = np.array( [ [f[1] for f in all_features] ], dtype=np.float32)
    
    scores = scores / scores[0].sum()
    return scores

この判定器(my_predict)を使って実際に判定してみます。

preds = my_predict(images)
for x in preds.argsort()[0][-10:]:
    print(x, tags[x], preds[0,x])

上位10件を表示してみます。

310 necktie 0.043684043
415 skirt 0.052816156
149 detached sleeves 0.058019824
489 very long hair 0.06485639
272 long hair 0.068677224
29 aqua hair 0.07140998
426 solo 0.07503151
476 twintails 0.07643303
2 1girl 0.076905265
459 thighhighs 0.07755106

おなじみのタグが出力されました!

一番左の数字がラベルのIDです。

では、先ほどと同じ様にLIMEで説明を生成します。対象モデルは my_predict, 対象ラベルは上位5件です。

explainer = lime_image.LimeImageExplainer()
explanation = explainer.explain_instance(images[0], my_predict, top_labels=5, batch_size=1, hide_color=0, num_samples=100)

いくつか結果を見てみましょう。

# thighhighs: 絶対領域に最も反応した箇所
temp, mask = explanation.get_image_and_mask(459, positive_only=False, num_features=5, hide_rest=False)
plt.imshow(mark_boundaries(temp, mask))

f:id:twx:20191002232452p:plain
絶対領域に反応した部分

# ツインテールに最も反応した箇所
temp, mask = explanation.get_image_and_mask(476, positive_only=False, num_features=5, hide_rest=False)
plt.imshow(mark_boundaries(temp, mask))

f:id:twx:20191002232851p:plain
ツインテールに反応した部分

いい感じで反応していますね!

2.4 LIMEを用いて2枚の画像の類似箇所を比較

以上、LIMEを用いた画像認識モデルの説明について紹介しましたが、ここではもうちょっと応用について考えてみます。

類似画像判定について考えてみます。

2つの画像があったとして、それらが似ているかどうかの判定を行い、しかも「どこが似ているのか」を提示する ということがLIMEを用いて実現できます。

考え方は簡単で、2つのイラストに対して上述の illust2vec でタグ付与を行い、上位10件に同一のタグを含んでいたら「似ている」と判定します。そして、そのラベルを出力した理由をLIMEで可視化します。

ここでは、以下の2つの画像を使います。

f:id:twx:20191002234149p:plain
実験に用いる2枚の画像

左側のミクはもうやったので、今度は右側のミクを判定してみましょう。

preds = my_predict(images)
for x in preds.argsort()[0][-10:]:
    print(x, tags[x], preds[0,x])
70 blue hair 0.029066972
28 aqua eyes 0.029589249
118 cloud 0.036495335
69 blue eyes 0.039897494
272 long hair 0.044192325
426 solo 0.04657326
476 twintails 0.05039558
418 sky 0.051117815
29 aqua hair 0.051458504
2 1girl 0.056882586

右側のミクからも twintails が検出されました。

したがって「この2枚の画像はともに twintails の要素があるので似ている」という結果となります。

では、twintailsに最も反応した領域を可視化してみます。

explainer = lime_image.LimeImageExplainer()
explanation = explainer.explain_instance(images[0], my_predict, top_labels=5, batch_size=1, hide_color=0, num_samples=100)

# ツインテール
temp, mask = explanation.get_image_and_mask(476, positive_only=False, num_features=5, hide_rest=False)
plt.imshow(mark_boundaries(temp, mask))

f:id:twx:20191002234425p:plain
ツインテールに反応した部分

LIMEの結果を並べてみます。

f:id:twx:20191002235112p:plain
LIMEの結果を並べて表示

このように、両者に共通している特徴である twintails を並べてることで、「どこが似ているのか」を可視化することができます。最後に、LIMEで反応した部分を緑に塗るのではなく、その部分を円で囲うようにしてみました。ちょっと見やすくなりましたかね?

f:id:twx:20191002235401p:plain
見やすくしてみた

以上、説明可能な人工知能「LIME」を使って初音ミクのイラストを分析してみました。

良い記事だと思っていただいた方は、以下の「★+」ボタンのクリック、SNSでのシェア、「読者になる」ボタンのクリック、Twitterのフォローをお願いします!

Koma Hirokazu 's Hacklog ―― Copyright © 2018 Koma Hirokazu