Pytorch在其tensor有着很多类似但又不同的操作,这里对这些操作做一个更加详细的解释。

tensor.clone(), tensor.detach(), tensor.data

tensor.clone(), tensor.detach(), tensor.data这三种操作,都有着对tensor进行copy的意思,但实际进行copy的时候却有着一定的不同!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# clone()真正的copy,新开辟存储
x = torch.tensor(([1.0]),requires_grad=True)
y = x.clone()
print("Id_x:{} Id_y:{}".format(id(x),id(y)))
y += 1
print("x:{} y:{}".format(x,y))

print('-----------------------------------')
# detach()与原来的tensor共享存储,操作之后requires_grad变为false
x = torch.tensor(([1.0]),requires_grad=True)
y = x.detach()
print("Id_x:{} Id_y:{}".format(id(x),id(y)))
y += 1
print("x:{} y:{}".format(x,y))

print('-----------------------------------')
# .data与detach()一样,官方回答是没有足够的时间改代码,所以这个东西还在
x = torch.tensor(([1.0]),requires_grad=True)
y = x.data
print("Id_x:{} Id_y:{}".format(id(x),id(y)))
y += 1
print("x:{} y:{}".format(x,y))
1
2
3
4
5
6
7
8
Id_x:140684285215008 Id_y:140684285217384
x:tensor([1.], requires_grad=True) y:tensor([2.], grad_fn=<AddBackward0>)
-----------------------------------
Id_x:140684285216808 Id_y:140684285215008
x:tensor([2.], requires_grad=True) y:tensor([2.])
-----------------------------------
Id_x:140684285216088 Id_y:140684285216808
x:tensor([2.], requires_grad=True) y:tensor([2.])

在官方的文档中,对于detach的解释是:Returns a new Tensor, detached from the current graph.
这句话该怎么理解,看下面的代码。
可以看出,z detach 得到的 p 就与之前的计算图完全没有关系了,这样pq反传的梯度就不会到x上。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
x = torch.tensor(([1.0]),requires_grad=True)
y = x**2
z = 2*y
w= z**3

# detach it, so the gradient w.r.t `p` does not effect `z`!
p = z.detach()
print(p)
q = torch.tensor(([2.0]), requires_grad=True)
pq = p*q
pq.backward(retain_graph=True)

w.backward()
print(x.grad)

x = torch.tensor(([1.0]),requires_grad=True)
y = x**2
z = 2*y
w= z**3

# create a subpath for z
p = z.clone()
print(p)
q = torch.tensor(([2.0]), requires_grad=True)
pq = p*q
pq.backward(retain_graph=True)

w.backward()
print(x.grad)
1
2
3
4
tensor([2.])
tensor([48.])
tensor([2.], grad_fn=<CloneBackward>)
tensor([56.])

tensor.reshape(),tensor.view()

暂时认为一样吧,我还没搞懂到底怎么不一样