Convert a BERT model from PyTorch to ONNX and run inference

Hemanth Sharma
2 min read · Jul 20, 2020

This short tutorial assumes that you already have a trained PyTorch BERT model.

Let us convert the PyTorch BERT sequence classification model to ONNX.

First, load the trained model into a variable named model. The PyTorch guide on saving and loading models covers this:
https://pytorch.org/tutorials/beginner/saving_loading_models.html

Run the code below to convert the BERT model to ONNX:

import torch

'''
Define arguments to pass to onnx exporter
'''
model_onnx_path = "model.onnx"
# The inputs "input_ids", "token_type_ids" and "attention_mask" are torch tensors of shape batch*seq_len.
# Their order in the tuple must match the argument order of the model's forward().
dummy_input = (input_ids, token_type_ids, attention_mask)
input_names = ["input_ids", "token_type_ids", "attention_mask"]
output_names = ["output"]
'''
convert model to onnx
'''
torch.onnx.export(model, dummy_input, model_onnx_path,
                  input_names=input_names, output_names=output_names,
                  verbose=False)

To clarify: output_names contains only "output" because this is a classification model, so it returns a single tensor of shape batch x num_of_classes.
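By default, the exporter fixes every input and output shape to match the dummy input, so a model exported with one batch size and sequence length expects exactly those sizes at inference time. One way around this is the dynamic_axes argument of torch.onnx.export. Below is a minimal sketch of building that dictionary; the helper build_dynamic_axes and the axis labels "batch" and "seq_len" are names I made up for illustration, and it assumes every input has shape batch x seq_len.

```python
from typing import Dict, List


def build_dynamic_axes(input_names: List[str],
                       output_names: List[str]) -> Dict[str, Dict[int, str]]:
    """Mark dims 0 and 1 of each input and dim 0 of each output as dynamic."""
    # Inputs are assumed to be batch x seq_len, so both dims may vary.
    axes = {name: {0: "batch", 1: "seq_len"} for name in input_names}
    # The classification output is batch x num_of_classes; only the batch dim varies.
    axes.update({name: {0: "batch"} for name in output_names})
    return axes


input_names = ["input_ids", "token_type_ids", "attention_mask"]
output_names = ["output"]
dynamic_axes = build_dynamic_axes(input_names, output_names)

# The dictionary would then be passed to the exporter, for example:
# torch.onnx.export(model, dummy_input, "model.onnx",
#                   input_names=input_names, output_names=output_names,
#                   dynamic_axes=dynamic_axes)
```

The export call is left as a comment because it needs the trained model and the dummy input tensors from the earlier step.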

Run inference with onnx file

1. Install onnxruntime to run inference on the ONNX model:
pip install onnxruntime
or, for GPU support:
pip install onnxruntime-gpu

2. Load onnx model using onnxruntime and run inference

import onnxruntime as ort
import torch
import torch.nn as nn

ort_session = ort.InferenceSession('model.onnx')

def to_numpy(tensor):
    # ONNX Runtime expects NumPy arrays, so move the tensors to the CPU first
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

ort_inputs = {
    ort_session.get_inputs()[0].name: to_numpy(input_ids),
    ort_session.get_inputs()[1].name: to_numpy(token_type_ids),
    ort_session.get_inputs()[2].name: to_numpy(attention_mask),
}
# run() returns a list of NumPy arrays, one per requested output
pred = ort_session.run(['output'], ort_inputs)
logits = torch.from_numpy(pred[0])
pred_output_softmax = nn.Softmax(dim=1)(logits)
_, predicted = torch.max(pred_output_softmax, 1)
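Since ONNX Runtime returns NumPy arrays, the softmax and the class lookup can also be done without converting back to torch tensors. Here is a minimal sketch, assuming logits of shape batch x num_of_classes; the helper names and the example logits are made up for illustration and stand in for the first element of the list returned by ort_session.run.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax over the class axis."""
    # Subtract the row maximum first so np.exp cannot overflow.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def predict_classes(logits: np.ndarray) -> np.ndarray:
    """Return the index of the most probable class for each row."""
    return softmax(logits).argmax(axis=1)


# Hypothetical logits for a batch of 2 examples and 3 classes.
example_logits = np.array([[2.0, 0.5, -1.0],
                           [0.1, 0.2, 3.0]], dtype=np.float32)
probs = softmax(example_logits)
predicted = predict_classes(example_logits)
```

If only the predicted class is needed, the softmax can be skipped, because argmax over the raw logits gives the same index.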

In my case, inference with the ONNX model took longer than inference with the PyTorch model.
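A speed comparison like this depends on the hardware, the batch size and the runtime versions, so it is worth measuring on your own setup. Below is a small timing sketch using only the standard library; the helper time_inference and its default run counts are my own choices, not part of onnxruntime or PyTorch, and the workload is a stand-in so the snippet runs on its own.

```python
import time
from statistics import mean
from typing import Callable


def time_inference(run_once: Callable[[], object],
                   n_warmup: int = 5, n_runs: int = 50) -> float:
    """Return the mean wall-clock time in seconds of one call to run_once."""
    # Warm-up calls are not timed; the first calls often include one-off setup cost.
    for _ in range(n_warmup):
        run_once()
    timings = []
    for _ in range(n_runs):
        start = time.perf_counter()
        run_once()
        timings.append(time.perf_counter() - start)
    return mean(timings)


# Stand-in workload. In practice, pass something like:
#   lambda: ort_session.run(['output'], ort_inputs)
#   lambda: model(input_ids, token_type_ids, attention_mask)
dummy_latency = time_inference(lambda: sum(range(1000)), n_warmup=2, n_runs=10)
print(f"mean latency: {dummy_latency * 1e3:.3f} ms")
```

For a fair comparison, time the PyTorch model in eval mode under torch.no_grad(), and on a GPU call torch.cuda.synchronize() inside the timed function, since CUDA kernels run asynchronously.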

The ONNX model can be converted to TensorRT for faster inference.

You can check out my TensorRT tutorial.

Hope this helps :)

I apologize if I have left out any references from which I may have taken code snippets.

References:

https://pytorch.org/docs/stable/onnx.html
https://developer.nvidia.com/blog/how-to-deploy-real-time-text-to-speech-applications-on-gpus-using-tensorrt/
https://github.com/microsoft/onnxruntime
https://github.com/microsoft/onnxruntime/issues/2796
