けやみぃアーカイブ

CV勉強中の大学生のアウトプットです

StyleGAN2で顔ランドマーク座標を指定して画像生成【前編】

PFNから発表された Surrogate Gradient Field for Latent Space Manipulation という論文の再現実装をしてみました.どんな論文かは以下の図を見てもらえるとわかるかと思います.上段は属性(年齢や性別など)を指定してStyleGAN2で生成した画像です.中段では,今度は文章を使って花の外見を指定して生成しています.そして下段,個人的にはこれが一番面白いと思うのですが,顔のランドマーク座標を指定してアニメ顔を生成しています! f:id:kym384:20210730203730p:plain

簡単な解説

用語

  • G: 学習済みGenerator

論文ではStyleGAN2とProGANの2つで実験していますが,Generatorの形式によらないため他のものでも問題ありません.

  • C: 性質推定器

画像からその性質を推定するネットワークです.顔のランドマーク座標を指定して生成するタスクの場合は,dlibのランドマーク推定器なんかもこれにあたります.

  • z: 潜在変数

StyleGANの場合はwを使いますが,一般的にzと書くことにします.

問題定義

論文では簡単のため,C(G(z)),つまり潜在変数zから画像を生成して性質推定器に入力する操作を\Phi(z)と書いてるのでそれに従います.ここでの目標は,ランダムな潜在変数z_0とその性質c_0=\Phi(z_0)からスタートして欲しい性質c_1を持つ画像を生成する潜在変数z_1を求めることです.その過程をz(t)と表すこととし,z(0)=z_0z(1)=z_1とします.

手法

Auxiliary Mapping F(z, c)を定義します.Fz=F(z, c)となるように学習させます.中身はAdaINと(Spectral Normalization付きの)FC層のくり返しなので,zの情報をAdaINの正規化の過程で削減して,それをcで補完させることになっています.(論文にも構成が書かれていますが,本当にシンプルなネットワークです.)

ここで3つの仮定を設けます.

  1. 目的の性質c_1を持つ画像を生成するz_1が必ず存在する.

  2. z(t)を変化させた時,その生成画像の性質\Phi(z(t))は一様にc_1-c_0で変化する.つまり\frac{\rm{d}\Phi(z(t))}{\rm{d}t}=c_1-c_0

  3. 性質を変化させると,Fの出力も変化する.つまり\frac{\partial F(z, \Phi(z))}{\partial c} \neq 0 (c=\Phi(z))

そしてz(t)=F(z(t), \Phi(z(t)))の両辺をt微分します.それから仮定2を使って式変形することで,以下が得られます.(詳しくは論文の3.3をご覧ください)

\frac{\rm{d}z(t)}{\rm{d}t}=\left(\mathbf{I}-\frac{\partial F(z(t), \Phi(z(t)))}{\partial z}\right)^{-1}\frac{\partial F(z(t), \Phi(z(t)))}{\partial c}(c_1-c_0)

この右辺をH(z)と置き,Surrogate Gradient Fieldと呼んでいます.この微分方程式\frac{\rm{d}z(t)}{\rm{d}t}=H(z)を初期条件z(0)=z_0で数値的に解くことで,z(1)=z_1の近似を求めることができます!

また重要な点として性質判定器C逆伝播を計算する必要がありません! なので論文ではこの性質判定器にChainerで実装したCNNやdlibを使っています.

Auxiliary Mapping

Fについてもう少し解説しておきます.以下が掲載されているモデル図です.AdaINとLinear×2から成るConditional Linear BlockをN層重ねています.活性化関数にはLeaky ReLUを使っていて,全てのLinearにはSpectral Normalizationをかけています. f:id:kym384:20210730214340p:plain

ちなみに顔属性の判定器を作るためにAzureのAPIを使ってラベル付けして学習データを作ったらしいです.

実装

今回はFFHQで学習させたStyleGAN2を使って顔のランドマーク座標を指定して画像を生成させたいと思います.StyleGAN2のモデルには,以下で提供されている256の解像度でFFHQを学習させたモデルを使います.(軽量化のため)

github.com

ランドマーク座標推定にはdlibを使います.

Auxiliary Mappingの学習

学習データには論文同様20万枚生成した画像を使い,バッチサイズ32で50k iteration学習させました.cにはdlibの座標の出力を[-1, 1]に正規化し,一次元に並べました.なので68×2次元になっているわけです.

結果は以下です.左がランダムに生成した画像,右がAuxiliary Mappingの出力をもとに生成した画像です.いい感じに潜在変数を復元できていますね.(実際には論文よりもiterationを低めに設定しているので改善の余地はあります)

f:id:kym384:20210730220133p:plain

ランドマーク座標変換

論文に記載されている疑似コードをベースにSGFを実装していきます.ちなみに torch.autograd.functional.jacobian を知らなかったのでヤコビアンを求める方法に苦戦しました…

def translate(G, C, F, z0, c1, m=1, max_iteration=100, step=0.2):
    with torch.no_grad():
        image, _ = G([z0], input_is_latent=True, randomize_noise=False)
        c0 = C(image).unsqueeze(0)

        d_c = step * (c1 - c0)    # dΦ/dc = c1-c0

    images = [image]

    _, style_dim = z0.shape
    _, c_dim = c1.shape

    z = z0
    c = c0
    c_diff = -1

    for i in range(max_iteration):
        z.requires_grad = True
        c.requires_grad = True

        z_grad = torch.autograd.functional.jacobian(lambda z_:F(z_, c), z)[0,:,0,:]     # dF/dz
        c_grad = torch.autograd.functional.jacobian(lambda c_:F(z, c_), c)[0,:,0,:]     # dF/dc
        
        with torch.no_grad():
            # dz/dt = ( I - dF/dz)^{-1} * dF/dc * dΦ/dc
            d_z = c_grad @ d_c[0]    # dF/dc * dΦ/dc
    
           # ( I - dF/dz)^{-1} = I + dF/dz + (dF/dz)^2 + ...
            d_z_j = d_z
            for j in range(m):
                d_z_j = z_grad @ d_z_j
                d_z += d_z_j

            z += d_z
            image, _ = G([z], input_is_latent=True, randomize_noise=False)
            
            c = C(image).unsqueeze(0)

            # || Φ(z(t)) - c1 ||_2
            c_diff_new = (c-c1).pow(2).mean()
            
            print(c_diff_new)
            images.append(image)

            if c_diff > 0 and c_diff < c_diff_new:
                break

            c_diff = c_diff_new

    return images, z

結果は以下です.左端がランダムに生成したソース画像,右端がターゲット画像でこのランドマーク座標を持ったソース画像を生成するのが目的です.中間はその変換の過程の出力です.段々とターゲットのランドマーク座標に近づいてるのがわかります. f:id:kym384:20210730223337p:plain

ただ見てわかる通り,座標移動の過程で髪色などの属性も変化してしまっています.これは論文中でも言及されており,以下の図では上が座標だけを入力した場合,下が座標と共に属性情報を入力した場合です.なので本来であれば属性情報と共に入力しなければいけないのですが,今回はそのような分類器がなかったので座標だけの入力にしています… f:id:kym384:20210730221012p:plain

コード

github.com

後半は何かしらの属性識別器を作ってから書くつもりです.