Convert a BERT model from PyTorch to ONNX and run inference
This short tutorial assumes that you already have a trained PyTorch BERT model.
Let us convert a PyTorch BERT sequence classification model to ONNX.
First, load the pretrained model into a variable named model, as described in the PyTorch guide on saving and loading models:
https://pytorch.org/tutorials/beginner/saving_loading_models.html
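The loading step can be sketched as follows, assuming the weights were saved with torch.save(model.state_dict(), path). TinyClassifier, load_for_export and the checkpoint file name are hypothetical stand-ins so the snippet runs on its own; replace them with your BERT classifier class and your real checkpoint path.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for a BERT sequence classification model."""

    def __init__(self, vocab_size=100, hidden=8, num_classes=3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden, num_classes)

    def forward(self, input_ids):
        pooled = self.embed(input_ids).mean(dim=1)
        return self.classifier(self.dropout(pooled))


def load_for_export(model, checkpoint_path):
    """Load saved weights into `model` and switch it to inference mode."""
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()  # disables dropout before the model is traced for export
    return model


# Simulate a previously trained and saved model (illustrative path only).
checkpoint_path = os.path.join(tempfile.mkdtemp(), "bert_classifier.pt")
trained = TinyClassifier()
torch.save(trained.state_dict(), checkpoint_path)

model = load_for_export(TinyClassifier(), checkpoint_path)
```

Calling model.eval() before exporting matters because layers such as dropout behave differently in training mode.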
Run the code below to convert the BERT model to ONNX.
import torch

'''
Define arguments to pass to onnx exporter
'''
model_onnx_path = "model.onnx"
# The inputs "input_ids", "token_type_ids" and "attention_mask" are torch tensors of shape batch*seq_len.
# They are passed positionally, so their order must match the model's forward() signature.
dummy_input = (input_ids, token_type_ids, attention_mask)
input_names = ["input_ids", "token_type_ids", "attention_mask"]
output_names = ["output"]

'''
convert model to onnx
'''
torch.onnx.export(model, dummy_input, model_onnx_path,
                  input_names=input_names, output_names=output_names,
                  verbose=False)
Note that output_names contains only "output" because this is a classification model with a single output of shape batch x num_of_classes.
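If you do not have a real tokenized batch at hand, placeholder tensors of the right shape and dtype are enough for the exporter to trace the model. The sketch below is only an illustration: batch_size and seq_len are assumed values, and vocab_size is the vocabulary size of bert-base-uncased, which may differ from your model.

```python
import torch

batch_size, seq_len = 1, 128  # assumed values, pick what your model expects
vocab_size = 30522            # assumed: vocabulary size of bert-base-uncased

# Random token ids, a single segment (all zeros), and no padding (all ones).
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long)
token_type_ids = torch.zeros((batch_size, seq_len), dtype=torch.long)
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long)

dummy_input = (input_ids, token_type_ids, attention_mask)
```

The values themselves do not matter for the export, only the shapes and integer dtype do.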
Run inference with the ONNX file
1. Install onnxruntime to run ONNX inference
pip install onnxruntime
or, for GPU support,
pip install onnxruntime-gpu
2. Load the ONNX model using onnxruntime and run inference
import onnxruntime as ort
import torch
import torch.nn as nn

ort_session = ort.InferenceSession('model.onnx')

def to_numpy(tensor):
    # onnxruntime expects numpy arrays, not torch tensors
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

ort_inputs = {
    ort_session.get_inputs()[0].name: to_numpy(input_ids),
    ort_session.get_inputs()[1].name: to_numpy(token_type_ids),
    ort_session.get_inputs()[2].name: to_numpy(attention_mask),
}

# run() returns a list of numpy arrays, one per requested output
pred = ort_session.run(['output'], ort_inputs)
logits = torch.from_numpy(pred[0])
pred_output_softmax = nn.Softmax(dim=1)(logits)
_, predicted = torch.max(pred_output_softmax, 1)
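Because ort_session.run returns NumPy arrays, the post-processing can also be done without torch. Here is a minimal sketch, assuming the logits have shape batch x num_of_classes; the softmax helper is written for illustration and the logits are made-up numbers standing in for pred[0].

```python
import numpy as np


def softmax(logits, axis=1):
    """Numerically stable softmax along `axis`."""
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)


# Made-up logits standing in for pred[0] from ort_session.run(...)
logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]], dtype=np.float32)

probabilities = softmax(logits, axis=1)       # each row sums to 1
predicted = np.argmax(probabilities, axis=1)  # class index per example
```

This keeps the inference path free of a torch dependency, which can be handy when only onnxruntime is installed on the serving machine.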
Note that inference with the ONNX model can take longer than inference with the PyTorch model.
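To check this on your own hardware, you can time both paths with the same small helper. This is a rough sketch rather than a rigorous benchmark: time_inference is a hypothetical helper, the warmup and run counts are arbitrary, and the lambda at the bottom is a placeholder workload so the snippet runs on its own.

```python
import statistics
import time


def time_inference(run_once, warmup=3, runs=20):
    """Return the median wall-clock time in seconds of calling `run_once()`."""
    # Warm-up calls are not timed, so one-off setup costs do not skew the result.
    for _ in range(warmup):
        run_once()
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        run_once()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


# Placeholder workload so the sketch is runnable without a model.
median_seconds = time_inference(lambda: sum(range(10_000)))
```

With the session from this tutorial you would pass lambda: ort_session.run(['output'], ort_inputs) for the ONNX side, and a function that calls the PyTorch model under torch.no_grad() for the other side, using the same inputs for both.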
The ONNX model can be converted to TensorRT for faster inference.
You can check out my TensorRT tutorial.
Hope this helps :)
I apologize if I have left out any references from which I may have taken code snippets.
References:
https://pytorch.org/docs/stable/onnx.html
https://developer.nvidia.com/blog/how-to-deploy-real-time-text-to-speech-applications-on-gpus-using-tensorrt/
https://github.com/microsoft/onnxruntime
https://github.com/microsoft/onnxruntime/issues/2796