
Building a ResNet18 Network with PyTorch



We all know that as a neural network gets deeper and deeper, two problems appear:

1. Vanishing and exploding gradients
2. Degradation: accuracy on the training set itself starts to drop

To address these problems, Kaiming He proposed the residual network (ResNet) in 2015. Even when the network is made quite deep, its training accuracy does not drop much and training remains stable. So how exactly is it built?
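
The key to ResNet is the shortcut (skip) connection: each block only learns a residual mapping F(x) and outputs F(x) + x, so even a very deep network can fall back to an identity mapping instead of degrading. Here is a minimal sketch of this idea in PyTorch (the class and attribute names are just illustrative; the actual block used in this post is the ResBlk defined further below):

import torch
from torch import nn

class IdentityResidual(nn.Module):
    """Toy residual block: output = relu(F(x) + x), where F is two 3x3 convolutions."""
    def __init__(self, channels):
        super().__init__()
        self.f = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        # the shortcut adds the input back before the final activation
        return torch.relu(self.f(x) + x)

print(IdentityResidual(64)(torch.randn(1, 64, 56, 56)).shape)  # torch.Size([1, 64, 56, 56])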

Take ResNet18 as an example: it is a network built by stacking residual blocks,
1 convolutional layer + 8 residual blocks (each containing 2 convolutional layers) + 1 fully connected layer, i.e. 1 + 8×2 + 1 = 18 weight layers.
As shown in the figure below:
[Figure: ResNet18 architecture diagram, taken from reference 1]
Personally, I think ResNet is one of the most beginner-friendly networks to study in deep learning: once you understand the principle, you can build the network just by following the diagram, and the code is clean and readable. Some bloggers describe it as "simple but practical", and I agree. When I first started learning ResNet, I managed to build a 10-layer version to train my garbage-classification dataset; tonight I extended it to 18 layers.
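
A quick check of the arithmetic behind the name (only convolutional and fully connected layers are counted as weight layers; BatchNorm, pooling and the 1x1 shortcut convolutions are not):

stem_conv = 1        # the initial 7x7 convolution
block_convs = 8 * 2  # 8 residual blocks, 2 convolutions each
fc = 1               # the final fully connected layer
print(stem_conv + block_convs + fc)  # 18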

The code is as follows:

import torch
from torch import nn
from torch.nn import functional as F


class ResBlk(nn.Module):
    """
    resnet block
    """

    def __init__(self, ch_in, ch_out, stride=1):
        """
        :param ch_in:
        :param ch_out:
        """
        super(ResBlk, self).__init__()

        # Main path: two 3x3 convolutions, each followed by BatchNorm
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Shortcut: identity if the shapes already match, otherwise a 1x1
        # convolution to match the channel count and spatial size
        self.extra = nn.Sequential()
        if ch_out != ch_in or stride != 1:
            # [b, ch_in, h, w] => [b, ch_out, h/stride, w/stride]
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        """
        :param x: [b, ch_in, h, w]
        :return: [b, ch_out, h/stride, w/stride]
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # add the shortcut before the final activation
        out = self.extra(x) + out
        out = F.relu(out)

        return out
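
# Optional sanity check for a single block (illustrative, not part of the model):
#   blk = ResBlk(64, 128, stride=2)
#   print(blk(torch.randn(2, 64, 56, 56)).shape)  # torch.Size([2, 128, 28, 28])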


class ResNet18(nn.Module):

    def __init__(self, num_class):
        super(ResNet18, self).__init__()

        # Stem: 7x7 convolution + BatchNorm, then 3x3 max pooling
        # (the ReLU is applied in forward)
        # [b, 3, 224, 224] => [b, 64, 56, 56]
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        self.blk1 = ResBlk(64, 64)
        # [b, 64, 56, 56] => [b, 64, 56, 56]

        self.blk1_1 = ResBlk(64, 64)
        # [b, 64, 56, 56] => [b, 64, 56, 56]

        self.blk2 = ResBlk(64, 128, stride=2)
        # [b, 64, 56, 56] => [b, 128, 28, 28]

        self.blk2_1 = ResBlk(128, 128)
        # [b, 128, 28, 28] => [b, 128, 28, 28]

        self.blk3 = ResBlk(128, 256, stride=2)
        # [b, 128, 28, 28] => [b, 256, 14, 14]

        self.blk3_1 = ResBlk(256, 256)
        # [b, 256, 14, 14] => [b, 256, 14, 14]

        self.blk4 = ResBlk(256, 512, stride=2)
        # [b, 256, 14, 14] => [b, 512, 7, 7]

        self.blk4_1 = ResBlk(512, 512)
        # [b, 512, 7, 7] => [b, 512, 7, 7]

        # Global average pooling over the 7x7 feature map (assumes a 224x224 input)
        self.pool2 = nn.AvgPool2d(kernel_size=7, stride=1, padding=0)

        self.outlayer = nn.Linear(512, num_class)

    def forward(self, x):
        """
        :param x: [b, 3, 224, 224]
        :return: [b, num_class]
        """
        x = F.relu(self.conv1(x))  # stem: [b, 3, 224, 224] => [b, 64, 56, 56]

        x = self.blk1(x)    # [b, 64, 56, 56]  => [b, 64, 56, 56]
        x = self.blk1_1(x)  # [b, 64, 56, 56]  => [b, 64, 56, 56]
        x = self.blk2(x)    # [b, 64, 56, 56]  => [b, 128, 28, 28]
        x = self.blk2_1(x)  # [b, 128, 28, 28] => [b, 128, 28, 28]
        x = self.blk3(x)    # [b, 128, 28, 28] => [b, 256, 14, 14]
        x = self.blk3_1(x)  # [b, 256, 14, 14] => [b, 256, 14, 14]
        x = self.blk4(x)    # [b, 256, 14, 14] => [b, 512, 7, 7]
        x = self.blk4_1(x)  # [b, 512, 7, 7]   => [b, 512, 7, 7]

        x = self.pool2(x)          # [b, 512, 7, 7] => [b, 512, 1, 1]
        x = x.view(x.size(0), -1)  # flatten to [b, 512]
        x = self.outlayer(x)       # [b, 512] => [b, num_class]

        return x


def main():
    # 7 output classes here; change num_class to match your own dataset
    model = ResNet18(7)
    tmp = torch.randn(1, 3, 224, 224)  # a dummy batch: one 3x224x224 RGB image
    out = model(tmp)
    print('resnet:', out.shape)  # expected: torch.Size([1, 7])


if __name__ == '__main__':
    main()
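
To double-check that the network above really has 18 weight layers, a quick count of its Conv2d and Linear modules can be appended after the code (by convention the 1x1 projection convolutions in the shortcuts are not counted; this snippet is only a sanity check, not part of the original post):

# run in the same file, after the class definitions above
model = ResNet18(7)
main_convs = sum(1 for m in model.modules()
                 if isinstance(m, nn.Conv2d) and m.kernel_size != (1, 1))
fc_layers = sum(1 for m in model.modules() if isinstance(m, nn.Linear))
print(main_convs + fc_layers)  # 18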

References:
(the figure is from link 1; if you are completely new to ResNet, the second reference is recommended)

https://blog.csdn.net/weixin_44331304/article/details/106127552?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522161779576016780271556440%2522%252C%2522scm%2522%253A%252220140713.130102334…%2522%257D&request_id=161779576016780271556440&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2alltop_positive~default-1-106127552.first_rank_v2_pc_rank_v29&utm_term=resnet18

https://blog.csdn.net/weixin_44331304/article/details/106127552?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522161779576016780271556440%2522%252C%2522scm%2522%253A%252220140713.130102334…%2522%257D&request_id=161779576016780271556440&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2alltop_positive~default-1-106127552.first_rank_v2_pc_rank_v29&utm_term=resnet18

https://www.bilibili.com/video/BV1j64y1D748?p=102&spm_id_from=pageDriver

Original article: https://blog.csdn.net/weixin_48638088/article/details/115495988
