Pytorch导出ONNX一些不支持操作的解决

Pytorch导出ONNX一些不支持操作的解决

在使用Pytorch导出ONNX格式模型时经常报一些操作不支持的问题,有些问题在Pytorch版本升级后得到解决,但是考虑到Pytorch版本升级后导出到NCNN又存在问题,所以这里还是考虑使用Pytorch0.3.1的版本进行解决

AdaptiveAvgPool2d

自适应池化层用起来很方便,但是导出到onnx时候就可恼了,onnx并不支持。

使用AvgPool2d替换

导出时候可以对存在AdaptiveAvgPool2d系列的那部分代码进行替换,如

1
2
3
4
5
import torch 
x = torch.autograd.Variable(torch.rand(1, 8, 5, 5))
avgpool = torch.nn.AdaptiveAvgPool2d(1)
x = avgpool(x)
print(x)

此处我们知道tensor为1x8x5x5大小,我们可以直接使用如下代码替换

1
avgpool = torch.nn.AvgPool2d(5)

自定义AdaptiveAvgPool2d类

有些时候一个模块多次调用AdaptiveAvgPool2d,但是每次输入的tensor大小又不一样,这时候使用上一种替换方式就不那么合适了,这时候重新定义一个AdaptiveAvgPool2d类是一个更合理的解决方式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class MyAdaptiveAvgPool2d(nn.Module):
def __init__(self, sz=None):
super().__init__()
self.sz = sz

def forward(self, x):
inp_size = x.size()
kernel_width, kernel_height = inp_size[2], inp_size[3]
if self.sz is not None:
if isinstance(self.sz, int):
kernel_width = ceil(inp_size[2] / self.sz)
kernel_height = ceil(inp_size[3] / self.sz)
elif isinstance(self.sz, list) or isinstance(self.sz, tuple):
assert len(self.sz) == 2
kernel_width = ceil(inp_size[2] / self.sz[0])
kernel_height = ceil(inp_size[3] / self.sz[1])
return F.avg_pool2d(input=x,
ceil_mode=False,
kernel_size=(kernel_width, kernel_height))

使用这个自定义类MyAdaptiveAvgPool2d()代替torch.nn.AdaptiveAvgPool2d(1)即可解决导出到onnx的问题

Expand

在进行tensor的数乘时由于两个tensor的维数不一致,比如维数为[3,8,8,8]的tensor和维数为[3,8,1,1]的tensor进行数乘时,Pytorch默认先将[3,8,1,1]的tensor进行Expand操作,将该tensor变换为[3,8,8,8]的tensor,两个tensor的维数保持一致后再进行数乘操作。看以下代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self,n_channels):
super(MyModel, self).__init__()
self.n_channels = n_channels
self.weight = nn.Parameter(torch.Tensor(self.n_channels))

def forward(self, x):
res = x * self.weight.view(1, -1, 1, 1)
return res


net = MyModel(256)
input_x = torch.autograd.Variable(torch.randn(1, 256, 160, 160))
torch.onnx.export(net, input_x, "model.onnx", verbose=True, export_param=True)

在Pytorch0.3.1版本上会报错为RuntimeError: ONNX export failed: Couldn't export operator expand; this usually means you used a form of broadcasting that ONNX does not currently support

这里可以使用最粗暴的方式来替换Expand操作,使用torch.cat进行替换,对上边MyModel类的forward函数进行如下修改

1
2
3
4
5
6
7
8
9
10
11
12
13
def forward(self):
self.weight = self.weight.view(1, -1, 1, 1)
x_size = x.size()
weight_list = []
for i in range(x_size[2]):
weight_list.append(self.weight)
self.weight = torch.cat(weight_list, 2)
weight_list = []
for i in range(x_size[3]):
weight_list.append(self.weight)
self.weight = torch.cat(weight_list, 3)
res = x * self.weight
return res

ReLU6

ReLU6的公式是y=min(max(0, x), 6),相当于有六个bernoulli分布,即六个硬币,同时抛出正面,这样鼓励网络学习到稀疏特征。而ReLU的公式为y=max(0,x),由于onnx对ReLU6的导出不支持,所以可以通过以下改写实现一样的功能。

1
2
3
4
5
6
7
8
9
10
11
import torch
import torch.nn as nn

input_x = torch.autograd.Variable(torch.randn(1, 256, 160, 160)*6)
relu6 = nn.ReLU6(inplace=True)(input_x)

relu = nn.ReLU(inplace=True)(input_x)
# replace_relu6 = relu.clamp(max=6)
replace_relu6 = 6.0 - nn.ReLU(inplace=True)(6.0-relu)

print((relu6-replace_relu6).sum())