【PyTorch】DataParallelを使った並列GPU化の躓きどころ - 加賀百万石ですが何か?

【PyTorch】DataParallelを使った並列GPU化の躓きどころ

今回はPyTorchを使って並列GPU化するときに実際に躓いたところの説明と解決方法を整理します。

なお、ここに書く内容はちゃんとドキュメントなどを調べたものではなくて経験則的に

「こうやってみたらうまく動いた」
「試行錯誤したらバグがとりあえず解消した」

というものですので、正確な理解が必要な方は別途公式ドキュメントなどで調査することをお勧めします。

実際に実装してみて分からなかったところ、躓いたところは以下の3点です。

  • torch.deviceの設定方法
  • Multi-GPUで学習したモデルを、Single-GPUで推論するときにLoadする方法
  • 並列化されたGPUからのデータが統合されるときの挙動

torch.deviceの設定方法

まずはシンプルに「torch.device」ってどう設定したらいいの?という疑問です。単一のGPUのときはシンプルにこのような感じで設定すると思います。

gpu_id = 0
if torch.cuda.is_available():
    device = torch.device(f'cuda:{gpu_id}')
else:
    device = torch.device('cpu')

並列化するときに気になることは「gpu_idはどうするの?」と言うことだと思います。

この設定については、 マルチGPUにするときも実はほとんど変わりません。

ここではGPU-ID=0,1の2つのGPUを用いて設定する例を記載します。

gpu_ids = [0, 1]
if torch.cuda.is_available():
    device = torch.device(f'cuda:{gpu_ids[0]}')
else:
    device = torch.device('cpu')

このようにtorch.deviceに設定する値はあくまでもひとつです。

詳細は不明ですが、おそらく(マスター、スレーブの意味での)マスターのGPUとして扱うものをdeviceに設定するような形なんだろうなと思います。

ここではGPI-ID=0のGPUをマスターとして設定しています。

並列GPUでの学習済みモデルを単一GPUにLoad

次によく引っかかる部分はこれだと思います。

マルチGPUで学習したモデルを推論時に使用するときに何も考えずにLoadすると、以下のようなエラーが出ることがあると思います。

RuntimeError: Error(s) in loading state_dict for YOUR_MODEL:
    Missing key(s) in state_dict: "base.0.weight", "base.0.bias", "base.2.weight", ...
    Unexpected key(s) in state_dict: "module.base.0.weight", "module.base.0.bias", "module.base.2.weight", ...

これはエラーそのものの通り、DictionaryのKeyが合っていないのが原因です。

まず、このようなエラーが出てしまう理由についてですが、torch.nn.DataParallelはmodelをラッピングしているので、DataParallelを適用した後ではモデルの持っているattributeやmethodの階層構造が変わってしまうことが原因です。

具体例を上げると、モデルが持っているはずのattributeやmethodにアクセスしようとした場合に、以下のようなAttributeErrorがよく出ると思います。

In [1]: model.XXXXX
AttributeError: 'DataParallel' object has no attribute 'XXXXX'

これを解決するには以下のようにアクセス方法を変えてやればOKです。

In [1]: model.moudle.XXXXX

さて、ではモデルをLoadするときにどうするかということですが、エラーの詳細を見てみると、

「base.0.weightが存在しないです」

でも

「module.base.0.weightというよく分からないものがあります」

という状態になってます。

ということなので「module.base.0.weight」となっているKeyを「base.0.weight」に強制的に書き換えてしまえば万事解決です。

書き換え方は以下のような方法でOKです。

▼修正前(よくあるLoad方法)

trained_weights = 'PATH_TO_TRAINED_WEIGHT.pth'
state_dict = torch.load(trained_weights)
self.load_state_dict(state_dict)

▼修正後

trained_weights = 'PATH_TO_TRAINED_WEIGHT.pth'
state_dict = torch.load(trained_weights, map_location=lambda storage, loc: storage)
from collections import OrderedDict
new_state_dict = OrderedDict()

for k, v in state_dict.items():
    if 'module' in k:
        k = k.replace('module.', '')
    new_state_dict[k] = v

self.load_state_dict(new_state_dict)

並列GPUからデータが統合されるときの挙動

これはエラーが発生した背景から少し説明した方が良いと思うので、まずはどういうときにエラーに遭遇したかを簡単に書きます。

まず発生したタイミングは、SSDの学習を並列化しようとしていたときです。

実装方法にも当然依存しますが、SSDを学習させるときにモデルからの出力が以下のようなサイズで出てきます。

--- DETECT MODE ----------------------------
  size of locs : torch.Size([8, 8732, 4]) # [n_batch, n_priors, n_loc]
 size of confs : torch.Size([8, 8732, 10]) # [n_batch, n_priors, n_class]
size of priors : torch.Size([8732, 4]) # [n_priors, n_loc]
--------------------------------------------

ただ、2GPUで並列化するとこの出力のサイズが以下のように変わります。

--- DETECT MODE ----------------------------
  size of locs : torch.Size([8, 8732, 4])
 size of confs : torch.Size([8, 8732, 10])
size of priors : torch.Size([17464, 4])
--------------------------------------------

ここでpriorsのサイズが倍になってしまっていて、locsやconfsのdim=1のサイズ:8732と合わなくてエラーになってしまうという現象が発生します。

ここからは正確にドキュメントなどを調べてみた訳ではないですが、試行錯誤した際の結果から考えられる結論は、

「各GPUからの結果をdim=0でconcatenateしている」

ということです。

SSDでは事前に設定するボックス群(priors)との位置関係を回帰問題として位置を検出します。

(SSDの詳細についてはこちらを参照ください)

そのため、並列化する際は当然各GPU上にそのアンカーとなるpriorsを配置する必要がある訳ですが、各GPUからマスターのGPUにpriorsも送られてしまい、それがconcatenateされてしまっているんだろうと思います。

したがって、SSDのこの問題に限っては、データの統合後に以下のような処理を挟み、priorsを1セット分だけ取り出してやると解決しました。

priors = priors[:n_priors]

という訳で、並列化されたGPU上での計算結果を統合する際には、

どのようなデータかは関係なく、dim=0ですべてconcatenateされる

という挙動になるのではないかと思います。

そのため、PyTorchでDataParallelを使用して並列化した際にTensorのサイズがGPU数分だけ大きくなってエラーを吐いてしまう場合は、この挙動が原因であることを疑ってみると解決するかもしれません。

まとめ

以上、実際にPyTorchを使ってマルチGPU化してみた際に躓いたところをまとめてみました。

冒頭にも言いましたが、ここに書いてある内容は試行錯誤してみて「こうやると上手くいっているっぽい」という経験則をベースにした情報ですので、正確な情報が必要な方は公式ドキュメントなどでちゃんと調査することをお勧めします。

スポンサーリンク