AnimeFace で pix2pix をやってみた.
pix2pix とは
かなり今更という感じですがやります.
pix2pix は Image-to-Image Translation with Conditional Adversarial Networks で提案されたConditional GAN (CGAN) の一種.
pix2pixでは, デフォルメされた地図を航空写真のように変換, 線画やグレースケールの写真に着色, 昼間から夜間の画像に変換など様々なタスクで上手くいってるようです.
pix2pixは通常のGANとは異なり, あるノイズのベクトルではなく画像を Generator に入力します. 学習過程は上図のように, Discriminatorに入力に用いた画像(条件として)と生成された画像を結合してから入力します. Generator での学習では, Discriminator に "Real" と予測させるように, Discriminator の学習では, 生成された画像を "Fake" , 本物の画像を "Real" と予測するように交互に学習させます. (通常のGANの学習と同じ)
Generator
Generator に使われてるモデルは, UNet でセマンティックセグメンテーションの分野で活躍しているモデルが採用されている.
Encoder-decoder モデルでは, ダウンサンプリングしていくことで位置情報などの情報が失われてしまうが, U-Netでは Encoderで各層でダウンサンプリングした特徴量マップを Decoder 側に skip connection しており, 失われるはずであった位置情報も考慮しながらアップサンプリング可能となる.
確かに, グレースケール画像や線画に色をつけるという意味ではセグメンテーション的な意味合いもなくもないと思う.
Discriminator
Discriminatorでは, PatchGANという手法が使われています.
通常は入力された画像に対して Fake か Real かを評価するが, PatchGANでは入力された画像を N × N のタイル状にみて, 各タイルが Fake か Real なのかを判定することでより, ロスを計算するそうです. (いわゆる N が Patch_size にあたる)
具体的には, Dsicriminator の出力( 1, N , N) と (1, N, N) の次元を持つすべて 1 で埋めたもの, すべて0 で埋めたものの行列を用いて, その差がDiscriminatorのロスになります.
損失関数
一般的なGANの損失関数に,
Real の画像とFake の画像の L1距離
以上を足して, これを最適化します.
GANの損失関数はPatchGANによるものなので, 画像の局所的な部分の誤差と解釈でき, Real の画像とFake の画像の L1距離は画像の大域的な誤差と解釈でき, バランスをうまく補っているようです.
ちなみに, L1距離の重みとしてある λ は大きくすると, 大域的な誤差を重視するようになるため, 生成される画像と元画像が似るようになるということですかね.
また論文中では, L1距離を足すことでブラー(生成画像のボケ)が減少すると書かれています.
AnimeFace で pix2pix
概要
線画と元画像でペアにしてデータセットを作成し, Generator に線画を入力することで着色された画像が生成できるように学習させることを目的とします. 学習の仕方としては, 序盤で記述したのと同様です.
GeneratorのネットワークはSkip-Connectionありの Conv, BatchNorm, ReLU,AvgPool, Upsampling,を用いた U-Net の構造です.
DiscriminatorのネットワークはConv, BatchNorm, ReLU ,AvgPool,を用いた普通のダウンサンプリングです.
実験環境
Google Colab で実験しました. また, pix2pixのネットワークでは, Discriminator の入力でスタイル変換させたい画像(今回は線画)を結合して入力しますが, 結合しないバージョンも実験してみました
- PyTorch (実装は[2][3]あたりを参考)
- batch_size = 4
- Epoch = 65
- データセットは, Kaggleに置いてあるanother-anime-face-dataset から 5000枚をサンプル (サイズは 256×256 )
今回使ったコードです
実験結果
Epoch 65 までの学習した, Generatorにテストデータを入力した着色生成結果を載せます.
所感
- Epoch数が早い段階で着色は始まり, 難しいことせずともしっかり最後は着色できてる.
- 線画を入力するかしないかの違いは着色結果ではあまりみられない.
- 全体の結果からは, 彩度が高い着色や, 青と緑の色の着色が失敗しやすく, 赤系の色の着色は成功する兆候がある
この発展としては, Generator に色のヒントをなんらかの形で与えて学習させることで, ユーザが色を指定して生成することができ, 実際成功例もweb上で見られます. 一番手っ取り早そうなのは, 線画と元の画像のペアを入力のチャンネル数を 1 -> 4 としてGeneratorに入力したり..., Decoder部分からヒントとして元の画像をEmbeddingして入力するなどが考えられます.
また, 解像度が高い画像でも上手くいくのかは気になるところです.
参考
[1] https://arxiv.org/pdf/1611.07004.pdf
[2] pix2pixを1から実装して白黒画像をカラー化してみた(PyTorch) | Shikoan's ML Blog
[3] GitHub - mrzhu-cool/pix2pix-pytorch: PyTorch implementation of "Image-to-Image Translation Using Conditional Adversarial Networks".