けやみぃアーカイブ

CV勉強中の大学生のアウトプットです

PyTorchの異なるバージョン間での重み変換

なんやかんや複雑な事情で、学習時と推論時のPyTorchのバージョンが違った時の対処法です。特に学習時のバージョンの方が新しいと、推論する時に古いバージョンでは学習済みモデルを読み込んでくれないことがあるかと思います。まあ推論環境のPyTorchをアップデートすればいいだけなんですが、anaconda自体アップデートしなきゃいけなかったり、そうすると他にもいろいろ一緒にアップデートされてしまったりと面倒くさそうだったのでcheckpointの方を変換することにしました。

方法というほどの方法でもないんですが…一度numpyに変換して保存してあげます。

# translate_weights.py

from collections import OrderedDict
import numpy as np
import shutil
import torch
import sys
import os

# PyTorchの異なるバージョン間での重みの変換
# 例 PyTorch 1.8 で学習 -> PyTorch 1.1 で推論

name = sys.argv[1]

if not os.path.isdir("tmp"):
  # 変換元のPyTorchで実行 (例 PyTorch 1.8)

  os.mkdir("tmp")

  ckpt = torch.load(name)

  for k, v in ckpt.items():
    v = v.to("cpu").detach().numpy()
    k = k.replace(".", "+")

    np.save(f"tmp/{k}.npy", v)

else:
  # 変換後のPyTorchで実行 (例 PyTorch 1.1)

  ckpt = OrderedDict()

  for k in os.listdir("tmp/"):
    v = np.load(f"tmp/{k}")
    k = k[:-4].replace("+", ".")

    ckpt[k] = torch.from_numpy(v)

  torch.save(ckpt, name)

  shutil.rmtree("tmp")

最初に新しい方のPyTorchで読み込んでnumpy形式で保存した後に、古い方のPyTorchでnumpy読み込み→ptファイルで保存という流れになってます。なので

conda activate new_pytorch                       # 新しい方のPyTorchが入ってる環境
python trasnlate_weights.py hoge.pt              # 読み込みたいcheckpoint
conda activate old_pytorch                       # 古い方のPyTorchが入っている環境
python translate_weights.py new_hoge.pt          # 変換後のcheckpointの保存先

とかで上手く行くと思います。

一応注意点。(コード読んでもらえればわかるんですが)

  • checkpointにGeneratorとかDiscriminatorとかoptimizerが一緒に保存されてる時はコピペじゃ動きません。for k, v in ckpt.items()の部分をもう1回繰り返したら行けると思います。

  • tmpフォルダを作る&消すという操作をしてるので、もともとtmpフォルダがある場合は適当なフォルダ名に書き換えてください。

パワープレイですが、まあ誰かの役に立てば…。