Tensor的镜像翻转

Tensor的镜像翻转

在使用numpy时我们可以对数组进行镜像翻转操作,如以下例子

1
2
3
4
import numpy as np
array = np.array(range(10))
print(array)
print(array[::-1])
[0 1 2 3 4 5 6 7 8 9]
[9 8 7 6 5 4 3 2 1 0]

但是在pytorch中并不能通过tensor[::-1]进行镜像的翻转,此处给出了tensor的镜像翻转方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# https://github.com/pytorch/pytorch/issues/229
import torch
from torch.autograd import Variable
def flip(x, dim):
xsize = x.size()
dim = x.dim() + dim if dim < 0 else dim
x = x.view(-1, *xsize[dim:])
x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1)-1,
-1, -1), ('cpu','cuda')[x.is_cuda])().long(), :]
return x.view(xsize)

# Code to test it with cpu Variable
a = Variable(torch.Tensor([range(1, 25)]).view(1, 2, 3, 4))
print(a)
print(flip(a, 0)) # Or -4
print(flip(a, 1)) # Or -3
print(flip(a, 2)) # Or -2
print(flip(a, 3)) # Or -1
tensor([[[[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.]],

         [[13., 14., 15., 16.],
          [17., 18., 19., 20.],
          [21., 22., 23., 24.]]]])
tensor([[[[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.]],

         [[13., 14., 15., 16.],
          [17., 18., 19., 20.],
          [21., 22., 23., 24.]]]])
tensor([[[[13., 14., 15., 16.],
          [17., 18., 19., 20.],
          [21., 22., 23., 24.]],

         [[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.]]]])
tensor([[[[ 9., 10., 11., 12.],
          [ 5.,  6.,  7.,  8.],
          [ 1.,  2.,  3.,  4.]],

         [[21., 22., 23., 24.],
          [17., 18., 19., 20.],
          [13., 14., 15., 16.]]]])
tensor([[[[ 4.,  3.,  2.,  1.],
          [ 8.,  7.,  6.,  5.],
          [12., 11., 10.,  9.]],

         [[16., 15., 14., 13.],
          [20., 19., 18., 17.],
          [24., 23., 22., 21.]]]])

以下是pytorch>=0.4.0的代码

1
2
3
4
5
6
7
# https://github.com/pytorch/pytorch/issues/229
import torch
def flip(x, dim):
indices = [slice(None)] * x.dim()
indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
dtype=torch.long, device=x.device)
return x[tuple(indices)]