陈云 PyTorch 学习笔记：用 50 行代码搭建 ResNet

 

 

 

import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
# Residual block: two 3x3 conv + BN layers on the main path, plus a skip path.
class ResidualBlock(nn.Module):
    """Basic residual block.

    Notation: a^[L] is the block input, z^[L+2] is the output of the second
    BatchNorm on the main path, and g is ReLU.

    * ``shortcut=None`` (identity skip, the solid line in the ResNet figure):
      a^[L+2] = g(a^[L] + z^[L+2]).
      This requires the main path to keep the input shape (same channels,
      stride 1), otherwise the addition fails.
    * ``shortcut`` given (projection skip, the dashed line, used when the
      channel count or spatial size changes):
      a^[L+2] = g(W_s(a^[L]) + z^[L+2]), where W_s is the ``shortcut`` module.

    Args:
        inchannel: number of channels of the input tensor.
        outchannel: number of channels produced by both 3x3 convolutions.
        stride: stride of the first 3x3 convolution.
        shortcut: optional module applied to the input on the skip path.
    """

    def __init__(self, inchannel, outchannel, stride=1, shortcut=None):
        super().__init__()
        # Main path producing z^[L+2]. The submodule order (indices 0..4) is
        # significant: state_dict keys are "left.<index>.*".
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride,
                      padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1,
                      padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
        )
        # Skip path: None means identity, otherwise the given projection module.
        self.right = shortcut

    def forward(self, x):
        """Return ReLU(main_path(x) + skip_path(x))."""
        main = self.left(x)
        if self.right is not None:
            # Project the input so its shape matches the main-path output.
            x = self.right(x)
        return F.relu(main + x)
# ResNet34: stem, four residual stages (3, 4, 6, 3 blocks), average pooling, FC classifier.
class ResNet(nn.Module):
    """ResNet34 classifier assembled from ``ResidualBlock`` stages.

    Layout:
    * stem: 7x7/2 convolution, BN, ReLU, then 3x3/2 max-pooling;
    * four residual stages with 3, 4, 6 and 3 blocks (64, 128, 256, 512 channels);
    * global average pooling and a fully connected classifier.

    Expects 3-channel image batches of shape (N, 3, H, W).

    Args:
        num_classes: number of logits produced by the final linear layer.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # Stem: image transformation before the residual stages.
        self.pre = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1),
        )
        # Residual stages. stride=2 makes the first block of a stage halve the
        # spatial size (its first conv and its projection shortcut use stride 2).
        self.layer1 = self._make_layer(64, 64, 3)
        self.layer2 = self._make_layer(64, 128, 4, stride=2)
        self.layer3 = self._make_layer(128, 256, 6, stride=2)
        self.layer4 = self._make_layer(256, 512, 3, stride=2)
        # Pool every channel down to 1x1 regardless of input resolution. For a
        # 224x224 input the layer4 map is 7x7, so this equals AvgPool2d(7).
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Classifier head; its output size follows num_classes.
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, inchannel, outchannel, block_num, stride=1):
        """Build one residual stage made of ``block_num`` ResidualBlocks.

        Only the first block may change the channel count (inchannel ->
        outchannel) or the spatial size (stride). It therefore gets a 1x1
        conv + BN projection shortcut only when one of those actually changes;
        otherwise the identity shortcut is used. This matches the ResNet design,
        where the solid line means identity. The remaining blocks keep the shape
        and always use the identity shortcut.
        """
        shortcut = None
        if stride != 1 or inchannel != outchannel:
            shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, 1, stride, bias=False),
                nn.BatchNorm2d(outchannel),
            )
        layers = [ResidualBlock(inchannel, outchannel, stride, shortcut)]
        layers.extend(ResidualBlock(outchannel, outchannel) for _ in range(1, block_num))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape (N, num_classes)."""
        x = self.pre(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)  # (N, 512, 1, 1)
        x = x.view(x.size(0), -1)  # (N, 512)
        return self.fc(x)

if __name__ == "__main__":
    # Smoke test: feed one random 224x224 RGB image through the hand-written
    # ResNet34 and through torchvision's resnet34, and print both outputs.
    # Plain tensors work with autograd directly; the old
    # torch.autograd.Variable wrapper is not needed.
    sample = t.randn(1, 3, 224, 224)
    # eval() makes BatchNorm use its running statistics instead of updating
    # them from this single demo sample.
    model = ResNet().eval()
    reference = models.resnet34().eval()  # torchvision's implementation
    # Inference only: skip building the autograd graph.
    with t.no_grad():
        o = model(sample)
        o1 = reference(sample)
    print(o)
    print(o1)

Output (one sample run; the weights and the input are random, so the numbers differ on every run):

D:\anaconda\anaconda\pythonw.exe D:/Code/Python/pytorch入门与实践/第四章_神经网络工具箱nn/搭建ResNet.py
tensor([[-3.7385e-01, -2.2010e-01,  8.9218e-01, -6.6067e-01, -5.8422e-01,
          7.6649e-01, -8.0401e-01, -3.0225e-01,  2.4314e+00,  9.1019e-01,
         -1.1270e+00, -2.9847e-01, -6.0022e-01, -3.5480e-01,  3.5396e-01,
          3.6958e-01, -5.9464e-01, -4.5049e-01,  7.3531e-01, -1.1082e+00,
         -2.9160e-01,  4.7690e-01, -5.0259e-01,  4.1628e-01, -9.3588e-01,
          2.4529e-01,  1.2500e+00, -1.6038e-01, -3.3023e-01,  3.6957e-01,
          4.5195e-01, -6.0984e-01,  8.9558e-02,  4.2407e-01,  1.2888e+00,
         -5.3017e-01, -5.0509e-01,  1.3775e+00,  2.0299e-01,  2.7299e-01,
         -8.6381e-02,  8.3670e-01,  1.9371e-01,  8.4775e-01,  4.0002e-01,
          8.1064e-01, -1.0556e+00,  7.7371e-01,  6.8891e-01, -1.0209e+00,
         -3.1227e-01,  3.0977e-02, -2.9849e-01,  1.0534e+00, -1.8499e-01,
         -8.6205e-01, -3.4020e-01,  4.2196e-01, -1.0207e-01, -1.2846e+00,
          3.1905e-01,  9.9212e-01,  6.0208e-01,  6.2960e-01,  4.1285e-01,
         -1.1597e+00, -2.5292e-01,  7.0145e-01,  4.8221e-01,  6.6790e-01,
          2.8644e-01, -2.6554e-01,  3.4255e-01,  3.8341e-01, -4.2387e-01,
         -7.5981e-01,  2.3203e-01, -3.8434e-01, -6.9346e-01, -5.4996e-01,
          2.7865e-01, -3.9543e-01, -7.8868e-01, -6.5345e-01,  2.9946e-01,
          5.2693e-01, -4.5380e-01, -2.2429e-01, -1.4129e+00, -3.7963e-01,
         -2.2408e-01, -6.9917e-02,  1.7414e-01, -2.7821e-01, -6.4848e-01,
         -4.3716e-01, -3.5371e-01, -3.6100e-01, -5.7401e-01,  3.7754e-01,
         -4.1583e-02, -3.4307e-01,  5.9179e-01,  2.7279e-02, -2.6988e-01,
         -4.7790e-01, -3.5140e-01,  5.1556e-01,  8.0434e-01, -4.4143e-01,
          6.0849e-01,  3.3159e-01, -7.6929e-03,  1.0759e-01,  2.6402e-01,
          2.3914e-01, -6.5949e-01, -9.0380e-01, -5.7449e-01,  8.1698e-01,
          9.0535e-01,  3.7668e-01,  3.2937e-01,  4.5524e-01,  1.6086e-01,
         -2.8713e-01,  1.7160e-01,  3.5057e-01,  7.0938e-01, -6.3579e-02,
         -3.9463e-01,  2.6736e-01, -4.4593e-01,  1.0601e+00, -3.6988e-01,
         -6.2878e-01,  3.7628e-01,  5.3490e-01, -3.2025e-01, -6.2648e-01,
         -5.2117e-02, -4.0097e-01, -1.1775e+00,  1.2687e+00,  1.1808e+00,
          3.4300e-01, -2.3935e-01, -7.8519e-01,  3.3952e-01, -2.1779e-01,
         -4.9251e-01, -4.1354e-01, -7.1647e-01,  1.1502e+00, -9.0239e-01,
         -1.8571e-01,  8.7283e-01,  5.6701e-01,  9.5695e-02, -2.6622e-01,
          3.3122e-02,  5.8339e-01,  6.4253e-01, -1.2866e-01,  1.9386e+00,
          3.0843e-01,  1.7281e-01,  1.3516e-01, -7.3507e-02, -3.3128e-01,
         -6.3045e-02, -1.6130e-01, -6.2078e-01,  1.0369e+00,  7.4816e-01,
          4.3222e-01, -1.2471e+00, -2.8628e-01, -2.2325e-01, -1.3061e+00,
          6.0621e-01,  1.2517e+00,  7.9576e-01,  1.3829e-01,  3.0933e-01,
         -5.7864e-01, -1.3680e-01,  4.2718e-01,  6.0374e-01, -6.5616e-01,
          8.8827e-01,  6.1121e-01,  9.2531e-01, -9.0994e-01, -1.6550e+00,
         -9.5535e-01, -1.3156e+00,  2.3245e-01, -4.6053e-01,  1.9782e-01,
         -7.3612e-01,  7.3810e-01, -6.6007e-02, -5.0354e-01,  5.7257e-01,
         -8.2178e-02, -1.0175e+00, -7.8140e-01, -8.3596e-02,  4.3341e-01,
          6.1036e-01, -6.0388e-01,  3.3036e-01,  3.2923e-01,  1.2033e+00,
         -6.1371e-01, -8.7145e-01, -7.0251e-02, -1.9632e-01, -4.0972e-01,
         -6.5015e-01, -1.1036e+00,  4.5884e-01, -7.8906e-01,  9.7192e-01,
          7.7442e-01,  3.4869e-01,  9.7635e-02, -9.9016e-01,  7.6778e-01,
          3.5343e-01,  1.1142e-01, -1.4715e-01, -3.1201e-01, -4.6759e-01,
          4.2290e-01,  2.9731e-01, -5.6528e-01, -4.6112e-01, -9.2171e-03,
          2.7790e-01,  2.2434e-01,  8.3167e-01,  6.0836e-01,  7.9597e-01,
          8.8949e-01, -5.5800e-01, -5.8002e-01,  2.3448e-02, -3.7334e-01,
         -2.2329e-01, -1.1076e+00, -3.0460e-01,  1.4154e-02,  3.7740e-04,
          9.9988e-02, -7.1763e-01, -2.9103e-01,  3.7885e-01,  4.5475e-01,
         -8.8300e-01, -5.9084e-01,  1.0630e-01, -7.6122e-01,  5.6615e-01,
          1.5967e-02, -1.3541e+00, -2.3975e-02,  3.4815e-01,  6.6317e-01,
          3.3460e-01,  7.1318e-01,  1.1366e+00,  1.1671e+00,  5.2543e-02,
         -8.5805e-01,  5.0632e-01, -1.0799e+00, -3.6625e-01,  6.3304e-01,
          1.7650e-01,  1.2427e+00, -4.2824e-01, -9.7264e-01, -6.5294e-01,
         -4.1995e-01,  1.6176e-01,  6.5306e-01,  6.5527e-01,  1.5359e-01,
         -1.4403e-02,  5.5345e-01, -1.2129e+00, -1.6561e-01, -1.8614e-01,
         -6.3296e-01,  9.8403e-01,  1.9044e-01, -7.3609e-01,  1.3295e-01,
         -2.9614e-01, -2.6278e-01,  6.1773e-01, -2.6080e-02, -2.6567e-01,
          2.1076e-01, -6.1336e-01, -3.0605e-01,  4.8003e-01, -3.7147e-01,
         -3.3662e-01,  6.8647e-01,  1.2991e+00,  1.0152e+00, -4.5890e-01,
         -4.8116e-01,  6.6182e-01,  2.0629e-01,  1.1687e+00, -1.4938e-01,
          5.4687e-01,  2.8266e-01,  1.2739e+00,  2.1758e-01,  3.4379e-01,
         -1.7554e-01, -2.9683e-01, -3.6898e-01, -3.4443e-01,  4.4539e-01,
          6.2362e-01, -7.0732e-01,  6.8179e-01, -8.1357e-01, -2.9273e-02,
          1.0692e-01,  3.0787e-01, -1.5126e-01, -6.9601e-01, -6.7925e-03,
         -1.2032e-01,  4.3556e-01, -2.5765e-01,  3.5000e-01,  2.6138e-01,
         -7.1060e-01, -7.8778e-01, -5.7865e-01,  1.9608e-02, -2.6077e-01,
          2.3804e-01,  5.9406e-01,  6.5240e-01,  5.8997e-01, -2.6604e-01,
         -8.2560e-01, -4.9733e-01, -1.1837e+00, -6.0205e-01,  4.3423e-01,
          8.6452e-01, -6.5541e-01, -4.8626e-01,  6.9140e-01,  1.4461e-01,
         -2.2926e-01, -2.0209e-01, -2.8848e-01, -9.6731e-01, -1.7899e-01,
         -9.3281e-01, -1.1425e-01,  8.9797e-01, -1.6610e-01, -1.0455e+00,
         -5.4111e-01, -6.5867e-01,  4.8375e-01,  8.1165e-01, -4.3111e-01,
          1.2101e+00, -7.1869e-01, -9.3713e-01,  8.0531e-01, -1.0737e+00,
         -8.2951e-02,  9.9555e-01,  5.7983e-01,  6.4511e-02, -1.5688e-01,
         -7.8642e-01,  2.1101e-01,  3.5995e-01, -1.1792e+00,  2.7492e-01,
         -1.1338e+00,  1.4390e+00,  7.4970e-01,  4.1525e-01, -2.2618e-01,
         -5.6868e-01, -1.0587e+00, -8.4871e-01,  4.2294e-01, -7.5937e-01,
          1.2478e+00,  3.1876e-01,  5.4925e-02, -4.2920e-01,  2.9908e-02,
         -7.2298e-01, -5.0745e-01, -1.0847e-01, -3.3263e-01,  3.8415e-01,
         -3.0520e-01, -7.3637e-01,  3.7017e-01,  2.0959e-01, -3.9341e-01,
          9.2597e-02, -7.0634e-01, -5.4246e-01, -5.9055e-01,  1.0688e-01,
         -5.5952e-01,  1.4558e+00, -7.3014e-01, -4.1277e-01,  1.1603e+00,
         -2.9320e-01, -1.1457e+00,  7.0188e-01, -7.2187e-01,  4.9911e-01,
         -1.7366e-01,  7.8894e-01, -6.5754e-01, -6.5171e-01,  2.3485e-01,
          2.2758e-01,  1.1867e+00,  4.1541e-01,  8.7716e-02,  5.4310e-01,
         -3.5418e-01, -2.2289e-01,  5.3745e-01,  4.9035e-01, -1.5985e-01,
          1.4329e-02,  1.1308e+00, -9.6263e-01,  2.2490e-01,  1.6313e-01,
          5.1836e-01,  2.7269e-01,  1.3424e-01, -3.2040e-01,  1.7400e-01,
         -6.9929e-01, -7.7752e-01, -6.4446e-01,  2.9672e-01, -1.3011e-02,
         -4.6553e-01,  5.0777e-01, -2.2849e-01, -3.6042e-01,  1.3338e+00,
          1.6765e-01, -9.3721e-01,  5.1379e-01,  2.6106e-01,  7.9202e-01,
          7.5039e-01, -7.1235e-02, -5.7395e-03, -5.5282e-01,  7.3725e-01,
          5.6211e-01, -6.3226e-01, -4.4917e-02,  4.0115e-01,  3.1123e-01,
         -8.5666e-01, -1.1569e+00,  3.8246e-01, -4.3587e-01, -1.1493e+00,
         -3.8287e-01,  3.4385e-01, -7.5745e-01, -2.5882e-01, -2.8164e-01,
         -1.1965e-02,  2.1589e-01, -4.3658e-01, -1.3746e-01,  9.7378e-01,
         -1.0517e+00, -4.0558e-01, -1.0544e-02, -6.4660e-02, -5.1592e-01,
          4.5318e-01,  1.3184e+00,  3.7338e-01,  3.4490e-02,  1.4038e+00,
         -1.6802e-01,  2.3007e-02,  4.4980e-03,  1.2705e-01,  2.7906e-01,
         -8.4902e-01, -3.8745e-01,  3.1278e-01, -8.8074e-01, -4.7914e-01,
          7.2190e-02, -6.4725e-01, -2.0902e-01,  1.2280e-01, -1.8186e-01,
          1.0589e-01, -6.0947e-01, -2.8543e-01, -1.0723e+00, -1.7837e-01,
          1.4746e+00,  1.1301e+00,  8.8037e-01, -3.8367e-01,  6.1571e-01,
          5.8543e-02,  5.3181e-01, -2.4058e-01,  6.9641e-01,  5.4891e-01,
          4.8759e-03, -1.2818e+00,  7.0707e-01,  4.6681e-01, -2.3600e-01,
         -2.7093e-01,  2.1033e-01, -3.7307e-01,  2.2353e-01, -1.5244e-01,
          2.1925e-02, -2.3214e-01,  2.5308e-02,  7.7142e-01,  6.8966e-01,
          5.5418e-01, -3.1878e-01, -8.5453e-01,  5.2859e-01,  7.1266e-01,
          3.4018e-01,  2.4858e-01, -7.1972e-01,  1.2186e+00, -8.9309e-01,
          3.7593e-01,  9.3331e-01,  1.6154e+00,  3.6179e-01,  9.8585e-01,
          5.0944e-01,  3.2588e-01, -2.8218e-01, -7.9708e-02,  5.0813e-01,
          7.1221e-01,  3.9624e-01,  6.6906e-01, -1.6557e-01, -4.3672e-01,
         -4.2653e-01,  7.3462e-01, -8.6661e-02,  6.1583e-01,  1.1201e+00,
         -1.3712e+00,  7.1885e-01, -9.1739e-02,  4.3945e-01,  1.1710e+00,
         -2.2104e-01,  2.4807e-01,  2.9516e-01, -1.0306e+00,  1.2226e+00,
         -2.5720e-01, -2.0021e-01, -6.3561e-01,  7.7526e-01, -1.3281e-01,
          4.2105e-01, -9.8721e-01,  6.0226e-02, -7.4250e-01, -1.1144e-01,
          2.1858e-01,  1.0423e+00, -2.9606e-01,  1.8390e-01,  3.6015e-01,
         -3.9900e-03,  8.7758e-02,  1.1577e-01,  1.1973e+00,  3.1103e-01,
         -4.2398e-01,  3.0271e-01, -2.8444e-01,  5.1662e-01, -1.3356e+00,
          5.8137e-01, -5.7901e-01,  6.6809e-01,  5.4561e-01,  7.4738e-01,
         -9.4391e-03, -2.7110e-01,  6.9678e-02, -3.0574e-01,  4.4471e-01,
         -6.3125e-02,  4.8383e-01, -4.3376e-01,  9.3516e-01, -1.0835e+00,
          3.9941e-01,  7.5706e-01, -3.8424e-02, -1.5844e-01,  6.9874e-01,
         -2.2643e-01, -8.1118e-01,  3.4155e-01,  6.2197e-01, -4.5426e-01,
         -1.1235e-01, -2.8171e-01, -4.2020e-01,  5.8910e-03, -2.1005e-01,
          9.5020e-01, -5.1491e-01, -5.2156e-01, -3.6768e-01, -3.3160e-01,
         -8.8831e-01,  1.6263e-01, -1.2790e+00, -2.2097e-01,  6.3435e-01,
          3.3925e-01,  5.1287e-01,  3.0567e-02,  1.4348e-01, -3.5183e-01,
         -4.7715e-01, -5.9157e-01, -5.3332e-01, -1.2875e-01, -6.4399e-01,
          1.2068e+00, -4.3453e-01,  3.6951e-01, -3.7221e-01,  1.0549e+00,
         -7.8913e-01, -7.2606e-01,  1.7644e+00, -7.1351e-01, -4.7304e-01,
          9.7223e-01, -7.1468e-01,  1.0018e-01,  6.0553e-01,  1.4333e-01,
          3.6390e-01, -5.6337e-02,  3.2207e-01,  5.4261e-01, -4.5484e-01,
         -9.2550e-02, -1.0209e-01, -5.5761e-01, -8.3987e-02,  6.7479e-01,
          7.3383e-01,  3.2637e-01,  2.2839e-01,  7.3619e-01,  4.5373e-02,
          1.4767e+00,  1.1286e+00,  1.0320e-01,  6.0987e-03,  1.8241e-02,
         -4.1522e-01,  1.4877e+00, -9.9928e-02,  1.1028e+00,  1.8680e-01,
          9.3361e-01,  1.1641e-03, -6.7221e-01, -1.1105e+00, -4.5087e-01,
         -1.9451e-01,  5.3225e-01, -7.0291e-02,  2.6069e-01,  3.2638e-01,
         -9.7803e-01,  2.5177e-01, -5.2165e-02,  3.2999e-01,  7.0848e-01,
         -4.7834e-01,  1.2501e+00, -3.3023e-01, -9.7759e-01,  6.8180e-01,
         -6.7149e-01,  1.3792e+00, -3.1857e-01, -9.0531e-01,  4.8713e-01,
         -1.1678e-01, -5.3198e-01, -8.2755e-01, -6.2357e-01, -7.8093e-01,
         -1.9248e-01,  2.3543e-01, -3.6204e-01, -3.1342e-01, -1.2192e+00,
          5.2952e-01,  2.6511e-01,  1.4131e+00, -9.5214e-02, -7.0332e-01,
         -1.9167e-01,  8.7300e-01, -3.5202e-01,  8.1680e-01, -1.1795e+00,
         -4.6051e-01, -7.1857e-01, -1.1671e-01, -3.5521e-01,  3.9610e-01,
          6.4604e-01,  2.4028e-01,  2.4551e-01,  3.8291e-01,  3.6420e-01,
         -6.4876e-01, -1.0353e-01, -3.0709e-01,  5.1353e-01,  5.7663e-01,
          2.8187e-01,  3.8347e-01, -3.2358e-01,  2.4605e-01, -5.6711e-01,
          7.2537e-01,  9.1548e-01,  8.8802e-01,  6.2908e-01,  3.6824e-01,
         -1.1807e-01, -6.2132e-01, -2.2127e-01,  7.0661e-01, -3.4173e-01,
          3.7389e-02,  6.3347e-02, -1.5815e-01, -5.4016e-02,  6.1224e-01,
         -1.1764e-01, -6.6400e-01, -6.9328e-01,  1.1396e-01,  2.9845e-01,
          1.1607e+00, -3.9230e-01,  6.9676e-01,  3.7753e-01, -4.5834e-01,
          3.5393e-01, -1.9701e-01, -4.0007e-02,  1.6460e+00,  1.0020e+00,
          2.7125e-01,  5.4337e-01,  9.6963e-01,  3.7202e-01,  1.3688e-01,
          5.5945e-01, -6.7389e-01,  4.1863e-01,  9.1551e-01, -1.4212e-01,
         -8.4382e-01,  6.4143e-01, -2.8692e-01, -8.3050e-01,  5.6636e-01,
          3.8771e-01,  6.7161e-01,  7.1014e-01,  8.8337e-01, -6.5802e-01,
         -1.9257e-01, -8.7510e-01, -8.5440e-01, -1.0751e+00, -3.9010e-01,
          9.7424e-01, -1.0402e-01,  5.7751e-01,  2.9744e-01,  7.4402e-01,
          1.5031e-01,  2.2013e-01, -2.3053e-01, -4.6259e-01,  2.5106e-01,
         -1.0140e+00,  8.0855e-01, -1.1636e-01, -7.8362e-02, -9.2715e-02,
         -1.0271e+00,  2.9693e-01, -2.0904e-01, -5.0984e-01,  1.3045e+00,
          4.8532e-01,  1.4346e-01,  7.6788e-01,  9.7047e-01,  4.0762e-01,
          2.3484e-01,  3.1950e-01, -2.3318e-01, -6.8306e-03,  5.6380e-01,
         -9.6460e-02, -2.8250e-01,  1.1092e+00, -3.1063e-02, -2.1305e-01,
          6.3479e-01,  2.1984e-01, -1.1693e+00,  4.7175e-01, -1.3506e-01,
          1.1924e-01,  4.1394e-01, -1.2817e+00, -2.7704e-01,  1.0168e+00,
          2.1124e-01,  6.1006e-02, -2.0014e-01,  1.4460e+00,  3.5466e-01,
          3.7454e-01, -1.2640e-02,  6.0403e-03,  3.2332e-01,  8.9131e-01,
          4.5607e-02, -6.6399e-02, -2.0708e+00,  3.2648e-01,  7.6369e-01,
          4.1520e-01, -2.7174e-01, -5.1358e-01,  9.6802e-01,  3.8855e-01,
          6.7598e-01,  3.1721e-01,  2.0969e-01, -1.3217e-01,  7.5170e-01,
         -1.0165e+00,  4.6450e-01,  3.1623e-01, -1.2664e-01, -5.8193e-01,
          7.5702e-01,  2.1583e-01,  8.0843e-01,  8.0445e-01, -6.3687e-01,
         -2.1509e-01,  1.3130e-01, -4.3707e-01, -3.1932e-01,  2.4451e-01,
          1.8980e-01,  1.7880e-01,  5.7971e-01, -9.6651e-01,  4.5083e-01,
          5.5928e-01,  3.5459e-01,  1.1491e-01,  1.0462e+00, -1.2330e-01,
         -2.5296e-01,  2.2241e-02,  1.1558e+00, -5.2790e-01, -5.4470e-01,
          6.7174e-01, -6.3254e-01,  1.0079e-01, -5.7307e-01, -5.3185e-01,
          1.3242e+00, -4.9839e-01, -5.1236e-03, -1.1210e+00,  1.0798e+00,
          7.2417e-02, -4.2987e-01,  1.2662e+00,  6.0314e-01, -9.2942e-02,
          9.2841e-01, -1.3543e-01, -3.3278e-01,  6.6304e-01, -4.1726e-01,
         -3.8710e-02,  6.5126e-01, -7.1032e-02,  3.6333e-01,  8.8658e-02,
          5.5386e-01, -1.1523e+00, -4.8741e-01, -3.0395e-01, -7.6689e-01,
          1.0167e+00,  2.8526e-01,  4.7810e-01,  1.0492e-01,  3.2575e-01,
         -8.6603e-01,  8.0494e-01,  6.8050e-01,  1.5820e-01, -1.0125e-01,
         -9.4131e-02, -8.0925e-01,  9.6557e-01, -3.8323e-04, -5.9713e-01,
         -4.7461e-01,  2.8437e-01, -7.0949e-02, -6.6371e-01,  5.8345e-01,
         -1.9877e-01, -1.0992e+00, -6.4899e-01,  7.9953e-01,  4.7137e-01,
          1.0099e+00,  6.3704e-01,  2.3527e-01,  2.2146e-01, -1.3238e-01,
         -4.7322e-01,  1.1008e+00,  4.1789e-01, -2.6206e-01, -2.6280e-01,
         -1.3215e-01, -7.2749e-01,  8.8819e-02,  8.2486e-01,  9.9206e-01]],
       grad_fn=<AddmmBackward>)
tensor([[ 0.8583,  0.2219,  0.0908, -0.3688,  1.1560,  0.7270, -0.1212, -0.3032,
         -0.1714, -0.0670, -1.0323,  0.0047, -1.0780, -0.8921, -1.3603, -0.1051,
         -0.7071,  0.1529,  0.4977,  0.6891, -0.3599,  0.1205, -0.0934,  0.2150,
          1.1170, -0.3915,  0.2290, -0.2707,  0.2720, -0.8762, -0.7861,  0.0707,
         -0.3628, -0.3093, -0.5939, -0.2183, -0.0052,  0.8033, -1.1063, -0.6420,
          0.1120,  0.3753, -0.4286, -0.6054, -0.1547, -0.4218,  0.3286, -0.2107,
         -0.4165, -0.0471,  0.0936, -0.8109, -0.2143, -1.0776,  0.7402, -0.2014,
          0.5503, -1.1897,  0.1982,  0.3422, -0.2176, -0.3140, -1.1125,  0.5685,
         -0.2621, -0.0292, -0.0085,  0.8044, -0.2474,  0.0391, -0.3945, -0.3764,
          1.2721,  0.0749,  0.2646,  0.0626,  0.2451, -1.3575,  0.6655,  0.0903,
         -0.1422,  0.3071,  0.9701,  0.4616,  0.8723, -1.8646,  0.0220,  1.1814,
         -0.9941,  1.6733, -0.0821, -0.4536, -0.0352, -0.6677, -0.1204,  1.2320,
          1.1207,  0.4285, -0.4761,  0.3149,  0.2277, -0.1647, -0.4141,  0.3424,
         -1.0660,  0.0641,  0.0542, -0.3486,  0.3432,  0.0687, -0.1604, -0.5909,
         -1.1305,  0.3894, -0.1907, -0.7874, -0.5636, -0.4908,  0.3239,  0.3048,
          0.6679,  0.7988, -0.8538, -0.2247,  0.3454, -0.9895,  0.9839, -0.4592,
         -0.3357, -1.3587,  0.7747,  0.2871, -0.3907,  0.2034,  0.1411, -0.5172,
          1.5404,  0.5140, -1.2516, -0.1875, -0.1076, -1.9001,  0.2219, -0.1132,
          0.2477, -0.1683,  0.9590, -0.3854, -0.2032, -0.3209, -0.6780,  0.0069,
          0.8684, -0.6738, -0.3045,  0.7861,  0.1697,  0.3226, -0.7716, -0.0477,
          0.0432, -0.5059,  0.2730, -0.6356,  1.0412, -0.2633, -0.9977, -0.3650,
          0.3670, -0.6272, -0.1175,  1.1047,  0.7568, -0.0852, -0.3786,  1.2422,
         -0.5950, -0.2451,  0.5207, -0.0848,  0.0894, -0.0162, -0.4142, -0.0570,
          0.1677,  0.4498, -0.8081, -0.1912,  0.4066, -0.1794, -0.1089,  0.7038,
         -1.1204,  1.4453, -0.5261, -0.2537,  0.2954,  0.0537,  1.1032, -1.6995,
         -0.8658,  0.0069, -0.7864, -0.5945,  0.3836, -0.8819, -0.7932,  0.7809,
         -0.5035,  0.1249, -0.7372, -0.0023,  0.1401,  1.6874, -0.5839,  0.0617,
         -0.3735,  0.1255,  0.8500,  0.1772, -0.2503, -0.9388,  0.4660,  0.0778,
         -0.2575,  0.9906, -0.6868, -0.1666,  1.5679, -0.1536, -0.6431,  0.2470,
         -0.6598, -0.3674,  1.2074, -0.5786, -1.2000,  0.5436, -0.9324, -0.1678,
          0.2622,  0.2365, -1.1233, -0.0316,  0.4280, -0.6036, -0.1521,  0.8521,
         -1.2506,  0.0447, -0.2429, -0.5794,  0.2477, -0.2386, -0.4713, -0.8464,
         -0.6100,  0.0416, -0.9101,  0.4154, -1.1316, -0.3032,  0.2720,  0.0818,
         -0.1726,  0.6396,  0.2227, -0.6746,  1.4707,  0.2891,  0.4319,  1.3665,
         -1.0922,  0.2068,  0.2742,  0.5250,  0.6502, -0.8084, -0.5297, -0.3780,
         -0.3048,  0.3210, -0.4358,  0.7772, -0.4798, -0.2714, -0.4301, -0.1023,
         -0.8924, -0.4756, -0.9159, -0.6420,  0.2500,  0.6301, -0.3656, -0.3115,
         -0.3092, -0.6765,  0.2568, -0.1190, -0.3246, -0.7433,  1.3411,  0.2621,
          0.6059, -1.6666,  0.5171,  0.9830,  0.4238,  0.9399, -0.2219,  0.1042,
          0.6885, -0.1398,  1.0048,  0.8237,  0.5311,  0.2481, -1.1185, -0.3169,
         -0.4606,  0.2594,  0.5915, -0.0420,  0.0353,  0.5132, -0.1115,  0.1641,
          0.6328,  0.0220, -0.1134,  0.3487,  0.7037,  0.1949,  0.9965, -0.3493,
         -0.1531,  0.3266,  0.3365, -0.9098,  0.6647, -0.4036,  0.6309,  0.2539,
         -0.6655,  0.2219,  1.3017, -0.2910,  0.0077, -0.1615, -0.0499, -1.0351,
          0.0631,  0.0526, -0.2282,  0.7903, -0.2692, -0.8018,  0.0425, -0.6413,
         -0.7523, -1.5698, -0.5256, -0.3538,  1.2565,  1.2090,  0.0132, -0.8184,
          1.1792, -0.4623, -0.1368, -0.1340, -0.4158, -0.3891, -0.6368, -0.6716,
          0.1764, -0.6001,  0.3692,  0.1826,  0.3553, -0.4659, -0.0166, -0.2227,
          0.0605, -0.1283,  0.4476, -0.2427,  0.7576,  0.8014, -0.1844,  0.4134,
         -0.3707, -0.5320, -0.2180, -0.7385,  0.5511, -1.1440,  0.7495, -0.1902,
          0.1369,  0.2095, -0.4616, -0.2702, -0.6023, -0.1063, -0.1010,  0.4664,
         -0.4199,  0.6815,  0.1581,  0.1553, -0.3236, -0.3660, -0.0891, -0.0942,
         -0.8452, -1.1930, -0.7743, -0.8862, -0.5736, -0.9316, -0.1222, -0.4710,
         -0.4420, -0.5289, -0.7370, -0.4200,  0.8102,  0.1068,  0.6879,  0.9414,
         -1.0126, -0.6519,  0.4527,  0.3266,  0.4081,  0.1996, -0.0257,  0.1841,
         -0.4881, -0.8573,  0.5010, -0.4788, -0.6908,  0.3824, -0.2642, -0.6462,
          0.2921,  0.8192,  0.6443,  0.5318,  0.8571,  0.5193, -0.0748, -0.0666,
          0.2659,  0.2960, -0.2691, -1.3030, -0.7433,  0.2877, -0.6012, -0.6165,
         -0.1664, -0.4276,  0.7057, -0.4753,  0.9193, -0.0858,  0.4529,  0.0187,
          0.5288, -0.0120, -0.2770,  0.4051,  0.0486, -0.2863,  0.7978, -0.1046,
          0.5071, -0.2378,  0.0393, -0.4039,  1.1442, -0.6032,  0.6462,  0.2437,
          0.6592, -0.1853, -0.3932,  0.1069, -0.3172,  0.0439, -0.0894, -0.7581,
          0.8185, -0.6686, -0.7607, -0.0244,  0.0612, -0.3434,  0.4846,  0.3707,
         -0.1968,  0.7238,  0.0380, -0.1852, -0.0509,  0.0693, -0.2527, -0.7352,
         -0.6229,  0.4219, -0.8397,  0.0265,  0.6799,  0.2732,  0.8133,  0.5658,
          0.4521,  0.2094, -0.1233,  0.2853, -0.4095, -0.0043,  0.0443, -0.1329,
          0.1748, -0.3173,  0.4919, -0.2481, -0.4752, -0.3148, -0.4381, -0.8508,
          0.4462,  0.6670,  0.3655,  2.1904,  0.3760, -0.1575, -0.4121, -0.8432,
          0.2034,  0.7576,  0.4390,  0.1646,  0.1873, -0.3555,  0.8141, -1.1400,
         -0.9080,  0.8897,  1.0261, -0.0397,  0.0434, -0.5660,  0.4741,  0.0463,
          0.6119, -0.7441, -0.1215, -0.3012,  0.5099, -1.2163, -0.3103, -0.4813,
         -0.3444,  0.2921, -0.9768,  0.4538,  0.6191,  0.8799, -0.3835, -0.0057,
          0.2141,  0.0442, -0.3738,  0.4000, -0.9016,  0.0222,  1.1992,  0.7037,
         -0.2133,  0.1393,  1.4536, -0.3438,  0.6417,  0.1219,  0.4277, -0.1041,
          0.1900,  0.0222, -0.2329, -0.0655,  0.3298, -0.8072,  0.1152, -0.0886,
         -0.0550, -0.1536, -0.0492, -0.2500, -0.2076, -0.3855,  0.8968,  0.2879,
          0.5730,  0.1542, -0.6952,  0.6044, -0.0396,  0.6409,  1.0697, -0.5936,
         -1.0671, -0.1631, -0.0559, -0.6267, -0.1045,  0.4992, -1.1814, -0.3745,
          0.1148,  1.2093,  0.6348, -0.0950,  0.6317,  0.3497,  0.9094,  0.2639,
          0.2698, -0.5232,  0.2271, -0.1841, -1.1478,  0.1940,  0.1754,  0.5913,
         -0.1162, -0.2418,  0.2757,  0.5607, -0.3401, -0.2242, -0.5553,  0.7191,
          1.1865,  0.4946, -0.0032, -0.3131,  0.4494,  0.2746, -0.0319, -0.8218,
         -0.1342,  0.2442, -0.5747, -0.1053,  0.5896, -0.8873, -0.6665,  0.5551,
          0.0782,  0.3987,  0.3041,  0.6591, -0.1150, -1.2871,  0.3905,  0.1369,
         -0.7377,  0.9123, -0.2117,  0.4595,  0.6514,  0.4681, -0.2784,  0.0099,
         -0.4514, -0.3678, -0.2100,  0.5424,  0.7370, -0.5189,  0.2916,  0.0367,
          0.7997,  0.2482, -1.2903,  0.3780,  0.6159, -0.1243, -0.6554,  0.9334,
          0.2720, -0.3412,  0.5984, -0.8122, -0.6410,  0.4535,  0.1734, -0.3975,
         -0.8048,  0.4323,  0.4416, -0.0587,  0.3416,  0.5949,  0.9841,  0.6708,
         -0.4209, -0.2902,  0.4471, -0.6019, -0.3284,  0.7052, -0.3894,  0.2325,
         -0.1371, -0.0458,  0.2366,  0.6565, -0.6877, -0.4468, -0.0416,  0.1399,
          0.3912, -0.5745, -0.5798,  0.3441,  0.4783,  0.6710,  1.5530, -0.2175,
          0.2798,  0.7343,  0.2631, -0.1522, -0.0929, -0.7242, -0.1866,  0.4094,
          0.9072,  0.7748, -0.9727,  0.2121,  0.6975,  0.5502, -1.5739, -0.1935,
          0.2408, -0.9197,  0.8733,  0.1751, -1.6064, -0.8624,  0.3407,  0.1941,
          0.2186,  1.0303,  0.9977,  0.9978,  0.5819,  0.3241, -0.0397, -0.0729,
         -0.2124,  0.6568,  0.3392,  0.5155, -0.0025, -1.6329,  0.0523,  0.0961,
         -1.1520,  0.9825, -0.1009,  0.3857, -0.6765,  0.2406,  0.7285, -0.1881,
          0.3678,  0.4719, -0.5791,  0.2218,  0.6020,  0.3131,  0.2334, -0.5597,
          0.7021,  0.0916, -0.0537, -1.1107, -1.3456,  0.1169,  0.2511,  0.0659,
          0.3046, -0.7241, -0.0933,  0.8756, -1.0751,  0.0476,  0.7796,  0.3287,
          0.5448, -0.9035,  0.5777, -0.7859,  0.4341,  1.1255,  0.4447,  0.2046,
          0.6836,  0.7843,  0.9051,  0.4812,  0.2248, -0.3002, -0.2056,  0.0062,
          0.6319, -0.6192,  0.1192,  0.6877, -0.1505, -0.5901,  1.1053, -0.4244,
          1.2949, -0.0225, -0.7752,  0.4802, -0.8232,  0.6852, -0.3668,  0.5570,
          0.6362,  0.1484, -0.1976,  0.1145,  0.4329, -0.3728, -0.0098, -0.1347,
         -0.1296, -0.8788, -0.9397, -0.9967,  0.3410, -0.5902,  0.2272,  0.4113,
          0.5386, -0.0544, -0.4600, -0.4375, -0.4247, -0.0638, -0.3998, -0.8774,
         -0.0317,  0.4021, -0.8222, -0.4809, -1.1616,  0.8198, -0.0503,  0.5451,
         -0.4983,  0.1550, -0.8350, -0.2284,  0.8163,  0.3057, -0.1393, -0.8876,
         -0.1237,  0.0322,  0.0652,  0.0870, -1.2977, -0.2600,  0.3764,  0.2252,
         -0.4700,  0.8265, -0.3017, -0.7971, -0.6706,  0.1718,  0.0871, -1.4964,
         -0.5555,  0.2503, -0.5699, -0.6307, -0.1777,  0.0038, -1.1909, -0.9936,
          0.8798,  0.3346, -0.5728, -1.2314, -0.4099,  0.3378,  0.7328, -0.1436,
         -0.1719,  0.0340, -1.2943, -0.0104,  1.0345,  0.2331,  0.0599,  0.6879,
         -0.7434, -0.5223,  0.2352, -0.9593,  0.0825, -0.5867,  0.5346,  0.7509,
          0.8674, -0.9182,  0.3365,  0.1850, -1.1668, -1.2389, -0.3753, -0.1050,
         -0.0569,  0.1062,  0.7333,  0.0465,  1.0837,  0.6633, -1.7137, -0.5974,
          0.4209, -0.3124,  0.3935,  0.5566, -0.0534, -0.1313, -0.4815,  0.7935,
          0.3970, -1.2496, -0.3091, -1.0784,  0.4957,  0.4002,  0.3540, -1.6088,
         -0.6770,  0.7959, -0.7764,  0.7647,  0.2016,  0.3960,  0.1153,  0.3679,
          0.4652, -0.1590, -0.3743, -0.0253, -0.1515, -0.7036, -0.8022, -0.2377,
         -0.4536,  0.2447,  0.2165,  0.0511, -1.0900, -0.3818, -0.9283,  0.4730,
          0.4143,  0.0216,  0.1163, -0.1247,  0.2278, -0.6479, -0.5509,  0.5441,
          0.2503, -0.0678,  0.8512,  0.3365, -0.5701,  0.1218,  0.2744, -0.7122]],
       grad_fn=<AddmmBackward>)

Process finished with exit code 0

View Code