PyTorch
概述
pip install torch
pip install torchvision
pip install onnx
pip install onnxruntime
构建模型与转ONNX
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
from torch import nn
from torch.nn import functional as F
import onnxruntime
class Net(nn.Module):
    """Toy two-input network: element-wise addition followed by softmax.

    The module defines no learnable parameters; it is a minimal example of a
    model with more than one input.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Return ``softmax(x1 + x2)`` computed over the last dimension.

        Args:
            x1: First input tensor.
            x2: Second input tensor; it must be broadcastable with ``x1``,
                as required by ``torch.add``.

        Returns:
            A tensor whose entries along the last dimension sum to 1.
        """
        summed = torch.add(x1, x2)
        # dim=-1 normalises each row (the last axis) independently.
        return torch.softmax(summed, -1)
# Build two random example inputs and run the PyTorch model once for reference.
input0 = torch.randn(4, 10)
input1 = torch.randn(4, 10)
net = Net()
output = net(input0, input1)
print(output)

# Convert the model to ONNX, then run the exported file with ONNX Runtime.
torch.onnx.export(
    net,  # PyTorch model
    (input0, input1),  # example inputs (a tuple for multiple inputs)
    "test.onnx",  # path of the ONNX file to write
    export_params=True,  # store the trained parameter weights inside the model file
    opset_version=13,  # ONNX opset version to target
    do_constant_folding=True,  # whether to execute constant folding for optimization
    input_names=['input0', 'input1'],  # the model's input names
    output_names=['output'],  # the model's output names
    dynamic_axes=None,  # no variable-length axes: shapes are fixed to the example inputs
)

# Name the execution provider explicitly: since ONNX Runtime 1.9, builds that
# ship additional providers (e.g. CUDA) reject a session created without one.
session = onnxruntime.InferenceSession(
    "test.onnx",
    providers=["CPUExecutionProvider"],
)
# Feed dictionary keyed by the input names given at export time.
# (Named ``ort_inputs`` rather than ``input`` to avoid shadowing the builtin.)
ort_inputs = {
    'input0': input0.numpy(),
    'input1': input1.numpy(),
}
# Passing None as the first argument requests every model output.
out = session.run(None, ort_inputs)
print(out)
获取Params和Flops
安装 torchstat:pip install torchstat
方法如下:
# Print a model statistics report using torchstat.
# NOTE(review): `net` and `params` are not defined in this snippet; presumably
# `net` is an nn.Module instance and `params` is a matching state dict loaded
# beforehand — confirm against the surrounding code.
import torch
from torchstat import stat
net.load_state_dict(params)  # load the weights into the model before profiling
# The tuple is the input size handed to torchstat — presumably
# (channels, height, width) of a single image; TODO confirm.
stat(net,(3,224,224))
显示结果如下:
Total params: 1,366,798
-----------------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 13.11MB
Total MAdd: 86.1MMAdd
Total Flops: 44.0MFlops
Total MemR+W: 29.92MB