本記事ではPyTorchのモデルをONNX変換する手順について解説します。
今回例としてResNet18をONNXに変換してみたいと思います。
そもそもONNXに変換するメリットって何?
まず初めにPyTocrhのモデルのままでなくONNXに変換するメリットについて簡単にお話します。
ズバリ本記事のタイトルにあるように「ディープラーニングモデルの共有を簡単に!」です。
もう少し具体的なお話についてはChatGPTさんに聞いてみましょう。
ONNX(Open Neural Network Exchange)は、ディープラーニングモデルを異なるフレームワーク間で互換性を持たせるための標準形式を提供することから、モデルの共有が簡単になります。PyTorchからONNXへの変換が特に便利である理由をいくつか挙げます:
ChatGPT
- フレームワークの互換性:
- ONNXはPyTorchだけでなく、TensorFlow、Caffe2、MXNetなど多くのディープラーニングフレームワークに対応しています。これにより、異なるフレームワーク間でモデルを容易に共有・移行できます。
- 推論エンジンの多様性:
- ONNXモデルは、多くの推論エンジンでサポートされています。例えば、ONNX Runtime、TensorRT、OpenVINOなどがあり、様々なハードウェア上での高速な推論が可能になります。
- デプロイの柔軟性:
- ONNXモデルはクラウド、エッジデバイス、モバイルデバイスなど、異なるプラットフォームに簡単にデプロイできるため、実運用環境への導入がスムーズです。
- エコシステムとツールの豊富さ:
- ONNXには豊富なツールとエコシステムがあり、モデルの変換、最適化、デバッグを支援するためのライブラリやツールが揃っています。これにより、開発者は効率的に作業を進めることができます。
- 標準化のメリット:
- モデルをONNX形式に変換することで、ディープラーニングの標準形式として広く認識され、他の研究者や開発者とモデルを共有する際のトラブルが減少します。
ChatGPTさんありがとうございます。
といった感じでONNXは様々なディープラーニングフレームワークと互換性がある標準フォーマットで、異なるフレームワーク間でモデルを簡単に移行・共有できます。
またONNXはNetronというツールで簡単にモデル構造を可視化できるところも魅力の一つですね。
PyTorchのResNet18をONNXに変換
今回はモデルの学習についてはスコープ外なので、torchvisionを使って学習済みのResNet18を用意してONNXに変換していきます。
まずはモデルの用意。
import torchvision
model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
ちなみにweightはImageNet-1Kの学習結果を使っています。
続いてPyTocrhモデルをONNXに変換する際にモデルへの入力サンプルが必要になるので用意していきます。
これはエクスポート時にモデルの動作をシミュレートするために使用されます。
import PIL
image = PIL.Image.open(image_path)
といった上記例のようにPILで画像を読み込んでおきます。
そして前処理。
preprocess = torchvision.transforms.Compose(
[
torchvision.transforms.Resize(256),
torchvision.transforms.CenterCrop(224),
torchvision.transforms.ToTensor(),
]
)
image = preprocess(image).unsqueeze(0)
ここまで準備できたらONNXに変換は以下。
torch.onnx.export(model, image, "resnet18.onnx")
第3引数はONNXを出力したいパスを指定しています。
ONNX Runtimeで推論してみる
おまけですが変換したONNXモデルで推論実行をしてみます。
import numpy as np
import onnxruntime as ort
ort_session = ort.InferenceSession("resnet18.onnx")
ort_inputs = {ort_session.get_inputs()[0].name: image.numpy()}
ort_outs = ort_session.run(None, ort_inputs)
predicted_class = np.argmax(ort_outs[0], axis=1).item()
最後に
以上の内容をもう少し整理して、さらにPyTorchモデルとONNXモデルの出力が一致していることを確認するテストが以下です。
良ければこちらも参考にしていただければと思います。