4. Deep Learningフレームワークの基礎

colab-logo

ChainerはDeep Learningフレームワークの一つで,現在様々なDeep Learningフレームワーク(TensorFlow, PyTorch, etc.)でも採用され主要なニューラルネットワークの記法となっているDefine-by-Runというアイデアを汎用的なDeep Learningフレームワークとしては初めて採用し,2015年からPreferred Networks社によって開発が続けられています.Define-by-Runとは,ニューラルネットワーク中の計算を行うコードを記述することでニューラルネットワークの構造を定義する考え方です.学習を行う前にネットワーク構造を定義しておき,そのネットワークに学習に用いるデータを入力するためのコードを別途書く必要がある方法はDefine-and-Runと呼ばれます.Define-by-Runは実行時にネットワーク構造が決定されるため,動的な構造を記述しやすいという特徴があります.

ここでは,その柔軟性直感的であることを特徴とするこのChainerというフレームワークの基本的な使い方を解説します.

4.1. 環境構築

まずはColab上で以下のセルを実行し,必要なライブラリをインストールしましょう.ここではgraphvizというソフトウェアをインストールしています.これは,後にニューラルネットワークのアーキテクチャをグラフ構造として可視化するために使用します.Google Colab上には,ChainerやCuPyは予めインストールされています.

[1]:
!apt-get install -y graphviz
Reading package lists... Done
Building dependency tree
Reading state information... Done
graphviz is already the newest version (2.40.1-2).
The following packages were automatically installed and are no longer required:
  cuda-cufft-10-1 cuda-cufft-dev-10-1 cuda-curand-10-1 cuda-curand-dev-10-1
  cuda-cusolver-10-1 cuda-cusolver-dev-10-1 cuda-cusparse-10-1
  cuda-cusparse-dev-10-1 cuda-license-10-2 cuda-npp-10-1 cuda-npp-dev-10-1
  cuda-nsight-10-1 cuda-nsight-compute-10-1 cuda-nsight-systems-10-1
  cuda-nvgraph-10-1 cuda-nvgraph-dev-10-1 cuda-nvjpeg-10-1
  cuda-nvjpeg-dev-10-1 cuda-nvrtc-10-1 cuda-nvrtc-dev-10-1 cuda-nvvp-10-1
  libcublas10 libnvidia-common-430 nsight-compute-2019.5.0
  nsight-systems-2019.5.2
Use 'apt autoremove' to remove them.
0 upgraded, 0 newly installed, 0 to remove and 5 not upgraded.

それでは,以下のコマンドをターミナルで実行し,Chainerや,ChainerでGPUを活用するために必要となるCuPyというパッケージが正しくインストールされているかどうかを確認してみましょう.

[2]:
!python -c 'import chainer; chainer.print_runtime_info()'
Platform: Linux-4.14.137+-x86_64-with-Ubuntu-18.04-bionic
Chainer: 6.5.0
ChainerX: Not Available
NumPy: 1.17.4
CuPy:
  CuPy Version          : 6.5.0
  CUDA Root             : /usr/local/cuda
  CUDA Build Version    : 10010
  CUDA Driver Version   : 10010
  CUDA Runtime Version  : 10010
  cuDNN Build Version   : 7603
  cuDNN Version         : 7603
  NCCL Build Version    : 2402
  NCCL Runtime Version  : 2402
iDeep: 2.0.0.post3

Chainer, NumPy, そしてCuPy, さらにCuPyの下にCUDAやcuDNN, NCCLといった項目があり,それぞれバージョン番号が表示されていれば成功です.

4.2. Chainerの基本的な使い方

はじめに,シンプルなタスクに実際に取り組むことによって,Chainerの基本的な使い方を説明していきます.さっそく,有名な手書き数字のデータセットMNISTを使って,画像を10クラス(数字の0 - 9)のいずれかに分類するネットワークを書き,学習させてみましょう.

4.2.1. データセットの準備

まずは学習対象となるデータセットの準備をします.教師あり学習の場合,データセットは「入力データ」と「それと対になるラベルデータ」のペアを返すオブジェクトである必要があります.

Chainerには,MNISTやCIFAR10/100のような良く用いられるデータセットに対して,データのダウンロードからオブジェクト作成までを自動的に行ってくれる便利なメソッドがあります.ここではひとまずこれを用いましょう.

[3]:
from chainer.datasets import mnist

# データセットがダウンロード済みでなければ,ダウンロードも行う
train_val, test = mnist.get_mnist(withlabel=True, ndim=1)
Downloading from http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz...
Downloading from http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz...
Downloading from http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz...
Downloading from http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz...

データセットオブジェクトの準備ができました.このオブジェクトは, train_val[i] のように指定すると,i番目の (data, label) というタプルを返すリスト と同様のものと考えてください.(実際ただのPythonリストもChainerのデータセットオブジェクトとして利用可能です).それでは,0番目のデータとラベルを取り出して,表示してみましょう.

[4]:
# matplotlibを使ったグラフ描画結果がnotebook内に表示されるようにします.
%matplotlib inline
import matplotlib.pyplot as plt

# データの例示
x, t = train_val[0]  # 0番目の (data, label) を取り出す
plt.imshow(x.reshape(28, 28), cmap='gray')
plt.axis('off')
plt.show()
print('label:', t)
../_images/notebooks_04_Introduction_to_Chainer_10_0.png
label: 5

4.2.2. Validation用データセットを作る

次に,さきほど作成したtrain_valデータセットを,Training用のデータセットとValidation用のデータセットに分割します.Validationデータセットとは,学習には用いずにモデルの汎化性能をチェックしたり,学習率などのハイパーパラメータを調整するために用いる検証用のデータセットスプリットのことです.分割処理も,Chainerが提供しているデータセット分割用の関数を用いて行うことができます.元々60000個のデータが入っているtrainデータセットを,ランダムに選択された50000個のデータと残りの10000個のデータの2つに分割しましょう.これには,split_dataset_randomという関数を使用します.

[5]:
from chainer.datasets import split_dataset_random

train, valid = split_dataset_random(train_val, 50000, seed=0)

関数の第1引数が分割したい対象のデータセットオブジェクト,第2引数が1つ目のデータセットの要素数,第3引数がランダムな抽出を行う際に用いられる乱数シード(これは省略可)となります.第3引数のseedとして同じ値を指定すると,再実行した際にデータセットを同じように分割するようになります.それでは,それぞれのデータセットの中に入っているデータの数を確認してみましょう.

[6]:
print('Training dataset size:', len(train))
print('Validation dataset size:', len(valid))
Training dataset size: 50000
Validation dataset size: 10000

4.2.3. Iteratorの作成

次に,さきほど準備したデータセットオブジェクトから,幾つかのデータ(入力とラベルのペア)を束ねて学習モデルに次々に渡す,Iteratorという機能を紹介します.なぜIteratorの機能が必要かというと,ニューラルネットワークのパラメータを更新する際に利用される,確率的勾配降下法(Stochastic Gradient Descent, SGD)をはじめとする最適化手法では,一つのデータだけを元に更新する処理を繰り返すのではなく,幾つかのデータを束ねた ミニバッチ を元に計算していくのが一般的となっているためです(ミニバッチ計算が一般的である理由としては,勾配のミニバッチ平均を計算することでパラメータ更新が安定することや,GPUなどを用いた並列化がしやすいこと等が挙げられます).

Iteratorは,さきほど作成したデータセットオブジェクトを引数として指定し,next()メソッドを呼ぶことで新しいミニバッチを返してくれます.データセット内のデータすべてを1度ずつ学習に利用し終えた時点のことを 1エポック(epoch) と呼びます.Iteratorの内部では,学習中に何エポックまで学習を行ったか,などの情報が逐次記録されており,データセット内のデータを何度も使って学習のループを回すようなコードを簡単に書くことができるようになります.

データセットオブジェクトからイテレータを作るには,以下のようにします.

[7]:
from chainer import iterators

batchsize = 128

train_iter = iterators.SerialIterator(train, batchsize)
valid_iter = iterators.SerialIterator(
    valid, batchsize, repeat=False, shuffle=False)
test_iter = iterators.SerialIterator(
    test, batchsize, repeat=False, shuffle=False)

今,学習データセット用のイテレータ(train_iter)と,検証データセット用のイテレータ(valid_iter),および学習したネットワークの評価に用いるテストデータセット用のイテレータ(test_iter)の計3つを作成しました.ここではbatchsize = 128としているため,作成した3つのイテレータはnext()メソッドが(train_iter.next()のように)呼ばれると,128枚の数字画像データを一括りにして返します.実際にnext()の返り値を調べてみましょう.

[8]:
minibatch = train_iter.next()

このminibatchという変数は,(img, label)というタプルが128個(ミニバッチサイズだけ)並んだリストになっています.実際に,このリストの長さが128であることを確認してみましょう.

[9]:
print('batchsize:', len(minibatch))
batchsize: 128

次に,このminibatchというリストの一つ目の要素(画像とラベルを持つタプルになっているはずです)をminibatch[0]として取り出してみます.

[10]:
x, t = minibatch[0]

print('x:', x.shape)
print('t:', t.shape)
x: (784,)
t: ()

そのときの返り値である2つの配列 xt のshapeを調べてみると,データはそれぞれ長さ784のベクトルとして格納されており,正解ラベルはスカラー値となっています.784は,\(28 \times 28\)で,28ピクセル四方の画像データの画素値を1列に並べたものになっています.

4.2.3.1. SerialIteratorについて

Chainerにいくつか用意されているイテレータの一種であるSerialIteratorは,データセットの中のデータを順番に取り出してくる最もシンプルなイテレータです.SerialIterator のコンストラクタ(クラスをインスタンス化するタイミングで呼ばれるメソッド)の引数にデータセットオブジェクトと,バッチサイズを取ります.このとき,渡したデータセットオブジェクトから,データを繰り返し読み出す必要がある場合はrepeat引数をTrueとし,1周が終わったらそれ以上データを取り出したくない場合はこれをFalseとします.これは,主にvalidation用のデータセットに対して使うフラグです.デフォルトでは,Trueになっています.また,shuffle引数にTrueを渡すと,データセットから取り出されてくるデータの順番をエポックごとにランダムに変更します.SerialIteratorの他にも,マルチプロセスで高速にデータを処理できるようにしたMultiprocessIteratorMultithreadIteratorなど,複数のイテレータが用意されています.詳しくは以下を見てください.

4.2.4. ネットワークの定義

それでは,学習させるネットワークを定義してみましょう.今回は,全結合層のみからなるニューラルネットワーク(多層パーセプトロン)を作ることにして,中間層のユニット数は100とします.今回用いるMNISTデータセットは0〜9までの数字のいずれかを意味する10種のラベルを持つことから,出力ユニット数は10とします.

ここで,ネットワークを定義するために必要なLink, Function, Chainについて簡単に説明します.

4.2.4.1. LinkとFunction

Chainerでは,ニューラルネットワークの各層を,LinkFunctionに区別します.

  • Linkは,パラメータを持つ関数です.
  • Functionは,パラメータを持たない関数です.

これらを組み合わせてネットワークを記述します.パラメータを持つ層は,chainer.linksモジュール以下に用意されています.例えば chainer.links.Linear は,前章で説明した全結合層に対応しており,内部に Wb という学習できるパラメータが保持されています.パラメータを持たない層は,chainer.functionsモジュール以下に用意されています.これらに簡単にアクセスするために,

import chainer.links as L
import chainer.functions as F

と別名を与えて,L.Convolution2D(...)F.relu(...)のように用いる慣習がありますが,特にこれが決まった書き方というわけではありません.

4.2.4.2. Chain

Chainは,パラメータを持つ層(Link)をまとめておくためのクラスです.パラメータを持つということは,基本的にネットワークの学習の際にそれらを更新していく必要があるということです(更新されないパラメータを持たせることもできます).Chainerでは,モデルのパラメータの更新は,Optimizerという機能が担います.その際,更新すべき全てのパラメータを簡単に発見できるように,Chainで一箇所にまとめておきます.

4.2.4.3. 同じ結果を保証する

ネットワークを書き始める際に乱数シードを固定すると,本記事とほぼ同様の結果が再現できるようになります.(cuDNNが有効になっている環境下でより厳密に計算結果の再現性を保証したい場合は,chainer.config.cudnn_deterministicというConfiguringオプションについて知る必要があります.こちらのドキュメントを参照してください:chainer.config.cudnn_deterministic

[11]:
import random
import numpy
import chainer

def reset_seed(seed=0):
    random.seed(seed)
    numpy.random.seed(seed)
    if chainer.cuda.available:
        chainer.cuda.cupy.random.seed(seed)

reset_seed(0)

4.2.4.4. Chainを継承したネットワークの定義

Chainerでは,ネットワークは Chain クラスを継承したクラスとして定義されることが一般的です. Chain を継承することで,中間層のユニット数=100,出力ユニット数=10とした3層の多層パーセプトロンは以下のように書くことができます.

[12]:
import chainer
import chainer.links as L
import chainer.functions as F

class MLP(chainer.Chain):

    def __init__(self, n_mid_units=100, n_out=10):
        super(MLP, self).__init__()

        # パラメータを持つ層の登録
        with self.init_scope():
            self.l1 = L.Linear(None, n_mid_units)
            self.l2 = L.Linear(n_mid_units, n_mid_units)
            self.l3 = L.Linear(n_mid_units, n_out)

    def forward(self, x):
        # データを受け取った際のforward計算を書く
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

gpu_id = 0  # CPUを用いる場合は,この値を-1にしてください

net = MLP()

if gpu_id >= 0:
    net.to_gpu(gpu_id)

継承した MLP クラスのコンストラクタ内で with self.init_scope() が呼ばれており,その中でネットワークに登場するLink (具体的には,全結合層の L.Linear )が定義されています.このような形で記述することで,Optimizer はこれらが最適化対象となるパラメータを持つ層であると自動的に解釈してくれるようになります.

また, forward というメソッドには,関数の名前の通り,ネットワークの順伝播を記述します.forward の引数としてデータ x を受け取り,出力として順伝播の計算結果を返すようにすることで, MLP クラスをインスタンス化して作成されたオブジェクトを,関数のように使えるようになります.(例:output = net(data)

Chainerには数多くの FunctionLink が用意されています.ぜひ一度以下の一覧のページを見てみてください.

Linkには,ニューラルネットワークによく用いられる全結合層や畳み込み層,LSTMなどに加えて,ResNetや,VGGなどの有名なネットワーク構造も登録されています.また,Functionには,ReLUなどの活性化関数や,画像の大きさをresizeする関数,サイン・コサインのような関数を始め,ネットワークの要素として使える関数が登録されています.Define-by-Runでは,データをネットワークに入力して順伝播計算を行ったあとに,データに適用された関数(パラメータあり・なし両方)の履歴をたどり直すことで,バックプロパゲーションによる勾配計算を行うパスを取得するため,パラメータを持たない関数であっても chainer.functions に含まれているものを繋げて用いる必要があります.

4.2.4.5. GPUで実行するには

深層学習で用いられるような多くのパラメータを持ったネットワークの学習には,GPUを用いることが一般的となっています.GPUを使うと,行列演算などの一部の処理をCPUに比べとても高速に行うことができます.Chainerで計算をGPUで行う方法は簡単です.Chainクラスはto_gpuメソッドを持ち,この引数にGPU IDを指定すると,指定したGPU IDのメモリ上にネットワークの全パラメータを転送します.こうしておくと,順伝播も学習の際のパラメータ更新なども全てGPU上で行われるようになります.GPU IDとして-1を指定すると,CPUを使用します.

4.2.4.6. 入力側ユニット数の自動計算

上のネットワーク定義で,最初のLinear層は第一引数にNoneが渡されています.このように引数を指定すると,データが最初にその層に入力されたタイミングで,自動的に必要な数の入力側のユニット数を判断し, n_input \(\times\) n_mid_units の大きさの行列を作成し,学習対象パラメータとして保持します.これは後々,畳み込み層を全結合層の前に配置する際などに便利な機能となるため,覚えておいてください.

4.2.5. 最適化手法の選択

それでは,上で定義したネットワークをMNISTデータセットを使って訓練してみましょう.学習時に用いる最適化の手法は数多く提案されていますが,Chainerは多くの手法を同一のインターフェースで利用できるよう,Optimizerという機能でそれらを提供しています.chainer.optimizersモジュール以下に定義されています.一覧はこちらにあります:

ここでは最もシンプルな勾配降下法の手法であるoptimizers.SGDを用います.Optimizerのオブジェクトには,setupメソッドを使ってモデル(Chainオブジェクト)を渡します.こうすることでOptimizerに,何を最適化すればいいか把握させることができます.

他にもいろいろな最適化手法が手軽に試せるので,色々と試してみて結果の変化を見てみてください.例えば,下のchainer.optimizers.SGDのうちSGDの部分をMomentumSGD, RMSprop, Adamなどに変えるだけで,最適化手法の違いがどのような学習曲線(ロスカーブとも言う.目的関数の値のプロットのこと)の違いを生むかなどを簡単に調べることができます.最適化の手法によっては,人が与える必要があった学習率を適切に自動決定するものもあります.

[13]:
from chainer import optimizers

optimizer = optimizers.SGD(lr=0.01).setup(net)

4.2.5.1. 学習率(learning rate)

今回はSGDのlrという引数に \(0.01\) を与えました.この値は学習率として知られ,モデルをうまく訓練して良いパフォーマンスを発揮させるために調整する必要がある重要なハイパーパラメータとして知られています.ハイパーパラメータは学習されるパラメータとは異なり人が手で与える学習の設定に関するものやネットワークの構造に関するもののことを指します.

4.2.6. 学習の開始

今回は0〜9の数字を区別する分類問題なので,softmax_cross_entropyという損失関数を使って最小化すべき損失を計算します.Softmax関数は,\(d\)次元のベクトル\({\bf y} \in \mathbb{R}^d\)が与えられたとき,その各次元の値の合計が1になるように正規化することができます.すなわち,確率分布のような出力を任意の実数ベクトルから作ることができます.\({\bf y}\)\(i\)番目の次元を\(y_i\)と書くと,Softmax関数は

\[p_i = \frac{\exp(y_i)}{\sum_{j=1}^d \exp(y_j)}\]

と表せます.これによって正規化された出力ベクトルを入力が各クラスに所属する確率を表しているものと考え,正解の1-hotベクトルとの間で前章で説明した交差エントロピーを計算するのが softmax_cross_entropy 関数です.

まずネットワークにデータを渡し,順伝播により予測値を計算します.そして,この予測値と入力データに対応する正解ラベルを損失関数に渡して損失(最小化したい値)を計算をします.損失は,chainer.Variableのオブジェクトとして得られます.このVariableは,過去の計算の履歴を覚えていて,辿れるようになっています.この仕組みが,Define-by-Run [Tokui 2015]とよばれる発明の中心的な役割を果たしています.

計算した損失に対する勾配をネットワークに逆向きに計算していく処理は,Chainerではネットワークが出力したVariableから,backwardメソッドを呼ぶだけで実現できます.これを呼ぶことで,誤差逆伝播用の計算グラフを構築し,途中のパラメータの勾配を連鎖率を使って計算してくれます.(詳しくは日本ソフトウェア科学会におけるチュートリアルの資料をご覧ください.)

最後に,計算された各パラメータに対する勾配を用いて,Optimizerによってネットワークパラメータの更新(=学習)が行われます.

まとめると,一連の更新処理の中で行われるのは,以下の4項目となります.

  1. ネットワークにデータを渡して順伝播を計算し,出力yを得る
  2. 出力yと正解ラベルtを使って,最小化すべき損失をsoftmax_cross_entropy関数で計算する
  3. softmax_cross_entropy関数の出力(Variable)のbackwardメソッドを呼んで,ネットワークの全てのパラメータの勾配を誤差逆伝播法で計算する
  4. Optimizerのupdateメソッドを呼び,3.で計算した勾配を使って全パラメータを更新する

パラメータの更新は,上記ステップを繰り返すことで行われます.一度のパラメータ更新に用いられるデータは,ネットワークに入力された,ミニバッチとして束ねられたデータのみです.次々と新しいミニバッチを入力し,上記のステップを繰り返すことで,データセット全体を用いて学習を行います.この過程を学習ループと呼んでいます.

4.2.6.1. 目的関数

目的関数として,例えば分類問題ではなく回帰問題を解きたいような場合,F.softmax_cross_entropyの代わりにF.mean_squared_errorなどを用いることもできます.他にも,いろいろな問題設定に対応するために様々な損失関数がChainerには用意されています.こちらからその一覧を見ることができます:

4.2.6.2. 学習ループのコード

[14]:
import numpy as np
from chainer.dataset import concat_examples
from chainer.cuda import to_cpu

max_epoch = 10

while train_iter.epoch < max_epoch:

    # ---------- 学習の1イテレーション ----------
    train_batch = train_iter.next()
    x, t = concat_examples(train_batch, gpu_id)

    # 予測値の計算
    y = net(x)

    # 損失の計算
    loss = F.softmax_cross_entropy(y, t)

    # 勾配の計算
    net.cleargrads()
    loss.backward()

    # パラメータの更新
    optimizer.update()
    # --------------- ここまで ----------------

    # 1エポック終了ごとにValidationデータに対する予測精度を測って,
    # モデルの汎化性能が向上していることをチェックしよう
    if train_iter.is_new_epoch:  # 1 epochが終わったら

        # 損失の表示
        print('epoch:{:02d} train_loss:{:.4f} '.format(
            train_iter.epoch, float(to_cpu(loss.data))), end='')

        valid_losses = []
        valid_accuracies = []
        while True:
            valid_batch = valid_iter.next()
            x_valid, t_valid = concat_examples(valid_batch, gpu_id)

            # Validationデータをforward
            with chainer.using_config('train', False), \
                    chainer.using_config('enable_backprop', False):
                y_valid = net(x_valid)

            # 損失を計算
            loss_valid = F.softmax_cross_entropy(y_valid, t_valid)
            valid_losses.append(to_cpu(loss_valid.array))

            # 精度を計算
            accuracy = F.accuracy(y_valid, t_valid)
            accuracy.to_cpu()
            valid_accuracies.append(accuracy.array)

            if valid_iter.is_new_epoch:
                valid_iter.reset()
                break

        print('val_loss:{:.4f} val_accuracy:{:.4f}'.format(
            np.mean(valid_losses), np.mean(valid_accuracies)))

# テストデータでの評価
test_accuracies = []
while True:
    test_batch = test_iter.next()
    x_test, t_test = concat_examples(test_batch, gpu_id)

    # テストデータをforward
    with chainer.using_config('train', False), \
            chainer.using_config('enable_backprop', False):
        y_test = net(x_test)

    # 精度を計算
    accuracy = F.accuracy(y_test, t_test)
    accuracy.to_cpu()
    test_accuracies.append(accuracy.array)

    if test_iter.is_new_epoch:
        test_iter.reset()
        break

print('test_accuracy:{:.4f}'.format(np.mean(test_accuracies)))
epoch:01 train_loss:0.9100 val_loss:0.9743 val_accuracy:0.8018
epoch:02 train_loss:0.5396 val_loss:0.5336 val_accuracy:0.8645
epoch:03 train_loss:0.4012 val_loss:0.4230 val_accuracy:0.8847
epoch:04 train_loss:0.3329 val_loss:0.3741 val_accuracy:0.8941
epoch:05 train_loss:0.4588 val_loss:0.3455 val_accuracy:0.9002
epoch:06 train_loss:0.2481 val_loss:0.3274 val_accuracy:0.9074
epoch:07 train_loss:0.3306 val_loss:0.3109 val_accuracy:0.9118
epoch:08 train_loss:0.3801 val_loss:0.2990 val_accuracy:0.9145
epoch:09 train_loss:0.2974 val_loss:0.2886 val_accuracy:0.9180
epoch:10 train_loss:0.3216 val_loss:0.2803 val_accuracy:0.9204
test_accuracy:0.9234

val_accuracyに着目してみると,最終的におよそ92%程度の精度で手書きの数字が分類できるようになりました.ここで言う精度とは,Validationデータセット中に \(N\) 個のデータがあり分類結果が正しかったものが \(M\) 個あるとすると \(M/N\) を指します.学習中は,各ループの終わりに始めに取り分けておいたValidationデータセットを使って精度をはかることで,モデルの汎化性能をチェックしています.汎化性能とは,主に未知のデータに対する性能の高さのことを意味します.学習終了後には,テスト用のデータセットを用いて,学習が完了したネットワークの評価を行います.テストデータでの評価結果は,およそ92.37%の正解率となりました.

4.2.6.3. ValidationやTestを行う際の注意点

学習終了後の最終的な評価には,ハイパーパラメータ調整などにも用いられるValidationデータセットとはさらに別のTestデータセットを用います.TestデータセットはTrainingデータセットともValidationデータセットともデータの重複がないように用意しておきます.

さて,これまでは主に,「学習」のやり方について説明してきましたが,「評価」を行う際には注意すべき点があります.なぜなら,一部の関数や,計算過程において,学習時と評価時でその挙動が異なるためです.以下では,それらの挙動の違いを制御するための方法について説明します.

4.2.6.3.1. chainer.using_config('train', False)

先程の例では,学習時と推論時で動作が異なる関数は含まれていませんでしたが,Validationやテストのために推論を行うときは以下のように,chainer.using_config('train', False)をwith構文と共に使うことで,その中では対応する関数が推論モードとして動作することになります.これによって,学習時と推論時で挙動が異なる関数などが正しく推論のための動作をするようになります(例えば,Dropoutなど).詳しくはこちらの train の項をお読みください:Configuration Keys

with chainer.using_config('train', False):
    --- 何か推論処理 ---
4.2.6.3.2. chainer.using_config('enable_backprop', False)

評価のみ行うことを考えた場合,出力の計算後に損失関数の各パラメータについての勾配の情報は不要なため,chainer.using_config('enable_backprop', False)とすることで,無駄な計算グラフの構築が行われず,メモリ消費量を節約することができます.詳しくはこちらの enable_backprop の項をお読みください:Configuration Keys

4.2.6.3.3. ChainerのConfig

Chainerにはこの他にも,いくつかのグローバルなConfigが用意されています.また,chainer.config以下にユーザが自由な設定値を置くこともできます.詳しくはこちらを一読してください:Configuring Chainer

4.2.7. 学習済みモデルの保存

学習が終了後,その結果を保存します.Chainerには,2種類のフォーマットで学習済みネットワークをファイルに保存する機能が用意されています.一つはHDF5形式,もう一つはNumPyのNPZ形式で,ネットワークを保存します.今回は,追加ライブラリのインストールが必要なHDF5ではなく,NumPy標準機能で提供されているシリアライズ機能(numpy.savez())を利用したNPZ形式でのモデルの保存を行います.

[15]:
from chainer import serializers

serializers.save_npz('my_mnist.model', net)
[16]:
# 保存されていることを確認
%ls -la my_mnist.model
-rw-r--r-- 1 root root 334084 Dec  9 11:13 my_mnist.model

4.2.8. 保存したモデルを読み込んで推論

学習が終了して保存したモデルを読み込み,推論を行う方法について説明します.はじめに,学習に利用したネットワークを再度インスタンス化して,そこにさきほど保存したNPZファイルを読み込ませます.

[17]:
# まず同じネットワークのオブジェクトを作る
infer_net = MLP()

# そのオブジェクトに保存済みパラメータをロードする
serializers.load_npz('my_mnist.model', infer_net)

以上で準備が整いました.それでは,試しにテストデータの中から一つ目の画像を取ってきて,それに対する分類を行ってみましょう.

[18]:
gpu_id = 0  # CPUで計算をしたい場合は,-1を指定してください

if gpu_id >= 0:
    infer_net.to_gpu(gpu_id)

# 1つ目のテストデータを取り出します
x, t = test[0]  #  tは使わない

# どんな画像か表示してみます
plt.imshow(x.reshape(28, 28), cmap='gray')
plt.show()

# ミニバッチの形にする(複数の画像をまとめて推論に使いたい場合は,サイズnのミニバッチにしてまとめればよい)
print('元の形:', x.shape, end=' -> ')

x = x[None, ...]

print('ミニバッチの形にしたあと:', x.shape)

# ネットワークと同じデバイス上にデータを送る
x = infer_net.xp.asarray(x)

# モデルのforward関数に渡す
with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
    y = infer_net(x)

# Variable形式で出てくるので中身を取り出す
y = y.array

# 結果をCPUに送る
y = to_cpu(y)

# 予測確率の最大値のインデックスを見る
pred_label = y.argmax(axis=1)

print('ネットワークの予測:', pred_label[0])
../_images/notebooks_04_Introduction_to_Chainer_53_0.png
元の形: (784,) -> ミニバッチの形にしたあと: (1, 784)
ネットワークの予測: 7

ネットワークの予測は7でした.画像を見る限り,正しく予測できていることが確認できます.

4.3. Trainerの使用方法

Chainerは,これまで書いてきたような学習ループを隠蔽するTrainerという機能を提供しています.これを使うと,学習ループを自ら書く必要がなくなり,また便利な拡張機能(Extention)を使うことで,学習過程での学習曲線の可視化や,ログの保存なども簡単に行うことができます.

4.3.1. データセット・Iterator・ネットワークの準備

データセット,Iterator,ネットワークは,Trainerを使用する場合にも同様に準備します.

[19]:
reset_seed(0)

train_val, test = mnist.get_mnist()
train, valid = split_dataset_random(train_val, 50000, seed=0)

batchsize = 128

train_iter = iterators.SerialIterator(train, batchsize)
valid_iter = iterators.SerialIterator(valid, batchsize, False, False)
test_iter = iterators.SerialIterator(test, batchsize, False, False)

gpu_id = 0  # CPUを用いたい場合は,-1を指定してください

net = MLP()

if gpu_id >= 0:
    net.to_gpu(gpu_id)

4.3.2. Updaterの準備

学習ループを自分で書く場合の学習ステップについて再度確認すると,「データセットからミニバッチを作成」「ネットワークに入力して予測を出力」「正解と比較し誤差を計算」「バックワード(誤差逆伝播)を実行」「Optimizerによってパラメータを更新」という一連のステップを,以下のように書いていました.

# ---------- 学習の1イテレーション ----------
train_batch = train_iter.next()
x, t = concat_examples(train_batch, gpu_id)

# 予測値の計算
y = net(x)

# 損失の計算
loss = F.softmax_cross_entropy(y, t)

# 勾配の計算
net.cleargrads()
loss.backward()

# パラメータの更新
optimizer.update()

Chainerの機能として提供されているUpdaterを用いることで,これらの一連の処理を簡単に書けるようになります.UpdaterにはIteratorOptimizerを渡します. Iteratorはデータセットオブジェクトを持っているため,そこからミニバッチを作成します.Optimizerは最適化対象のネットワークを持っているため,それを使って順伝播と誤差計算・パラメータのアップデートをすることができます.従って,この2つを渡すことで,Updater内で全ての処理が完結します.さっそく,Updaterオブジェクトを作成してみましょう.

[20]:
from chainer import training

gpu_id = 0  # CPUを使いたい場合は-1を指定してください

# ネットワークをClassifierで包んで,損失の計算などをモデルに含める
net = L.Classifier(net)

# 最適化手法の選択
optimizer = optimizers.SGD(lr=0.01).setup(net)

# UpdaterにIteratorとOptimizerを渡す
updater = training.StandardUpdater(train_iter, optimizer, device=gpu_id)

4.3.2.1. 損失計算のためのChain

ここでは,ネットワークをL.Classifierで包んでいます.L.Classifierは,渡されたネットワーク自体をpredictorというattributeに持ち,損失計算を行う機能を追加してくれます.こうすることで,net()はデータxだけでなくラベルtも取るようになり,受け取ったデータをpredictorに通して予測値を計算し,正解ラベルtと比較して損失のVariableを返します.損失関数として何を用いるかはデフォルトではF.softmax_cross_entropyとなっていますが,L.Classifierの引数lossfunに損失計算を行う関数を渡してやれば変更することができ,(Classifierという名前ながら)回帰問題などの損失計算機能の追加にも使うことができます.(L.Classifier(net, lossfun=L.mean_squared_error, compute_accuracy=False)のようにする)

StandardUpdaterは前述のようなUpdaterの担当する処理を遂行するための最もシンプルなクラスです.この他にも複数のGPUを用いるためのParallelUpdaterなどが用意されています.

4.3.3. Trainerの準備

実際に学習ループ部分を隠蔽しているのはUpdaterですが,TrainerはさらにUpdaterを受け取って学習全体の管理を行う機能を提供しています.例えば,データセットを何周したら学習を終了するか(stop_trigger) や,途中の損失の値をどのファイルに保存したいか学習曲線を可視化した画像ファイルを保存するかどうかなど,学習全体の設定として必須・もしくはあると便利な色々な機能を提供しています.

必須なものとしては学習終了のタイミングを指定するstop_triggerがありますが,これはTrainerオブジェクトを作成するときのコンストラクタで指定します.指定の方法は単純で,(長さ, 単位)という形のタプルを与えればよいだけです.「長さ」には数字を,「単位」には'iteration'もしくは'epoch'のいずれかの文字列を指定します.こうすると,たとえば100 epoch(データセット100周)で学習を終了してください,とか,1000 iteration(1000回更新)で学習を終了してください,といったことが指定できます.Trainerを作るときに,stop_triggerを指定しないと,学習は自動的には止まりません.

では,実際にTrainerオブジェクトを作ってみましょう.

[21]:
max_epoch = 10

# TrainerにUpdaterを渡す
trainer = training.Trainer(
    updater, (max_epoch, 'epoch'), out='results/mnist_result')

out引数では,この次に説明するExtensionを使って,ログファイルや損失の変化の過程を描画したグラフの画像ファイルなどを保存するディレクトリを指定しています.

Trainerと,その内側にあるいろいろなオブジェクトの関係は,図にまとめると以下のようになっています.このイメージを持っておくと自分で部分的に改造したりする際に便利だと思います.

Trainerに関連するオブジェクト間の関係図

4.3.4. TrainerにExtensionを追加

Trainerを使う利点として,

  • ログを自動的にファイルに保存(LogReport)
  • ターミナルに定期的に損失などの情報を表示(PrintReport
  • 損失を定期的にグラフで可視化して画像として保存(PlotReport)
  • 定期的にモデルやOptimizerの状態を自動シリアライズ(snapshot
  • 学習の進捗を示すプログレスバーを表示(ProgressBar
  • ネットワークの構造をGraphvizのdot形式で保存(dump_graph
  • ネットワークのパラメータの平均や分散などの統計情報を出力(ParameterStatistics

などの様々な便利な機能を簡単に利用することができる点があります.これらの機能を利用するには,Trainerオブジェクトに対してextendメソッドを使って追加したいExtensionのオブジェクトを渡すだけです.では実際に幾つかのExtensionを追加してみましょう.

[22]:
from chainer.training import extensions

trainer.extend(extensions.LogReport())
trainer.extend(extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}'))
trainer.extend(extensions.Evaluator(valid_iter, net, device=gpu_id), name='val')
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'val/main/loss', 'val/main/accuracy', 'l1/W/data/std', 'elapsed_time']))
trainer.extend(extensions.ParameterStatistics(net.predictor.l1, {'std': np.std}))
trainer.extend(extensions.PlotReport(['l1/W/data/std'], x_key='epoch', file_name='std.png'))
trainer.extend(extensions.PlotReport(['main/loss', 'val/main/loss'], x_key='epoch', file_name='loss.png'))
trainer.extend(extensions.PlotReport(['main/accuracy', 'val/main/accuracy'], x_key='epoch', file_name='accuracy.png'))
trainer.extend(extensions.dump_graph('main/loss'))

4.3.4.1. LogReport

epochiterationごとのloss, accuracyなどを自動的に集計し,logというファイル名で保存します.

4.3.4.2. snapshot

Trainerオブジェクトを指定されたタイミング(デフォルトでは1エポックごと)で保存します.Trainerオブジェクトは上述のようにUpdaterを持っており,この中にOptimizerとモデルが保持されているため,このExtensionでスナップショットをとっておけば,その時点から学習を再開させたり,学習済みモデルを使った推論などが可能になります.

4.3.4.3. dump_graph

指定されたVariableオブジェクトから辿れる計算グラフをGraphvizのdot形式で保存します.

4.3.4.4. Evaluator

評価用のデータセットのIteratorと,学習に使うモデルのオブジェクトを渡しておくことで,学習中のモデルを指定されたタイミングで評価用データセットを用いて評価します.内部では,chainer.config.using_config('train', False)が自動的に行われます.backprop_enableFalseにすることは行われないため,メモリ使用効率はデフォルトでは最適ではありませんが,基本的にはEvaluatorを使えば評価を行えるという点において問題はありません.

4.3.4.5. PrintReport

LogReportと同様に集計された値を標準出力に出力します.この際,どの値を出力するかをリストの形で与えます.

4.3.4.6. PlotReport

引数のリストで指定された値の変遷をmatplotlibライブラリを使ってグラフに描画し,出力ディレクトリにfile_name引数で指定されたファイル名で画像として保存します.

4.3.4.7. ParameterStatistics

指定したレイヤ(Link)が持つパラメータの平均・分散・最小値・最大値などなどの統計情報を計算して,ログに保存します.パラメータが発散していないかなどをチェックするのに便利です.


これらのExtensionは,ここで紹介した以外にも,例えばtriggerによって個別に作動するタイミングを指定できるなどのいくつかのオプションを持っており,より柔軟に組み合わせることができます.詳しくは公式のドキュメントを見てください.

4.3.5. 学習の開始 (Trainer利用)

学習を開始するために,Trainerオブジェクトのメソッドrunを実行してください.

[23]:
trainer.run()
epoch       main/loss   main/accuracy  val/main/loss  val/main/accuracy  l1/W/data/std  elapsed_time
1           1.66917     0.599904       0.938911       0.806764           0.0359232      4.00874
2           0.673337    0.843211       0.519283       0.86699            0.0366054      6.59789
3           0.459913    0.878686       0.414855       0.887757           0.0370351      9.17739
4           0.38953     0.893262       0.370488       0.896855           0.037301       11.7382
5           0.353165    0.901215       0.342328       0.90447            0.03749        14.4128
6           0.33014     0.90609        0.32212        0.90981            0.037639       17.1353
7           0.312328    0.910906       0.30679        0.913172           0.0377671      19.834
8           0.298127    0.914704       0.295095       0.915744           0.0378811      22.4303
9           0.28583     0.917659       0.284156       0.918513           0.0379864      25.0918
10          0.275227    0.921096       0.274761       0.921677           0.0380852      27.7848

学習ループを自分で書いた場合よりも遥かに簡単に,同様の結果を得ることができました.さらに,Extensionの機能を利用することで,様々なスコアや,学習曲線の可視化も自動で出力されます.

では,保存されている損失のグラフを確認してみましょう.

[24]:
from IPython.display import Image
Image(filename='results/mnist_result/loss.png')
[24]:
../_images/notebooks_04_Introduction_to_Chainer_77_0.png

精度のグラフも見てみましょう.

[25]:
Image(filename='results/mnist_result/accuracy.png')
[25]:
../_images/notebooks_04_Introduction_to_Chainer_79_0.png

もう少し学習を続ければ,さらに精度の向上が期待できそうです.

最後に,dump_graphというExtensionによって出力された計算グラフのファイルを,Graphvizで画像化してみましょう.

[26]:
!dot -Tpng results/mnist_result/cg.dot -o results/mnist_result/cg.png
[27]:
Image(filename='results/mnist_result/cg.png')
[27]:
../_images/notebooks_04_Introduction_to_Chainer_82_0.png

データやパラメータが関数に次々と渡され,損失が出力されるまでの一連の計算過程が確認できます.

4.3.6. テストデータ評価

Validationデータに対する評価を学習中に行うために使用されるEvaluatorは,Trainerと関係なく独立して使うこともできます.以下のようにしてIteratorとネットワークのオブジェクト(net),使用するデバイスIDを渡してEvaluatorオブジェクトを作成し,これを関数として実行するだけです.

[28]:
test_evaluator = extensions.Evaluator(test_iter, net, device=gpu_id)
results = test_evaluator()
print('Test accuracy:', results['main/accuracy'])
Test accuracy: 0.9250395

4.3.7. 学習済みモデルで推論する

それでは,Trainer Extensionのsnapshotが保存した学習済みパラメータを読み込んで,以前と同様に1番目のテストデータで推論を行ってみましょう.

ここで一点注意が必要ですが,snapshotが保存するnpzファイルはTrainer全体のスナップショットとなっており,学習の再開に必要となるextensionの内部のパラメータなども一緒に保存されています.しかし,今回はネットワークのパラメータだけを読み込めば良いので, serializers.load_npz()path引数にネットワーク部分までのパスを指定します.こうすることで,ネットワークのオブジェクトにパラメータだけを読み込ませることができます.

[29]:
reset_seed(0)

infer_net = MLP()
serializers.load_npz(
    'results/mnist_result/snapshot_epoch-10',
    infer_net, path='updater/model:main/predictor/')

if gpu_id >= 0:
    infer_net.to_gpu(gpu_id)

x, t = test[0]
plt.imshow(x.reshape(28, 28), cmap='gray')
plt.show()

x = infer_net.xp.asarray(x[None, ...])
with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
    y = infer_net(x)
y = to_cpu(y.array)

print('予測ラベル:', y.argmax(axis=1)[0])
../_images/notebooks_04_Introduction_to_Chainer_87_0.png
予測ラベル: 7

無事正解できていることが確認できました.

4.4. 新しいネットワークの利用

ここでは,MNISTデータセットではなくCIFAR10という32x32サイズの小さなカラー画像に10クラスのいずれかのラベルがついたデータセットを用いて,いろいろなモデルを自分で書いて試行錯誤する流れを体験してみます.

airplane automobile bird cat deer dog frog horse ship truck
Airplane Automobile Bird Cat Deer Dog Frog Horse Ship Truck

4.4.1. 新しいネットワークの定義

ここでは,さきほど試した全結合層だけからなるネットワークではなく,前章で紹介した,畳込み層を持つネットワークを定義してみます.3つの畳み込み層を持ち,2つの全結合層がそのあとに続いています.

[30]:
class MyNet(chainer.Chain):

    def __init__(self, n_out):
        super(MyNet, self).__init__()
        with self.init_scope():
            self.conv1 = L.Convolution2D(None, 32, 3, 3, 1)
            self.conv2 = L.Convolution2D(32, 64, 3, 3, 1)
            self.conv3 = L.Convolution2D(64, 128, 3, 3, 1)
            self.fc4 = L.Linear(None, 1000)
            self.fc5 = L.Linear(1000, n_out)

    def forward(self, x):
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        h = F.relu(self.conv3(h))
        h = F.relu(self.fc4(h))
        h = self.fc5(h)
        return h

4.4.2. 学習

ここで,あとから別のネットワークも簡単に同じ設定で訓練できるよう,train関数を作っておきます.これは,

  • ネットワークのオブジェクト
  • バッチサイズ
  • 使用するGPU ID
  • 学習を終了するエポック数
  • データセットオブジェクト
  • 学習率の初期値
  • 学習率減衰のタイミング

などを渡すと,内部でTrainerを用いて渡されたデータセットを使ってネットワークを訓練し,学習が終了した状態のネットワークを返してくれる関数です.Trainer.run()が終了した後に,テストデータセットを使って評価まで行ってくれます.先程のMNISTでの例と違い,最適化手法にはMomentumSGDを用い,ExponentialShiftというExtentionを使って,指定したタイミングごとに学習率を減衰させるようにしてみます.

また,ここではcifar.get_cifar10()が返す学習用データセットのうち9割のデータをtrain,残りの1割をvalidとして使うようにしています.

このtrain関数を用いて,上で定義したMyNetモデルを訓練してみます.

[31]:
from chainer.datasets import cifar


def train(network_object, batchsize=128, gpu_id=0, max_epoch=20, train_dataset=None, valid_dataset=None, test_dataset=None, postfix='', base_lr=0.01, lr_decay=None, snapshot=None):

    # 1. Dataset
    if train_dataset is None and valid_dataset is None and test_dataset is None:
        train_val, test = cifar.get_cifar10()
        train_size = int(len(train_val) * 0.9)
        train, valid = split_dataset_random(train_val, train_size, seed=0)
    else:
        train, valid, test = train_dataset, valid_dataset, test_dataset

    # 2. Iterator
    train_iter = iterators.MultiprocessIterator(train, batchsize)
    valid_iter = iterators.MultiprocessIterator(valid, batchsize, False, False)

    # 3. Model
    net = L.Classifier(network_object)

    # 4. Optimizer
    optimizer = optimizers.MomentumSGD(lr=base_lr).setup(net)
    optimizer.add_hook(chainer.optimizer.WeightDecay(0.0005))

    # 5. Updater
    updater = training.StandardUpdater(train_iter, optimizer, device=gpu_id)

    # 6. Trainer
    trainer = training.Trainer(updater, (max_epoch, 'epoch'), out='results/{}_cifar10_{}result'.format(network_object.__class__.__name__, postfix))

    # 7. Trainer extensions
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.observe_lr())
    trainer.extend(extensions.snapshot(filename='snapshot_epoch_{.updater.epoch}'), trigger=(10, 'epoch'))
    trainer.extend(extensions.Evaluator(valid_iter, net, device=gpu_id), name='val')
    trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'val/main/loss', 'val/main/accuracy', 'elapsed_time', 'lr']))
    trainer.extend(extensions.PlotReport(['main/loss', 'val/main/loss'], x_key='epoch', file_name='loss.png'))
    trainer.extend(extensions.PlotReport(['main/accuracy', 'val/main/accuracy'], x_key='epoch', file_name='accuracy.png'))
    if lr_decay is not None:
        trainer.extend(extensions.ExponentialShift('lr', 0.1), trigger=lr_decay)
    if snapshot is not None:
        chainer.serializers.load_npz(snapshot, trainer)
    trainer.run()
    del trainer

    # 8. Evaluation
    test_iter = iterators.MultiprocessIterator(test, batchsize, False, False)
    test_evaluator = extensions.Evaluator(test_iter, net, device=gpu_id)
    results = test_evaluator()
    print('Test accuracy:', results['main/accuracy'])

    return net
[32]:
net = train(MyNet(10), gpu_id=0)
Downloading from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz...
epoch       main/loss   main/accuracy  val/main/loss  val/main/accuracy  elapsed_time  lr
1           1.92583     0.305065       1.72466        0.39668            4.88937       0.01
2           1.60857     0.423007       1.53026        0.463281           9.25879       0.01
3           1.47209     0.46964        1.48127        0.478125           13.5662       0.01
4           1.39223     0.499911       1.39299        0.499609           18.0876       0.01
5           1.32882     0.526197       1.3789         0.511719           22.5673       0.01
6           1.26765     0.547852       1.35271        0.516406           27.1432       0.01
7           1.21327     0.568999       1.25582        0.560547           31.8979       0.01
8           1.16433     0.583984       1.22899        0.570508           36.4486       0.01
9           1.12036     0.602384       1.23554        0.565039           40.9875       0.01
10          1.07057     0.61899        1.21995        0.56543            45.5839       0.01
11          1.02992     0.636808       1.1724         0.585938           50.4524       0.01
12          0.98116     0.653112       1.19605        0.579883           55.0429       0.01
13          0.938254    0.667392       1.159          0.59375            59.4494       0.01
14          0.901819    0.681067       1.20838        0.579492           64.0684       0.01
15          0.855333    0.698287       1.19485        0.585938           68.5982       0.01
16          0.810262    0.714321       1.19381        0.583984           73.0674       0.01
17          0.764117    0.731423       1.21938        0.587109           77.5318       0.01
18          0.72205     0.743697       1.20823        0.585742           81.9437       0.01
19          0.666414    0.764712       1.23899        0.593164           86.2922       0.01
20          0.620457    0.782715       1.24922        0.597461           90.6681       0.01
Test accuracy: 0.6065071

学習が20エポックまで終わりました.損失と精度のプロットを見てみましょう.

[33]:
Image(filename='results/MyNet_cifar10_result/loss.png')
[33]:
../_images/notebooks_04_Introduction_to_Chainer_96_0.png
[34]:
Image(filename='results/MyNet_cifar10_result/accuracy.png')
[34]:
../_images/notebooks_04_Introduction_to_Chainer_97_0.png

学習データでの精度(main/accuracy)は77%程度まで到達していますが,テストデータでの損失(val/main/loss)は途中から下げ止まり,精度(val/main/accuracy)も60%前後で頭打ちになってしまっています.表示されたログの最後の行を確認すると,テストデータでの精度も同様に60%程度となっています.学習データでは精度が良いが, テストデータでは精度が良くない場合,モデルが学習データにオーバーフィッティングしていると考えられます.

4.4.3. 学習済みネットワークを使った予測

テスト精度は60%程度でしたが,試しにこの学習済みネットワークを使っていくつかのテスト画像を分類させてみましょう.あとで使いまわせるようにpredict関数を作っておきます.

[35]:
cls_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
             'dog', 'frog', 'horse', 'ship', 'truck']

def predict(net, image_id):
    _, test = cifar.get_cifar10()
    x, t = test[image_id]
    net.to_cpu()
    with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
        y = net.predictor(x[None, ...]).data.argmax(axis=1)[0]

    plt.imshow(x.transpose(1, 2, 0))
    plt.show()
    print('predicted_label:', cls_names[y])
    print('answer:', cls_names[t])

for i in range(10, 15):
    predict(net, i)
../_images/notebooks_04_Introduction_to_Chainer_100_0.png
predicted_label: airplane
answer: airplane
../_images/notebooks_04_Introduction_to_Chainer_100_2.png
predicted_label: truck
answer: truck
../_images/notebooks_04_Introduction_to_Chainer_100_4.png
predicted_label: dog
answer: dog
../_images/notebooks_04_Introduction_to_Chainer_100_6.png
predicted_label: horse
answer: horse
../_images/notebooks_04_Introduction_to_Chainer_100_8.png
predicted_label: truck
answer: truck

うまく分類できているものもあれば,そうでないものもありました.ネットワークの学習に使用したデータセット上ではほぼ百発百中で正解できても,未知のデータ,すなわちテストデータセットの画像に対して高精度な予測ができなければ意味がありません.テストデータでの精度は,モデルの汎化性能に関係していると言われています.

どうすれば高い汎化性能を持つネットワークを設計し,学習することができるでしょうか?これは非常に難しい問いですが,機械学習を使った応用を考えるとき,最も重要な問いの一つです.

4.4.4. 深いネットワークの定義

では,さきほどのネットワークよりも多層のネットワークを定義してみましょう.ここでは,1層の畳み込みネットワークをConvBlock,1層の全結合ネットワークをLinearBlockとして定義し,これを数多く積み重ねることで大きなネットワークを定義してみます.

4.4.4.1. 構成要素を定義する

まず,ネットワークの構成要素となるConvBlockLinearBlockを定義してみましょう.

[36]:
class ConvBlock(chainer.Chain):

    def __init__(self, n_ch, pool_drop=False):
        w = chainer.initializers.HeNormal()
        super(ConvBlock, self).__init__()
        with self.init_scope():
            self.conv = L.Convolution2D(None, n_ch, 3, 1, 1, nobias=True, initialW=w)
            self.bn = L.BatchNormalization(n_ch)
        self.pool_drop = pool_drop

    def forward(self, x):
        h = F.relu(self.bn(self.conv(x)))
        if self.pool_drop:
            h = F.max_pooling_2d(h, 2, 2)
            h = F.dropout(h, ratio=0.25)
        return h

class LinearBlock(chainer.Chain):

    def __init__(self, drop=False):
        w = chainer.initializers.HeNormal()
        super(LinearBlock, self).__init__()
        with self.init_scope():
            self.fc = L.Linear(None, 1024, initialW=w)
        self.drop = drop

    def forward(self, x):
        h = F.relu(self.fc(x))
        if self.drop:
            h = F.dropout(h)
        return h

ConvBlockChainを継承した小さなネットワークとして定義されており,一つの畳み込み層とBatch Normalization層で構成されます.Batch Normalization層は,ネットワークの学習プロセスを安定させるために広く利用されている手法の一つで,例えば今回のように,畳み込み層の直後に挿入する形で利用されます.forwardメソッドでは,これらにデータを渡しつつ,活性化関数ReLUを適用して,さらにpool_drop引数がTrueであれば,Max PoolingとDropoutを適用するような順伝播の計算が行われます.Dropoutは,ネットワークの過学習を避けて汎化性能を上げる目的で利用される手法の一つで,層の中のノードのうち,一定割合(dropout ratioと呼ばれる)をランダムに無効にしながら学習を行います(無効にする割合はratioという引数で指定でき,何も指定しなければ50%が無効化されます).推論時は,dropout ratioを\(p\)とすると,Dropout層への入力をただ\(p\)倍して出力するだけの層として働きます.これによって,擬似的に複数のネットワークの学習結果をアンサンブル(参考:Ensemble averaging)するような効果があると言われ,汎化性能が向上する場合があります.最適化の際にモデルのパラメータに何らかの制約を与えて汎化性能を向上させるための工夫は正則化(regularization)と呼ばれ,このDropoutやパラメータの絶対値が大きくなりすぎないようにするWeight decayなどの方法が知られています.

Chainerでは,Pythonを使って書いたforward計算のコード自体がネットワークの構造を表します.すなわち,実行時にデータがどの層を通過していったか,によってネットワークそのものが定義されます.この性質によって,上記のような分岐などを含むネットワークも簡単に記述でき,柔軟かつシンプルで可読性の高いネットワーク定義が可能になります.これがDefine-by-Runの大きな特徴となっています.

4.4.4.2. 大きなネットワークの定義

次に,これらの小さなネットワークを構成要素として積み重ねて,大きなネットワークを定義してみましょう.

[37]:
class DeepCNN(chainer.ChainList):

    def __init__(self, n_output):
        super(DeepCNN, self).__init__(
            ConvBlock(64),
            ConvBlock(64, True),
            ConvBlock(128),
            ConvBlock(128, True),
            ConvBlock(256),
            ConvBlock(256),
            ConvBlock(256),
            ConvBlock(256, True),
            LinearBlock(),
            LinearBlock(),
            L.Linear(None, n_output)
        )

    def forward(self, x):
        for f in self.children():
            x = f(x)
        return x

ここで,ChainListというクラスが利用されています.このクラスはChainを継承したクラスで,いくつものLinkChainを順次呼び出していくようなネットワークを定義するときに便利です.ChainListを継承して定義されるモデルは,親クラスのコンストラクタを呼び出す際に,キーワード引数ではなく通常の引数としてLinkもしくはChainオブジェクトを渡すことができ,self.children()メソッドによって登録した順番に取り出すことができます.この特徴を使うと,forward計算が上記のように簡単に記述可能となります.

4.4.4.3. 高速化のTIPS

今回は多くの畳込み層を使う大きなネットワークを使うので,Chainerが用意してくれているcuDNNのautotune機能を有効にしてみます.やり方は簡単で,以下の二行を事前に実行しておくだけです.これを有効にすると,cuDNNが自動的に高速な畳み込みのアルゴリズムを選択するなどの実行時の調整を行ってくれるようになります.

[38]:
chainer.cuda.set_max_workspace_size(1024 * 1024 * 1024)
chainer.config.autotune = True

それでは,学習を回してみます.今回はパラメータ数も多いので,学習を停止するエポック数を100に設定します.また,学習率を0.1から始めて,30エポックごとに10分の1にするように設定します.

本来は,以下の2行を実行することで乱数シードを固定し,100エポック分上で定義した DeepCNN というクラスが表すモデルの学習ができるのですが,これは40分以上の時間を要するので,今回は事前に90エポックまで学習を進めておいた重みを読み込んで,90エポック終了時点から学習を再開し,最後の10エポックだけ実際にここで学習を回すことにします.

[39]:
!wget https://github.com/japan-medical-ai/medical-ai-course-materials/releases/download/v0.1/DeepCNN_cifar10_snapshot_epoch_90.npz

reset_seed(0)

model = train(DeepCNN(10), max_epoch=100, base_lr=0.1, lr_decay=(30, 'epoch'), snapshot='DeepCNN_cifar10_snapshot_epoch_90.npz')
--2019-12-09 11:16:00--  https://github.com/japan-medical-ai/medical-ai-course-materials/releases/download/v0.1/DeepCNN_cifar10_snapshot_epoch_90.npz
Resolving github.com (github.com)... 13.250.177.223
Connecting to github.com (github.com)|13.250.177.223|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github-production-release-asset-2e65be.s3.amazonaws.com/153412006/4fcc1200-eeb7-11e8-8ca0-9095e5bca078?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20191209%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20191209T111601Z&X-Amz-Expires=300&X-Amz-Signature=f2d9063e8a6ec0edd8567b4ae8caeebbd09f2d1e6eabc59db80bcd2c519f2787&X-Amz-SignedHeaders=host&actor_id=0&response-content-disposition=attachment%3B%20filename%3DDeepCNN_cifar10_snapshot_epoch_90.npz&response-content-type=application%2Foctet-stream [following]
--2019-12-09 11:16:01--  https://github-production-release-asset-2e65be.s3.amazonaws.com/153412006/4fcc1200-eeb7-11e8-8ca0-9095e5bca078?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20191209%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20191209T111601Z&X-Amz-Expires=300&X-Amz-Signature=f2d9063e8a6ec0edd8567b4ae8caeebbd09f2d1e6eabc59db80bcd2c519f2787&X-Amz-SignedHeaders=host&actor_id=0&response-content-disposition=attachment%3B%20filename%3DDeepCNN_cifar10_snapshot_epoch_90.npz&response-content-type=application%2Foctet-stream
Resolving github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)... 52.216.164.75
Connecting to github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)|52.216.164.75|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 56889445 (54M) [application/octet-stream]
Saving to: ‘DeepCNN_cifar10_snapshot_epoch_90.npz’

DeepCNN_cifar10_sna 100%[===================>]  54.25M  12.5MB/s    in 5.4s

2019-12-09 11:16:07 (10.1 MB/s) - ‘DeepCNN_cifar10_snapshot_epoch_90.npz’ saved [56889445/56889445]

/usr/local/lib/python3.6/dist-packages/chainer/training/triggers/interval_trigger.py:92: UserWarning: The previous value of iteration is not saved. IntervalTrigger guesses it using current iteration. If this trigger is not called at every iteration, it may not work correctly.
  'The previous value of iteration is not saved. '
/usr/local/lib/python3.6/dist-packages/chainer/training/triggers/interval_trigger.py:104: UserWarning: The previous value of epoch_detail is not saved. IntervalTrigger uses the value of trainer.updater.previous_epoch_detail. If this trigger is not called at every iteration, it may not work correctly.
  'The previous value of epoch_detail is not saved. '
epoch       main/loss   main/accuracy  val/main/loss  val/main/accuracy  elapsed_time  lr
1           2.62849     0.144931       2.22525        0.15625            27.9748       0.1
2           2.11316     0.210804       1.97533        0.266406           54.5483       0.1
3           1.87483     0.289396       1.8026         0.320508           81.0827       0.1
4           1.74066     0.340443       1.74728        0.358203           107.7         0.1
5           1.58789     0.409411       1.63916        0.40332            134.235       0.1
6           1.38757     0.492831       1.22399        0.561719           161.093       0.1
7           1.2036      0.566128       1.32939        0.553906           187.55        0.1
8           1.07527     0.617077       1.17079        0.589648           214.184       0.1
9           0.964984    0.660711       0.950063       0.67207            240.817       0.1
10          0.895905    0.688056       0.972697       0.657031           267.372       0.1
11          0.828796    0.715043       0.944733       0.686914           297.779       0.1
12          0.784123    0.731793       0.960205       0.687305           324.312       0.1
13          0.742033    0.748291       0.858308       0.719727           351.032       0.1
14          0.692516    0.76627        0.811853       0.725195           377.591       0.1
15          0.65644     0.776256       0.692374       0.767578           404.23        0.1
16          0.650682    0.780738       0.796731       0.733008           430.797       0.1
17          0.610249    0.793435       0.632162       0.791406           457.467       0.1
18          0.591896    0.803667       0.791757       0.739453           484.04        0.1
19          0.578718    0.804465       1.04659        0.671484           510.602       0.1
20          0.554954    0.814431       0.8127         0.733594           537.247       0.1
21          0.549188    0.814236       0.654926       0.787109           567.567       0.1
22          0.535866    0.820446       0.640323       0.788672           594.214       0.1
23          0.527765    0.823028       0.958373       0.70957            620.761       0.1
24          0.512286    0.830056       0.793664       0.748633           647.363       0.1
25          0.497195    0.833141       0.717548       0.759961           673.951       0.1
26          0.495153    0.835759       1.7557         0.511328           700.494       0.1
27          0.486062    0.837069       0.732518       0.775391           727.13        0.1
28          0.47963     0.839387       0.669157       0.786328           753.692       0.1
29          0.475502    0.838312       0.915904       0.718555           780.308       0.1
30          0.460236    0.844841       0.877713       0.72793            806.806       0.1
31          0.298483    0.897239       0.381921       0.879687           837.16        0.01
32          0.211182    0.92784        0.364046       0.882617           863.898       0.01
33          0.180429    0.937478       0.374651       0.883984           890.554       0.01
34          0.164156    0.943692       0.361041       0.888867           917.199       0.01
35          0.144584    0.950387       0.375391       0.889258           943.774       0.01
36          0.132288    0.954235       0.377427       0.890625           970.4         0.01
37          0.12103     0.957376       0.390434       0.892578           996.96        0.01
38          0.111974    0.961204       0.400307       0.886133           1023.62       0.01
39          0.102573    0.964476       0.399275       0.892773           1050.2        0.01
40          0.0972647   0.965931       0.432854       0.887109           1076.83       0.01
41          0.0928545   0.966597       0.418165       0.887305           1107.19       0.01
42          0.08498     0.96964        0.432462       0.884961           1133.75       0.01
43          0.0845448   0.970792       0.4365         0.880273           1160.37       0.01
44          0.0770674   0.973914       0.441935       0.883789           1187.02       0.01
45          0.0732439   0.974565       0.469901       0.881836           1213.62       0.01
46          0.070491    0.975962       0.453856       0.8875             1240.16       0.01
47          0.068846    0.97583        0.461264       0.881055           1266.75       0.01
48          0.0694101   0.976941       0.435111       0.885742           1293.25       0.01
49          0.0653772   0.977117       0.461284       0.882617           1319.87       0.01
50          0.0633419   0.97836        0.464232       0.889648           1346.47       0.01
51          0.0584663   0.979834       0.464193       0.885547           1376.6        0.01
52          0.0607617   0.979714       0.466352       0.881445           1403.21       0.01
53          0.0615791   0.978632       0.451807       0.889453           1429.69       0.01
54          0.0588031   0.979914       0.489054       0.881836           1456.28       0.01
55          0.0582368   0.979367       0.4719         0.882617           1482.83       0.01
56          0.0558719   0.981379       0.495846       0.878125           1509.45       0.01
57          0.0579962   0.979936       0.472415       0.875781           1536.06       0.01
58          0.0592009   0.9793         0.454762       0.88418            1562.6        0.01
59          0.0568546   0.980735       0.487556       0.876172           1589.17       0.01
60          0.0579785   0.980079       0.472908       0.883398           1615.93       0.01
61          0.0318918   0.98968        0.416953       0.895703           1646.15       0.001
62          0.0220127   0.993612       0.416859       0.899609           1672.69       0.001
63          0.0186169   0.99434        0.417849       0.898242           1699.29       0.001
64          0.0159041   0.995526       0.41804        0.900781           1725.78       0.001
65          0.0147089   0.995916       0.429896       0.899609           1752.39       0.001
66          0.0129457   0.996404       0.433748       0.898828           1779.01       0.001
67          0.0131643   0.996283       0.433923       0.898828           1805.54       0.001
68          0.0112659   0.996893       0.437222       0.901758           1832.17       0.001
69          0.0106502   0.997151       0.443475       0.901758           1858.73       0.001
70          0.0107926   0.997203       0.445066       0.900391           1885.34       0.001
71          0.0105973   0.997062       0.44159        0.898633           1915.8        0.001
72          0.00934292  0.99767        0.450084       0.897266           1942.45       0.001
73          0.0104884   0.997092       0.451691       0.899023           1969.06       0.001
74          0.00849317  0.997707       0.450391       0.9                1995.61       0.001
75          0.00846932  0.997891       0.451362       0.902148           2022.21       0.001
76          0.00826699  0.997841       0.448779       0.900781           2048.69       0.001
77          0.00875069  0.997492       0.45095        0.900391           2075.31       0.001
78          0.00823296  0.998019       0.449194       0.898438           2102.1        0.001
79          0.00701245  0.998113       0.454196       0.899609           2128.74       0.001
80          0.00846517  0.997596       0.455877       0.901172           2155.29       0.001
81          0.00677115  0.99818        0.459518       0.899805           2185.54       0.001
82          0.00717393  0.998047       0.465337       0.899805           2212.13       0.001
83          0.00709802  0.997908       0.464472       0.898828           2238.69       0.001
84          0.00699702  0.998091       0.470595       0.899219           2265.27       0.001
85          0.00746627  0.997975       0.470063       0.901172           2291.78       0.001
86          0.00666763  0.998069       0.466201       0.899414           2318.38       0.001
87          0.00616266  0.998531       0.462948       0.9                2344.86       0.001
88          0.00719447  0.997847       0.463587       0.899609           2371.48       0.001
89          0.00638493  0.998335       0.465655       0.901367           2398.08       0.001
90          0.0061445   0.998286       0.464918       0.900586           2424.59       0.001
91          0.00582547  0.99838        0.46484        0.900977           2439.32       0.0001
92          0.00602776  0.998331       0.461773       0.901367           2452.62       0.0001
93          0.00597372  0.998491       0.464172       0.900195           2466.17       0.0001
94          0.00619668  0.99813        0.46446        0.900391           2479.48       0.0001
95          0.00545569  0.998557       0.466654       0.900977           2492.64       0.0001
96          0.00613322  0.998308       0.465335       0.900781           2505.78       0.0001
97          0.0054181   0.998624       0.465642       0.900586           2518.9        0.0001
98          0.00512285  0.998801       0.467116       0.900781           2532.13       0.0001
99          0.00562234  0.998576       0.464966       0.901367           2545.35       0.0001
100         0.00551726  0.998624       0.462725       0.901367           2558.63       0.0001
Test accuracy: 0.8966574

ゼロから学習する場合:

reset_seed(0)

model = train(DeepCNN(10), max_epoch=100, base_lr=0.1, lr_decay=(30, 'epoch'))

学習が終了しました.学習曲線と精度のグラフを見てみましょう.

[40]:
Image(filename='results/DeepCNN_cifar10_result/loss.png')
[40]:
../_images/notebooks_04_Introduction_to_Chainer_115_0.png
[41]:
Image(filename='results/DeepCNN_cifar10_result/accuracy.png')
[41]:
../_images/notebooks_04_Introduction_to_Chainer_116_0.png

先程の浅い(層数の少ない)CNNを用いた際には60%前後だったValidationデータでの精度が,90%程度まで上がりました.また,テストデータを用いた精度も,およそ90%程度となっています.しかし最新の研究成果では97%以上まで達成されています.さらに精度を上げるには,今回行ったようなネットワークの構造自体の改良ももちろんのこと,学習データを擬似的に増やす操作(Data augmentation)や,複数のモデルの出力を一つの出力に統合する操作(Ensemble)などなど,いろいろな工夫が考えられます.

4.5. データセットクラスの使用方法

ここでは,Chainerにすでに用意されているCIFAR10のデータを取得する機能を使って,データセットクラスを自分で書いてみます.Chainerでは,データセットを表すクラスは以下の機能を持っている必要があります.

  • データセット内のデータ数を返す__len__メソッド
  • 引数として渡されるiに対応したデータもしくはデータとラベルの組を返すget_exampleメソッド

その他のデータセットに必要な機能は,chainer.dataset.DatasetMixinクラスを継承することで用意できます.ここでは,DatasetMixinクラスを継承し,学習時に学習データに変換を施してモデルが受け取るデータのバリエーションを増やすData augmentation機能のついたデータセットクラスを作成してみましょう.

4.5.1. CIFAR10データセットクラス

[42]:
class CIFAR10Augmented(chainer.dataset.DatasetMixin):

    def __init__(self, split='train', train_ratio=0.9):
        train_val, test_data = cifar.get_cifar10()
        train_size = int(len(train_val) * train_ratio)
        train_data, valid_data = split_dataset_random(train_val, train_size, seed=0)
        if split == 'train':
            self.data = train_data
        elif split == 'valid':
            self.data = valid_data
        elif split == 'test':
            self.data = test_data
        else:
            raise ValueError("'split' argument should be either 'train', 'valid', or 'test'. But {} was given.".format(split))

        self.split = split
        self.random_crop = 4

    def __len__(self):
        return len(self.data)

    def get_example(self, i):
        x, t = self.data[i]
        if self.split == 'train':
            x = x.transpose(1, 2, 0)
            h, w, _ = x.shape
            x_offset = np.random.randint(self.random_crop)
            y_offset = np.random.randint(self.random_crop)
            x = x[y_offset:y_offset + h - self.random_crop,
                  x_offset:x_offset + w - self.random_crop]
            if np.random.rand() > 0.5:
                x = np.fliplr(x)
            x = x.transpose(2, 0, 1)

        return x, t

このクラスは,CIFAR10のデータのそれぞれに対し,

  • 32x32の大きさの中からランダムに28x28の領域をクロップ
  • 1/2の確率で左右を反転させる

という加工を行っています.このような操作を加えて擬似的に学習データのバリエーションを増やすことで,オーバーフィッティングの抑制などに寄与することが知られています.これらの操作以外にも,画像の色味を変化させるような変換やランダムな回転,アフィン変換など,さまざまな加工によって学習データ数を擬似的に増やす方法が提案されています.

4.5.2. 作成したデータセットクラスを用いた学習

それではさっそくこのCIFAR10クラスを使って学習を行ってみましょう.先程と同じネットワークを用い,Data augmentationの効果がどの程度あるのかを調べてみましょう.train関数も含め,データセットクラス以外は先程とすべて同様です.

ここでも,40分ほどの時間がかかりますので,上と同様に90エポックまで学習したあとのsnapshotをダウンロードして読み込ませ,最後の10エポックだけ実際に学習させてみましょう.

[43]:
!wget https://github.com/japan-medical-ai/medical-ai-course-materials/releases/download/v0.1/DeepCNN_cifar10_augmented_snapshot_epoch_90.npz

reset_seed(0)

model = train(DeepCNN(10), max_epoch=100, train_dataset=CIFAR10Augmented(), valid_dataset=CIFAR10Augmented('valid'), test_dataset=CIFAR10Augmented('test'), postfix='augmented_', base_lr=0.1, lr_decay=(30, 'epoch'), snapshot='DeepCNN_cifar10_augmented_snapshot_epoch_90.npz')
--2019-12-09 11:18:31--  https://github.com/japan-medical-ai/medical-ai-course-materials/releases/download/v0.1/DeepCNN_cifar10_augmented_snapshot_epoch_90.npz
Resolving github.com (github.com)... 52.74.223.119
Connecting to github.com (github.com)|52.74.223.119|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github-production-release-asset-2e65be.s3.amazonaws.com/153412006/5064a880-eeb7-11e8-95bf-80b5d9533256?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20191209%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20191209T111832Z&X-Amz-Expires=300&X-Amz-Signature=1dc193a84a383234eb92fd9c7cb0db88c9ad7dad53e4e9a32d06f33a3825b832&X-Amz-SignedHeaders=host&actor_id=0&response-content-disposition=attachment%3B%20filename%3DDeepCNN_cifar10_augmented_snapshot_epoch_90.npz&response-content-type=application%2Foctet-stream [following]
--2019-12-09 11:18:32--  https://github-production-release-asset-2e65be.s3.amazonaws.com/153412006/5064a880-eeb7-11e8-95bf-80b5d9533256?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20191209%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20191209T111832Z&X-Amz-Expires=300&X-Amz-Signature=1dc193a84a383234eb92fd9c7cb0db88c9ad7dad53e4e9a32d06f33a3825b832&X-Amz-SignedHeaders=host&actor_id=0&response-content-disposition=attachment%3B%20filename%3DDeepCNN_cifar10_augmented_snapshot_epoch_90.npz&response-content-type=application%2Foctet-stream
Resolving github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)... 52.216.241.116
Connecting to github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)|52.216.241.116|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 56730280 (54M) [application/octet-stream]
Saving to: ‘DeepCNN_cifar10_augmented_snapshot_epoch_90.npz’

DeepCNN_cifar10_aug 100%[===================>]  54.10M  12.6MB/s    in 5.1s

2019-12-09 11:18:38 (10.5 MB/s) - ‘DeepCNN_cifar10_augmented_snapshot_epoch_90.npz’ saved [56730280/56730280]

/usr/local/lib/python3.6/dist-packages/chainer/training/triggers/interval_trigger.py:92: UserWarning: The previous value of iteration is not saved. IntervalTrigger guesses it using current iteration. If this trigger is not called at every iteration, it may not work correctly.
  'The previous value of iteration is not saved. '
/usr/local/lib/python3.6/dist-packages/chainer/training/triggers/interval_trigger.py:104: UserWarning: The previous value of epoch_detail is not saved. IntervalTrigger uses the value of trainer.updater.previous_epoch_detail. If this trigger is not called at every iteration, it may not work correctly.
  'The previous value of epoch_detail is not saved. '
epoch       main/loss   main/accuracy  val/main/loss  val/main/accuracy  elapsed_time  lr
1           2.5875      0.156405       2.11656        0.203125           24.1767       0.1
2           1.99359     0.233842       1.84577        0.304492           47.9466       0.1
3           1.76968     0.325365       1.98983        0.26543            71.6864       0.1
4           1.61662     0.389537       2.07369        0.26875            95.5072       0.1
5           1.41259     0.478989       1.52089        0.446484           119.236       0.1
6           1.23382     0.555487       1.48775        0.480664           143.06        0.1
7           1.09323     0.613404       1.11949        0.590234           166.813       0.1
8           0.99303     0.650857       1.33017        0.566992           190.624       0.1
9           0.926848    0.678755       1.0075         0.665234           214.446       0.1
10          0.863165    0.702769       0.858035       0.717383           238.218       0.1
11          0.807948    0.727162       0.9556         0.679297           265.757       0.1
12          0.769224    0.739116       0.857765       0.711328           289.507       0.1
13          0.739977    0.752952       0.91583        0.711133           313.339       0.1
14          0.72064     0.759393       1.20587        0.61875            337.097       0.1
15          0.690136    0.77093        0.837919       0.726562           360.905       0.1
16          0.673935    0.774706       1.03539        0.678711           384.67        0.1
17          0.662879    0.778742       0.730712       0.758789           408.47        0.1
18          0.639202    0.78742        0.758566       0.765625           432.298       0.1
19          0.625988    0.792713       1.24791        0.664062           456.061       0.1
20          0.616269    0.795277       0.963706       0.70625            479.882       0.1
21          0.611734    0.795962       0.887129       0.723437           507.312       0.1
22          0.600444    0.800582       0.889526       0.710352           531.173       0.1
23          0.605317    0.800414       0.715702       0.756445           554.926       0.1
24          0.584194    0.805731       0.984225       0.694336           578.687       0.1
25          0.584041    0.805464       0.956576       0.685156           602.49        0.1
26          0.57384     0.80954        0.977559       0.712695           627.031       0.1
27          0.560405    0.814298       0.894127       0.718945           650.845       0.1
28          0.559933    0.816195       0.729981       0.752734           674.584       0.1
29          0.555933    0.814387       0.841304       0.727344           698.4         0.1
30          0.558057    0.814971       0.753542       0.757227           722.119       0.1
31          0.397613    0.866455       0.337977       0.888867           749.619       0.01
32          0.320024    0.890202       0.322082       0.894336           773.368       0.01
33          0.293655    0.900479       0.323365       0.888867           797.169       0.01
34          0.279553    0.904874       0.308263       0.897656           820.951       0.01
35          0.26519     0.909945       0.301763       0.897852           844.682       0.01
36          0.259166    0.909846       0.286909       0.904688           868.482       0.01
37          0.246168    0.915064       0.289997       0.904688           892.324       0.01
38          0.24097     0.91697        0.280986       0.903906           916.139       0.01
39          0.233951    0.918892       0.291962       0.904102           939.884       0.01
40          0.220939    0.923628       0.299502       0.902539           963.666       0.01
41          0.216055    0.924272       0.2946         0.905664           991.125       0.01
42          0.215143    0.926972       0.308637       0.897266           1014.88       0.01
43          0.213903    0.926625       0.291742       0.907812           1038.64       0.01
44          0.203283    0.929198       0.296043       0.905469           1062.33       0.01
45          0.19772     0.931041       0.327708       0.89375            1086.11       0.01
46          0.194907    0.932893       0.312555       0.901563           1109.82       0.01
47          0.190354    0.934326       0.336271       0.895703           1133.58       0.01
48          0.190431    0.932915       0.326305       0.902539           1157.3        0.01
49          0.18926     0.934326       0.312767       0.901172           1181.05       0.01
50          0.184469    0.936035       0.296937       0.907812           1204.93       0.01
51          0.181691    0.936521       0.324149       0.901172           1232.22       0.01
52          0.17546     0.939675       0.347524       0.893945           1256          0.01
53          0.175786    0.937723       0.335627       0.893164           1279.74       0.01
54          0.173612    0.940274       0.317897       0.902344           1303.52       0.01
55          0.171849    0.939548       0.306998       0.90625            1327.23       0.01
56          0.168304    0.94165        0.31145        0.902148           1351.04       0.01
57          0.170139    0.941495       0.301311       0.910156           1374.82       0.01
58          0.165011    0.942708       0.359516       0.892773           1398.52       0.01
59          0.163968    0.94256        0.365818       0.886133           1422.3        0.01
60          0.16541     0.942575       0.357          0.890234           1446.01       0.01
61          0.129435    0.955988       0.277052       0.915234           1473.42       0.001
62          0.101981    0.965434       0.284798       0.916406           1497.13       0.001
63          0.0953637   0.967285       0.279956       0.919727           1520.93       0.001
64          0.0911066   0.968171       0.28204        0.918359           1544.63       0.001
65          0.0853851   0.97037        0.28504        0.918945           1568.41       0.001
66          0.0800331   0.972101       0.287688       0.917969           1592.21       0.001
67          0.0761374   0.973202       0.29148        0.920117           1616.19       0.001
68          0.0756613   0.973699       0.299635       0.918945           1639.98       0.001
69          0.075577    0.97407        0.293845       0.918359           1663.68       0.001
70          0.0730666   0.974676       0.29563        0.920508           1687.47       0.001
71          0.070825    0.975516       0.295581       0.920313           1714.73       0.001
72          0.0710753   0.975697       0.29838        0.919336           1738.52       0.001
73          0.0705982   0.975142       0.298369       0.920508           1762.3        0.001
74          0.0667571   0.976562       0.299809       0.920508           1786.01       0.001
75          0.0642319   0.978427       0.300881       0.920898           1809.77       0.001
76          0.0640179   0.977742       0.304647       0.918359           1833.51       0.001
77          0.0629752   0.977761       0.299763       0.919336           1857.31       0.001
78          0.0586612   0.979523       0.306034       0.922461           1881.03       0.001
79          0.059752    0.979869       0.311227       0.921289           1904.82       0.001
80          0.0571715   0.980213       0.304607       0.920703           1928.51       0.001
81          0.0573339   0.980136       0.315108       0.92168            1956.06       0.001
82          0.0560348   0.979847       0.321934       0.916992           1979.87       0.001
83          0.0553193   0.980613       0.315378       0.914648           2003.58       0.001
84          0.0531816   0.98129        0.318977       0.919531           2027.34       0.001
85          0.0560367   0.98097        0.310993       0.919141           2051.02       0.001
86          0.0535048   0.981534       0.317829       0.920117           2075.04       0.001
87          0.0522188   0.981571       0.31144        0.920313           2098.73       0.001
88          0.0526632   0.982156       0.318594       0.920703           2122.46       0.001
89          0.0528096   0.981445       0.309017       0.92207            2146.21       0.001
90          0.0499371   0.982928       0.313269       0.920508           2169.92       0.001
91          0.046129    0.984375       0.312747       0.919141           2182.89       0.0001
92          0.0442634   0.984642       0.308682       0.921484           2195.39       0.0001
93          0.0456881   0.98422        0.308501       0.920898           2208.13       0.0001
94          0.0450576   0.984553       0.311052       0.921484           2220.46       0.0001
95          0.0450287   0.985152       0.310078       0.920117           2232.59       0.0001
96          0.0444907   0.984486       0.312837       0.921289           2244.59       0.0001
97          0.0439379   0.985418       0.310588       0.92168            2256.65       0.0001
98          0.0430186   0.985352       0.310459       0.920508           2268.82       0.0001
99          0.0419429   0.985288       0.310179       0.92168            2281.02       0.0001
100         0.0421573   0.985574       0.31397        0.920898           2293.27       0.0001
Test accuracy: 0.917227

先程のData augmentationなしの場合は90%程度だったテスト精度が,学習データにaugmentationを施すことでおよそ1.8%程度向上していることが分かりました.

損失と精度のグラフを見てみましょう.

[44]:
Image(filename='results/DeepCNN_cifar10_augmented_result/loss.png')
[44]:
../_images/notebooks_04_Introduction_to_Chainer_125_0.png
[45]:
Image(filename='results/DeepCNN_cifar10_augmented_result/accuracy.png')
[45]:
../_images/notebooks_04_Introduction_to_Chainer_126_0.png

4.6. Data Augmentationの簡単な使い方

前述のようにデータセット内の各画像についていろいろな変換を行って擬似的にデータを増やすような操作をData Augmentationといいます.上では,オリジナルのデータセットクラスを作る方法を示すために変換の操作もget_example()内に書くという実装を行いましたが,実はもっと簡単にいろいろな変換をデータに対して行う方法があります.

それは,TransformDatasetクラスを使う方法です.TransformDatasetは,元になるデータセットオブジェクトと,そこからサンプルしてきた各データ点に対して行いたい変換を関数の形で与えると,変換済みのデータを返してくれるようなデータセットオブジェクトに加工してくれる便利なクラスです.簡単な使い方は以下のようになります.

[46]:
from chainer.datasets import TransformDataset

train_val, test_dataset = cifar.get_cifar10()
train_size = int(len(train_val) * 0.9)
train_dataset, valid_dataset = split_dataset_random(train_val, train_size, seed=0)


# 行いたい変換を関数の形で書く
def transform(inputs):
    x, t = inputs
    x = x.transpose(1, 2, 0)
    h, w, _ = x.shape
    x_offset = np.random.randint(4)
    y_offset = np.random.randint(4)
    x = x[y_offset:y_offset + h - 4,
          x_offset:x_offset + w - 4]
    if np.random.rand() > 0.5:
        x = np.fliplr(x)
    x = x.transpose(2, 0, 1)

    return x, t


# 各データをtransform関数で処理して返すデータセットオブジェクト
train_dataset = TransformDataset(train_dataset, transform)

このようにして得られた新しいtrain_datasetは,自作のデータセットクラスと同じような変換処理を行った上でデータを返してくれるデータセットオブジェクトとなります.

4.6.1. ChainerCVを活用した変換処理

さて,先ほどご紹介したコードでは,画像に対するランダムクロップ,及びランダムな左右反転の処理を自ら実装していました.もし,より多様な変換を行いたい場合,上記のtransform関数に処理を追加していくことになりますが,一般的に用いられる変換処理をその度に自ら実装するのは手間です.そこで本項では最後に,ChainerCV[Niitani 2017]をご紹介します.ChainerCVは,Computer Visionに特化した機能が豊富に追加された,Chainerの補助パッケージとしての役割を担うオープンソース・ソフトウェアです.

[47]:
!pip install chainercv
Collecting chainercv
  Downloading https://files.pythonhosted.org/packages/e8/1c/1f267ccf5ebdf1f63f1812fa0d2d0e6e35f0d08f63d2dcdb1351b0e77d85/chainercv-0.13.1.tar.gz (260kB)
     |████████████████████████████████| 266kB 35.5MB/s
Requirement already satisfied: chainer>=6.0 in /usr/local/lib/python3.6/dist-packages (from chainercv) (6.5.0)
Requirement already satisfied: Pillow in /usr/local/lib/python3.6/dist-packages (from chainercv) (4.3.0)
Requirement already satisfied: protobuf>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from chainer>=6.0->chainercv) (3.10.0)
Requirement already satisfied: six>=1.9.0 in /usr/local/lib/python3.6/dist-packages (from chainer>=6.0->chainercv) (1.12.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from chainer>=6.0->chainercv) (3.0.12)
Requirement already satisfied: typing-extensions<=3.6.6 in /usr/local/lib/python3.6/dist-packages (from chainer>=6.0->chainercv) (3.6.6)
Requirement already satisfied: typing<=3.6.6 in /usr/local/lib/python3.6/dist-packages (from chainer>=6.0->chainercv) (3.6.6)
Requirement already satisfied: numpy>=1.9.0 in /usr/local/lib/python3.6/dist-packages (from chainer>=6.0->chainercv) (1.17.4)
Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from chainer>=6.0->chainercv) (42.0.1)
Requirement already satisfied: olefile in /usr/local/lib/python3.6/dist-packages (from Pillow->chainercv) (0.46)
Building wheels for collected packages: chainercv
  Building wheel for chainercv (setup.py) ... done
  Created wheel for chainercv: filename=chainercv-0.13.1-cp36-cp36m-linux_x86_64.whl size=537355 sha256=fc0aaac281e6d6effbb0699ea77604eacc16f6683691d9463803c54c192a746b
  Stored in directory: /root/.cache/pip/wheels/ea/10/01/e221beaa4b3d8341aa819a39ab8d4677457c79c81f521f3a94
Successfully built chainercv
Installing collected packages: chainercv
Successfully installed chainercv-0.13.1

ChainerCVには,画像に対する様々な変換があらかじめ用意されています.

例えば,上でNumPyを使って書いていたランダムクロップやランダム左右反転は,chainercv.transformsモジュールを使うと,それぞれ以下のように1行で書くことができます:

x = chainercv.transforms.random_crop(x, (28, 28))  # ランダムクロップ
x = chainercv.transforms.random_flip(x)  # ランダム左右反転

chainercv.transformsモジュールを使って,transform関数をアップデートしてみましょう.ちなみに,get_cifar10()で得られるデータセットでは,デフォルトで画像の画素値の範囲が[0, 1]にスケールされています.しかし,get_cifar10()scale=255.を渡しておくと,値の範囲をもともとの[0, 255]のままにできます.今回行われる処理は,以下の5つです:

  1. PCA lighting: 先行研究(AlexNet)の学習で使われていた方法で,色を変化させる変換処理を行います.
  2. Standardization: 訓練用データセット全体からチャンネルごとの画素値の平均・標準偏差を求めて標準化をします
  3. Random flip: ランダムに画像の左右を反転します
  4. Random expand: [1, 1.5]からランダムに決めた大きさの黒いキャンバスを作り,その中のランダムな位置へ画像を配置します
  5. Random crop: (28, 28)の大きさの領域をランダムにクロップします
[48]:
from functools import partial
from chainercv import transforms

train_val, test_dataset = cifar.get_cifar10(scale=255.)
train_size = int(len(train_val) * 0.9)
train_dataset, valid_dataset = split_dataset_random(train_val, train_size, seed=0)

mean = np.mean([x for x, _ in train_dataset], axis=(0, 2, 3))
std = np.std([x for x, _ in train_dataset], axis=(0, 2, 3))


def transform(inputs, mean, std, train=True):
    img, label = inputs
    img = img.copy()

    # Color augmentation
    if train:
        img = transforms.pca_lighting(img, 76.5)

    # Standardization
    img -= mean[:, None, None]
    img /= std[:, None, None]

    # Random flip & crop
    if train:
        img = transforms.random_flip(img, x_random=True)
        img = transforms.random_expand(img, max_ratio=1.5)
        img = transforms.random_crop(img, (28, 28))

    return img, label

train_dataset = TransformDataset(train_dataset, partial(transform, mean=mean, std=std, train=True))
valid_dataset = TransformDataset(valid_dataset, partial(transform, mean=mean, std=std, train=False))
test_dataset = TransformDataset(test_dataset, partial(transform, mean=mean, std=std, train=False))

では,standardizationとChainerCVによるPCA Lightingを追加したTransformDatasetを使って学習をしてみましょう.

これまでと同様,90エポックまで学習させておいたsnapshotを用いて,最後の10エポックだけ学習を行います.

[49]:
!wget https://github.com/japan-medical-ai/medical-ai-course-materials/releases/download/v0.1/DeepCNN_cifar10_augmented2_snapshot_epoch_90.npz

reset_seed(0)

model = train(DeepCNN(10), max_epoch=100, train_dataset=train_dataset, valid_dataset=valid_dataset, test_dataset=test_dataset, postfix='augmented2_', base_lr=0.1, lr_decay=(30, 'epoch'), snapshot='DeepCNN_cifar10_augmented2_snapshot_epoch_90.npz')
--2019-12-09 11:21:16--  https://github.com/japan-medical-ai/medical-ai-course-materials/releases/download/v0.1/DeepCNN_cifar10_augmented2_snapshot_epoch_90.npz
Resolving github.com (github.com)... 13.229.188.59
Connecting to github.com (github.com)|13.229.188.59|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github-production-release-asset-2e65be.s3.amazonaws.com/153412006/5064a880-eeb7-11e8-8e8b-fddbe76ecd56?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20191209%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20191209T112116Z&X-Amz-Expires=300&X-Amz-Signature=bf6cac01b9855110a9abeb2df76a65e697c6226c2353f102c24da044d69f80ef&X-Amz-SignedHeaders=host&actor_id=0&response-content-disposition=attachment%3B%20filename%3DDeepCNN_cifar10_augmented2_snapshot_epoch_90.npz&response-content-type=application%2Foctet-stream [following]
--2019-12-09 11:21:16--  https://github-production-release-asset-2e65be.s3.amazonaws.com/153412006/5064a880-eeb7-11e8-8e8b-fddbe76ecd56?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20191209%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20191209T112116Z&X-Amz-Expires=300&X-Amz-Signature=bf6cac01b9855110a9abeb2df76a65e697c6226c2353f102c24da044d69f80ef&X-Amz-SignedHeaders=host&actor_id=0&response-content-disposition=attachment%3B%20filename%3DDeepCNN_cifar10_augmented2_snapshot_epoch_90.npz&response-content-type=application%2Foctet-stream
Resolving github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)... 52.216.20.243
Connecting to github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)|52.216.20.243|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 56734002 (54M) [application/octet-stream]
Saving to: ‘DeepCNN_cifar10_augmented2_snapshot_epoch_90.npz’

DeepCNN_cifar10_aug 100%[===================>]  54.11M  11.5MB/s    in 6.0s

2019-12-09 11:21:23 (8.96 MB/s) - ‘DeepCNN_cifar10_augmented2_snapshot_epoch_90.npz’ saved [56734002/56734002]

/usr/local/lib/python3.6/dist-packages/chainer/training/triggers/interval_trigger.py:92: UserWarning: The previous value of iteration is not saved. IntervalTrigger guesses it using current iteration. If this trigger is not called at every iteration, it may not work correctly.
  'The previous value of iteration is not saved. '
/usr/local/lib/python3.6/dist-packages/chainer/training/triggers/interval_trigger.py:104: UserWarning: The previous value of epoch_detail is not saved. IntervalTrigger uses the value of trainer.updater.previous_epoch_detail. If this trigger is not called at every iteration, it may not work correctly.
  'The previous value of epoch_detail is not saved. '
epoch       main/loss   main/accuracy  val/main/loss  val/main/accuracy  elapsed_time  lr
1           2.64248     0.137296       2.15132        0.175              23.5716       0.1
2           2.09025     0.202814       1.90949        0.248633           47.2547       0.1
3           1.92989     0.251647       1.80744        0.313867           70.8962       0.1
4           1.82033     0.296742       1.75358        0.319727           94.6262       0.1
5           1.70069     0.350717       1.59837        0.397266           118.252       0.1
6           1.56064     0.418435       1.6835         0.418359           141.967       0.1
7           1.45132     0.473691       1.30451        0.528711           165.613       0.1
8           1.33843     0.523815       2.81471        0.372656           189.382       0.1
9           1.27017     0.556019       1.61386        0.499023           213.846       0.1
10          1.20204     0.579972       1.46257        0.519141           237.532       0.1
11          1.14896     0.601252       1.1028         0.623633           265.125       0.1
12          1.10957     0.614428       1.17823        0.594141           288.775       0.1
13          1.07055     0.635387       1.00013        0.670898           312.476       0.1
14          1.04187     0.647058       1.07628        0.642383           336.226       0.1
15          1.00359     0.661066       1.00439        0.655859           360.038       0.1
16          0.971513    0.675503       0.8598         0.723828           383.802       0.1
17          0.941225    0.686301       1.09454        0.661328           407.665       0.1
18          0.917967    0.694869       0.998599       0.681055           431.468       0.1
19          0.905397    0.699586       0.811614       0.738281           455.224       0.1
20          0.874732    0.706188       0.714926       0.762695           479.061       0.1
21          0.86753     0.712607       0.850176       0.738477           506.629       0.1
22          0.858491    0.714844       0.95919        0.692187           530.459       0.1
23          0.845307    0.717637       1.12914        0.668164           554.185       0.1
24          0.827677    0.72785        0.827821       0.746484           577.974       0.1
25          0.821103    0.728715       0.7663         0.741602           601.76        0.1
26          0.816583    0.728321       0.763894       0.757812           625.472       0.1
27          0.805109    0.735596       0.74157        0.754492           649.375       0.1
28          0.809963    0.734152       0.742951       0.7625             673.099       0.1
29          0.792127    0.737149       0.692472       0.773633           696.895       0.1
30          0.783124    0.741965       1.02244        0.702734           720.621       0.1
31          0.604456    0.795898       0.394226       0.868359           748.179       0.01
32          0.523254    0.822227       0.379255       0.873242           771.92        0.01
33          0.502693    0.830078       0.360393       0.877734           795.747       0.01
34          0.484705    0.835627       0.351212       0.884766           819.578       0.01
35          0.465001    0.842192       0.348862       0.882031           843.291       0.01
36          0.45372     0.846613       0.339081       0.885742           867.053       0.01
37          0.450468    0.846043       0.335244       0.887695           890.763       0.01
38          0.439256    0.848411       0.331884       0.889453           914.524       0.01
39          0.430965    0.852475       0.324836       0.895312           938.351       0.01
40          0.424651    0.855447       0.370822       0.881641           962.135       0.01
41          0.418223    0.857          0.328078       0.89082            989.723       0.01
42          0.409156    0.860332       0.334291       0.888672           1013.41       0.01
43          0.410969    0.860574       0.338266       0.888281           1037.14       0.01
44          0.397083    0.862469       0.337938       0.889258           1060.81       0.01
45          0.390805    0.8661         0.31488        0.900586           1084.56       0.01
46          0.39124     0.866008       0.324671       0.893945           1108.24       0.01
47          0.390007    0.865101       0.320954       0.895117           1131.96       0.01
48          0.383439    0.868367       0.355809       0.890625           1155.62       0.01
49          0.380758    0.87065        0.314927       0.899805           1179.36       0.01
50          0.379198    0.870716       0.302073       0.903125           1203.09       0.01
51          0.371325    0.873019       0.317537       0.899805           1230.61       0.01
52          0.368561    0.874578       0.338393       0.891016           1254.34       0.01
53          0.366614    0.872975       0.318233       0.89707            1278.2        0.01
54          0.366082    0.874667       0.328654       0.892187           1301.94       0.01
55          0.365454    0.873331       0.311355       0.902148           1325.61       0.01
56          0.357191    0.876509       0.330824       0.895898           1349.36       0.01
57          0.361274    0.874667       0.320739       0.897461           1373.1        0.01
58          0.354885    0.878339       0.308104       0.901172           1396.78       0.01
59          0.359033    0.876931       0.316534       0.900586           1420.55       0.01
60          0.356422    0.876291       0.366406       0.888281           1444.25       0.01
61          0.314513    0.891513       0.261128       0.914062           1471.67       0.001
62          0.275813    0.905159       0.257487       0.916406           1495.37       0.001
63          0.268649    0.907804       0.253204       0.918164           1519.12       0.001
64          0.264412    0.908298       0.256485       0.918555           1542.84       0.001
65          0.261826    0.910511       0.253851       0.917773           1566.66       0.001
66          0.255704    0.911998       0.257955       0.916602           1590.48       0.001
67          0.25836     0.909566       0.256463       0.919141           1614.22       0.001
68          0.251792    0.912753       0.254393       0.920117           1638.04       0.001
69          0.251154    0.914508       0.25251        0.920508           1661.81       0.001
70          0.24735     0.91464        0.255754       0.91875            1685.83       0.001
71          0.241314    0.917268       0.253207       0.921094           1713.35       0.001
72          0.246059    0.914617       0.257668       0.920508           1737.18       0.001
73          0.238097    0.918213       0.257092       0.919727           1761.02       0.001
74          0.235832    0.918981       0.251519       0.919922           1784.77       0.001
75          0.236254    0.918857       0.253711       0.919531           1808.58       0.001
76          0.235273    0.917646       0.249922       0.920117           1832.36       0.001
77          0.233553    0.918635       0.251188       0.921094           1856.13       0.001
78          0.229216    0.920829       0.256883       0.921484           1879.92       0.001
79          0.231176    0.919877       0.254759       0.92207            1903.75       0.001
80          0.227571    0.921007       0.251693       0.920508           1927.51       0.001
81          0.229313    0.92041        0.257421       0.920508           1955.13       0.001
82          0.225018    0.922896       0.255896       0.919922           1978.98       0.001
83          0.223251    0.923478       0.25848        0.918555           2002.75       0.001
84          0.221238    0.924294       0.259179       0.920703           2026.56       0.001
85          0.222084    0.923411       0.251648       0.920313           2050.3        0.001
86          0.221973    0.924339       0.252463       0.920898           2074.11       0.001
87          0.217851    0.925147       0.252802       0.920898           2097.85       0.001
88          0.215969    0.925138       0.255438       0.921094           2121.65       0.001
89          0.217307    0.92294        0.253423       0.920313           2145.48       0.001
90          0.217241    0.925058       0.254931       0.919922           2169.54       0.001
91          0.210398    0.926336       0.248383       0.924023           2182.44       0.0001
92          0.202761    0.929131       0.248566       0.923828           2195.01       0.0001
93          0.202395    0.930686       0.248684       0.923633           2207.6        0.0001
94          0.205619    0.926794       0.25063        0.921875           2219.81       0.0001
95          0.207099    0.929088       0.251314       0.922852           2231.86       0.0001
96          0.198571    0.932692       0.250731       0.92168            2243.82       0.0001
97          0.200812    0.930331       0.250747       0.921875           2255.85       0.0001
98          0.204992    0.928689       0.250508       0.92207            2267.96       0.0001
99          0.203764    0.92922        0.248585       0.921289           2280.19       0.0001
100         0.202649    0.930331       0.25095        0.923047           2292.4        0.0001
Test accuracy: 0.92523736

わずかに精度が向上しました.他にも,ResNet [He 2016]と呼ばれる有名なネットワーク構造を採用して学習を行うなど,簡単に試せる改善方法がいくつかあります.ぜひ色々と試してみてください.


4.7. 参考文献

[Tokui 2015] Tokui, S., Oono, K., Hido, S. and Clayton, J., Chainer: a Next-Generation Open Source Framework for Deep Learning, Proceedings of Workshop on Machine Learning Systems(LearningSys) in The Twenty-ninth Annual Conference on Neural Information Processing Systems (NIPS), (2015)

[Niitani 2017] Yusuke Niitani, Toru Ogawa, Shunta Saito, Masaki Saito, "ChainerCV: a Library for Deep Learning in Computer Vision", ACM Multimedia (ACMMM), Open Source Software Competition, 2017

[Hidaka 2017] Masatoshi Hidaka, Yuichiro Kikura, Yoshitaka Ushiku, Tatsuya Harada. WebDNN: Fastest DNN Execution Framework on Web Browser. ACM International Conference on Multimedia (ACMMM), Open Source Software Competition, pp.1213-1216, 2017.

[He 2016] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun, "Deep Residual Learning for Image Recognition", CVPR 2016