Pytorch导出ONNX一些不支持操作的解决
在使用Pytorch导出ONNX格式模型时经常报一些操作不支持的问题,有些问题在Pytorch版本升级后得到解决,但是考虑到Pytorch版本升级后导出到NCNN又存在问题,所以这里还是考虑使用Pytorch0.3.1的版本进行解决
AdaptiveAvgPool2d
自适应池化层用起来很方便,但是导出到onnx时候就可恼了,onnx并不支持。
使用AvgPool2d替换
导出时候可以对存在AdaptiveAvgPool2d系列的那部分代码进行替换,如
1 | import torch |
此处我们知道tensor为1x8x5x5大小,我们可以直接使用如下代码替换
1 | avgpool = torch.nn.AvgPool2d(5) |
自定义AdaptiveAvgPool2d类
有些时候一个模块多次调用AdaptiveAvgPool2d,但是每次输入的tensor大小又不一样,这时候使用上一种替换方式就不那么合适了,这时候重新定义一个AdaptiveAvgPool2d类是一个更合理的解决方式
1 | class MyAdaptiveAvgPool2d(nn.Module): |
使用这个自定义类MyAdaptiveAvgPool2d()
代替torch.nn.AdaptiveAvgPool2d(1)
即可解决导出到onnx的问题
Expand
在进行tensor的数乘时由于两个tensor的维数不一致,比如维数为[3,8,8,8]的tensor和维数为[3,8,1,1]的tensor进行数乘时,Pytorch默认先将[3,8,1,1]的tensor进行Expand操作,将该tensor变换为[3,8,8,8]的tensor,两个tensor的维数保持一致后再进行数乘操作。看以下代码
1 | import torch |
在Pytorch0.3.1版本上会报错为RuntimeError: ONNX export failed: Couldn't export operator expand; this usually means you used a form of broadcasting that ONNX does not currently support
这里可以使用最粗暴的方式来替换Expand操作,使用torch.cat进行替换,对上边MyModel类的forward函数进行如下修改
1 | def forward(self): |
ReLU6
ReLU6的公式是y=min(max(0, x), 6),相当于有六个bernoulli分布,即六个硬币,同时抛出正面,这样鼓励网络学习到稀疏特征。
而ReLU的公式为y=max(0,x)
,由于onnx对ReLU6的导出不支持,所以可以通过以下改写实现一样的功能。
1 | import torch |