Convert BERT model from PyTorch to ONNX and run inference

Hemanth Sharma
2 min read · Jul 20, 2020

This short tutorial assumes that you already have a trained PyTorch BERT model.

Let us convert the PyTorch BERT sequence classification model to ONNX.

First, load your trained weights into a model object named model, following the PyTorch guide on saving and loading models:
https://pytorch.org/tutorials/beginner/saving_loading_models.html
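The linked guide recommends saving a model's state_dict and loading it back into a freshly constructed model of the same architecture. A minimal sketch of that round trip follows. TinyClassifier is a hypothetical stand-in that takes BERT's three inputs, so the snippet runs without downloading any weights. For a real model you would construct the same BERT architecture (for example BertForSequenceClassification from Hugging Face transformers, which is an assumption, since the post does not name its library) and call load_state_dict on it.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for a BERT classifier: same three inputs, logits out."""

    def __init__(self, vocab_size=100, hidden=16, num_classes=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.type_embed = nn.Embedding(2, hidden)
        self.classifier = nn.Linear(hidden, num_classes)

    def forward(self, input_ids, token_type_ids, attention_mask):
        x = self.embed(input_ids) + self.type_embed(token_type_ids)
        mask = attention_mask.unsqueeze(-1).to(x.dtype)
        # Mean-pool over non-padding tokens, then classify.
        pooled = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return self.classifier(pooled)


# Save the "trained" weights (here simply the freshly initialised ones) ...
trained = TinyClassifier()
path = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.save(trained.state_dict(), path)

# ... then load them into a new instance of the same architecture.
model = TinyClassifier()
model.load_state_dict(torch.load(path, map_location="cpu"))
model.eval()  # inference mode: disables dropout in models that have it
```

Only the parameters are stored in the file, which is why the architecture must be rebuilt in code before load_state_dict is called.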

Run the code below to convert the BERT model to ONNX:

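import torch

# Define arguments to pass to the ONNX exporter
model.eval()  # put the model in inference mode (disables dropout)
model_onnx_path = "model.onnx"
# The inputs "input_ids", "token_type_ids" and "attention_mask" are torch tensors of shape batch*seq_len
dummy_input = (input_ids, token_type_ids, attention_mask)
input_names = ["input_ids", "token_type_ids", "attention_mask"]
output_names = ["output"]
# Without dynamic_axes the exported graph only accepts the dummy input's exact shape
dynamic_axes = {name: {0: "batch", 1: "seq_len"} for name in input_names}
dynamic_axes["output"] = {0: "batch"}

# Convert model to ONNX
torch.onnx.export(model, dummy_input, model_onnx_path,
                  input_names=input_names, output_names=output_names,
                  dynamic_axes=dynamic_axes, verbose=False)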

Note that output_names contains only “output” because this is a classification model: its single output is the logits tensor of shape (batch, num_of_classes).
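The export code above assumes input_ids, token_type_ids and attention_mask already exist. For tracing, any int64 tensors of shape batch × seq_len will typically do, because BERT's graph does not branch on token values; real tokenizer output works just as well. A sketch of building such dummy inputs, where the vocabulary size (30522, that of bert-base-uncased) and the batch and sequence lengths are assumptions:

```python
import torch

# Assumptions: bert-base-uncased vocabulary size, arbitrary batch and sequence length.
VOCAB_SIZE = 30522
batch_size, seq_len = 2, 16

# Token ids must be int64 (torch.long), the dtype embedding layers expect.
input_ids = torch.randint(0, VOCAB_SIZE, (batch_size, seq_len), dtype=torch.long)
# Single-segment input: every token belongs to segment 0.
token_type_ids = torch.zeros((batch_size, seq_len), dtype=torch.long)
# No padding in the dummy batch, so every position is attended to.
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)

dummy_input = (input_ids, token_type_ids, attention_mask)
```

If the model is exported without dynamic axes, the ONNX graph will only accept inputs of exactly these shapes, so pick sizes that match your inference data in that case.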

Run inference with the ONNX file

1. Install onnxruntime to run ONNX inference:
pip install onnxruntime
or, for GPU inference,
pip install onnxruntime-gpu

2. Load the ONNX model with onnxruntime and run inference:

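import onnxruntime as ort
import torch

# CPU session; with onnxruntime-gpu, pass ["CUDAExecutionProvider", "CPUExecutionProvider"]
ort_session = ort.InferenceSession('model.onnx', providers=["CPUExecutionProvider"])

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# Graph inputs appear in the same order as input_names at export time
ort_inputs = {
    ort_session.get_inputs()[0].name: to_numpy(input_ids),
    ort_session.get_inputs()[1].name: to_numpy(token_type_ids),
    ort_session.get_inputs()[2].name: to_numpy(attention_mask),
}
# run() returns a list holding one numpy array per requested output
pred = ort_session.run(['output'], ort_inputs)
logits = torch.from_numpy(pred[0])  # shape: batch x num_of_classes
pred_output_softmax = torch.softmax(logits, dim=1)
_, predicted = torch.max(pred_output_softmax, 1)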
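ort_session.run returns a list of NumPy arrays, one per requested output, so the softmax and argmax can also be done in NumPy if you prefer not to import torch at inference time. A minimal sketch; the logits below are made-up stand-in values, not real model output:

```python
import numpy as np


def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax: subtract the row max before exponentiating."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


# Stand-in for ort_session.run(["output"], ort_inputs): a list holding one
# (batch, num_of_classes) array. Values are invented for illustration.
pred = [np.array([[2.0, 0.5], [-1.0, 3.0]], dtype=np.float32)]

logits = pred[0]
probabilities = softmax(logits, axis=1)
predicted = probabilities.argmax(axis=1)  # predicted class index per example
```

Since softmax preserves the ordering within each row, argmax over the raw logits gives the same predicted class; the softmax is only needed if you want probabilities.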

In my case, inference with the ONNX model took longer than inference with the PyTorch model.
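How the two backends compare depends on the device, batch size, sequence length and session settings, so it is worth measuring both the same way. A minimal standard-library timing helper follows; it is a sketch that assumes you wrap each backend's forward pass in a zero-argument callable, and benchmark is a hypothetical helper name:

```python
import statistics
import time
from typing import Callable


def benchmark(fn: Callable[[], object], warmup: int = 5, runs: int = 50) -> float:
    """Return the median wall-clock latency of fn() in milliseconds."""
    # Warm-up calls absorb one-off costs such as lazy initialisation.
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        timings.append((time.perf_counter() - start) * 1000.0)
    # The median is less sensitive to occasional slow outliers than the mean.
    return statistics.median(timings)
```

For example, compare benchmark(lambda: ort_session.run(['output'], ort_inputs)) with benchmark(lambda: model(*dummy_input)) called inside a torch.no_grad() block, on the same device and with the same inputs. Timing the CPU onnxruntime package against PyTorch running on a GPU, for instance, would not be a like-for-like test.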

For faster inference, the ONNX model can be converted to TensorRT.

You can check out my TensorRT tutorial.

Hope this helps :)

I apologize if I have left out any references from which I may have taken code snippets.

References:

https://pytorch.org/docs/stable/onnx.html
https://developer.nvidia.com/blog/how-to-deploy-real-time-text-to-speech-applications-on-gpus-using-tensorrt/
https://github.com/microsoft/onnxruntime
https://github.com/microsoft/onnxruntime/issues/2796

Written by Hemanth Sharma

Senior Data Science Engineer @ Qyrus
