如何在pytorch中使用自定义的激活函数?

如果自定义的激活函数是可导的,并且完全由PyTorch的张量运算组成,那么可以直接写一个Python函数来定义并调用,因为PyTorch的autograd会自动对其求导。

如果自定义的激活函数不是处处可导的(比如类似于ReLU的分段可导函数),或者其中包含autograd无法追踪的操作(例如先转成numpy再计算),又或者需要自行指定梯度,就可以写一个继承torch.autograd.Function的类,并自行定义forward和backward的过程。

在pytorch中提供了定义新的autograd function的tutorial: https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html, tutorial以ReLU为例介绍了在forward, backward中需要自行定义的内容。

 1 import torch
 2 
 3 
 4 class MyReLU(torch.autograd.Function):
 5     """
 6     We can implement our own custom autograd Functions by subclassing
 7     torch.autograd.Function and implementing the forward and backward passes
 8     which operate on Tensors.
 9     """
10 
11     @staticmethod
12     def forward(ctx, input):
13         """
14         In the forward pass we receive a Tensor containing the input and return
15         a Tensor containing the output. ctx is a context object that can be used
16         to stash information for backward computation. You can cache arbitrary
17         objects for use in the backward pass using the ctx.save_for_backward method.
18         """
19         ctx.save_for_backward(input)
20         return input.clamp(min=0)
21 
22     @staticmethod
23     def backward(ctx, grad_output):
24         """
25         In the backward pass we receive a Tensor containing the gradient of the loss
26         with respect to the output, and we need to compute the gradient of the loss
27         with respect to the input.
28         """
29         input, = ctx.saved_tensors
30         grad_input = grad_output.clone()
31         grad_input[input < 0] = 0
32         return grad_input
33 
34 
dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for weights.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6

# To apply our Function, we use Function.apply method. We alias this as 'relu'.
# The alias never changes, so bind it once rather than on every iteration.
relu = MyReLU.apply

for t in range(500):
    # Forward pass: compute predicted y using operations; we compute
    # ReLU using our custom autograd operation.
    y_pred = relu(x.mm(w1)).mm(w2)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent; no_grad keeps the in-place
    # updates out of the autograd graph.
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # Manually zero the gradients after updating weights, since
        # backward() accumulates into .grad.
        w1.grad.zero_()
        w2.grad.zero_()

但是如果定义ReLU函数时,没有使用以上正确的方法,而是直接用普通的Python函数来定义,会出现什么问题呢?

这里对比了使用上面的MyReLU与使用自定义函数no_back的实验结果。

def no_back(x):
    """ReLU written as a plain function, with no custom backward.

    Autograd differentiates it automatically. The comparison ``x > 0`` yields
    a mask that carries no gradient, so the gradient w.r.t. ``x`` is the mask
    itself: 1 where x > 0, else 0.

    The mask is cast to ``x.dtype`` rather than to float32. This keeps half,
    bfloat16 and float64 inputs in their own dtype instead of promoting them.
    """
    return x * (x > 0).to(x.dtype)

代码:

N, D_in, H, D_out = 2, 3, 4, 5

# Shared factory options so every random tensor lives on the same device/dtype.
opts = {"device": device, "dtype": dtype}

# Random input batch and regression targets.
x = torch.randn(N, D_in, **opts)
y = torch.randn(N, D_out, **opts)

# Initial weights; myReLU copies these, so each experiment starts from them.
origin_w1 = torch.randn(D_in, H, requires_grad=True, **opts)
origin_w2 = torch.randn(H, D_out, requires_grad=True, **opts)

learning_rate = 1e-3

def myReLU(func, x, y, origin_w1, origin_w2, learning_rate, N=2, D_in=3, H=4, D_out=5, *, num_steps=5):
    """Train a tiny two-layer net with ``func`` as the activation, printing diagnostics.

    Each step computes ``func(x.mm(w1)).mm(w2)`` and a sum-of-squares loss
    against ``y``. It backpropagates and prints the loss, ``w1``, the hidden
    pre-activations, the activated hidden layer and ``w1.grad``. It then
    takes a plain gradient-descent step.

    Args:
        func: activation applied to ``x.mm(w1)`` (e.g. ``MyReLU.apply`` or
            ``no_back``).
        x, y: input batch and target tensors.
        origin_w1, origin_w2: initial weights. They are copied and never
            modified, so several calls can start from the same point.
        learning_rate: gradient-descent step size.
        N, D_in, H, D_out: unused; kept so existing call sites keep working.
        num_steps: number of gradient-descent steps (default 5).

    Returns:
        The trained ``(w1, w2)`` copies, so runs with different ``func`` can be
        compared directly.
    """
    # Fresh leaf copies. Detaching means origin_* (and any .grad it holds) is
    # never touched or inherited. requires_grad lets loss.backward() fill w1.grad/w2.grad.
    w1 = origin_w1.detach().clone().requires_grad_(True)
    w2 = origin_w2.detach().clone().requires_grad_(True)
    for t in range(num_steps):
        # Forward pass: compute predicted y, applying ``func`` to the hidden layer.
        y_pred = func(x.mm(w1)).mm(w2)

        # Compute and print loss
        loss = (y_pred - y).pow(2).sum()
        print("------", t, loss.item(), "------------")

        # Use autograd to compute the backward pass.
        loss.backward()

        # Inspect and update weights outside the autograd graph.
        with torch.no_grad():
            print('w1 = ')
            print(w1)
            print('---------------------')
            print("x.mm(w1) = ")
            print(x.mm(w1))
            print('---------------------')
            print('func(x.mm(w1))')
            print(func(x.mm(w1)))
            print('---------------------')
            print("w1.grad:", w1.grad)
            print('---------------------')

            w1 -= learning_rate * w1.grad
            w2 -= learning_rate * w2.grad

            # Manually zero the gradients after updating weights, since
            # backward() accumulates into .grad.
            w1.grad.zero_()
            w2.grad.zero_()
            print('========================')
            print()
    return w1, w2


# Arguments shared by both runs; only the activation function differs.
common_args = dict(
    x=x,
    y=y,
    origin_w1=origin_w1,
    origin_w2=origin_w2,
    learning_rate=learning_rate,
    N=2,
    D_in=3,
    H=4,
    D_out=5,
)

myReLU(func=MyReLU.apply, **common_args)
# Visual separator between the two runs.
for _ in range(3):
    print('============')
myReLU(func=no_back, **common_args)

对于使用了MyReLU.apply的实验结果为:

 1 ------ 0 20.18220329284668 ------------
 2 w1 = 
 3 tensor([[ 0.7070,  2.5772,  0.7987,  2.2287],
 4         [ 0.7425, -0.6309,  0.3268, -1.5072],
 5         [ 0.6930, -2.6128,  0.1949,  0.8819]], requires_grad=True)
 6 ---------------------
 7 x.mm(w1) = 
 8 tensor([[-0.9788,  1.0135, -0.4164,  1.8834],
 9         [-0.7692, -1.8556, -0.7085, -0.9849]])
10 ---------------------
11 func(x.mm(w1))
12 tensor([[0.0000, 1.0135, 0.0000, 1.8834],
13         [0.0000, 0.0000, 0.0000, 0.0000]])
14 ---------------------
15 w1.grad: tensor([[  0.0000,   0.0499,   0.0000,   0.1881],
16         [  0.0000,  -4.4962,   0.0000, -16.9378],
17         [  0.0000,  -0.2401,   0.0000,  -0.9043]])
18 ---------------------
19 ========================
20 
21 ------ 1 19.546737670898438 ------------
22 w1 = 
23 tensor([[ 0.7070,  2.5772,  0.7987,  2.2285],
24         [ 0.7425, -0.6265,  0.3268, -1.4903],
25         [ 0.6930, -2.6126,  0.1949,  0.8828]], requires_grad=True)
26 ---------------------
27 x.mm(w1) = 
28 tensor([[-0.9788,  1.0078, -0.4164,  1.8618],
29         [-0.7692, -1.8574, -0.7085, -0.9915]])
30 ---------------------
31 func(x.mm(w1))
32 tensor([[0.0000, 1.0078, 0.0000, 1.8618],
33         [0.0000, 0.0000, 0.0000, 0.0000]])
34 ---------------------
35 w1.grad: tensor([[  0.0000,   0.0483,   0.0000,   0.1827],
36         [  0.0000,  -4.3446,   0.0000, -16.4493],
37         [  0.0000,  -0.2320,   0.0000,  -0.8782]])
38 ---------------------
39 ========================
40 
41 ------ 2 18.94647789001465 ------------
42 w1 = 
43 tensor([[ 0.7070,  2.5771,  0.7987,  2.2283],
44         [ 0.7425, -0.6221,  0.3268, -1.4738],
45         [ 0.6930, -2.6123,  0.1949,  0.8837]], requires_grad=True)
46 ---------------------
47 x.mm(w1) = 
48 tensor([[-0.9788,  1.0023, -0.4164,  1.8409],
49         [-0.7692, -1.8591, -0.7085, -0.9978]])
50 ---------------------
51 func(x.mm(w1))
52 tensor([[0.0000, 1.0023, 0.0000, 1.8409],
53         [0.0000, 0.0000, 0.0000, 0.0000]])
54 ---------------------
55 w1.grad: tensor([[  0.0000,   0.0467,   0.0000,   0.1775],
56         [  0.0000,  -4.2009,   0.0000, -15.9835],
57         [  0.0000,  -0.2243,   0.0000,  -0.8534]])
58 ---------------------
59 ========================
60 
61 ------ 3 18.378826141357422 ------------
62 w1 = 
63 tensor([[ 0.7070,  2.5771,  0.7987,  2.2281],
64         [ 0.7425, -0.6179,  0.3268, -1.4578],
65         [ 0.6930, -2.6121,  0.1949,  0.8846]], requires_grad=True)
66 ---------------------
67 x.mm(w1) = 
68 tensor([[-0.9788,  0.9969, -0.4164,  1.8206],
69         [-0.7692, -1.8607, -0.7085, -1.0040]])
70 ---------------------
71 func(x.mm(w1))
72 tensor([[0.0000, 0.9969, 0.0000, 1.8206],
73         [0.0000, 0.0000, 0.0000, 0.0000]])
74 ---------------------
75 w1.grad: tensor([[  0.0000,   0.0451,   0.0000,   0.1726],
76         [  0.0000,  -4.0644,   0.0000, -15.5391],
77         [  0.0000,  -0.2170,   0.0000,  -0.8296]])
78 ---------------------
79 ========================
80 
81 ------ 4 17.841421127319336 ------------
82 w1 = 
83 tensor([[ 0.7070,  2.5770,  0.7987,  2.2280],
84         [ 0.7425, -0.6138,  0.3268, -1.4423],
85         [ 0.6930, -2.6119,  0.1949,  0.8854]], requires_grad=True)
86 ---------------------
87 x.mm(w1) = 
88 tensor([[-0.9788,  0.9918, -0.4164,  1.8008],
89         [-0.7692, -1.8623, -0.7085, -1.0100]])
90 ---------------------
91 func(x.mm(w1))
92 tensor([[0.0000, 0.9918, 0.0000, 1.8008],
93         [0.0000, 0.0000, 0.0000, 0.0000]])
94 ---------------------
95 w1.grad: tensor([[  0.0000,   0.0437,   0.0000,   0.1679],
96         [  0.0000,  -3.9346,   0.0000, -15.1145],
97         [  0.0000,  -0.2101,   0.0000,  -0.8070]])
98 ---------------------
99 ========================

View Code