4. Deep Learningフレームワークの基礎¶
ChainerはDeep Learningフレームワークの一つで,現在様々なDeep Learningフレームワーク(TensorFlow, PyTorch, etc.)でも採用され主要なニューラルネットワークの記法となっているDefine-by-Runというアイデアを汎用的なDeep Learningフレームワークとしては初めて採用し,2015年からPreferred Networks社によって開発が続けられています.Define-by-Runとは,ニューラルネットワーク中の計算を行うコードを記述することでニューラルネットワークの構造を定義する考え方です.学習を行う前にネットワーク構造を定義しておき,そのネットワークに学習に用いるデータを入力するためのコードを別途書く必要がある方法はDefine-and-Runと呼ばれます.Define-by-Runは実行時にネットワーク構造が決定されるため,動的な構造を記述しやすいという特徴があります.
ここでは,その柔軟性と直感的であることを特徴とするこのChainerというフレームワークの基本的な使い方を解説します.
4.1. 環境構築¶
まずはColab上で以下のセルを実行し,必要なライブラリをインストールしましょう.ここではgraphviz
というソフトウェアをインストールしています.これは,後にニューラルネットワークのアーキテクチャをグラフ構造として可視化するために使用します.Google Colab上には,ChainerやCuPyは予めインストールされています.
[1]:
!apt-get install -y graphviz
Reading package lists... Done
Building dependency tree
Reading state information... Done
graphviz is already the newest version (2.40.1-2).
The following packages were automatically installed and are no longer required:
cuda-cufft-10-1 cuda-cufft-dev-10-1 cuda-curand-10-1 cuda-curand-dev-10-1
cuda-cusolver-10-1 cuda-cusolver-dev-10-1 cuda-cusparse-10-1
cuda-cusparse-dev-10-1 cuda-license-10-2 cuda-npp-10-1 cuda-npp-dev-10-1
cuda-nsight-10-1 cuda-nsight-compute-10-1 cuda-nsight-systems-10-1
cuda-nvgraph-10-1 cuda-nvgraph-dev-10-1 cuda-nvjpeg-10-1
cuda-nvjpeg-dev-10-1 cuda-nvrtc-10-1 cuda-nvrtc-dev-10-1 cuda-nvvp-10-1
libcublas10 libnvidia-common-430 nsight-compute-2019.5.0
nsight-systems-2019.5.2
Use 'apt autoremove' to remove them.
0 upgraded, 0 newly installed, 0 to remove and 5 not upgraded.
それでは,以下のコマンドをターミナルで実行し,Chainerや,ChainerでGPUを活用するために必要となるCuPyというパッケージが正しくインストールされているかどうかを確認してみましょう.
[2]:
!python -c 'import chainer; chainer.print_runtime_info()'
Platform: Linux-4.14.137+-x86_64-with-Ubuntu-18.04-bionic
Chainer: 6.5.0
ChainerX: Not Available
NumPy: 1.17.4
CuPy:
CuPy Version : 6.5.0
CUDA Root : /usr/local/cuda
CUDA Build Version : 10010
CUDA Driver Version : 10010
CUDA Runtime Version : 10010
cuDNN Build Version : 7603
cuDNN Version : 7603
NCCL Build Version : 2402
NCCL Runtime Version : 2402
iDeep: 2.0.0.post3
Chainer, NumPy, そしてCuPy, さらにCuPyの下にCUDAやcuDNN, NCCLといった項目があり,それぞれバージョン番号が表示されていれば成功です.
4.2. Chainerの基本的な使い方¶
はじめに,シンプルなタスクに実際に取り組むことによって,Chainerの基本的な使い方を説明していきます.さっそく,有名な手書き数字のデータセットMNISTを使って,画像を10クラス(数字の0 - 9)のいずれかに分類するネットワークを書き,学習させてみましょう.
4.2.1. データセットの準備¶
まずは学習対象となるデータセットの準備をします.教師あり学習の場合,データセットは「入力データ」と「それと対になるラベルデータ」のペアを返すオブジェクトである必要があります.
Chainerには,MNISTやCIFAR10/100のような良く用いられるデータセットに対して,データのダウンロードからオブジェクト作成までを自動的に行ってくれる便利なメソッドがあります.ここではひとまずこれを用いましょう.
[3]:
from chainer.datasets import mnist
# データセットがダウンロード済みでなければ,ダウンロードも行う
train_val, test = mnist.get_mnist(withlabel=True, ndim=1)
Downloading from http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz...
Downloading from http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz...
Downloading from http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz...
Downloading from http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz...
データセットオブジェクトの準備ができました.このオブジェクトは, train_val[i]
のように指定すると,i番目の (data, label) というタプルを返すリスト と同様のものと考えてください.(実際ただのPythonリストもChainerのデータセットオブジェクトとして利用可能です).それでは,0番目のデータとラベルを取り出して,表示してみましょう.
[4]:
# matplotlibを使ったグラフ描画結果がnotebook内に表示されるようにします.
%matplotlib inline
import matplotlib.pyplot as plt
# データの例示
x, t = train_val[0] # 0番目の (data, label) を取り出す
plt.imshow(x.reshape(28, 28), cmap='gray')
plt.axis('off')
plt.show()
print('label:', t)
label: 5
4.2.2. Validation用データセットを作る¶
次に,さきほど作成したtrain_val
データセットを,Training用のデータセットとValidation用のデータセットに分割します.Validationデータセットとは,学習には用いずにモデルの汎化性能をチェックしたり,学習率などのハイパーパラメータを調整するために用いる検証用のデータセットスプリットのことです.分割処理も,Chainerが提供しているデータセット分割用の関数を用いて行うことができます.元々60000個のデータが入っているtrain
データセットを,ランダムに選択された50000個のデータと残りの10000個のデータの2つに分割しましょう.これには,split_dataset_random
という関数を使用します.
[5]:
from chainer.datasets import split_dataset_random
train, valid = split_dataset_random(train_val, 50000, seed=0)
関数の第1引数が分割したい対象のデータセットオブジェクト,第2引数が1つ目のデータセットの要素数,第3引数がランダムな抽出を行う際に用いられる乱数シード(これは省略可)となります.第3引数のseed
として同じ値を指定すると,再実行した際にデータセットを同じように分割するようになります.それでは,それぞれのデータセットの中に入っているデータの数を確認してみましょう.
[6]:
print('Training dataset size:', len(train))
print('Validation dataset size:', len(valid))
Training dataset size: 50000
Validation dataset size: 10000
4.2.3. Iteratorの作成¶
次に,さきほど準備したデータセットオブジェクトから,幾つかのデータ(入力とラベルのペア)を束ねて学習モデルに次々に渡す,Iteratorという機能を紹介します.なぜIteratorの機能が必要かというと,ニューラルネットワークのパラメータを更新する際に利用される,確率的勾配降下法(Stochastic Gradient Descent, SGD)をはじめとする最適化手法では,一つのデータだけを元に更新する処理を繰り返すのではなく,幾つかのデータを束ねた ミニバッチ を元に計算していくのが一般的となっているためです(ミニバッチ計算が一般的である理由としては,勾配のミニバッチ平均を計算することでパラメータ更新が安定することや,GPUなどを用いた並列化がしやすいこと等が挙げられます).
Iterator
は,さきほど作成したデータセットオブジェクトを引数として指定し,next()
メソッドを呼ぶことで新しいミニバッチを返してくれます.データセット内のデータすべてを1度ずつ学習に利用し終えた時点のことを 1エポック(epoch) と呼びます.Iteratorの内部では,学習中に何エポックまで学習を行ったか,などの情報が逐次記録されており,データセット内のデータを何度も使って学習のループを回すようなコードを簡単に書くことができるようになります.
データセットオブジェクトからイテレータを作るには,以下のようにします.
[7]:
from chainer import iterators
batchsize = 128
train_iter = iterators.SerialIterator(train, batchsize)
valid_iter = iterators.SerialIterator(
valid, batchsize, repeat=False, shuffle=False)
test_iter = iterators.SerialIterator(
test, batchsize, repeat=False, shuffle=False)
今,学習データセット用のイテレータ(train_iter
)と,検証データセット用のイテレータ(valid_iter
),および学習したネットワークの評価に用いるテストデータセット用のイテレータ(test_iter
)の計3つを作成しました.ここではbatchsize = 128
としているため,作成した3つのイテレータはnext()
メソッドが(train_iter.next()
のように)呼ばれると,128枚の数字画像データを一括りにして返します.実際にnext()
の返り値を調べてみましょう.
[8]:
minibatch = train_iter.next()
このminibatch
という変数は,(img, label)
というタプルが128個(ミニバッチサイズだけ)並んだリストになっています.実際に,このリストの長さが128であることを確認してみましょう.
[9]:
print('batchsize:', len(minibatch))
batchsize: 128
次に,このminibatch
というリストの一つ目の要素(画像とラベルを持つタプルになっているはずです)をminibatch[0]
として取り出してみます.
[10]:
x, t = minibatch[0]
print('x:', x.shape)
print('t:', t.shape)
x: (784,)
t: ()
そのときの返り値である2つの配列 x
と t
のshapeを調べてみると,データはそれぞれ長さ784のベクトルとして格納されており,正解ラベルはスカラー値となっています.784は,\(28 \times 28\)で,28ピクセル四方の画像データの画素値を1列に並べたものになっています.
4.2.3.1. SerialIteratorについて¶
Chainerにいくつか用意されているイテレータの一種であるSerialIterator
は,データセットの中のデータを順番に取り出してくる最もシンプルなイテレータです.SerialIterator
のコンストラクタ(クラスをインスタンス化するタイミングで呼ばれるメソッド)の引数にデータセットオブジェクトと,バッチサイズを取ります.このとき,渡したデータセットオブジェクトから,データを繰り返し読み出す必要がある場合はrepeat
引数をTrue
とし,1周が終わったらそれ以上データを取り出したくない場合はこれをFalse
とします.これは,主にvalidation用のデータセットに対して使うフラグです.デフォルトでは,True
になっています.また,shuffle
引数にTrue
を渡すと,データセットから取り出されてくるデータの順番をエポックごとにランダムに変更します.SerialIterator
の他にも,マルチプロセスで高速にデータを処理できるようにしたMultiprocessIterator
やMultithreadIterator
など,複数のイテレータが用意されています.詳しくは以下を見てください.
4.2.4. ネットワークの定義¶
それでは,学習させるネットワークを定義してみましょう.今回は,全結合層のみからなるニューラルネットワーク(多層パーセプトロン)を作ることにして,中間層のユニット数は100とします.今回用いるMNISTデータセットは0〜9までの数字のいずれかを意味する10種のラベルを持つことから,出力ユニット数は10とします.
ここで,ネットワークを定義するために必要なLink
, Function
, Chain
について簡単に説明します.
4.2.4.1. LinkとFunction¶
Chainerでは,ニューラルネットワークの各層を,Link
とFunction
に区別します.
- Linkは,パラメータを持つ関数です.
- Functionは,パラメータを持たない関数です.
これらを組み合わせてネットワークを記述します.パラメータを持つ層は,chainer.links
モジュール以下に用意されています.例えば chainer.links.Linear
は,前章で説明した全結合層に対応しており,内部に W
と b
という学習できるパラメータが保持されています.パラメータを持たない層は,chainer.functions
モジュール以下に用意されています.これらに簡単にアクセスするために,
import chainer.links as L
import chainer.functions as F
と別名を与えて,L.Convolution2D(...)
やF.relu(...)
のように用いる慣習がありますが,特にこれが決まった書き方というわけではありません.
4.2.4.2. Chain¶
Chain
は,パラメータを持つ層(Link)をまとめておくためのクラスです.パラメータを持つということは,基本的にネットワークの学習の際にそれらを更新していく必要があるということです(更新されないパラメータを持たせることもできます).Chainerでは,モデルのパラメータの更新は,Optimizer
という機能が担います.その際,更新すべき全てのパラメータを簡単に発見できるように,Chain
で一箇所にまとめておきます.
4.2.4.3. 同じ結果を保証する¶
ネットワークを書き始める際に乱数シードを固定すると,本記事とほぼ同様の結果が再現できるようになります.(cuDNNが有効になっている環境下でより厳密に計算結果の再現性を保証したい場合は,chainer.config.cudnn_deterministic
というConfiguringオプションについて知る必要があります.こちらのドキュメントを参照してください:chainer.config.cudnn_deterministic.
[11]:
import random
import numpy
import chainer
def reset_seed(seed=0):
random.seed(seed)
numpy.random.seed(seed)
if chainer.cuda.available:
chainer.cuda.cupy.random.seed(seed)
reset_seed(0)
4.2.4.4. Chainを継承したネットワークの定義¶
Chainerでは,ネットワークは Chain
クラスを継承したクラスとして定義されることが一般的です. Chain
を継承することで,中間層のユニット数=100,出力ユニット数=10とした3層の多層パーセプトロンは以下のように書くことができます.
[12]:
import chainer
import chainer.links as L
import chainer.functions as F
class MLP(chainer.Chain):
def __init__(self, n_mid_units=100, n_out=10):
super(MLP, self).__init__()
# パラメータを持つ層の登録
with self.init_scope():
self.l1 = L.Linear(None, n_mid_units)
self.l2 = L.Linear(n_mid_units, n_mid_units)
self.l3 = L.Linear(n_mid_units, n_out)
def forward(self, x):
# データを受け取った際のforward計算を書く
h1 = F.relu(self.l1(x))
h2 = F.relu(self.l2(h1))
return self.l3(h2)
gpu_id = 0 # CPUを用いる場合は,この値を-1にしてください
net = MLP()
if gpu_id >= 0:
net.to_gpu(gpu_id)
継承した MLP
クラスのコンストラクタ内で with self.init_scope()
が呼ばれており,その中でネットワークに登場するLink
(具体的には,全結合層の L.Linear
)が定義されています.このような形で記述することで,Optimizer
はこれらが最適化対象となるパラメータを持つ層であると自動的に解釈してくれるようになります.
また, forward
というメソッドには,関数の名前の通り,ネットワークの順伝播を記述します.forward
の引数としてデータ x
を受け取り,出力として順伝播の計算結果を返すようにすることで, MLP
クラスをインスタンス化して作成されたオブジェクトを,関数のように使えるようになります.(例:output = net(data)
)
Chainerには数多くの Function
や Link
が用意されています.ぜひ一度以下の一覧のページを見てみてください.
Link
には,ニューラルネットワークによく用いられる全結合層や畳み込み層,LSTMなどに加えて,ResNetや,VGGなどの有名なネットワーク構造も登録されています.また,Function
には,ReLUなどの活性化関数や,画像の大きさをresizeする関数,サイン・コサインのような関数を始め,ネットワークの要素として使える関数が登録されています.Define-by-Runでは,データをネットワークに入力して順伝播計算を行ったあとに,データに適用された関数(パラメータあり・なし両方)の履歴をたどり直すことで,バックプロパゲーションによる勾配計算を行うパスを取得するため,パラメータを持たない関数であっても
chainer.functions
に含まれているものを繋げて用いる必要があります.
4.2.4.5. GPUで実行するには¶
深層学習で用いられるような多くのパラメータを持ったネットワークの学習には,GPUを用いることが一般的となっています.GPUを使うと,行列演算などの一部の処理をCPUに比べとても高速に行うことができます.Chainerで計算をGPUで行う方法は簡単です.Chain
クラスはto_gpu
メソッドを持ち,この引数にGPU IDを指定すると,指定したGPU IDのメモリ上にネットワークの全パラメータを転送します.こうしておくと,順伝播も学習の際のパラメータ更新なども全てGPU上で行われるようになります.GPU
IDとして-1を指定すると,CPUを使用します.
4.2.4.6. 入力側ユニット数の自動計算¶
上のネットワーク定義で,最初のLinear層は第一引数にNone
が渡されています.このように引数を指定すると,データが最初にその層に入力されたタイミングで,自動的に必要な数の入力側のユニット数を判断し, n_input
\(\times\) n_mid_units
の大きさの行列を作成し,学習対象パラメータとして保持します.これは後々,畳み込み層を全結合層の前に配置する際などに便利な機能となるため,覚えておいてください.
4.2.5. 最適化手法の選択¶
それでは,上で定義したネットワークをMNISTデータセットを使って訓練してみましょう.学習時に用いる最適化の手法は数多く提案されていますが,Chainerは多くの手法を同一のインターフェースで利用できるよう,Optimizer
という機能でそれらを提供しています.chainer.optimizers
モジュール以下に定義されています.一覧はこちらにあります:
ここでは最もシンプルな勾配降下法の手法であるoptimizers.SGD
を用います.Optimizer
のオブジェクトには,setup
メソッドを使ってモデル(Chain
オブジェクト)を渡します.こうすることでOptimizer
に,何を最適化すればいいか把握させることができます.
他にもいろいろな最適化手法が手軽に試せるので,色々と試してみて結果の変化を見てみてください.例えば,下のchainer.optimizers.SGD
のうちSGD
の部分をMomentumSGD
, RMSprop
, Adam
などに変えるだけで,最適化手法の違いがどのような学習曲線(ロスカーブとも言う.目的関数の値のプロットのこと)の違いを生むかなどを簡単に調べることができます.最適化の手法によっては,人が与える必要があった学習率を適切に自動決定するものもあります.
[13]:
from chainer import optimizers
optimizer = optimizers.SGD(lr=0.01).setup(net)
4.2.5.1. 学習率(learning rate)¶
今回はSGDのlr
という引数に \(0.01\) を与えました.この値は学習率として知られ,モデルをうまく訓練して良いパフォーマンスを発揮させるために調整する必要がある重要なハイパーパラメータとして知られています.ハイパーパラメータは学習されるパラメータとは異なり人が手で与える学習の設定に関するものやネットワークの構造に関するもののことを指します.
4.2.6. 学習の開始¶
今回は0〜9の数字を区別する分類問題なので,softmax_cross_entropy
という損失関数を使って最小化すべき損失を計算します.Softmax関数は,\(d\)次元のベクトル\({\bf y} \in \mathbb{R}^d\)が与えられたとき,その各次元の値の合計が1になるように正規化することができます.すなわち,確率分布のような出力を任意の実数ベクトルから作ることができます.\({\bf y}\)の\(i\)番目の次元を\(y_i\)と書くと,Softmax関数は
と表せます.これによって正規化された出力ベクトルを入力が各クラスに所属する確率を表しているものと考え,正解の1-hotベクトルとの間で前章で説明した交差エントロピーを計算するのが softmax_cross_entropy
関数です.
まずネットワークにデータを渡し,順伝播により予測値を計算します.そして,この予測値と入力データに対応する正解ラベルを損失関数に渡して損失(最小化したい値)を計算をします.損失は,chainer.Variable
のオブジェクトとして得られます.このVariable
は,過去の計算の履歴を覚えていて,辿れるようになっています.この仕組みが,Define-by-Run [Tokui 2015]とよばれる発明の中心的な役割を果たしています.
計算した損失に対する勾配をネットワークに逆向きに計算していく処理は,Chainerではネットワークが出力したVariable
から,backward
メソッドを呼ぶだけで実現できます.これを呼ぶことで,誤差逆伝播用の計算グラフを構築し,途中のパラメータの勾配を連鎖率を使って計算してくれます.(詳しくは日本ソフトウェア科学会におけるチュートリアルの資料をご覧ください.)
最後に,計算された各パラメータに対する勾配を用いて,Optimizer
によってネットワークパラメータの更新(=学習)が行われます.
まとめると,一連の更新処理の中で行われるのは,以下の4項目となります.
- ネットワークにデータを渡して順伝播を計算し,出力
y
を得る - 出力
y
と正解ラベルt
を使って,最小化すべき損失をsoftmax_cross_entropy
関数で計算する softmax_cross_entropy
関数の出力(Variable
)のbackward
メソッドを呼んで,ネットワークの全てのパラメータの勾配を誤差逆伝播法で計算する- Optimizerの
update
メソッドを呼び,3.で計算した勾配を使って全パラメータを更新する
パラメータの更新は,上記ステップを繰り返すことで行われます.一度のパラメータ更新に用いられるデータは,ネットワークに入力された,ミニバッチとして束ねられたデータのみです.次々と新しいミニバッチを入力し,上記のステップを繰り返すことで,データセット全体を用いて学習を行います.この過程を学習ループと呼んでいます.
4.2.6.1. 目的関数¶
目的関数として,例えば分類問題ではなく回帰問題を解きたいような場合,F.softmax_cross_entropy
の代わりにF.mean_squared_error
などを用いることもできます.他にも,いろいろな問題設定に対応するために様々な損失関数がChainerには用意されています.こちらからその一覧を見ることができます:
4.2.6.2. 学習ループのコード¶
[14]:
import numpy as np
from chainer.dataset import concat_examples
from chainer.cuda import to_cpu
max_epoch = 10
while train_iter.epoch < max_epoch:
# ---------- 学習の1イテレーション ----------
train_batch = train_iter.next()
x, t = concat_examples(train_batch, gpu_id)
# 予測値の計算
y = net(x)
# 損失の計算
loss = F.softmax_cross_entropy(y, t)
# 勾配の計算
net.cleargrads()
loss.backward()
# パラメータの更新
optimizer.update()
# --------------- ここまで ----------------
# 1エポック終了ごとにValidationデータに対する予測精度を測って,
# モデルの汎化性能が向上していることをチェックしよう
if train_iter.is_new_epoch: # 1 epochが終わったら
# 損失の表示
print('epoch:{:02d} train_loss:{:.4f} '.format(
train_iter.epoch, float(to_cpu(loss.data))), end='')
valid_losses = []
valid_accuracies = []
while True:
valid_batch = valid_iter.next()
x_valid, t_valid = concat_examples(valid_batch, gpu_id)
# Validationデータをforward
with chainer.using_config('train', False), \
chainer.using_config('enable_backprop', False):
y_valid = net(x_valid)
# 損失を計算
loss_valid = F.softmax_cross_entropy(y_valid, t_valid)
valid_losses.append(to_cpu(loss_valid.array))
# 精度を計算
accuracy = F.accuracy(y_valid, t_valid)
accuracy.to_cpu()
valid_accuracies.append(accuracy.array)
if valid_iter.is_new_epoch:
valid_iter.reset()
break
print('val_loss:{:.4f} val_accuracy:{:.4f}'.format(
np.mean(valid_losses), np.mean(valid_accuracies)))
# テストデータでの評価
test_accuracies = []
while True:
test_batch = test_iter.next()
x_test, t_test = concat_examples(test_batch, gpu_id)
# テストデータをforward
with chainer.using_config('train', False), \
chainer.using_config('enable_backprop', False):
y_test = net(x_test)
# 精度を計算
accuracy = F.accuracy(y_test, t_test)
accuracy.to_cpu()
test_accuracies.append(accuracy.array)
if test_iter.is_new_epoch:
test_iter.reset()
break
print('test_accuracy:{:.4f}'.format(np.mean(test_accuracies)))
epoch:01 train_loss:0.9100 val_loss:0.9743 val_accuracy:0.8018
epoch:02 train_loss:0.5396 val_loss:0.5336 val_accuracy:0.8645
epoch:03 train_loss:0.4012 val_loss:0.4230 val_accuracy:0.8847
epoch:04 train_loss:0.3329 val_loss:0.3741 val_accuracy:0.8941
epoch:05 train_loss:0.4588 val_loss:0.3455 val_accuracy:0.9002
epoch:06 train_loss:0.2481 val_loss:0.3274 val_accuracy:0.9074
epoch:07 train_loss:0.3306 val_loss:0.3109 val_accuracy:0.9118
epoch:08 train_loss:0.3801 val_loss:0.2990 val_accuracy:0.9145
epoch:09 train_loss:0.2974 val_loss:0.2886 val_accuracy:0.9180
epoch:10 train_loss:0.3216 val_loss:0.2803 val_accuracy:0.9204
test_accuracy:0.9234
val_accuracy
に着目してみると,最終的におよそ92%程度の精度で手書きの数字が分類できるようになりました.ここで言う精度とは,Validationデータセット中に \(N\) 個のデータがあり分類結果が正しかったものが \(M\) 個あるとすると \(M/N\)
を指します.学習中は,各ループの終わりに始めに取り分けておいたValidationデータセットを使って精度をはかることで,モデルの汎化性能をチェックしています.汎化性能とは,主に未知のデータに対する性能の高さのことを意味します.学習終了後には,テスト用のデータセットを用いて,学習が完了したネットワークの評価を行います.テストデータでの評価結果は,およそ92.37%の正解率となりました.
4.2.6.3. ValidationやTestを行う際の注意点¶
学習終了後の最終的な評価には,ハイパーパラメータ調整などにも用いられるValidationデータセットとはさらに別のTestデータセットを用います.TestデータセットはTrainingデータセットともValidationデータセットともデータの重複がないように用意しておきます.
さて,これまでは主に,「学習」のやり方について説明してきましたが,「評価」を行う際には注意すべき点があります.なぜなら,一部の関数や,計算過程において,学習時と評価時でその挙動が異なるためです.以下では,それらの挙動の違いを制御するための方法について説明します.
4.2.6.3.1. chainer.using_config('train', False)
¶
先程の例では,学習時と推論時で動作が異なる関数は含まれていませんでしたが,Validationやテストのために推論を行うときは以下のように,chainer.using_config('train', False)
をwith構文と共に使うことで,その中では対応する関数が推論モードとして動作することになります.これによって,学習時と推論時で挙動が異なる関数などが正しく推論のための動作をするようになります(例えば,Dropoutなど).詳しくはこちらの train
の項をお読みください:Configuration Keys.
with chainer.using_config('train', False):
--- 何か推論処理 ---
4.2.6.3.2. chainer.using_config('enable_backprop', False)
¶
評価のみ行うことを考えた場合,出力の計算後に損失関数の各パラメータについての勾配の情報は不要なため,chainer.using_config('enable_backprop', False)
とすることで,無駄な計算グラフの構築が行われず,メモリ消費量を節約することができます.詳しくはこちらの enable_backprop
の項をお読みください:Configuration Keys.
4.2.6.3.3. ChainerのConfig¶
Chainerにはこの他にも,いくつかのグローバルなConfigが用意されています.また,chainer.config
以下にユーザが自由な設定値を置くこともできます.詳しくはこちらを一読してください:Configuring Chainer
4.2.7. 学習済みモデルの保存¶
学習が終了後,その結果を保存します.Chainerには,2種類のフォーマットで学習済みネットワークをファイルに保存する機能が用意されています.一つはHDF5形式,もう一つはNumPyのNPZ形式で,ネットワークを保存します.今回は,追加ライブラリのインストールが必要なHDF5ではなく,NumPy標準機能で提供されているシリアライズ機能(numpy.savez()
)を利用したNPZ形式でのモデルの保存を行います.
[15]:
from chainer import serializers
serializers.save_npz('my_mnist.model', net)
[16]:
# 保存されていることを確認
%ls -la my_mnist.model
-rw-r--r-- 1 root root 334084 Dec 9 11:13 my_mnist.model
4.2.8. 保存したモデルを読み込んで推論¶
学習が終了して保存したモデルを読み込み,推論を行う方法について説明します.はじめに,学習に利用したネットワークを再度インスタンス化して,そこにさきほど保存したNPZファイルを読み込ませます.
[17]:
# まず同じネットワークのオブジェクトを作る
infer_net = MLP()
# そのオブジェクトに保存済みパラメータをロードする
serializers.load_npz('my_mnist.model', infer_net)
以上で準備が整いました.それでは,試しにテストデータの中から一つ目の画像を取ってきて,それに対する分類を行ってみましょう.
[18]:
gpu_id = 0 # CPUで計算をしたい場合は,-1を指定してください
if gpu_id >= 0:
infer_net.to_gpu(gpu_id)
# 1つ目のテストデータを取り出します
x, t = test[0] # tは使わない
# どんな画像か表示してみます
plt.imshow(x.reshape(28, 28), cmap='gray')
plt.show()
# ミニバッチの形にする(複数の画像をまとめて推論に使いたい場合は,サイズnのミニバッチにしてまとめればよい)
print('元の形:', x.shape, end=' -> ')
x = x[None, ...]
print('ミニバッチの形にしたあと:', x.shape)
# ネットワークと同じデバイス上にデータを送る
x = infer_net.xp.asarray(x)
# モデルのforward関数に渡す
with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
y = infer_net(x)
# Variable形式で出てくるので中身を取り出す
y = y.array
# 結果をCPUに送る
y = to_cpu(y)
# 予測確率の最大値のインデックスを見る
pred_label = y.argmax(axis=1)
print('ネットワークの予測:', pred_label[0])
元の形: (784,) -> ミニバッチの形にしたあと: (1, 784)
ネットワークの予測: 7
ネットワークの予測は7でした.画像を見る限り,正しく予測できていることが確認できます.
4.3. Trainerの使用方法¶
Chainerは,これまで書いてきたような学習ループを隠蔽するTrainer
という機能を提供しています.これを使うと,学習ループを自ら書く必要がなくなり,また便利な拡張機能(Extention
)を使うことで,学習過程での学習曲線の可視化や,ログの保存なども簡単に行うことができます.
4.3.1. データセット・Iterator・ネットワークの準備¶
データセット,Iterator,ネットワークは,Trainerを使用する場合にも同様に準備します.
[19]:
reset_seed(0)
train_val, test = mnist.get_mnist()
train, valid = split_dataset_random(train_val, 50000, seed=0)
batchsize = 128
train_iter = iterators.SerialIterator(train, batchsize)
valid_iter = iterators.SerialIterator(valid, batchsize, False, False)
test_iter = iterators.SerialIterator(test, batchsize, False, False)
gpu_id = 0 # CPUを用いたい場合は,-1を指定してください
net = MLP()
if gpu_id >= 0:
net.to_gpu(gpu_id)
4.3.2. Updaterの準備¶
学習ループを自分で書く場合の学習ステップについて再度確認すると,「データセットからミニバッチを作成」「ネットワークに入力して予測を出力」「正解と比較し誤差を計算」「バックワード(誤差逆伝播)を実行」「Optimizer
によってパラメータを更新」という一連のステップを,以下のように書いていました.
# ---------- 学習の1イテレーション ----------
train_batch = train_iter.next()
x, t = concat_examples(train_batch, gpu_id)
# 予測値の計算
y = net(x)
# 損失の計算
loss = F.softmax_cross_entropy(y, t)
# 勾配の計算
net.cleargrads()
loss.backward()
# パラメータの更新
optimizer.update()
Chainerの機能として提供されているUpdater
を用いることで,これらの一連の処理を簡単に書けるようになります.Updater
にはIterator
とOptimizer
を渡します.
Iterator
はデータセットオブジェクトを持っているため,そこからミニバッチを作成します.Optimizer
は最適化対象のネットワークを持っているため,それを使って順伝播と誤差計算・パラメータのアップデートをすることができます.従って,この2つを渡すことで,Updater
内で全ての処理が完結します.さっそく,Updater
オブジェクトを作成してみましょう.
[20]:
from chainer import training
gpu_id = 0 # CPUを使いたい場合は-1を指定してください
# ネットワークをClassifierで包んで,損失の計算などをモデルに含める
net = L.Classifier(net)
# 最適化手法の選択
optimizer = optimizers.SGD(lr=0.01).setup(net)
# UpdaterにIteratorとOptimizerを渡す
updater = training.StandardUpdater(train_iter, optimizer, device=gpu_id)
4.3.2.1. 損失計算のためのChain¶
ここでは,ネットワークをL.Classifier
で包んでいます.L.Classifier
は,渡されたネットワーク自体をpredictor
というattributeに持ち,損失計算を行う機能を追加してくれます.こうすることで,net()
はデータx
だけでなくラベルt
も取るようになり,受け取ったデータをpredictor
に通して予測値を計算し,正解ラベルt
と比較して損失のVariableを返します.損失関数として何を用いるかはデフォルトではF.softmax_cross_entropy
となっていますが,L.Classifier
の引数lossfun
に損失計算を行う関数を渡してやれば変更することができ,(Classifierという名前ながら)回帰問題などの損失計算機能の追加にも使うことができます.(L.Classifier(net, lossfun=L.mean_squared_error, compute_accuracy=False)
のようにする)
StandardUpdater
は前述のようなUpdater
の担当する処理を遂行するための最もシンプルなクラスです.この他にも複数のGPUを用いるためのParallelUpdater
などが用意されています.
4.3.3. Trainerの準備¶
実際に学習ループ部分を隠蔽しているのはUpdater
ですが,Trainer
はさらにUpdater
を受け取って学習全体の管理を行う機能を提供しています.例えば,データセットを何周したら学習を終了するか(stop_trigger) や,途中の損失の値をどのファイルに保存したいか,学習曲線を可視化した画像ファイルを保存するかどうかなど,学習全体の設定として必須・もしくはあると便利な色々な機能を提供しています.
必須なものとしては学習終了のタイミングを指定するstop_trigger
がありますが,これはTrainer
オブジェクトを作成するときのコンストラクタで指定します.指定の方法は単純で,(長さ, 単位)
という形のタプルを与えればよいだけです.「長さ」には数字を,「単位」には'iteration'
もしくは'epoch'
のいずれかの文字列を指定します.こうすると,たとえば100 epoch(データセット100周)で学習を終了してください,とか,1000
iteration(1000回更新)で学習を終了してください,といったことが指定できます.Trainer
を作るときに,stop_trigger
を指定しないと,学習は自動的には止まりません.
では,実際にTrainer
オブジェクトを作ってみましょう.
[21]:
max_epoch = 10
# TrainerにUpdaterを渡す
trainer = training.Trainer(
updater, (max_epoch, 'epoch'), out='results/mnist_result')
out
引数では,この次に説明するExtension
を使って,ログファイルや損失の変化の過程を描画したグラフの画像ファイルなどを保存するディレクトリを指定しています.
Trainerと,その内側にあるいろいろなオブジェクトの関係は,図にまとめると以下のようになっています.このイメージを持っておくと自分で部分的に改造したりする際に便利だと思います.
4.3.4. TrainerにExtensionを追加¶
Trainer
を使う利点として,
- ログを自動的にファイルに保存(
LogReport
) - ターミナルに定期的に損失などの情報を表示(
PrintReport
) - 損失を定期的にグラフで可視化して画像として保存(
PlotReport
) - 定期的にモデルやOptimizerの状態を自動シリアライズ(
snapshot
) - 学習の進捗を示すプログレスバーを表示(
ProgressBar
) - ネットワークの構造をGraphvizのdot形式で保存(
dump_graph
) - ネットワークのパラメータの平均や分散などの統計情報を出力(
ParameterStatistics
)
などの様々な便利な機能を簡単に利用することができる点があります.これらの機能を利用するには,Trainer
オブジェクトに対してextend
メソッドを使って追加したいExtension
のオブジェクトを渡すだけです.では実際に幾つかのExtension
を追加してみましょう.
[22]:
from chainer.training import extensions
trainer.extend(extensions.LogReport())
trainer.extend(extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}'))
trainer.extend(extensions.Evaluator(valid_iter, net, device=gpu_id), name='val')
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'val/main/loss', 'val/main/accuracy', 'l1/W/data/std', 'elapsed_time']))
trainer.extend(extensions.ParameterStatistics(net.predictor.l1, {'std': np.std}))
trainer.extend(extensions.PlotReport(['l1/W/data/std'], x_key='epoch', file_name='std.png'))
trainer.extend(extensions.PlotReport(['main/loss', 'val/main/loss'], x_key='epoch', file_name='loss.png'))
trainer.extend(extensions.PlotReport(['main/accuracy', 'val/main/accuracy'], x_key='epoch', file_name='accuracy.png'))
trainer.extend(extensions.dump_graph('main/loss'))
4.3.4.1. LogReport
¶
epoch
やiteration
ごとのloss
, accuracy
などを自動的に集計し,log
というファイル名で保存します.
4.3.4.2. snapshot
¶
Trainer
オブジェクトを指定されたタイミング(デフォルトでは1エポックごと)で保存します.Trainer
オブジェクトは上述のようにUpdater
を持っており,この中にOptimizer
とモデルが保持されているため,このExtension
でスナップショットをとっておけば,その時点から学習を再開させたり,学習済みモデルを使った推論などが可能になります.
4.3.4.3. dump_graph
¶
指定されたVariable
オブジェクトから辿れる計算グラフをGraphvizのdot形式で保存します.
4.3.4.4. Evaluator
¶
評価用のデータセットのIterator
と,学習に使うモデルのオブジェクトを渡しておくことで,学習中のモデルを指定されたタイミングで評価用データセットを用いて評価します.内部では,chainer.config.using_config('train', False)
が自動的に行われます.backprop_enable
をFalse
にすることは行われないため,メモリ使用効率はデフォルトでは最適ではありませんが,基本的にはEvaluator
を使えば評価を行えるという点において問題はありません.
4.3.4.5. PrintReport
¶
LogReport
と同様に集計された値を標準出力に出力します.この際,どの値を出力するかをリストの形で与えます.
4.3.4.6. PlotReport
¶
引数のリストで指定された値の変遷をmatplotlib
ライブラリを使ってグラフに描画し,出力ディレクトリにfile_name
引数で指定されたファイル名で画像として保存します.
4.3.4.7. ParameterStatistics
¶
指定したレイヤ(Link)が持つパラメータの平均・分散・最小値・最大値などなどの統計情報を計算して,ログに保存します.パラメータが発散していないかなどをチェックするのに便利です.
これらのExtension
は,ここで紹介した以外にも,例えばtrigger
によって個別に作動するタイミングを指定できるなどのいくつかのオプションを持っており,より柔軟に組み合わせることができます.詳しくは公式のドキュメントを見てください.
4.3.5. 学習の開始 (Trainer利用)¶
学習を開始するために,Trainer
オブジェクトのメソッドrun
を実行してください.
[23]:
trainer.run()
epoch main/loss main/accuracy val/main/loss val/main/accuracy l1/W/data/std elapsed_time
1 1.66917 0.599904 0.938911 0.806764 0.0359232 4.00874
2 0.673337 0.843211 0.519283 0.86699 0.0366054 6.59789
3 0.459913 0.878686 0.414855 0.887757 0.0370351 9.17739
4 0.38953 0.893262 0.370488 0.896855 0.037301 11.7382
5 0.353165 0.901215 0.342328 0.90447 0.03749 14.4128
6 0.33014 0.90609 0.32212 0.90981 0.037639 17.1353
7 0.312328 0.910906 0.30679 0.913172 0.0377671 19.834
8 0.298127 0.914704 0.295095 0.915744 0.0378811 22.4303
9 0.28583 0.917659 0.284156 0.918513 0.0379864 25.0918
10 0.275227 0.921096 0.274761 0.921677 0.0380852 27.7848
学習ループを自分で書いた場合よりも遥かに簡単に,同様の結果を得ることができました.さらに,Extension
の機能を利用することで,様々なスコアや,学習曲線の可視化も自動で出力されます.
では,保存されている損失のグラフを確認してみましょう.
[24]:
from IPython.display import Image
Image(filename='results/mnist_result/loss.png')
[24]:
精度のグラフも見てみましょう.
[25]:
Image(filename='results/mnist_result/accuracy.png')
[25]:
もう少し学習を続ければ,さらに精度の向上が期待できそうです.
最後に,dump_graph
というExtension
によって出力された計算グラフのファイルを,Graphviz
で画像化してみましょう.
[26]:
!dot -Tpng results/mnist_result/cg.dot -o results/mnist_result/cg.png
[27]:
Image(filename='results/mnist_result/cg.png')
[27]:
データやパラメータが関数に次々と渡され,損失が出力されるまでの一連の計算過程が確認できます.
4.3.6. テストデータ評価¶
Validationデータに対する評価を学習中に行うために使用されるEvaluatorは,Trainerと関係なく独立して使うこともできます.以下のようにしてIterator
とネットワークのオブジェクト(net
),使用するデバイスIDを渡してEvaluatorオブジェクトを作成し,これを関数として実行するだけです.
[28]:
test_evaluator = extensions.Evaluator(test_iter, net, device=gpu_id)
results = test_evaluator()
print('Test accuracy:', results['main/accuracy'])
Test accuracy: 0.9250395
4.3.7. 学習済みモデルで推論する¶
それでは,Trainer Extensionのsnapshotが保存した学習済みパラメータを読み込んで,以前と同様に1番目のテストデータで推論を行ってみましょう.
ここで一点注意が必要ですが,snapshotが保存するnpzファイルはTrainer全体のスナップショットとなっており,学習の再開に必要となるextensionの内部のパラメータなども一緒に保存されています.しかし,今回はネットワークのパラメータだけを読み込めば良いので, serializers.load_npz()
のpath
引数にネットワーク部分までのパスを指定します.こうすることで,ネットワークのオブジェクトにパラメータだけを読み込ませることができます.
[29]:
reset_seed(0)
infer_net = MLP()
serializers.load_npz(
'results/mnist_result/snapshot_epoch-10',
infer_net, path='updater/model:main/predictor/')
if gpu_id >= 0:
infer_net.to_gpu(gpu_id)
x, t = test[0]
plt.imshow(x.reshape(28, 28), cmap='gray')
plt.show()
x = infer_net.xp.asarray(x[None, ...])
with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
y = infer_net(x)
y = to_cpu(y.array)
print('予測ラベル:', y.argmax(axis=1)[0])
予測ラベル: 7
無事正解できていることが確認できました.
4.4. 新しいネットワークの利用¶
ここでは,MNISTデータセットではなくCIFAR10という32x32サイズの小さなカラー画像に10クラスのいずれかのラベルがついたデータセットを用いて,いろいろなモデルを自分で書いて試行錯誤する流れを体験してみます.
airplane | automobile | bird | cat | deer | dog | frog | horse | ship | truck |
---|---|---|---|---|---|---|---|---|---|
4.4.1. 新しいネットワークの定義¶
ここでは,さきほど試した全結合層だけからなるネットワークではなく,前章で紹介した,畳込み層を持つネットワークを定義してみます.3つの畳み込み層を持ち,2つの全結合層がそのあとに続いています.
[30]:
class MyNet(chainer.Chain):
def __init__(self, n_out):
super(MyNet, self).__init__()
with self.init_scope():
self.conv1 = L.Convolution2D(None, 32, 3, 3, 1)
self.conv2 = L.Convolution2D(32, 64, 3, 3, 1)
self.conv3 = L.Convolution2D(64, 128, 3, 3, 1)
self.fc4 = L.Linear(None, 1000)
self.fc5 = L.Linear(1000, n_out)
def forward(self, x):
h = F.relu(self.conv1(x))
h = F.relu(self.conv2(h))
h = F.relu(self.conv3(h))
h = F.relu(self.fc4(h))
h = self.fc5(h)
return h
4.4.2. 学習¶
ここで,あとから別のネットワークも簡単に同じ設定で訓練できるよう,train
関数を作っておきます.これは,
- ネットワークのオブジェクト
- バッチサイズ
- 使用するGPU ID
- 学習を終了するエポック数
- データセットオブジェクト
- 学習率の初期値
- 学習率減衰のタイミング
などを渡すと,内部でTrainer
を用いて渡されたデータセットを使ってネットワークを訓練し,学習が終了した状態のネットワークを返してくれる関数です.Trainer.run()
が終了した後に,テストデータセットを使って評価まで行ってくれます.先程のMNISTでの例と違い,最適化手法にはMomentumSGDを用い,ExponentialShift
というExtention
を使って,指定したタイミングごとに学習率を減衰させるようにしてみます.
また,ここではcifar.get_cifar10()
が返す学習用データセットのうち9割のデータをtrain
,残りの1割をvalid
として使うようにしています.
このtrain
関数を用いて,上で定義したMyNet
モデルを訓練してみます.
[31]:
from chainer.datasets import cifar
def train(network_object, batchsize=128, gpu_id=0, max_epoch=20, train_dataset=None, valid_dataset=None, test_dataset=None, postfix='', base_lr=0.01, lr_decay=None, snapshot=None):
# 1. Dataset
if train_dataset is None and valid_dataset is None and test_dataset is None:
train_val, test = cifar.get_cifar10()
train_size = int(len(train_val) * 0.9)
train, valid = split_dataset_random(train_val, train_size, seed=0)
else:
train, valid, test = train_dataset, valid_dataset, test_dataset
# 2. Iterator
train_iter = iterators.MultiprocessIterator(train, batchsize)
valid_iter = iterators.MultiprocessIterator(valid, batchsize, False, False)
# 3. Model
net = L.Classifier(network_object)
# 4. Optimizer
optimizer = optimizers.MomentumSGD(lr=base_lr).setup(net)
optimizer.add_hook(chainer.optimizer.WeightDecay(0.0005))
# 5. Updater
updater = training.StandardUpdater(train_iter, optimizer, device=gpu_id)
# 6. Trainer
trainer = training.Trainer(updater, (max_epoch, 'epoch'), out='results/{}_cifar10_{}result'.format(network_object.__class__.__name__, postfix))
# 7. Trainer extensions
trainer.extend(extensions.LogReport())
trainer.extend(extensions.observe_lr())
trainer.extend(extensions.snapshot(filename='snapshot_epoch_{.updater.epoch}'), trigger=(10, 'epoch'))
trainer.extend(extensions.Evaluator(valid_iter, net, device=gpu_id), name='val')
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'val/main/loss', 'val/main/accuracy', 'elapsed_time', 'lr']))
trainer.extend(extensions.PlotReport(['main/loss', 'val/main/loss'], x_key='epoch', file_name='loss.png'))
trainer.extend(extensions.PlotReport(['main/accuracy', 'val/main/accuracy'], x_key='epoch', file_name='accuracy.png'))
if lr_decay is not None:
trainer.extend(extensions.ExponentialShift('lr', 0.1), trigger=lr_decay)
if snapshot is not None:
chainer.serializers.load_npz(snapshot, trainer)
trainer.run()
del trainer
# 8. Evaluation
test_iter = iterators.MultiprocessIterator(test, batchsize, False, False)
test_evaluator = extensions.Evaluator(test_iter, net, device=gpu_id)
results = test_evaluator()
print('Test accuracy:', results['main/accuracy'])
return net
[32]:
net = train(MyNet(10), gpu_id=0)
Downloading from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz...
epoch main/loss main/accuracy val/main/loss val/main/accuracy elapsed_time lr
1 1.92583 0.305065 1.72466 0.39668 4.88937 0.01
2 1.60857 0.423007 1.53026 0.463281 9.25879 0.01
3 1.47209 0.46964 1.48127 0.478125 13.5662 0.01
4 1.39223 0.499911 1.39299 0.499609 18.0876 0.01
5 1.32882 0.526197 1.3789 0.511719 22.5673 0.01
6 1.26765 0.547852 1.35271 0.516406 27.1432 0.01
7 1.21327 0.568999 1.25582 0.560547 31.8979 0.01
8 1.16433 0.583984 1.22899 0.570508 36.4486 0.01
9 1.12036 0.602384 1.23554 0.565039 40.9875 0.01
10 1.07057 0.61899 1.21995 0.56543 45.5839 0.01
11 1.02992 0.636808 1.1724 0.585938 50.4524 0.01
12 0.98116 0.653112 1.19605 0.579883 55.0429 0.01
13 0.938254 0.667392 1.159 0.59375 59.4494 0.01
14 0.901819 0.681067 1.20838 0.579492 64.0684 0.01
15 0.855333 0.698287 1.19485 0.585938 68.5982 0.01
16 0.810262 0.714321 1.19381 0.583984 73.0674 0.01
17 0.764117 0.731423 1.21938 0.587109 77.5318 0.01
18 0.72205 0.743697 1.20823 0.585742 81.9437 0.01
19 0.666414 0.764712 1.23899 0.593164 86.2922 0.01
20 0.620457 0.782715 1.24922 0.597461 90.6681 0.01
Test accuracy: 0.6065071
学習が20エポックまで終わりました.損失と精度のプロットを見てみましょう.
[33]:
Image(filename='results/MyNet_cifar10_result/loss.png')
[33]:
[34]:
Image(filename='results/MyNet_cifar10_result/accuracy.png')
[34]:
学習データでの精度(main/accuracy
)は77%程度まで到達していますが,テストデータでの損失(val/main/loss
)は途中から下げ止まり,精度(val/main/accuracy
)も60%前後で頭打ちになってしまっています.表示されたログの最後の行を確認すると,テストデータでの精度も同様に60%程度となっています.学習データでは精度が良いが, テストデータでは精度が良くない場合,モデルが学習データにオーバーフィッティングしていると考えられます.
4.4.3. 学習済みネットワークを使った予測¶
テスト精度は60%程度でしたが,試しにこの学習済みネットワークを使っていくつかのテスト画像を分類させてみましょう.あとで使いまわせるようにpredict
関数を作っておきます.
[35]:
cls_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck']
def predict(net, image_id):
_, test = cifar.get_cifar10()
x, t = test[image_id]
net.to_cpu()
with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
y = net.predictor(x[None, ...]).data.argmax(axis=1)[0]
plt.imshow(x.transpose(1, 2, 0))
plt.show()
print('predicted_label:', cls_names[y])
print('answer:', cls_names[t])
for i in range(10, 15):
predict(net, i)
predicted_label: airplane
answer: airplane
predicted_label: truck
answer: truck
predicted_label: dog
answer: dog
predicted_label: horse
answer: horse
predicted_label: truck
answer: truck
うまく分類できているものもあれば,そうでないものもありました.ネットワークの学習に使用したデータセット上ではほぼ百発百中で正解できても,未知のデータ,すなわちテストデータセットの画像に対して高精度な予測ができなければ意味がありません.テストデータでの精度は,モデルの汎化性能に関係していると言われています.
どうすれば高い汎化性能を持つネットワークを設計し,学習することができるでしょうか?これは非常に難しい問いですが,機械学習を使った応用を考えるとき,最も重要な問いの一つです.
4.4.4. 深いネットワークの定義¶
では,さきほどのネットワークよりも多層のネットワークを定義してみましょう.ここでは,1層の畳み込みネットワークをConvBlock
,1層の全結合ネットワークをLinearBlock
として定義し,これを数多く積み重ねることで大きなネットワークを定義してみます.
4.4.4.1. 構成要素を定義する¶
まず,ネットワークの構成要素となるConvBlock
とLinearBlock
を定義してみましょう.
[36]:
class ConvBlock(chainer.Chain):
def __init__(self, n_ch, pool_drop=False):
w = chainer.initializers.HeNormal()
super(ConvBlock, self).__init__()
with self.init_scope():
self.conv = L.Convolution2D(None, n_ch, 3, 1, 1, nobias=True, initialW=w)
self.bn = L.BatchNormalization(n_ch)
self.pool_drop = pool_drop
def forward(self, x):
h = F.relu(self.bn(self.conv(x)))
if self.pool_drop:
h = F.max_pooling_2d(h, 2, 2)
h = F.dropout(h, ratio=0.25)
return h
class LinearBlock(chainer.Chain):
def __init__(self, drop=False):
w = chainer.initializers.HeNormal()
super(LinearBlock, self).__init__()
with self.init_scope():
self.fc = L.Linear(None, 1024, initialW=w)
self.drop = drop
def forward(self, x):
h = F.relu(self.fc(x))
if self.drop:
h = F.dropout(h)
return h
ConvBlock
はChain
を継承した小さなネットワークとして定義されており,一つの畳み込み層とBatch Normalization層で構成されます.Batch Normalization層は,ネットワークの学習プロセスを安定させるために広く利用されている手法の一つで,例えば今回のように,畳み込み層の直後に挿入する形で利用されます.forward
メソッドでは,これらにデータを渡しつつ,活性化関数ReLUを適用して,さらにpool_drop
引数がTrue
であれば,Max
PoolingとDropoutを適用するような順伝播の計算が行われます.Dropoutは,ネットワークの過学習を避けて汎化性能を上げる目的で利用される手法の一つで,層の中のノードのうち,一定割合(dropout ratioと呼ばれる)をランダムに無効にしながら学習を行います(無効にする割合はratio
という引数で指定でき,何も指定しなければ50%が無効化されます).推論時は,dropout
ratioを\(p\)とすると,Dropout層への入力をただ\(p\)倍して出力するだけの層として働きます.これによって,擬似的に複数のネットワークの学習結果をアンサンブル(参考:Ensemble
averaging)するような効果があると言われ,汎化性能が向上する場合があります.最適化の際にモデルのパラメータに何らかの制約を与えて汎化性能を向上させるための工夫は正則化(regularization)と呼ばれ,このDropoutやパラメータの絶対値が大きくなりすぎないようにするWeight decayなどの方法が知られています.
Chainerでは,Pythonを使って書いたforward計算のコード自体がネットワークの構造を表します.すなわち,実行時にデータがどの層を通過していったか,によってネットワークそのものが定義されます.この性質によって,上記のような分岐などを含むネットワークも簡単に記述でき,柔軟かつシンプルで可読性の高いネットワーク定義が可能になります.これがDefine-by-Runの大きな特徴となっています.
4.4.4.2. 大きなネットワークの定義¶
次に,これらの小さなネットワークを構成要素として積み重ねて,大きなネットワークを定義してみましょう.
[37]:
class DeepCNN(chainer.ChainList):
def __init__(self, n_output):
super(DeepCNN, self).__init__(
ConvBlock(64),
ConvBlock(64, True),
ConvBlock(128),
ConvBlock(128, True),
ConvBlock(256),
ConvBlock(256),
ConvBlock(256),
ConvBlock(256, True),
LinearBlock(),
LinearBlock(),
L.Linear(None, n_output)
)
def forward(self, x):
for f in self.children():
x = f(x)
return x
ここで,ChainList
というクラスが利用されています.このクラスはChain
を継承したクラスで,いくつものLink
やChain
を順次呼び出していくようなネットワークを定義するときに便利です.ChainList
を継承して定義されるモデルは,親クラスのコンストラクタを呼び出す際に,キーワード引数ではなく通常の引数としてLink
もしくはChain
オブジェクトを渡すことができ,self.children()
メソッドによって登録した順番に取り出すことができます.この特徴を使うと,forward計算が上記のように簡単に記述可能となります.
4.4.4.3. 高速化のTIPS¶
今回は多くの畳込み層を使う大きなネットワークを使うので,Chainerが用意してくれているcuDNNのautotune機能を有効にしてみます.やり方は簡単で,以下の二行を事前に実行しておくだけです.これを有効にすると,cuDNNが自動的に高速な畳み込みのアルゴリズムを選択するなどの実行時の調整を行ってくれるようになります.
[38]:
chainer.cuda.set_max_workspace_size(1024 * 1024 * 1024)
chainer.config.autotune = True
それでは,学習を回してみます.今回はパラメータ数も多いので,学習を停止するエポック数を100に設定します.また,学習率を0.1から始めて,30エポックごとに10分の1にするように設定します.
本来は,以下の2行を実行することで乱数シードを固定し,100エポック分上で定義した DeepCNN
というクラスが表すモデルの学習ができるのですが,これは40分以上の時間を要するので,今回は事前に90エポックまで学習を進めておいた重みを読み込んで,90エポック終了時点から学習を再開し,最後の10エポックだけ実際にここで学習を回すことにします.
[39]:
!wget https://github.com/japan-medical-ai/medical-ai-course-materials/releases/download/v0.1/DeepCNN_cifar10_snapshot_epoch_90.npz
reset_seed(0)
model = train(DeepCNN(10), max_epoch=100, base_lr=0.1, lr_decay=(30, 'epoch'), snapshot='DeepCNN_cifar10_snapshot_epoch_90.npz')
--2019-12-09 11:16:00-- https://github.com/japan-medical-ai/medical-ai-course-materials/releases/download/v0.1/DeepCNN_cifar10_snapshot_epoch_90.npz
Resolving github.com (github.com)... 13.250.177.223
Connecting to github.com (github.com)|13.250.177.223|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github-production-release-asset-2e65be.s3.amazonaws.com/153412006/4fcc1200-eeb7-11e8-8ca0-9095e5bca078?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20191209%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20191209T111601Z&X-Amz-Expires=300&X-Amz-Signature=f2d9063e8a6ec0edd8567b4ae8caeebbd09f2d1e6eabc59db80bcd2c519f2787&X-Amz-SignedHeaders=host&actor_id=0&response-content-disposition=attachment%3B%20filename%3DDeepCNN_cifar10_snapshot_epoch_90.npz&response-content-type=application%2Foctet-stream [following]
--2019-12-09 11:16:01-- https://github-production-release-asset-2e65be.s3.amazonaws.com/153412006/4fcc1200-eeb7-11e8-8ca0-9095e5bca078?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20191209%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20191209T111601Z&X-Amz-Expires=300&X-Amz-Signature=f2d9063e8a6ec0edd8567b4ae8caeebbd09f2d1e6eabc59db80bcd2c519f2787&X-Amz-SignedHeaders=host&actor_id=0&response-content-disposition=attachment%3B%20filename%3DDeepCNN_cifar10_snapshot_epoch_90.npz&response-content-type=application%2Foctet-stream
Resolving github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)... 52.216.164.75
Connecting to github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)|52.216.164.75|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 56889445 (54M) [application/octet-stream]
Saving to: ‘DeepCNN_cifar10_snapshot_epoch_90.npz’
DeepCNN_cifar10_sna 100%[===================>] 54.25M 12.5MB/s in 5.4s
2019-12-09 11:16:07 (10.1 MB/s) - ‘DeepCNN_cifar10_snapshot_epoch_90.npz’ saved [56889445/56889445]
/usr/local/lib/python3.6/dist-packages/chainer/training/triggers/interval_trigger.py:92: UserWarning: The previous value of iteration is not saved. IntervalTrigger guesses it using current iteration. If this trigger is not called at every iteration, it may not work correctly.
'The previous value of iteration is not saved. '
/usr/local/lib/python3.6/dist-packages/chainer/training/triggers/interval_trigger.py:104: UserWarning: The previous value of epoch_detail is not saved. IntervalTrigger uses the value of trainer.updater.previous_epoch_detail. If this trigger is not called at every iteration, it may not work correctly.
'The previous value of epoch_detail is not saved. '
epoch main/loss main/accuracy val/main/loss val/main/accuracy elapsed_time lr
1 2.62849 0.144931 2.22525 0.15625 27.9748 0.1
2 2.11316 0.210804 1.97533 0.266406 54.5483 0.1
3 1.87483 0.289396 1.8026 0.320508 81.0827 0.1
4 1.74066 0.340443 1.74728 0.358203 107.7 0.1
5 1.58789 0.409411 1.63916 0.40332 134.235 0.1
6 1.38757 0.492831 1.22399 0.561719 161.093 0.1
7 1.2036 0.566128 1.32939 0.553906 187.55 0.1
8 1.07527 0.617077 1.17079 0.589648 214.184 0.1
9 0.964984 0.660711 0.950063 0.67207 240.817 0.1
10 0.895905 0.688056 0.972697 0.657031 267.372 0.1
11 0.828796 0.715043 0.944733 0.686914 297.779 0.1
12 0.784123 0.731793 0.960205 0.687305 324.312 0.1
13 0.742033 0.748291 0.858308 0.719727 351.032 0.1
14 0.692516 0.76627 0.811853 0.725195 377.591 0.1
15 0.65644 0.776256 0.692374 0.767578 404.23 0.1
16 0.650682 0.780738 0.796731 0.733008 430.797 0.1
17 0.610249 0.793435 0.632162 0.791406 457.467 0.1
18 0.591896 0.803667 0.791757 0.739453 484.04 0.1
19 0.578718 0.804465 1.04659 0.671484 510.602 0.1
20 0.554954 0.814431 0.8127 0.733594 537.247 0.1
21 0.549188 0.814236 0.654926 0.787109 567.567 0.1
22 0.535866 0.820446 0.640323 0.788672 594.214 0.1
23 0.527765 0.823028 0.958373 0.70957 620.761 0.1
24 0.512286 0.830056 0.793664 0.748633 647.363 0.1
25 0.497195 0.833141 0.717548 0.759961 673.951 0.1
26 0.495153 0.835759 1.7557 0.511328 700.494 0.1
27 0.486062 0.837069 0.732518 0.775391 727.13 0.1
28 0.47963 0.839387 0.669157 0.786328 753.692 0.1
29 0.475502 0.838312 0.915904 0.718555 780.308 0.1
30 0.460236 0.844841 0.877713 0.72793 806.806 0.1
31 0.298483 0.897239 0.381921 0.879687 837.16 0.01
32 0.211182 0.92784 0.364046 0.882617 863.898 0.01
33 0.180429 0.937478 0.374651 0.883984 890.554 0.01
34 0.164156 0.943692 0.361041 0.888867 917.199 0.01
35 0.144584 0.950387 0.375391 0.889258 943.774 0.01
36 0.132288 0.954235 0.377427 0.890625 970.4 0.01
37 0.12103 0.957376 0.390434 0.892578 996.96 0.01
38 0.111974 0.961204 0.400307 0.886133 1023.62 0.01
39 0.102573 0.964476 0.399275 0.892773 1050.2 0.01
40 0.0972647 0.965931 0.432854 0.887109 1076.83 0.01
41 0.0928545 0.966597 0.418165 0.887305 1107.19 0.01
42 0.08498 0.96964 0.432462 0.884961 1133.75 0.01
43 0.0845448 0.970792 0.4365 0.880273 1160.37 0.01
44 0.0770674 0.973914 0.441935 0.883789 1187.02 0.01
45 0.0732439 0.974565 0.469901 0.881836 1213.62 0.01
46 0.070491 0.975962 0.453856 0.8875 1240.16 0.01
47 0.068846 0.97583 0.461264 0.881055 1266.75 0.01
48 0.0694101 0.976941 0.435111 0.885742 1293.25 0.01
49 0.0653772 0.977117 0.461284 0.882617 1319.87 0.01
50 0.0633419 0.97836 0.464232 0.889648 1346.47 0.01
51 0.0584663 0.979834 0.464193 0.885547 1376.6 0.01
52 0.0607617 0.979714 0.466352 0.881445 1403.21 0.01
53 0.0615791 0.978632 0.451807 0.889453 1429.69 0.01
54 0.0588031 0.979914 0.489054 0.881836 1456.28 0.01
55 0.0582368 0.979367 0.4719 0.882617 1482.83 0.01
56 0.0558719 0.981379 0.495846 0.878125 1509.45 0.01
57 0.0579962 0.979936 0.472415 0.875781 1536.06 0.01
58 0.0592009 0.9793 0.454762 0.88418 1562.6 0.01
59 0.0568546 0.980735 0.487556 0.876172 1589.17 0.01
60 0.0579785 0.980079 0.472908 0.883398 1615.93 0.01
61 0.0318918 0.98968 0.416953 0.895703 1646.15 0.001
62 0.0220127 0.993612 0.416859 0.899609 1672.69 0.001
63 0.0186169 0.99434 0.417849 0.898242 1699.29 0.001
64 0.0159041 0.995526 0.41804 0.900781 1725.78 0.001
65 0.0147089 0.995916 0.429896 0.899609 1752.39 0.001
66 0.0129457 0.996404 0.433748 0.898828 1779.01 0.001
67 0.0131643 0.996283 0.433923 0.898828 1805.54 0.001
68 0.0112659 0.996893 0.437222 0.901758 1832.17 0.001
69 0.0106502 0.997151 0.443475 0.901758 1858.73 0.001
70 0.0107926 0.997203 0.445066 0.900391 1885.34 0.001
71 0.0105973 0.997062 0.44159 0.898633 1915.8 0.001
72 0.00934292 0.99767 0.450084 0.897266 1942.45 0.001
73 0.0104884 0.997092 0.451691 0.899023 1969.06 0.001
74 0.00849317 0.997707 0.450391 0.9 1995.61 0.001
75 0.00846932 0.997891 0.451362 0.902148 2022.21 0.001
76 0.00826699 0.997841 0.448779 0.900781 2048.69 0.001
77 0.00875069 0.997492 0.45095 0.900391 2075.31 0.001
78 0.00823296 0.998019 0.449194 0.898438 2102.1 0.001
79 0.00701245 0.998113 0.454196 0.899609 2128.74 0.001
80 0.00846517 0.997596 0.455877 0.901172 2155.29 0.001
81 0.00677115 0.99818 0.459518 0.899805 2185.54 0.001
82 0.00717393 0.998047 0.465337 0.899805 2212.13 0.001
83 0.00709802 0.997908 0.464472 0.898828 2238.69 0.001
84 0.00699702 0.998091 0.470595 0.899219 2265.27 0.001
85 0.00746627 0.997975 0.470063 0.901172 2291.78 0.001
86 0.00666763 0.998069 0.466201 0.899414 2318.38 0.001
87 0.00616266 0.998531 0.462948 0.9 2344.86 0.001
88 0.00719447 0.997847 0.463587 0.899609 2371.48 0.001
89 0.00638493 0.998335 0.465655 0.901367 2398.08 0.001
90 0.0061445 0.998286 0.464918 0.900586 2424.59 0.001
91 0.00582547 0.99838 0.46484 0.900977 2439.32 0.0001
92 0.00602776 0.998331 0.461773 0.901367 2452.62 0.0001
93 0.00597372 0.998491 0.464172 0.900195 2466.17 0.0001
94 0.00619668 0.99813 0.46446 0.900391 2479.48 0.0001
95 0.00545569 0.998557 0.466654 0.900977 2492.64 0.0001
96 0.00613322 0.998308 0.465335 0.900781 2505.78 0.0001
97 0.0054181 0.998624 0.465642 0.900586 2518.9 0.0001
98 0.00512285 0.998801 0.467116 0.900781 2532.13 0.0001
99 0.00562234 0.998576 0.464966 0.901367 2545.35 0.0001
100 0.00551726 0.998624 0.462725 0.901367 2558.63 0.0001
Test accuracy: 0.8966574
ゼロから学習する場合:
reset_seed(0)
model = train(DeepCNN(10), max_epoch=100, base_lr=0.1, lr_decay=(30, 'epoch'))
学習が終了しました.学習曲線と精度のグラフを見てみましょう.
[40]:
Image(filename='results/DeepCNN_cifar10_result/loss.png')
[40]:
[41]:
Image(filename='results/DeepCNN_cifar10_result/accuracy.png')
[41]:
先程の浅い(層数の少ない)CNNを用いた際には60%前後だったValidationデータでの精度が,90%程度まで上がりました.また,テストデータを用いた精度も,およそ90%程度となっています.しかし最新の研究成果では97%以上まで達成されています.さらに精度を上げるには,今回行ったようなネットワークの構造自体の改良ももちろんのこと,学習データを擬似的に増やす操作(Data augmentation)や,複数のモデルの出力を一つの出力に統合する操作(Ensemble)などなど,いろいろな工夫が考えられます.
4.5. データセットクラスの使用方法¶
ここでは,Chainerにすでに用意されているCIFAR10のデータを取得する機能を使って,データセットクラスを自分で書いてみます.Chainerでは,データセットを表すクラスは以下の機能を持っている必要があります.
- データセット内のデータ数を返す
__len__
メソッド - 引数として渡される
i
に対応したデータもしくはデータとラベルの組を返すget_example
メソッド
その他のデータセットに必要な機能は,chainer.dataset.DatasetMixin
クラスを継承することで用意できます.ここでは,DatasetMixin
クラスを継承し,学習時に学習データに変換を施してモデルが受け取るデータのバリエーションを増やすData augmentation機能のついたデータセットクラスを作成してみましょう.
4.5.1. CIFAR10データセットクラス¶
[42]:
class CIFAR10Augmented(chainer.dataset.DatasetMixin):
def __init__(self, split='train', train_ratio=0.9):
train_val, test_data = cifar.get_cifar10()
train_size = int(len(train_val) * train_ratio)
train_data, valid_data = split_dataset_random(train_val, train_size, seed=0)
if split == 'train':
self.data = train_data
elif split == 'valid':
self.data = valid_data
elif split == 'test':
self.data = test_data
else:
raise ValueError("'split' argument should be either 'train', 'valid', or 'test'. But {} was given.".format(split))
self.split = split
self.random_crop = 4
def __len__(self):
return len(self.data)
def get_example(self, i):
x, t = self.data[i]
if self.split == 'train':
x = x.transpose(1, 2, 0)
h, w, _ = x.shape
x_offset = np.random.randint(self.random_crop)
y_offset = np.random.randint(self.random_crop)
x = x[y_offset:y_offset + h - self.random_crop,
x_offset:x_offset + w - self.random_crop]
if np.random.rand() > 0.5:
x = np.fliplr(x)
x = x.transpose(2, 0, 1)
return x, t
このクラスは,CIFAR10のデータのそれぞれに対し,
- 32x32の大きさの中からランダムに28x28の領域をクロップ
- 1/2の確率で左右を反転させる
という加工を行っています.このような操作を加えて擬似的に学習データのバリエーションを増やすことで,オーバーフィッティングの抑制などに寄与することが知られています.これらの操作以外にも,画像の色味を変化させるような変換やランダムな回転,アフィン変換など,さまざまな加工によって学習データ数を擬似的に増やす方法が提案されています.
4.5.2. 作成したデータセットクラスを用いた学習¶
それではさっそくこのCIFAR10
クラスを使って学習を行ってみましょう.先程と同じネットワークを用い,Data augmentationの効果がどの程度あるのかを調べてみましょう.train
関数も含め,データセットクラス以外は先程とすべて同様です.
ここでも,40分ほどの時間がかかりますので,上と同様に90エポックまで学習したあとのsnapshotをダウンロードして読み込ませ,最後の10エポックだけ実際に学習させてみましょう.
[43]:
!wget https://github.com/japan-medical-ai/medical-ai-course-materials/releases/download/v0.1/DeepCNN_cifar10_augmented_snapshot_epoch_90.npz
reset_seed(0)
model = train(DeepCNN(10), max_epoch=100, train_dataset=CIFAR10Augmented(), valid_dataset=CIFAR10Augmented('valid'), test_dataset=CIFAR10Augmented('test'), postfix='augmented_', base_lr=0.1, lr_decay=(30, 'epoch'), snapshot='DeepCNN_cifar10_augmented_snapshot_epoch_90.npz')
--2019-12-09 11:18:31-- https://github.com/japan-medical-ai/medical-ai-course-materials/releases/download/v0.1/DeepCNN_cifar10_augmented_snapshot_epoch_90.npz
Resolving github.com (github.com)... 52.74.223.119
Connecting to github.com (github.com)|52.74.223.119|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github-production-release-asset-2e65be.s3.amazonaws.com/153412006/5064a880-eeb7-11e8-95bf-80b5d9533256?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20191209%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20191209T111832Z&X-Amz-Expires=300&X-Amz-Signature=1dc193a84a383234eb92fd9c7cb0db88c9ad7dad53e4e9a32d06f33a3825b832&X-Amz-SignedHeaders=host&actor_id=0&response-content-disposition=attachment%3B%20filename%3DDeepCNN_cifar10_augmented_snapshot_epoch_90.npz&response-content-type=application%2Foctet-stream [following]
--2019-12-09 11:18:32-- https://github-production-release-asset-2e65be.s3.amazonaws.com/153412006/5064a880-eeb7-11e8-95bf-80b5d9533256?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20191209%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20191209T111832Z&X-Amz-Expires=300&X-Amz-Signature=1dc193a84a383234eb92fd9c7cb0db88c9ad7dad53e4e9a32d06f33a3825b832&X-Amz-SignedHeaders=host&actor_id=0&response-content-disposition=attachment%3B%20filename%3DDeepCNN_cifar10_augmented_snapshot_epoch_90.npz&response-content-type=application%2Foctet-stream
Resolving github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)... 52.216.241.116
Connecting to github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)|52.216.241.116|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 56730280 (54M) [application/octet-stream]
Saving to: ‘DeepCNN_cifar10_augmented_snapshot_epoch_90.npz’
DeepCNN_cifar10_aug 100%[===================>] 54.10M 12.6MB/s in 5.1s
2019-12-09 11:18:38 (10.5 MB/s) - ‘DeepCNN_cifar10_augmented_snapshot_epoch_90.npz’ saved [56730280/56730280]
/usr/local/lib/python3.6/dist-packages/chainer/training/triggers/interval_trigger.py:92: UserWarning: The previous value of iteration is not saved. IntervalTrigger guesses it using current iteration. If this trigger is not called at every iteration, it may not work correctly.
'The previous value of iteration is not saved. '
/usr/local/lib/python3.6/dist-packages/chainer/training/triggers/interval_trigger.py:104: UserWarning: The previous value of epoch_detail is not saved. IntervalTrigger uses the value of trainer.updater.previous_epoch_detail. If this trigger is not called at every iteration, it may not work correctly.
'The previous value of epoch_detail is not saved. '
epoch main/loss main/accuracy val/main/loss val/main/accuracy elapsed_time lr
1 2.5875 0.156405 2.11656 0.203125 24.1767 0.1
2 1.99359 0.233842 1.84577 0.304492 47.9466 0.1
3 1.76968 0.325365 1.98983 0.26543 71.6864 0.1
4 1.61662 0.389537 2.07369 0.26875 95.5072 0.1
5 1.41259 0.478989 1.52089 0.446484 119.236 0.1
6 1.23382 0.555487 1.48775 0.480664 143.06 0.1
7 1.09323 0.613404 1.11949 0.590234 166.813 0.1
8 0.99303 0.650857 1.33017 0.566992 190.624 0.1
9 0.926848 0.678755 1.0075 0.665234 214.446 0.1
10 0.863165 0.702769 0.858035 0.717383 238.218 0.1
11 0.807948 0.727162 0.9556 0.679297 265.757 0.1
12 0.769224 0.739116 0.857765 0.711328 289.507 0.1
13 0.739977 0.752952 0.91583 0.711133 313.339 0.1
14 0.72064 0.759393 1.20587 0.61875 337.097 0.1
15 0.690136 0.77093 0.837919 0.726562 360.905 0.1
16 0.673935 0.774706 1.03539 0.678711 384.67 0.1
17 0.662879 0.778742 0.730712 0.758789 408.47 0.1
18 0.639202 0.78742 0.758566 0.765625 432.298 0.1
19 0.625988 0.792713 1.24791 0.664062 456.061 0.1
20 0.616269 0.795277 0.963706 0.70625 479.882 0.1
21 0.611734 0.795962 0.887129 0.723437 507.312 0.1
22 0.600444 0.800582 0.889526 0.710352 531.173 0.1
23 0.605317 0.800414 0.715702 0.756445 554.926 0.1
24 0.584194 0.805731 0.984225 0.694336 578.687 0.1
25 0.584041 0.805464 0.956576 0.685156 602.49 0.1
26 0.57384 0.80954 0.977559 0.712695 627.031 0.1
27 0.560405 0.814298 0.894127 0.718945 650.845 0.1
28 0.559933 0.816195 0.729981 0.752734 674.584 0.1
29 0.555933 0.814387 0.841304 0.727344 698.4 0.1
30 0.558057 0.814971 0.753542 0.757227 722.119 0.1
31 0.397613 0.866455 0.337977 0.888867 749.619 0.01
32 0.320024 0.890202 0.322082 0.894336 773.368 0.01
33 0.293655 0.900479 0.323365 0.888867 797.169 0.01
34 0.279553 0.904874 0.308263 0.897656 820.951 0.01
35 0.26519 0.909945 0.301763 0.897852 844.682 0.01
36 0.259166 0.909846 0.286909 0.904688 868.482 0.01
37 0.246168 0.915064 0.289997 0.904688 892.324 0.01
38 0.24097 0.91697 0.280986 0.903906 916.139 0.01
39 0.233951 0.918892 0.291962 0.904102 939.884 0.01
40 0.220939 0.923628 0.299502 0.902539 963.666 0.01
41 0.216055 0.924272 0.2946 0.905664 991.125 0.01
42 0.215143 0.926972 0.308637 0.897266 1014.88 0.01
43 0.213903 0.926625 0.291742 0.907812 1038.64 0.01
44 0.203283 0.929198 0.296043 0.905469 1062.33 0.01
45 0.19772 0.931041 0.327708 0.89375 1086.11 0.01
46 0.194907 0.932893 0.312555 0.901563 1109.82 0.01
47 0.190354 0.934326 0.336271 0.895703 1133.58 0.01
48 0.190431 0.932915 0.326305 0.902539 1157.3 0.01
49 0.18926 0.934326 0.312767 0.901172 1181.05 0.01
50 0.184469 0.936035 0.296937 0.907812 1204.93 0.01
51 0.181691 0.936521 0.324149 0.901172 1232.22 0.01
52 0.17546 0.939675 0.347524 0.893945 1256 0.01
53 0.175786 0.937723 0.335627 0.893164 1279.74 0.01
54 0.173612 0.940274 0.317897 0.902344 1303.52 0.01
55 0.171849 0.939548 0.306998 0.90625 1327.23 0.01
56 0.168304 0.94165 0.31145 0.902148 1351.04 0.01
57 0.170139 0.941495 0.301311 0.910156 1374.82 0.01
58 0.165011 0.942708 0.359516 0.892773 1398.52 0.01
59 0.163968 0.94256 0.365818 0.886133 1422.3 0.01
60 0.16541 0.942575 0.357 0.890234 1446.01 0.01
61 0.129435 0.955988 0.277052 0.915234 1473.42 0.001
62 0.101981 0.965434 0.284798 0.916406 1497.13 0.001
63 0.0953637 0.967285 0.279956 0.919727 1520.93 0.001
64 0.0911066 0.968171 0.28204 0.918359 1544.63 0.001
65 0.0853851 0.97037 0.28504 0.918945 1568.41 0.001
66 0.0800331 0.972101 0.287688 0.917969 1592.21 0.001
67 0.0761374 0.973202 0.29148 0.920117 1616.19 0.001
68 0.0756613 0.973699 0.299635 0.918945 1639.98 0.001
69 0.075577 0.97407 0.293845 0.918359 1663.68 0.001
70 0.0730666 0.974676 0.29563 0.920508 1687.47 0.001
71 0.070825 0.975516 0.295581 0.920313 1714.73 0.001
72 0.0710753 0.975697 0.29838 0.919336 1738.52 0.001
73 0.0705982 0.975142 0.298369 0.920508 1762.3 0.001
74 0.0667571 0.976562 0.299809 0.920508 1786.01 0.001
75 0.0642319 0.978427 0.300881 0.920898 1809.77 0.001
76 0.0640179 0.977742 0.304647 0.918359 1833.51 0.001
77 0.0629752 0.977761 0.299763 0.919336 1857.31 0.001
78 0.0586612 0.979523 0.306034 0.922461 1881.03 0.001
79 0.059752 0.979869 0.311227 0.921289 1904.82 0.001
80 0.0571715 0.980213 0.304607 0.920703 1928.51 0.001
81 0.0573339 0.980136 0.315108 0.92168 1956.06 0.001
82 0.0560348 0.979847 0.321934 0.916992 1979.87 0.001
83 0.0553193 0.980613 0.315378 0.914648 2003.58 0.001
84 0.0531816 0.98129 0.318977 0.919531 2027.34 0.001
85 0.0560367 0.98097 0.310993 0.919141 2051.02 0.001
86 0.0535048 0.981534 0.317829 0.920117 2075.04 0.001
87 0.0522188 0.981571 0.31144 0.920313 2098.73 0.001
88 0.0526632 0.982156 0.318594 0.920703 2122.46 0.001
89 0.0528096 0.981445 0.309017 0.92207 2146.21 0.001
90 0.0499371 0.982928 0.313269 0.920508 2169.92 0.001
91 0.046129 0.984375 0.312747 0.919141 2182.89 0.0001
92 0.0442634 0.984642 0.308682 0.921484 2195.39 0.0001
93 0.0456881 0.98422 0.308501 0.920898 2208.13 0.0001
94 0.0450576 0.984553 0.311052 0.921484 2220.46 0.0001
95 0.0450287 0.985152 0.310078 0.920117 2232.59 0.0001
96 0.0444907 0.984486 0.312837 0.921289 2244.59 0.0001
97 0.0439379 0.985418 0.310588 0.92168 2256.65 0.0001
98 0.0430186 0.985352 0.310459 0.920508 2268.82 0.0001
99 0.0419429 0.985288 0.310179 0.92168 2281.02 0.0001
100 0.0421573 0.985574 0.31397 0.920898 2293.27 0.0001
Test accuracy: 0.917227
先程のData augmentationなしの場合は90%程度だったテスト精度が,学習データにaugmentationを施すことでおよそ1.8%程度向上していることが分かりました.
損失と精度のグラフを見てみましょう.
[44]:
Image(filename='results/DeepCNN_cifar10_augmented_result/loss.png')
[44]:
[45]:
Image(filename='results/DeepCNN_cifar10_augmented_result/accuracy.png')
[45]:
4.6. Data Augmentationの簡単な使い方¶
前述のようにデータセット内の各画像についていろいろな変換を行って擬似的にデータを増やすような操作をData Augmentationといいます.上では,オリジナルのデータセットクラスを作る方法を示すために変換の操作もget_example()
内に書くという実装を行いましたが,実はもっと簡単にいろいろな変換をデータに対して行う方法があります.
それは,TransformDataset
クラスを使う方法です.TransformDataset
は,元になるデータセットオブジェクトと,そこからサンプルしてきた各データ点に対して行いたい変換を関数の形で与えると,変換済みのデータを返してくれるようなデータセットオブジェクトに加工してくれる便利なクラスです.簡単な使い方は以下のようになります.
[46]:
from chainer.datasets import TransformDataset
train_val, test_dataset = cifar.get_cifar10()
train_size = int(len(train_val) * 0.9)
train_dataset, valid_dataset = split_dataset_random(train_val, train_size, seed=0)
# 行いたい変換を関数の形で書く
def transform(inputs):
x, t = inputs
x = x.transpose(1, 2, 0)
h, w, _ = x.shape
x_offset = np.random.randint(4)
y_offset = np.random.randint(4)
x = x[y_offset:y_offset + h - 4,
x_offset:x_offset + w - 4]
if np.random.rand() > 0.5:
x = np.fliplr(x)
x = x.transpose(2, 0, 1)
return x, t
# 各データをtransform関数で処理して返すデータセットオブジェクト
train_dataset = TransformDataset(train_dataset, transform)
このようにして得られた新しいtrain_dataset
は,自作のデータセットクラスと同じような変換処理を行った上でデータを返してくれるデータセットオブジェクトとなります.
4.6.1. ChainerCVを活用した変換処理¶
さて,先ほどご紹介したコードでは,画像に対するランダムクロップ,及びランダムな左右反転の処理を自ら実装していました.もし,より多様な変換を行いたい場合,上記のtransform
関数に処理を追加していくことになりますが,一般的に用いられる変換処理をその度に自ら実装するのは手間です.そこで本項では最後に,ChainerCV[Niitani 2017]をご紹介します.ChainerCVは,Computer
Visionに特化した機能が豊富に追加された,Chainerの補助パッケージとしての役割を担うオープンソース・ソフトウェアです.
[47]:
!pip install chainercv
Collecting chainercv
Downloading https://files.pythonhosted.org/packages/e8/1c/1f267ccf5ebdf1f63f1812fa0d2d0e6e35f0d08f63d2dcdb1351b0e77d85/chainercv-0.13.1.tar.gz (260kB)
|████████████████████████████████| 266kB 35.5MB/s
Requirement already satisfied: chainer>=6.0 in /usr/local/lib/python3.6/dist-packages (from chainercv) (6.5.0)
Requirement already satisfied: Pillow in /usr/local/lib/python3.6/dist-packages (from chainercv) (4.3.0)
Requirement already satisfied: protobuf>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from chainer>=6.0->chainercv) (3.10.0)
Requirement already satisfied: six>=1.9.0 in /usr/local/lib/python3.6/dist-packages (from chainer>=6.0->chainercv) (1.12.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from chainer>=6.0->chainercv) (3.0.12)
Requirement already satisfied: typing-extensions<=3.6.6 in /usr/local/lib/python3.6/dist-packages (from chainer>=6.0->chainercv) (3.6.6)
Requirement already satisfied: typing<=3.6.6 in /usr/local/lib/python3.6/dist-packages (from chainer>=6.0->chainercv) (3.6.6)
Requirement already satisfied: numpy>=1.9.0 in /usr/local/lib/python3.6/dist-packages (from chainer>=6.0->chainercv) (1.17.4)
Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from chainer>=6.0->chainercv) (42.0.1)
Requirement already satisfied: olefile in /usr/local/lib/python3.6/dist-packages (from Pillow->chainercv) (0.46)
Building wheels for collected packages: chainercv
Building wheel for chainercv (setup.py) ... done
Created wheel for chainercv: filename=chainercv-0.13.1-cp36-cp36m-linux_x86_64.whl size=537355 sha256=fc0aaac281e6d6effbb0699ea77604eacc16f6683691d9463803c54c192a746b
Stored in directory: /root/.cache/pip/wheels/ea/10/01/e221beaa4b3d8341aa819a39ab8d4677457c79c81f521f3a94
Successfully built chainercv
Installing collected packages: chainercv
Successfully installed chainercv-0.13.1
ChainerCVには,画像に対する様々な変換があらかじめ用意されています.
例えば,上でNumPyを使って書いていたランダムクロップやランダム左右反転は,chainercv.transforms
モジュールを使うと,それぞれ以下のように1行で書くことができます:
x = chainercv.transforms.random_crop(x, (28, 28)) # ランダムクロップ
x = chainercv.transforms.random_flip(x) # ランダム左右反転
chainercv.transforms
モジュールを使って,transform
関数をアップデートしてみましょう.ちなみに,get_cifar10()
で得られるデータセットでは,デフォルトで画像の画素値の範囲が[0, 1]
にスケールされています.しかし,get_cifar10()
にscale=255.
を渡しておくと,値の範囲をもともとの[0, 255]
のままにできます.今回行われる処理は,以下の5つです:
- PCA lighting: 先行研究(AlexNet)の学習で使われていた方法で,色を変化させる変換処理を行います.
- Standardization: 訓練用データセット全体からチャンネルごとの画素値の平均・標準偏差を求めて標準化をします
- Random flip: ランダムに画像の左右を反転します
- Random expand:
[1, 1.5]
からランダムに決めた大きさの黒いキャンバスを作り,その中のランダムな位置へ画像を配置します - Random crop:
(28, 28)
の大きさの領域をランダムにクロップします
[48]:
from functools import partial
from chainercv import transforms
train_val, test_dataset = cifar.get_cifar10(scale=255.)
train_size = int(len(train_val) * 0.9)
train_dataset, valid_dataset = split_dataset_random(train_val, train_size, seed=0)
mean = np.mean([x for x, _ in train_dataset], axis=(0, 2, 3))
std = np.std([x for x, _ in train_dataset], axis=(0, 2, 3))
def transform(inputs, mean, std, train=True):
img, label = inputs
img = img.copy()
# Color augmentation
if train:
img = transforms.pca_lighting(img, 76.5)
# Standardization
img -= mean[:, None, None]
img /= std[:, None, None]
# Random flip & crop
if train:
img = transforms.random_flip(img, x_random=True)
img = transforms.random_expand(img, max_ratio=1.5)
img = transforms.random_crop(img, (28, 28))
return img, label
train_dataset = TransformDataset(train_dataset, partial(transform, mean=mean, std=std, train=True))
valid_dataset = TransformDataset(valid_dataset, partial(transform, mean=mean, std=std, train=False))
test_dataset = TransformDataset(test_dataset, partial(transform, mean=mean, std=std, train=False))
では,standardizationとChainerCVによるPCA Lightingを追加したTransformDataset
を使って学習をしてみましょう.
これまでと同様,90エポックまで学習させておいたsnapshotを用いて,最後の10エポックだけ学習を行います.
[49]:
!wget https://github.com/japan-medical-ai/medical-ai-course-materials/releases/download/v0.1/DeepCNN_cifar10_augmented2_snapshot_epoch_90.npz
reset_seed(0)
model = train(DeepCNN(10), max_epoch=100, train_dataset=train_dataset, valid_dataset=valid_dataset, test_dataset=test_dataset, postfix='augmented2_', base_lr=0.1, lr_decay=(30, 'epoch'), snapshot='DeepCNN_cifar10_augmented2_snapshot_epoch_90.npz')
--2019-12-09 11:21:16-- https://github.com/japan-medical-ai/medical-ai-course-materials/releases/download/v0.1/DeepCNN_cifar10_augmented2_snapshot_epoch_90.npz
Resolving github.com (github.com)... 13.229.188.59
Connecting to github.com (github.com)|13.229.188.59|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github-production-release-asset-2e65be.s3.amazonaws.com/153412006/5064a880-eeb7-11e8-8e8b-fddbe76ecd56?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20191209%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20191209T112116Z&X-Amz-Expires=300&X-Amz-Signature=bf6cac01b9855110a9abeb2df76a65e697c6226c2353f102c24da044d69f80ef&X-Amz-SignedHeaders=host&actor_id=0&response-content-disposition=attachment%3B%20filename%3DDeepCNN_cifar10_augmented2_snapshot_epoch_90.npz&response-content-type=application%2Foctet-stream [following]
--2019-12-09 11:21:16-- https://github-production-release-asset-2e65be.s3.amazonaws.com/153412006/5064a880-eeb7-11e8-8e8b-fddbe76ecd56?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20191209%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20191209T112116Z&X-Amz-Expires=300&X-Amz-Signature=bf6cac01b9855110a9abeb2df76a65e697c6226c2353f102c24da044d69f80ef&X-Amz-SignedHeaders=host&actor_id=0&response-content-disposition=attachment%3B%20filename%3DDeepCNN_cifar10_augmented2_snapshot_epoch_90.npz&response-content-type=application%2Foctet-stream
Resolving github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)... 52.216.20.243
Connecting to github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)|52.216.20.243|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 56734002 (54M) [application/octet-stream]
Saving to: ‘DeepCNN_cifar10_augmented2_snapshot_epoch_90.npz’
DeepCNN_cifar10_aug 100%[===================>] 54.11M 11.5MB/s in 6.0s
2019-12-09 11:21:23 (8.96 MB/s) - ‘DeepCNN_cifar10_augmented2_snapshot_epoch_90.npz’ saved [56734002/56734002]
/usr/local/lib/python3.6/dist-packages/chainer/training/triggers/interval_trigger.py:92: UserWarning: The previous value of iteration is not saved. IntervalTrigger guesses it using current iteration. If this trigger is not called at every iteration, it may not work correctly.
'The previous value of iteration is not saved. '
/usr/local/lib/python3.6/dist-packages/chainer/training/triggers/interval_trigger.py:104: UserWarning: The previous value of epoch_detail is not saved. IntervalTrigger uses the value of trainer.updater.previous_epoch_detail. If this trigger is not called at every iteration, it may not work correctly.
'The previous value of epoch_detail is not saved. '
epoch main/loss main/accuracy val/main/loss val/main/accuracy elapsed_time lr
1 2.64248 0.137296 2.15132 0.175 23.5716 0.1
2 2.09025 0.202814 1.90949 0.248633 47.2547 0.1
3 1.92989 0.251647 1.80744 0.313867 70.8962 0.1
4 1.82033 0.296742 1.75358 0.319727 94.6262 0.1
5 1.70069 0.350717 1.59837 0.397266 118.252 0.1
6 1.56064 0.418435 1.6835 0.418359 141.967 0.1
7 1.45132 0.473691 1.30451 0.528711 165.613 0.1
8 1.33843 0.523815 2.81471 0.372656 189.382 0.1
9 1.27017 0.556019 1.61386 0.499023 213.846 0.1
10 1.20204 0.579972 1.46257 0.519141 237.532 0.1
11 1.14896 0.601252 1.1028 0.623633 265.125 0.1
12 1.10957 0.614428 1.17823 0.594141 288.775 0.1
13 1.07055 0.635387 1.00013 0.670898 312.476 0.1
14 1.04187 0.647058 1.07628 0.642383 336.226 0.1
15 1.00359 0.661066 1.00439 0.655859 360.038 0.1
16 0.971513 0.675503 0.8598 0.723828 383.802 0.1
17 0.941225 0.686301 1.09454 0.661328 407.665 0.1
18 0.917967 0.694869 0.998599 0.681055 431.468 0.1
19 0.905397 0.699586 0.811614 0.738281 455.224 0.1
20 0.874732 0.706188 0.714926 0.762695 479.061 0.1
21 0.86753 0.712607 0.850176 0.738477 506.629 0.1
22 0.858491 0.714844 0.95919 0.692187 530.459 0.1
23 0.845307 0.717637 1.12914 0.668164 554.185 0.1
24 0.827677 0.72785 0.827821 0.746484 577.974 0.1
25 0.821103 0.728715 0.7663 0.741602 601.76 0.1
26 0.816583 0.728321 0.763894 0.757812 625.472 0.1
27 0.805109 0.735596 0.74157 0.754492 649.375 0.1
28 0.809963 0.734152 0.742951 0.7625 673.099 0.1
29 0.792127 0.737149 0.692472 0.773633 696.895 0.1
30 0.783124 0.741965 1.02244 0.702734 720.621 0.1
31 0.604456 0.795898 0.394226 0.868359 748.179 0.01
32 0.523254 0.822227 0.379255 0.873242 771.92 0.01
33 0.502693 0.830078 0.360393 0.877734 795.747 0.01
34 0.484705 0.835627 0.351212 0.884766 819.578 0.01
35 0.465001 0.842192 0.348862 0.882031 843.291 0.01
36 0.45372 0.846613 0.339081 0.885742 867.053 0.01
37 0.450468 0.846043 0.335244 0.887695 890.763 0.01
38 0.439256 0.848411 0.331884 0.889453 914.524 0.01
39 0.430965 0.852475 0.324836 0.895312 938.351 0.01
40 0.424651 0.855447 0.370822 0.881641 962.135 0.01
41 0.418223 0.857 0.328078 0.89082 989.723 0.01
42 0.409156 0.860332 0.334291 0.888672 1013.41 0.01
43 0.410969 0.860574 0.338266 0.888281 1037.14 0.01
44 0.397083 0.862469 0.337938 0.889258 1060.81 0.01
45 0.390805 0.8661 0.31488 0.900586 1084.56 0.01
46 0.39124 0.866008 0.324671 0.893945 1108.24 0.01
47 0.390007 0.865101 0.320954 0.895117 1131.96 0.01
48 0.383439 0.868367 0.355809 0.890625 1155.62 0.01
49 0.380758 0.87065 0.314927 0.899805 1179.36 0.01
50 0.379198 0.870716 0.302073 0.903125 1203.09 0.01
51 0.371325 0.873019 0.317537 0.899805 1230.61 0.01
52 0.368561 0.874578 0.338393 0.891016 1254.34 0.01
53 0.366614 0.872975 0.318233 0.89707 1278.2 0.01
54 0.366082 0.874667 0.328654 0.892187 1301.94 0.01
55 0.365454 0.873331 0.311355 0.902148 1325.61 0.01
56 0.357191 0.876509 0.330824 0.895898 1349.36 0.01
57 0.361274 0.874667 0.320739 0.897461 1373.1 0.01
58 0.354885 0.878339 0.308104 0.901172 1396.78 0.01
59 0.359033 0.876931 0.316534 0.900586 1420.55 0.01
60 0.356422 0.876291 0.366406 0.888281 1444.25 0.01
61 0.314513 0.891513 0.261128 0.914062 1471.67 0.001
62 0.275813 0.905159 0.257487 0.916406 1495.37 0.001
63 0.268649 0.907804 0.253204 0.918164 1519.12 0.001
64 0.264412 0.908298 0.256485 0.918555 1542.84 0.001
65 0.261826 0.910511 0.253851 0.917773 1566.66 0.001
66 0.255704 0.911998 0.257955 0.916602 1590.48 0.001
67 0.25836 0.909566 0.256463 0.919141 1614.22 0.001
68 0.251792 0.912753 0.254393 0.920117 1638.04 0.001
69 0.251154 0.914508 0.25251 0.920508 1661.81 0.001
70 0.24735 0.91464 0.255754 0.91875 1685.83 0.001
71 0.241314 0.917268 0.253207 0.921094 1713.35 0.001
72 0.246059 0.914617 0.257668 0.920508 1737.18 0.001
73 0.238097 0.918213 0.257092 0.919727 1761.02 0.001
74 0.235832 0.918981 0.251519 0.919922 1784.77 0.001
75 0.236254 0.918857 0.253711 0.919531 1808.58 0.001
76 0.235273 0.917646 0.249922 0.920117 1832.36 0.001
77 0.233553 0.918635 0.251188 0.921094 1856.13 0.001
78 0.229216 0.920829 0.256883 0.921484 1879.92 0.001
79 0.231176 0.919877 0.254759 0.92207 1903.75 0.001
80 0.227571 0.921007 0.251693 0.920508 1927.51 0.001
81 0.229313 0.92041 0.257421 0.920508 1955.13 0.001
82 0.225018 0.922896 0.255896 0.919922 1978.98 0.001
83 0.223251 0.923478 0.25848 0.918555 2002.75 0.001
84 0.221238 0.924294 0.259179 0.920703 2026.56 0.001
85 0.222084 0.923411 0.251648 0.920313 2050.3 0.001
86 0.221973 0.924339 0.252463 0.920898 2074.11 0.001
87 0.217851 0.925147 0.252802 0.920898 2097.85 0.001
88 0.215969 0.925138 0.255438 0.921094 2121.65 0.001
89 0.217307 0.92294 0.253423 0.920313 2145.48 0.001
90 0.217241 0.925058 0.254931 0.919922 2169.54 0.001
91 0.210398 0.926336 0.248383 0.924023 2182.44 0.0001
92 0.202761 0.929131 0.248566 0.923828 2195.01 0.0001
93 0.202395 0.930686 0.248684 0.923633 2207.6 0.0001
94 0.205619 0.926794 0.25063 0.921875 2219.81 0.0001
95 0.207099 0.929088 0.251314 0.922852 2231.86 0.0001
96 0.198571 0.932692 0.250731 0.92168 2243.82 0.0001
97 0.200812 0.930331 0.250747 0.921875 2255.85 0.0001
98 0.204992 0.928689 0.250508 0.92207 2267.96 0.0001
99 0.203764 0.92922 0.248585 0.921289 2280.19 0.0001
100 0.202649 0.930331 0.25095 0.923047 2292.4 0.0001
Test accuracy: 0.92523736
わずかに精度が向上しました.他にも,ResNet [He 2016]と呼ばれる有名なネットワーク構造を採用して学習を行うなど,簡単に試せる改善方法がいくつかあります.ぜひ色々と試してみてください.
4.7. 参考文献¶
[Tokui 2015] Tokui, S., Oono, K., Hido, S. and Clayton, J., Chainer: a Next-Generation Open Source Framework for Deep Learning, Proceedings of Workshop on Machine Learning Systems(LearningSys) in The Twenty-ninth Annual Conference on Neural Information Processing Systems (NIPS), (2015)
[Niitani 2017] Yusuke Niitani, Toru Ogawa, Shunta Saito, Masaki Saito, "ChainerCV: a Library for Deep Learning in Computer Vision", ACM Multimedia (ACMMM), Open Source Software Competition, 2017
[Hidaka 2017] Masatoshi Hidaka, Yuichiro Kikura, Yoshitaka Ushiku, Tatsuya Harada. WebDNN: Fastest DNN Execution Framework on Web Browser. ACM International Conference on Multimedia (ACMMM), Open Source Software Competition, pp.1213-1216, 2017.
[He 2016] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun, "Deep Residual Learning for Image Recognition", CVPR 2016