PyTorch-Studiennotizen (17) – Einführung in die Nutzung von Torchvision.transforms
Bei diesem Blog-Beitrag handelt es sich um die Studiennotizen von PyTorch, den 17. Inhaltsdatensatz, der hauptsächlich die Verwendung von Torchvision.transforms aufzeichnet.
Inhaltsverzeichnis
- PyTorch-Studiennotizen (17) – Einführung in die Verwendung von Torchvision.transforms
- 1. Die Ursache des Problems
- 2. Die spezifische Verwendung von Torchvision.transforms
- 3. Andere Verwendungen von Torchvision.transforms
- 4. Ergänzen Sie weitere Funktionen des Torchvision-Moduls
- 5. Fehlerbehebung beim Ausführen
1. Die Ursache des Problems
Beim Lesen des Anwendungscodes von ResNet bin ich auf den folgenden kleinen Code gestoßen. Dieser Code erscheint vor dem Lesen der Bildinformationen. Was ist die spezifische Funktion dieses Codes? Anfänger müssen die spezifische Bedeutung dieses Codes verstehen
data_transform = transforms.Compose(
[transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
2. Die spezifische Verwendung von Torchvision.transforms
Es gibt ein sehr wichtiges und nützliches Paket im PyTorch-Framework: Torchvision, das hauptsächlich aus drei Unterpaketen besteht, nämlich: Torchvision.datasets, Torchvision.models, Torchvision.transforms. Der obige Code verwendet torchvision.transforms
dieses Paket.
Die hier verwendete Torchvision-Toolbibliothek ist ein häufig verwendetes Bildverarbeitungspaket unter dem Pytorch-Framework, das zum Generieren von Bild- und Videodatensätzen (torchvision.datasets), zur Bildvorverarbeitung (torchvision.transforms) und zum Importieren vorab trainierter Modelle verwendet werden kann (torchvision. models) und das Erzeugen von Diagrammen und das Speichern von Bildern (torchvision.utils). Unter diesen
kann die Transformationsfunktion zur Vorverarbeitung des Bildes sein: 归一化(normalize)
, usw. Bei den oben genannten Schritten handelt es sich häufig um eine Reihe tatsächlicher Vorgänge. Zu diesem Zeitpunkt kann Compose verwendet werden, um diese Bildvorverarbeitungsvorgänge zu verbinden. Wie im obigen Code lautet die Operation hier: transforms.ToTensor() , die ein PIL-Bild in einen Tensor umwandelt. Das heißt, ( H ∗ W ∗ C ) (H\ast W\ast C)尺寸剪裁(resize)
翻转(flip)
( H∗W∗C ) Das PIL-Bild im Bereich [0,255] wird konvertiert in(C ∗ H ∗ W) (C\ast H\ast W)( C∗H∗W ) Torch.tensor im Bereich [0,1].
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) , normalisieren Sie das Bild mit Mittelwert [0.485, 0.456, 0.406] und Standardabweichung [0.229, 0.224, 0.225].
3. Andere Verwendungen von Torchvision.transforms
Zu den weiteren Funktionen der Transformationsfunktion gehören:
Größe ändern: Ändern Sie die Größe des angegebenen Bildes auf die angegebene Größe.
ToPILImage: Torch.tensor in PIL-Bild konvertieren.
CenterCrop: Verwenden Sie den Mittelpunkt des Eingabebilds als Mittelpunkt, um den Zuschneidevorgang der angegebenen Größe durchzuführen.
RandomCrop: Der Zuschneidevorgang der angegebenen Größe wird um die zufällige Position des Eingabebildes herum ausgeführt.
RandomHorizontalFlip: Dreht das gegebene PIL-Bild horizontal mit einer Wahrscheinlichkeit von 0,5.
RandomVerticalFlip: Spiegelt das angegebene PIL-Bild vertikal mit einer Wahrscheinlichkeit von 0,5.
RandomResizedCrop: Schneiden Sie das gegebene Bild nach dem Zufallsprinzip auf verschiedene Größen und Seitenverhältnisse zu und skalieren Sie das zugeschnittene Bild dann auf die angegebene Größe (mit einem Parameter n).
Graustufen: Konvertiert ein bestimmtes Bild in ein Graustufenbild.
RandomGrayscale: Konvertiert ein Bild mit einer angegebenen Wahrscheinlichkeit in ein Graustufenbild.
FiveCrop: 5 Bilder einer bestimmten Größe aus einem Eingabebild zuschneiden, einschließlich 4 Eckbildern und einem Bild in der Mitte.
TenCrop: 10 Bilder der angegebenen Größe ausschneiden. Die Methode besteht darin, das Eingabebild auf der Grundlage von FiveCrop horizontal oder vertikal zu spiegeln und dann die FiveCrop-Operation auszuführen, sodass ein Bild 10 Zuschneidebilder erhalten kann.
Pad: Füllt die „Auffüll“-Pixel auf allen Seiten des angegebenen Bildes mit dem „Füll“-Wert auf.
ColorJitter: Ändern Sie Helligkeit, Kontrast, Sättigung und Farbton eines Bildes.
Lambda: Führen Sie die durch seine Parameter angegebene Transformation durch.
Eine detaillierte Einführung in die oben genannten vier Pakete und ihre spezifischen Funktionen finden Sie in der chinesischen Dokumentation von Pytorch .
Die Code-Implementierung kann sich auf die Code-Implementierung von Github beziehen .
4. Ergänzen Sie weitere Funktionen des Torchvision-Moduls
Torchvision ist eine von PyTorch unabhängige Werkzeugbibliothek zur Bildbearbeitung. Sie umfasst derzeit sechs Module:
1) Torchvision.datasets: Mehrere häufig verwendete visuelle Datensätze, die heruntergeladen und geladen werden können, und wie Sie Ihren eigenen Datensatz schreiben.
2) Torchvision.models: klassische Modelle wie AlexNet, VGG, ResNet usw. und trainierte Parameter.
3) Torchvision.transforms: häufig verwendete Bildoperationen, wie z. B. zufälliges Schneiden, Drehen, Datentypkonvertierung, Tensor und Numpy sowie PIL-Bildaustausch usw.
4) Torchvision.ops: Bietet einige häufig verwendete Vorgänge in CV, wie z. B. NMS, ROI_Align, ROI_Pool usw.
5) Torchvision.io: Bietet einige Vorgänge für die Eingabe und Ausgabe, derzeit für das Schreiben und Schreiben von Videos.
6) Torchvision.utils: Andere Tools, wie zum Beispiel das Generieren eines Bildrasters usw.
5. Fehlerbehebung beim Ausführen
Frage 1: Der Datensatz ist ein Farbbild und die Anzahl der Kanäle beträgt 3, aber die Anzahl der Eingangskanäle im Modell beträgt 1, dh es empfängt ein graues Bild. Zu diesem Zeitpunkt wird ein Fehler gemeldet Trainieren des Modells. Der spezifische Fehler ist:
RuntimeError: Given groups=1, weight of size 32 3 3 3, expected input[1, 4, 416, 416] to have 3 channels
Um das Problem der Anzahl der Eingangskanäle zu lösen, dh das 3-Kanal-Farbbild in ein 1-Kanal-Graubild zu ändern, lautet die Änderungsmethode derzeit:
修改前:
train_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=True,
transform=torchvision.transforms.torchvision.transforms.ToTensor(),
download=True)
test_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=False,
transform=torchvision.transforms.ToTensor(),
download=True)
修改后:
train_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.Grayscale(),
torchvision.transforms.ToTensor()]),
download=True)
test_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=False,
transform=torchvision.transforms.Compose([
torchvision.transforms.Grayscale(),
torchvision.transforms.ToTensor()]),
download=True)
Das heißt, torchvision.transforms.Grayscale()
der Vorgang des Hinzufügens eines.
Frage 1: