Vertiefendes Verständnis des ResNet-Netzwerkmodells + der PyTorch-Implementierung

Wenn Sie die Vergangenheit Revue passieren lassen und Neues lernen, können Sie Lehrer werden!

1. Referenzmaterialien

These:Identitätszuordnungen in Deep Residual Networks
These:Deep Residual Learning für die Bilderkennung a>AutobahnnetzwerkeLesen Sie „Deep Residual Learning“ für Bilderkennung noch einmal, um das Geheimnis des Restnetzwerks besser zu verstehen (mit Pytorch-Code)Detaillierte Erläuterung und Implementierung von ResNetResNet-Notizen zum Restnetzwerk[pytorch] ResNet18, ResNet20, ResNet34, ResNet50 Netzwerkstruktur und ImplementierungOffizielle PyTorch-Implementierung von ResNet
Ausführliche Erklärung von ResNet + PyTorch-Implementierung





2. Verwandte Einführung

1. Tiefes Netzwerk

Mit zunehmender Anzahl der Netzwerkschichten wird die Darstellungsfähigkeit des Netzwerks stärker. Dies liegt daran, dass die Funktion des Faltungskerns darin besteht, die Merkmale des Bildes zu extrahieren. Ein Faltungskern reicht jedoch nicht aus. Ein Faltungskern kann nur spiegeln einen bestimmten Aspekt des Bildes wider. Merkmale, daher sind mehrere Faltungskerne erforderlich. Diese verschiedenen Faltungskerne können unterschiedliche Merkmale des Bildes extrahieren, wodurch das Modell besser in der Lage ist, Bildmerkmale zu lernen. Daher können die Eigenschaften des Originalbilds nur mit genügend Faltungskernen und ausreichenden Parametern besser ausgedrückt werden.

Daher haben tiefe Netzwerke zwei Vorteile:

  1. Je tiefer das Netzwerk,desto höher die Ebene der Funktionen;
  2. Je tiefer das Netzwerk, desto stärkerdie Darstellungsfähigkeit.

2. Benennung des Netzwerkmodells

Heutzutage bestehen viele Netzwerkstrukturen aus „Name + Nummer“, wobei die Nummer die Tiefe des Netzwerks darstellt.Netzwerktiefe bezieht sich auf dieGewichtungsschicht des Netzwerks, einschließlich Faltungsschichten und vollständig verbundener Schichten, ausgenommen Pooling-Schichten und BN-Schichten

3. BN-Chargennormalisierungsschicht

Die Stapelnormalisierungsschicht (Batch Normalization, als BN bezeichnet) sorgt dafür, dass die Merkmalskarte eines Datenstapels die Verteilungsregel mit einem Mittelwert von 0 und einer Varianz von 1 erfüllt.

Während des Bildvorverarbeitungsprozesses werden normalerweise BN-Operationen am Bild ausgeführt, was die Konvergenz des Netzwerks beschleunigen kann. Wie in der folgenden Abbildung dargestellt, ist die Eingabe für Conv1 eine Merkmalsmatrix, die eine bestimmte Verteilung erfüllt. Für Conv2 erfüllt die Eingabe-Feature-Map jedoch nicht unbedingt ein bestimmtes Verteilungsgesetz (beachten Sie, dass die Erfüllung eines bestimmten Verteilungsgesetzes hier nicht bedeutet). Dies bedeutet, dass die Daten einer bestimmten Feature-Karte dem Verteilungsgesetz entsprechen müssen. Theoretisch bedeutet dies, dass die Daten der Feature-Karte, die dem gesamten Trainingsmustersatz entsprechen, dem Verteilungsgesetz entsprechen müssen). Der Zweck von BN besteht darin, dafür zu sorgen, dass die Merkmalskarte das Verteilungsgesetz mit einem Mittelwert von 0 und einer Varianz von 1 erfüllt.
Fügen Sie hier eine Bildbeschreibung ein

3. Einführung in ResNet

Detaillierte Erklärung von ResNet

Das Deep Residual Network (ResNet) wurde 2015 von Microsoft Labs vorgeschlagen. Es gewann den ersten Platz bei der Klassifizierungsaufgabe, den ersten Platz bei der Zielerkennung im ImageNet-Wettbewerb in diesem Jahr und den ersten Platz bei der Zielerkennung und Bildsegmentierung im COCO-Datensatz. Erster Platz. Der Vorschlag von ResNet stellt einen Meilenstein in der Geschichte der CNN-Bilder dar. Aufgrund seiner Vorteile bei öffentlichen Daten gewann der Autor He Kaiming auch den CVPR2016 Best Paper Award.

1. Einleitung

Warum ist Netzwerktiefe wichtig?

Da CNN low/mid/high-level Merkmale extrahieren kann, gilt: Je mehr Schichten das Netzwerk hat, desto umfangreicher sind die Merkmale verschiedener Ebenen, die extrahiert werden können. Darüber hinaus sind die vom tieferen Netzwerk extrahierten Merkmale abstrakter und enthalten mehr semantische Informationen.

Warum können wir nicht einfach die Anzahl der Netzwerkschichten erhöhen?

Bevor das ResNet-Netzwerk vorgeschlagen wurde, wurden herkömmliche Faltungs-Neuronale Netze durch Stapeln einer Reihe von Faltungsschichten und Poolschichten erhalten. Im Allgemeinen glauben wir, dass die Funktionsinformationen umso umfangreicher und der Modelleffekt umso besser sein sollten, je tiefer das Netzwerk ist. Experimente haben jedoch gezeigt, dass herkömmliche Faltungsnetzwerke oder vollständig verbundene Netzwerke bei der Übertragung von Informationen mehr oder weniger Informationen verlieren,<. a i=3 >Informationsverlust und andere Probleme, eine einfache Erhöhung der Netzwerktiefe führt zu Netzwerkverschlechterungsproblemen und führt auch zu Gradient verschwindet oder Gradient explodiert, was dazu führt, dass kein sehr tiefes Training möglich ist Netzwerke.

1.1 Problem mit verschwindendem/explodierendem Gradienten

Mit zunehmender Anzahl der Netzwerkschichten treten bei der Backpropagation Probleme wie das Verschwinden des Gradienten oder die Explosion des Gradienten auf. Die Rückausbreitung wird verwendet, um das Gewicht des Netzwerks anzupassen, einschließlich des Werts des Faltungskerns, des Gewichts und der Vorspannung der verborgenen Schicht, die alle durch Rückausbreitung angepasst werden müssen. Die Rückausbreitung berechnet hauptsächlich den Änderungsfaktor, um das Gewicht anzupassen und Die Berechnung des Änderungsfaktors erfordert zunächst die Berechnung der partiellen Ableitung der Zielfunktion (die Summe der Quadrate der Differenz zwischen dem vorhergesagten Wert und dem wahren Wert) in Bezug auf das Gewicht jeder Schicht des Netzwerks. Wenn wir den Gradienten durch Backpropagation ermitteln, verwenden wir daherKettenregel, durchläuft der Gradientenwert eine Reihe aufeinanderfolgender Multiplikationen und schrumpft oder nimmt drastisch zu. Dieses Phänomen behindert die Konvergenz des Modells.

Verschwindender Gradient: 0,99^1000=0,00004317

Gradientenexplosion: 1,01^1000=20959,155

Wenn der Fehlergradient jeder Schicht kleiner als 1 ist, wird während des Rückwärtsausbreitungsprozesses jede Vorwärtsausbreitung mit einem Fehlergradienten kleiner als 1 multipliziert. Je tiefer das Netzwerk, desto mehr Koeffizienten kleiner als 1 werden multipliziert und der Gradient wird größer sein. Bei Annäherung an 0 tritt ein „Verschwinden des Gradienten“ auf. Wenn umgekehrt der Fehlergradient jeder Schicht größer als 1 ist, muss während des Rückwärtsausbreitungsprozesses jede Vorwärtsausbreitung mit einem Fehlergradienten größer als 1 multipliziert werden Netzwerk Je tiefer Sie gehen, desto größer ist der Gradient, und es kommt zu einer „Gradientenexplosion“.

Lösung: Um das Problem des Verschwindens oder der Explosion von Gradienten zu lösen, schlägt das ResNet-Papier vor, die Datenvorverarbeitung (Datenstandardisierungsverarbeitung) zu verwenden, die Standardgewichtsinitialisierung zu verwenden und zu verwenden BN-Schichten im Netzwerk lösen.

1.2 Netzwerkdegradationsproblem (Degradationsproblem)

Je tiefer das Netzwerk wird, desto schwieriger wird das Training und die Optimierung des Netzwerks wird immer schwieriger. Theoretisch ist der Effekt umso besser, je tiefer das Netzwerk ist. Aufgrund der Schwierigkeit des Trainings führt ein zu tiefes Netzwerk jedoch zu Degradationsproblemen, und der Effekt ist nicht so gut wie bei einem relativ flachen Netzwerk.Mit zunehmender Anzahl der Netzwerkschichten wird die Genauigkeit des Netzwerks gesättigt oder nimmt sogar ab. Dies wird alsDegradationsproblem bezeichnet
Fügen Sie hier eine Bildbeschreibung ein

Lösung: Um das Degradationsproblem in tiefen Netzwerken zu lösen,Lassen Sie einige Schichten des neuronalen Netzwerks die Verbindung der Neuronen der nächsten Schicht überspringen, verbinden Sie die Schichten und schwächen Sie die starke Verbindung zwischen den einzelnen Schichten. Diese Art von neuronalem Netzwerk wird als Restnetzwerk (ResNet) bezeichnet.. Das ResNet-Papier schlägt eine Reststruktur (Reststruktur) vor, um das Degradationsproblem zu lindern. Die folgende Abbildung zeigt ein Faltungsnetzwerk, das die Reststruktur verwendet. Es ist ersichtlich, dass der Effekt mit zunehmender Vertiefung des Netzwerks nicht schlechter, sondern immer schlimmer wird besser. (Die gepunktete Linie ist ein Zugfehler und die durchgezogene Linie ist ein Testfehler).
Fügen Sie hier eine Bildbeschreibung ein

2. Restkartierung

Fügen Sie hier eine Bildbeschreibung ein

Wie in der Abbildung oben gezeigt, wird die Abbildung links aufgerufenIdentitätszuordnung, heißt das Bild rechtsRestkartierung. Unter der Annahme, dass im Bild links die ursprüngliche Eingabe x und die ideale Zuordnung f(x) ist, muss der Teil im gepunkteten Feld links direkt passen die Zuordnung f(x), während der Teil im gepunkteten Feld rechts an die Restkarte angepasst werden mussf(x)-x und die Restkarte oft einfacher zu optimieren ist Wirklichkeit. f(x) im Bild rechts ist ideale Kartierung: Wenn die Gewichts- und Bias-Parameter der gewichteten Operation (z. B. Strahlung) über dem gepunkteten Kästchen in der rechten Abbildung auf 0 gesetzt sind, ist f(x) die Identitätszuordnung. In der PraxisWenn die Grenze der idealen Abbildung f(x) nahe an der Identitätsabbildung liegt, kann die Restkarte auch die subtilen Schwankungen der Identitätsabbildung leicht erfassen.

3. ResNet und VGG

Das ResNet-Netzwerk basiert auf dem VGG19-Netzwerk, wurde auf seiner Basis modifiziert und durchKurzschlussmechanismus< a i hinzugefügt =3>Resteinheit, wie in der Abbildung unten gezeigt. Im Vergleich zu gewöhnlichen Netzwerken fügt ResNet einen Kurzschlussmechanismus zwischen jeweils zwei Schichten hinzu, der Restlernen bildet.
Fügen Sie hier eine Bildbeschreibung ein

Im Vergleich zum VGG19-Netzwerk spiegeln sich die wichtigsten Änderungen von ResNet wider in: ResNet verwendet direkt die Stride=2-Faltung für das Downsampling (die Größe der Feature-Map wird halbiert und die Anzahl der Kanäle verdoppelt) und global average pool ersetzt die vollständig verbundene Ebene. Dies spiegelt ein wichtiges Designprinzip von ResNet wider:Wenn die Größe der Feature-Map um die Hälfte reduziert wird, verdoppelt sich die Anzahl der Kanäle der Feature-Map, wodurch die Komplexität der Netzwerkschicht erhalten bleibt

4. Restliche Reststruktur

1. einfaches und restliches Netzwerk

Ein neuronales Netzwerk, das aus mehreren Restblöcken besteht, ist ein Restnetzwerk. Seine Struktur ist in der folgenden Abbildung dargestellt:
Fügen Sie hier eine Bildbeschreibung ein

Experimente zeigen, dass diese Modellstruktur gut für das Training sehr tiefer neuronaler Netze geeignet ist. Darüber hinaus nennen wir zur leichteren Unterscheidung nicht-residuales Netzwerk einfaches Netzwerk.

2. Restliche Reststruktur

2.1 short cutStruktureinführung

Der größte Unterschied zwischen ResNet und VGGNet besteht darin, dass es viele Seitenkanäle gibt, die den Eingang direkt mit nachfolgenden Schichten verbinden. Diese Struktur wird auch short cut oder skip connections

  • Durch die Verwendungshort cut eines strukturierten Restblocks kann seine Eingabe schneller über schichtübergreifende Datenrouten weitergeleitet werden

  • short cutDie Zweige auf dem Pfad werden „Abkürzungszweige“ genannt und unterscheiden sich von den „Hauptzweigen“.

Wie in der folgenden Abbildung gezeigt, verwendet die Reststruktur die Verbindungsmethode short cut, um die Merkmalsmatrizen in Schichten hinzuzufügen. Die sogenannte Addition ist Fügen Sie die Werte an derselben Position in der Merkmalsmatrix hinzu. In praktischen Anwendungen ist die short cut der Reststruktur nicht unbedingt durch eine Schicht verbunden, sondern kann auch durch mehrere Schichten getrennt sein. Das von ResNet vorgeschlagene Restnetzwerk ist durch mehrere Schichten getrennt.
Fügen Sie hier eine Bildbeschreibung ein
wird im Allgemeinen identity Function genannt, was ein istVerbindung überspringen; rufe F(x) aufResNet Function. Beachten Sie, dass F(x) und x die gleiche Form haben müssen.

2.2 Einführung in die Reststruktur

Wie nachfolgend dargestellt,Es gibt zwei verschiedene Reststrukturen in ResNet. Die linke Reststruktur heißt BasicBlock und die rechte Reststruktur heißt Engpass. Die Reststruktur von ResNet18/34 ist BasicBlock und verwendet zwei 3x3-Faltungen. Die Reststruktur von ResNet50/101/152 ist Bottleneck und die Faltung von 1x1+3x3+1x1 wird verwendet.
Fügen Sie hier eine Bildbeschreibung ein
ResNet folgt dem vollständigen 3×3-Faltungsschichtdesign von VGG. Erstens verfügt die BasicBlockReststruktur über zwei 3×3-Faltungsschichten mit der gleichen Anzahl von Ausgangskanälen. Auf jede Faltungsschicht folgen eine BN-Batch-Normalisierungsschicht und eine ReLU-Aktivierungsfunktion. Anschließend werden diese beiden Faltungsoperationen über den schichtübergreifenden Datenpfad übersprungen und die Eingabe direkt vor der endgültigen ReLU-Aktivierungsfunktion hinzugefügt. Ein solches Design erfordert, dass die Ausgaben der beiden Faltungsschichten dieselbe Form wie die Eingaben haben, damit sie addiert werden können. Wenn Sie die Anzahl der Kanäle ändern möchten, müssen Sie eine zusätzliche 1×1-Faltungsschicht einführen, um die Eingabe in die erforderliche Form umzuwandeln, und dann die Additionsoperation ausführen.
Fügen Sie hier eine Bildbeschreibung ein

Ähnlich wie VggNet verfügt auch ResNet über mehrere Versionen mit unterschiedlichen Schichten, und es gibt zwei Arten von Reststrukturen, die flachen und tiefen Netzwerken entsprechen:

ResNet Reststruktur
Flaches Netzwerk ResNet18/34 BasicBlock
Deep Web ResNet50/101/152 Engpass

Das Folgende ist das spezifische Reststrukturdiagramm mit durchgezogener/gestrichelter Linie von ResNet 18/34 und ResNet 50/101/152:

ResNet 18/34

Fügen Sie hier eine Bildbeschreibung ein

ResNet 50/101/152

Fügen Sie hier eine Bildbeschreibung ein

3. BasicBlockReststruktur

Für ResNet, das weniger 18-schichtige und 34-schichtige Netzwerkschichten hat, besteht es aus BasicBlock, das Restlernen zwischen zwei Schichten durchführt. Die zweischichtigen Faltungskerne sind 3x3 bzw. 3x3.basic_block=identity_blockDiese Struktur stellt sicher, dass Eingang und Ausgang gleich sind, und realisiert die Reihenschaltung des Netzwerks
Fügen Sie hier eine Bildbeschreibung ein

import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo

# 这个文件内包括6中不同的网络架构
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']

# 每一种架构下都有训练好的可以用的参数文件
model_urls = {
    
    
    'resnet18': 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth',
    'resnet34': 'https://s3.amazonaws.com/pytorch/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth',
    'resnet101': 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth',
}

# 常见的3x3卷积
def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    # 残差结构中,主分支的卷积核个数是否发生变化,不变则为1
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None):  
        # downsample对应虚线残差结构
    	# inplanes代表输入通道数,planes代表输出通道数。
        super(BasicBlock, self).__init__()
        # Conv1
        self.conv1 = conv3x3(inplanes, planes, stride)
        # stride=1为实线残差结构,不需要改变大小,stride=2为虚线残差结构
        # stride=1,output=(input-3+2*1)/ 1 + 1 = input   输入和输出的shape不变
        # stride=2,output=(input-3+2*1)/ 2 + 1 = input = input/2 + 0.5 = input/2(向下取整)
        self.bn1 = nn.BatchNorm2d(planes)  # 使用BN时不使用偏置
        self.relu = nn.ReLU(inplace=True)
        # Conv2
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # 下采样
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:  # 虚线残差结构,需要下采样
            residual = self.downsample(x)  # 捷径分支 short cut
		# F(x)+x
        out += residual
        out = self.relu(out)

        return out

BasicBlockinit()Die Funktion in der Klasse forward() definiert die Vorwärtsausbreitung und die implementierte Funktion ist der Restblock. definiert die Netzwerkarchitektur, die Funktion
Fügen Sie hier eine Bildbeschreibung ein

4. BottleneckReststruktur

Für ResNet mit einer großen Anzahl von 50-schichtigen, 101-schichtigen und 152-schichtigen Netzwerkschichten besteht es aus Bottleneck, das Restlernen zwischen drei Schichten durchführt. Die dreischichtigen Faltungskerne sind 1x1, 3x3 und 1x1. Bei tiefem Engpass spielt der 1×1-Faltungskern die Rolle der Dimensionsreduzierung und Dimensionalität (Feature-Matrix-Tiefe) und kann Netzwerkparameter erheblich reduzieren. Konkret besteht die Funktion des 1×1-Faltungskerns in der ersten Schicht darin, eine Dimensionsreduktionsoperation an der Merkmalsmatrix durchzuführen und die Tiefe der Merkmalsmatrix um 256 auf 64 zu reduzieren ; Der 1×1-Faltungskern der dritten Schicht führt eine Dimensionserhöhungsoperation an der Merkmalsmatrix aus, und die Tiefe der Merkmalsmatrix beträgt von 64 aufsteigend auf 256 reduziert. Die Verringerung der Tiefe der Merkmalsmatrix dient hauptsächlich dazu, die Anzahl der Parameter zu verringern. Reduzieren Sie zuerst die Dimension und erhöhen Sie dann die Dimension, um die Merkmalsmatrix auf dem Hauptzweig auszugeben und Verknüpfung Die auf Zweigen ausgegebenen Feature-Matrizen haben für Additionsoperationen die gleiche Form
Fügen Sie hier eine Bildbeschreibung ein

Es ist zu beachten, dass die Anzahl der Kanäle der Feature-Map der verborgenen Ebene relativ gering ist und 1/4 der Anzahl der Kanäle der Ausgabe-Feature-Map beträgt. Wie in der folgenden Abbildung dargestellt, beträgt die Anzahl der Kanäle der verborgenen Schicht, die den ersten beiden Faltungskernen im dreischichtigen Faltungskern entsprechen, 64, und die Anzahl der Kanäle der Ausgabeschicht, die dem letzten Faltungskern entsprechen, beträgt 256. Die Anzahl der Kanäle In der verborgenen Ebene wird 1/4 der Anzahl der Ebenenkanäle ausgegeben.
Fügen Sie hier eine Bildbeschreibung ein

# ResNet50/101/152的残差结构,用的是1x1+3x3+1x1的卷积核
class Bottleneck(nn.Module):
    """
    注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。
    但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,
    这么做的好处是能够在top1上提升大概0.5%的准确率。
    可参考 Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch
    """
    # 残差结构中第三层卷积核个数是第一/二层卷积核个数的4倍
    expansion = 4      # 输出通道数的倍乘

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # conv1   1x1
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # conv2   3x3
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        # stride=stride根据传入的进行调整,因为实线中的第二层是1,虚线中是2
        self.bn2 = nn.BatchNorm2d(planes)
        # conv3   1x1  
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:  # 捷径分支 short cut
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

BottleneckDie Klasse ist ein weiterer Blocktyp, die Funktion init() definiert die Netzwerkarchitektur und die Funktion forward() definiert die Vorwärtsausbreitung. In diesem Block gibt es drei Faltungen, nämlich 1x1, 3x3 und 1x1. Die von ihnen ausgeführten Funktionen bestehen darin, Dimensionen zu komprimieren, zu falten und Dimensionen wiederherzustellen.. Daher besteht die von der Klasse Bottleneck implementierte Funktion darin, die Anzahl der Kanäle zu komprimieren und dann zu verstärken. Hinweis: Die Ebene ist hier nicht mehr die Anzahl der Ausgabekanäle. Die Anzahl der Ausgabekanäle sollte p l a n e ∗ e x p a n s i o n plane*expansion seinplaneexpans ion,即 4 ∗ p l a n e 4 *Flugzeug 4plane
Fügen Sie hier eine Bildbeschreibung ein

5. Berechnung der Reststrukturmenge

Nehmen Sie an, dass die Anzahl der Kanäle der Eingabemerkmals- und Ausgabemerkmalsmatrizen der beiden Reststrukturen jeweils 256 Dimensionen beträgt, wie unten gezeigt:
Fügen Sie hier eine Bildbeschreibung ein

Wenn der Bias-Term nicht berücksichtigt wird, lautet die CNN-Parameterberechnungsformel: D K ∗ D K ∗ M ∗ N D_K*D_K*M*N DKDKMN

Wenn die BasicBlock-Reststruktur verwendet wird, beträgt die Anzahl der Parameter: 3×3x256×256+3×3x256×256=1179648.
Wenn die Bottleneck-Reststruktur verwendet wird, beträgt die Anzahl der Parameter: 1×1×256×64+3×3×64×64+1×1×64×256=69632.

Zusammenfassen: Offensichtlich hat die Verwendung der Bottleneck-Reststruktur weniger Parameter und ist besser für den Aufbau tiefer Netzwerke geeignet.

5. ResNet-Netzwerk

1. ResNet-Netzwerkstruktur

  • resnet18: ResNet(BasicBlock, [2, 2, 2, 2])
  • resnet34: ResNet(BasicBlock, [3, 4, 6, 3])
  • resnet50: ResNet (Engpass, [3, 4, 6, 3])
  • resnet101: ResNet (Engpass, [3, 4, 23, 3])
  • resnet152: ResNet (Engpass, [3, 8, 36, 3])

Wie in der folgenden Abbildung gezeigt, ist ResNet50 in 5 Schichten unterteilt: conv1, conv2_x, conv3_x, conv4_x und conv5_x. Die Anzahl der Netzwerkschichten beträgt:1+1+3x3+4x3+6x3+3x3=50, die vorherige Faltung Schicht + eine Schicht-Pooling-Schicht + 4 Faltungsgruppen, ohne Berücksichtigung der letzten vollständig verbundenen und Pooling-Schichten.
Fügen Sie hier eine Bildbeschreibung ein
ResNet hat im Allgemeinen 4 Stufen. Jeder Stapel ist ein Stapel von Blöcken. Beispielsweise ist [3, 4, 6, 3] die Anzahl der in jeder Stufe gestapelten Blöcke, also unterschiedliche Deep ResNet.

Fügen Sie hier eine Bildbeschreibung ein
BildquelleResNet50-Netzwerkstrukturdiagramm und detaillierte Strukturerklärung,
Bild-Download-Link: Link: Extraktionscode: 1ojdBaidu Cloud Disk

2. Innovationspunkte des ResNet-Netzwerks

  • Bauen Sie eine ultratiefe Netzwerkstruktur auf (kann 1.000 Schichten überschreiten).

  • Vorgeschlagene Reststruktur (Reststruktur) zur Linderung des Degradationsproblems.

  • Verwenden Sie BN-Schichten, um verschwindende oder explodierende Gradientenprobleme zu lösen. Verwenden Sie BN, um das Training zu beschleunigen (Abbrecher verwerfen).

    Im Bildvorverarbeitungsprozess werden Bilder normalerweise standardisiert, was die Konvergenz des Netzwerks beschleunigen kann. Der Zweck von BN besteht darin, dafür zu sorgen, dass die Merkmalskarte das Verteilungsgesetz mit einem Mittelwert von 0 und einer Varianz von 1 erfüllt.

3. Kerncode

# 整个网络的框架部分
class ResNet(nn.Module):
    # block = BasicBlock or Bottleneck
    # layers为残差结构中conv2_x~conv5_x中残差块个数,是一个列表,如34层中的是[3,4,6,3]
    def __init__(self, block, layers, num_classes=1000):  
        self.inplanes = 64 
        super(ResNet, self).__init__()
        # 1.conv1
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        # 2.conv2_x
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        # 3.conv3_x
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        # 4.conv4_x
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        # 5.conv5_x
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        
        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.expansion, num_classes)
		
		# 初始化权重
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        # 每个blocks的第一个residual结构保存在layers列表中。
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            # 通过循环将剩下的一系列实线残差结构添加到layers
            layers.append(block(self.inplanes, planes))   
		
        # Sequential将一系列网络结构组合在一起
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)   # 将输出结果展成一行
        x = self.fc(x)

        return x

ResNet hat insgesamt 5 Stufen. Die erste Stufe ist eine 7x7-Faltung mit Schritt = 2 und durchläuft dann die Pooling-Schicht. Die resultierende Feature-Map-Größe beträgt 1/4 des Originalbilds. _make_layer() Die Funktion wird zum Generieren von 4 Ebenen verwendet, und ein Netzwerk kann basierend auf der Liste der Eingabeebenen erstellt werden.

# resnet18
def resnet18(pretrained=False):
    """Constructs a ResNet-18 model.
	# https://download.pytorch.org/models/resnet18-f37072fd.pth
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2])
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model

# resnet34
def resnet34(pretrained=False):
    """Constructs a ResNet-34 model.
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3])
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
    return model

# resnet50
def resnet50(pretrained=False):
    """Constructs a ResNet-50 model.
	# https://download.pytorch.org/models/resnet50-19c8e357.pth
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3])
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model

# resnet101
def resnet101(pretrained=False):
    """Constructs a ResNet-101 model.
	# https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3])
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model

# resnet152
def resnet152(pretrained=False):
    """Constructs a ResNet-152 model.
	# https://download.pytorch.org/models/resnet152-394f9c45.pth
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3])
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model

6. Das einfachste ResNet18

Die spezifische Zusammensetzung des ResNet18-Netzwerks
PyTorch implementiert ResNet18
ResNet18-Struktur und Ausgabedimensionen jeder Schicht
ResNet 18 Strukturinterpretation „empfohlene Sammlung“
Resnet 18-Netzwerkmodell [leicht zu verstehen]
Resnet 18-Netzwerkmodell
Resnet- 18 Netzwerk Grafisches Verständnis

ResNet18-Netzwerkstruktur

Das 18-schichtige ResNet heißt ResNet18 und seine Netzwerktiefe beträgt 18 Schichten, insbesondere 18 Schichten mit Gewichten, einschließlich: Faltungsschichten und vollständig verbundene Schichten, ausgenommen Pooling-Schichten und BN-Schichten. Wie in der folgenden Abbildung dargestellt, gibt es 17 Faltungsschichten und 1 FC-Schicht, also 18 Schichten.
Fügen Sie hier eine Bildbeschreibung ein

Gepunktete Linieshort cut Die Dimensionsverarbeitung erfolgt über einen 1×1-Faltungskern (die Merkmalsmatrix wird in Längen- und Breitenrichtung heruntergesampelt und die Tiefenrichtung wird angepasst). die Bedürfnisse der nächsten Schicht des Reststrukturkanals).

  • Kanalkanal verdoppelt. Die Anzahl der Kanäle wird durch 1x1-Faltung angepasst. Die durchgezogene Linie zeigt an, dass sich die Anzahl der Kanäle im Restblock nicht geändert hat, und die gepunktete Linie zeigt an, dass sich die Anzahl der Kanäle geändert hat, z. B. 64 -> 128.
  • Die Form der Merkmalsmatrix wird halbiert. Passen Sie die Schrittgröße auf 2 an, um Downsampling zu implementieren.
    Fügen Sie hier eine Bildbeschreibung ein

Hinweis:

BN 表示批量归一化
   
RELU 表示激活函数
   
lambda x:x 这个函数的意思是输出等于输入
   
identity 表示残差
   
1个resnet block 包含2个basic block
1个resnet block 需要添加2个残差
   
在resnet block之间残差形式是1*1conv,在resnet block内部残差形式是lambda x:x
resnet block之间的残差用实线箭头表示,resnet block内部的残差用虚线箭头表示
   
3*3conv s=2,p=1 特征图尺寸会缩小
3*3conv s=1,p=1 特征图尺寸不变

(1) Conv1-Faltungsschicht

Durchlaufen Sie zunächst eine Faltungsschicht. Die Größe des Faltungskerns dieser Faltungsschicht beträgt 7x7, der Schritt beträgt 2, die Polsterung beträgt 3 und der Ausgabekanal beträgt 64. Nach der Formel:
n o u t = ⌊ n i n + 2 p − k s ⌋ + 1 n_{out}=\left\lfloor\frac{n_{in}+2p-k}{s} \right \rfloor+1 Nout=SNin+2pk+1
Wir können berechnen, dass die Größe der endgültigen Ausgabedaten 64 x 112 x 112 beträgt.

(2) Maxpooling-Pooling-Schicht

Fügen Sie hier eine Bildbeschreibung ein

Maximale Pooling-Schicht, die Größe des Faltungskerns dieser Schicht beträgt 3x3, der Schritt beträgt 2 und die Polsterung beträgt 1. Die Größe der endgültigen Ausgabedaten beträgt 64 x 56 x 56. Mit anderen Worten: Diese Ebene ändert nicht die Anzahl der Datenkanäle, aber die Form der Merkmalsmatrix wird halbiert.

(3) conv2_x-Faltungsschicht (Resnet-Block1)

Die Faltungskerngröße dieser Faltungsschicht beträgt 3x3, der Schritt ist 1 und die Polsterung ist 1. Durch zwei Faltungsberechnungen beträgt die Ausgabedatengröße schließlich 64 x 56 x 56. Diese Ebene ändert nichts an der Datengröße und der Anzahl der Kanäle.
Fügen Sie hier eine Bildbeschreibung ein

(4) conv3_x-Faltungsschicht (Resnet-Block2)

Die Dimensionalität wird durch eine 1x1-Faltungsschicht und anschließendes Downsampling erhöht. Die endgültige Ausgabe beträgt 128 x 28 x 28. Der Ausgabekanal wird verdoppelt und die Form der Ausgabe-Feature-Matrix wird halbiert.
Fügen Sie hier eine Bildbeschreibung ein

(5) conv4_x-Faltungsschicht (Resnet-Block3)

Übergeben Sie eine 1x1-Faltungsschicht und führen Sie ein Downsampling durch. Die endgültige Ausgabe beträgt 256 x 14 x 14. Der Ausgabekanal wird verdoppelt und die Form der Ausgabe-Feature-Matrix wird halbiert.
Fügen Sie hier eine Bildbeschreibung ein

(6) conv5_x-Faltungsschicht (Resnet-Block4)

Auf die gleiche Weise wie oben beträgt die endgültige Ausgabe 512 x 7 x 7. Der Ausgabekanal wird verdoppelt und die Form der Ausgabe-Feature-Matrix wird halbiert.
Fügen Sie hier eine Bildbeschreibung ein

(7) Avgpooling-Schicht

Die endgültige Ausgabe ist 512x1x1.

(8) FC-Schicht

7. Quellcode-Analyse

Detaillierte Erläuterung der ResNet-Netzwerkstruktur, Netzwerkkonstruktion, Transferlernen
Pytorch-Bildklassifizierungsartikel: 6. Detaillierte Erläuterung der ResNet-Netzwerkstruktur und Einführung in Transferlernen

1.model.py

import torch.nn as nn
import torch


# ResNet18/34的残差结构,用的是2个3x3的卷积核
class BasicBlock(nn.Module):
    expansion = 1  # 残差结构中,主分支的卷积核个数是否发生变化,不变则为1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):  # downsample对应虚线残差结构
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=(3, 3), stride=(stride, stride),
                               padding=1, bias=False)
        # stride=1为实线残差结构,不需要改变大小,stride=2为虚线残差结构
        # stride=1,output=(input-3+2*1)/ 1 + 1 = input   输入和输出的shape不变
        # stride=2,output=(input-3+2*1)/ 2 + 1 = input = input/2 + 0.5 = input/2(向下取整)
        self.bn1 = nn.BatchNorm2d(out_channel)   # 使用BN时不使用偏置
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:  # 虚线残差结构,需要下采样
            identity = self.downsample(x)  # 捷径分支 short cut

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


# ResNet50/101/152的残差结构,用的是1x1+3x3+1x1的卷积核
class Bottleneck(nn.Module):
    """
    注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。
    但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,
    这么做的好处是能够在top1上提升大概0.5%的准确率。
    可参考 Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch
    """
    expansion = 4  # 残差结构中第三层卷积核个数是第一/二层卷积核个数的4倍

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=(1, 1), stride=(1, 1), bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=(3, 3), stride=(stride, stride), bias=False, padding=1)
        # stride=stride根据传入的进行调整,因为实线中的第二层是1,虚线中是2
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel * self.expansion,  # 卷积核个数变为4倍
                               kernel_size=(1, 1), stride=(1, 1), bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)  # 捷径分支 short cut

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


# 整个网络的框架部分
class ResNet(nn.Module):
    # block = BasicBlock or Bottleneck
    # block_num为残差结构中conv2_x~conv5_x中残差块个数,是一个列表,如34层中的是[3,4,6,3]
    def __init__(self,
                 block,
                 blocks_num,
                 num_classes=1000,
                 include_top=True,  # 方便在resnet网络的基础上搭建其他网络,这里用不到
                 groups=1,
                 width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=(7, 7), stride=(2, 2),
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1),自适应平均池化下采样
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    # channel为残差结构中第一层卷积核个数,block_num表示该层一共包含多少个残差结构,如34层中的是3,4,6,3
    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        # ResNet50/101/152的残差结构,block.expansion=4
        if stride != 1 or self.in_channel != channel * block.expansion:  # layer2,3,4都会经过这个结构
            downsample = nn.Sequential(  # 生成下采样函数,这里只需要调整conv2的特征矩阵的深度
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=(1, 1), stride=(stride, stride), bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        # 首先将第一层残差结构添加进去,block = BasicBlock or Bottleneck
        layers.append(block(self.in_channel,  # 输入特征矩阵的深度64
                            channel,  # 残差结构对应主分支上的第一个卷积层的卷积核个数
                            downsample=downsample,  # 50/101/152对应的是高宽不变,深度4倍,对应的虚线残差结构
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            # 通过循环将剩下的一系列实线残差结构添加到layers
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        # Sequential将一系列网络结构组合在一起
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


def resnet18(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet18-f37072fd.pth
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, include_top=include_top)


def resnet34(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet50(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet50-19c8e357.pth
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet101(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)

def resnet152(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet152-394f9c45.pth
    return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, include_top=include_top)


def resnext50_32x4d(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
    groups = 32
    width_per_group = 4
    return ResNet(Bottleneck, [3, 4, 6, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)


def resnext101_32x8d(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth
    groups = 32
    width_per_group = 8
    return ResNet(Bottleneck, [3, 4, 23, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)

2.train.py

import os
import sys
import json
 
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
 
from model import resnet34


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))
 
    data_transform = {
    
    
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),      #原图的长宽比固定不动,把最小边长缩放到256
                                   transforms.CenterCrop(224),      #中心裁剪
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
 
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../"))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
 
    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
 
    batch_size = 4
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
 
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)
 
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
 
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))
    
    net = resnet34()
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    # for param in net.parameters():
    #     param.requires_grad = False
 
    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)
    net.to(device)
 
    # define loss function
    loss_function = nn.CrossEntropyLoss()
 
    # construct an optimizer
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)
 
    epochs = 3
    best_acc = 0.0
    save_path = './resNet34.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()
 
            # print statistics
            running_loss += loss.item()
 
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)
 
        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                # loss = loss_function(outputs, test_labels)
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
 
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                           epochs)
 
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))
 
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
 
    print('Finished Training')
 
 
if __name__ == '__main__':
    main()

3.predict.py

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import resnet34


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
 
    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)
 
    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
 
    with open(json_path, "r") as f:
        class_indict = json.load(f)
 
    # create model
    model = resnet34(num_classes=5).to(device)
 
    # load model weights
    weights_path = "./resNet34.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))
 
    # prediction
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()
 
    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()
 
 
if __name__ == '__main__':
    main()

Guess you like

Origin blog.csdn.net/m0_37605642/article/details/134299994