Pytorch 模型版本切换
0.3.1转到0.4.1或更高版本
直接使用代码导入时常碰到 ‘BatchNorm2d’ object has no attribute ‘track_running_stats’的报错信息,这是由于0.3.1中的BN操作中没有配置track_running_stats参数,0.3.1中BatchNorm的定义如下class torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True)
$$ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta$$
Parameters:
num_features – num_features from an expected input of size batch_size x num_features x height x width
eps – a value added to the denominator for numerical stability. Default: 1e-5
momentum – the value used for the running_mean and running_var computation. Default: 0.1
affine – a boolean value that when set to True, gives the layer learnable affine parameters. Default: True
Shape:
Input: (N,C,H,W)
Output: (N,C,H,W) (same shape as input)
而在0.4.1中定义发生了变化class torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
If track_running_stats is set to False, this layer then does not keep running estimates,
and batch statistics are instead used during evaluation time as well.
所以使用0.4.1或以上版本导入0.3.1模型时需要对模型中的BN层添加track_running_stats参数,代码如下
1 | def recursion_change_bn(module): |
另外,也可以在导入模型处直接修改模型,模型的statedict本身可以理解为一个Orderdict,在模型中添加参数num_batches_tracked对应的值即可. 具体做法是在键值为running_var后添加一个键值为num_batches_tracked,值为0的Tensor. 具体代码如下
1 | checkpoint = torch.load(checkpoint_path, map_location=device) |
0.3.1版本导入0.4.1以上版本模型
0.4中使用设备:.to(device)
0.4中删除了Variable,直接tensor就可以
with torch.no_grad():的使用代替volatile;弃用volatile,测试中不需要计算梯度的话,用with torch.no_grad():
data改用.detach;x.detach()返回一个requires_grad=False的共享数据的Tensor,并且,如果反向传播中需要x,那么x.detach返回的Tensor的变动会被autograd追踪。相反,x.data()返回的Tensor,其变动不会被autograd追踪,如果反向传播需要用到x的话,值就不对了。
torchvision
pytorch0.4有一些接口已经改变,且模型向下版本兼容,不向上兼容。
In PyTorch 0.4, is it recommended to use
reshape
thanview
when it is possible?Question about ‘rebuild_tensor_v2’?
使用pytorch0.3导入pytorch0.4保存的模型时候在导入前添加如下代码段,解决的报错内容为(AttributeError: Can’t get attribute ‘_rebuild_tensor_v2’ on
\\lib\\site-packages\\torch\\_utils.py'>),详情可对比查看_utils.py文件:
1 | # This can be removed once PyTorch 0.4.x is out. |
在导出为ONNX模型时还可能会报错存在多余的num_batches_tracked值, 错误代码为KeyError: 'unexpected key "module.bn1.num_batches_tracked" in state_dict'
, 此处的处理方式和上边的添加num_batches_tracked键值对应,删除该键值即可,具体代码如下
1 | checkpoint = torch.load(checkpoint_path, map_location=device) |
由0.4.1导出为0.3.1的ONNX模型时,上述两段代码都需要加入
导出为1.0.0模型
pytorch1.0.0添加了torch.jit, 可以直接将模型和网络打包到模型文件中,而不需要在使用模型文件时导入网络定义,在模型的使用时变得更加方便了
模型的jit导出
1 | def pth_to_jit(model, save_path, device="cuda:0"): |
jit模型导入使用
1 | def load_jit(jit_model_path): |