フォントを生成するGANを作った話(後編)
2021 ISer Advent Calendar 23日目の記事です.
まだ前回の記事を読んでいない,あるいはもう忘れたという方は前回の記事を読まれることをお勧めします.
訓練1
さて,ようやく訓練フェーズに入ったわけですが,前述のネットワークをそのまま訓練しようとすると,非常に時間がかかるので,段階的に訓練していきます.まずはモデルの低層部分である,文字の画像を次元削減し,再び文字の画像を生成する部分(いわゆるオートエンコーダ)を学習します.下図のようにGeneratorの低層部分の出力から画像を生成し,元の文字の画像と一致するように訓練させることを順に行います.

ここで難しいのが,どの程度までこのモデルを訓練させるかです.変換できる文字の種類を今回学習する文字[1]に限ってしまえば,母集団=標本となるため,過適合がなく,訓練すればするほど精度が上昇するのですが,それ以外の文字もある程度変換できるようにしたい場合は早期に切り上げる必要があります.
今回は,頻出する構造がはっきりと再現できる程度に訓練をしました.特に再現に時間がかかった構造としては,かなの半濁点,諫などの内部の点,馬へんなどのれんがの部分があります.


また,この訓練の損失関数は初めはL1損失を用いました.フォント画像を出力するので,ぼやけた画像よりも二値化されたような画像のほうが望ましいためこの選択をしたのですが,意外と学習が遅く,数千エポック回しても下図のような細かい構造が再現できませんでした.これをもってL2損失に切り替えたところ,100エポックほどで以下のような画像を出力しました.この時は出力にシグモイド関数を挟んでいたため,L2損失を使うと勾配消失が起きるのではないかと思っていましたが…難しいものですね.

訓練2
ようやくGANとしての訓練を始められますが,その前に.Generatorが生成した画像の分類を行うDiscriminatorを説明します.
まず,Generatorが生成した画像か変換先のフォントの画像かを判別する通常のDiscriminatorを用意します.これには前述の画像のほかに,Genarotrに入力した画像と同じ変換先のフォントの組も同時に入れます.これにより,DiscriminatorはGeneratorが作った画像かどうかというよりは,入力された画像の文字が同じフォントに属するかという判定を行います.
これに加えて,入力された画像が何の文字かを判定するCharaDiscriminatorも用意しました.これは文字の代わりにそれに対応する特徴量ベクトルを出力します.訓練はこの特徴量がGenerator内のCharaEncoderの出力と一致するように行います.
以上の分類器とともに以下の図のように2つの損失を算出し,最小化させていくことでGeneratorを訓練していきます.同時にDiscriminatorにも損失関数を用意し,訓練させています.

Generatorが出力した画像(とそのほか必要な入力)をそれぞれ分類器に入れ,Generator, 分類器ごとにそれぞれの損失関数を最小化しさせました.
(途中)結果
今のところの結果をあげておきます.
まずはうまくいっている(ように見える)例から.以下で挙げる画像は,左から,変換したい字(入力),変換先の教師画像,Generatorの出力結果,変換したいフォントで書いた字の画像2組(入力)となっております.
もちろんこんなにうまくいくものばかりではありません.下のようになんか背景が真っ白になっている時があったり,

ほかにも,
- 明らかに他の画像の特徴で表現されている
- →他の画像の学習結果が大きくでてしまった?
- 周りに変なごみが出る(うまくいった例の画像などにも出てる)
- ^ , .などの小さい字が異様に膨らむ
などの問題があります.
また,フォントの大きさの変化,回転などといった特徴は表現できないことも分かっております,これはおそらくフォントのエンコードをする部分で主にConvolutionレイヤを使っているモデルを使用していることが原因だと思いますが,特に支障をきたさなそうなのでそのまま訓練を続けようと思っております[2],
感想
とにかくGenerator, Discriminatorの訓練バランス調整が難しく,崩れた途端意味のない画像しか生成しなくなるので,苦戦させられました.その調整ができても,訓練に時間がかかるという問題があり,それを改善しようとしてバランスが崩れる… というようなイタチゴッコが続いており,つらいです.
ただ,今のところはうまくいってそうでその点は少し気が楽なので,もう少し続けてみようと思います.
備忘録
初めてのGANの訓練でつまずいたところ,気づいたところを書いておきます.
- Dのちょうどいい正解率がよくわからない
- 50%がちょうどいい正解率らしいけど,Dが全然学習してないときとの区別ができない
- とりあえず65~55%あたりにしてみる
- G, Dのバランス調整はdropout率を動かすのが早い
- Dの正解率等に応じて動的に調整できればなお楽
- 人の実装はしっかりコードを読んで意味を確認してから使うこと
- 自分の直感とは違う書き方がたまにされている
- これで損失関数の符号が逆になったまま訓練してました()
- メモリが足りなくなったらdel; torch.cuda.empty_cache(); gc.collect()をする
- バッチサイズを大きくして訓練できる
- でもこれをすると逆効果になることもある
- 地味に時間も食うので毎iterationやるときなどは必要がなければ使わないほうがいいかも
- 勾配を伝播させる必要のないテンソルはdetach()で計算グラフから切り離す
- しなくてもバグらないけどメモリをかなり無駄遣いする
- GANで必須のテクニックだと思うけど意外と解説が少ない気がする
- line_profilerやtorchinfo.summaryで時間,メモリの無駄遣いをしているところを探す
- 損失関数をAdamにこだわらない
- RMSPropやAdamWが意外といけるときもある
余談
最後に,訓練中の出力画像を表示してくれたTensorboard君の渾身の煽りを御覧ください.
