PyTorch

3 分钟阅读

概述

帮助文档

pip install torch, pip install torchvision, pip install onnx, pip install onnxruntime

构建模型与转ONNX

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
from torch import nn
from torch.nn import functional as F
import onnxruntime

class Net(nn.Module):
    """Parameter-free module: add two tensors, then softmax over the last dim."""

    def __init__(self):
        super().__init__()

    def forward(self, x1, x2):
        """Return the softmax of ``x1 + x2`` taken along the last dimension."""
        summed = torch.add(x1, x2)
        return F.softmax(summed, dim=-1)


# Two random inputs of the same shape; Net adds them and applies softmax.
input0 = torch.randn(4, 10)
input1 = torch.randn(4, 10)
net = Net()
output = net(input0, input1)
print(output)

# Convert to an ONNX model, then run the exported file with ONNX Runtime.
torch.onnx.export(
    net,                       # PyTorch model
    (input0, input1),          # inputs (a tuple for multiple inputs)
    "test.onnx",               # path of the ONNX model to save
    export_params=True,        # store the trained parameter weights inside the model file
    opset_version=13,          # the ONNX opset version to export the model to
    do_constant_folding=True,  # whether to execute constant folding for optimization
    input_names=['input0', 'input1'],  # the model's input names
    output_names=['output'],   # the model's output names
    dynamic_axes=None,         # no variable-length axes
)

# Name the execution provider explicitly: since ONNX Runtime 1.9, builds that
# ship additional providers raise if `providers` is omitted.
session = onnxruntime.InferenceSession(
    "test.onnx",
    providers=["CPUExecutionProvider"],
)
# Feed keys must match the `input_names` given to the exporter above.
# Named `ort_inputs` so the builtin `input` is not shadowed.
ort_inputs = {
    'input0': input0.numpy(),
    'input1': input1.numpy(),
}
out = session.run(None, ort_inputs)

print(out)

获取Params和Flops

安装:pip install torchstat

方法如下:

import torch
from torchstat import stat
# NOTE(review): `params` is not defined in this snippet; presumably a state_dict
# loaded elsewhere (e.g. from a checkpoint file). Confirm before running.
# NOTE(review): `net` is presumably a model that takes a single image input,
# not the two-input Net defined earlier. Confirm.
net.load_state_dict(params)
# Report the model statistics; (3,224,224) is presumably the input size as
# (channels, height, width). TODO confirm against the torchstat docs.
stat(net,(3,224,224))

显示结果如下:

Total params: 1,366,798
-----------------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 13.11MB
Total MAdd: 86.1MMAdd
Total Flops: 44.0MFlops
Total MemR+W: 29.92MB