【PyTorch】入力データが複数ある場合の実装方法

PyTorchでDeep Learningを実装する際に,データを入力する箇所がネットワーク内に複数ある場合の実装方法についてです。

モデル/②DataLoaderの作り方/③DataLoaderの使い方を順に確認していった後に,最後にまとめて全体のソースを記載しておきます。

なお,PyTorchでの基本的な実装方法は【PyTorch】MNISTのサンプルを動かしてみたを参考にしてみて下さい。

モデルの実装方法

今回はサンプルとして,下図のような入力が2か所あるネットワークを実装してみます。

まず最初に畳み込み層にINPUT(1):”高さ15×幅20×チャネル1″の2次元行列を入力します。その後,畳み込み層からの出力とINPUT(2):”3成分のベクトル”を一緒に全結合層に入力するような例です。

まずネットワークの書き方のソースは以下の通りです。

各層を定義するinit関数では,27行目がポイントになります。in_featuresの値を畳み込み層からの入力数(64×13×18つ)INPUT(2)データからの入力数(3つ)を足した値を設定します。

次にforward関数の中では44行目の部分がポイントです。ここでは,畳み込み層からの入力をviewメソッドで1次元化したあとに,torch.catメソッドを使ってINPUT(2)のデータと結合して2つの入力をひとつにまとめた形式に変換します。

DataLoaderの作り方

次にDataLoaderの作り方です。今回は複数入力がある場合の実装方法に焦点を当てますので,データの値はランダムに適当に作成します。

ここでは,32~33行目でデータセットをまとめて作成しているところがポイントになります。一度覚えてしまえば単純な話ではありますが,すべての入力データと出力データをまとめて引数に与えてデータセットを作成します。

最終的にDataLoaderを作成するときは,ここで作成したデータセットをそのまま引数として与えればいいだけですので特に変わったポイントはありません。

DataLoaderの使い方

DataLoaderはよくあるパターンとしてはtrain関数を用意して,その中で使用するという形が多いかと思います。そこでの実装方法は以下の通りです。

ここでポイントになるのは7行目と14行目です。

まず7行目についてですが,DataLoaderからはデータセットを作成したときに与えた引数の形でデータがfeedされます。そのため,”for in_oneD, in_twoD, target in data_loader”という形で3つのデータを受け取れます。また,その受け取ったデータは通常と同じように,それぞれ個別にto関数でGPUに送ってやれば問題ありません。

次に14行目ですが,ここでは2つの入力のデータを引数としてモデルに与えてやります。引数の順番は,”モデルの実装方法”で作成したforward関数で指定した引数の順番と同じにすれば良いです。

まとめ

PyTorchで複数の入力がある場合の実装方法のポイントは以上の通りです。一度覚えてしまえばそれ程難しくないというか,そのまま実装しているだけだなというのが分かるかと思います。

最後に,そのまま学習を実行できる状態のソース全体を記載しておきます。PyTorchがインストール出来ている環境でこのソースを適当なPythonファイルとして保存して実行すれば動くはずです。

今回は複数の入力がある場合にポイントとなる箇所に焦点を当てた記事でしたので,損失関数やオプティマイザなどの細かい箇所の基本的なことが気になる場合は,【PyTorch】MNISTのサンプルを動かしてみたもご覧下さい。

スポンサーリンク