Idea básica: use CNN para la clasificación de caracteres de longitud fija;
Requisitos del sistema operativo: Python2 / 3, memoria 4G, con o sin GPU
Este problema vuelve a ocurrir en% pylab en línea, y los símbolos anteriores parecen ser inapropiados para el sistema actual.
nombre | Talla | Enlace |
---|---|---|
OCNLI_train1128.csv | 5,78 MB | http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531841/OCNLI_train1128.csv |
TNEWS_train1128.csv | 4,38 MB | http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531841/TNEWS_train1128.csv |
OCEMOTION_train1128.csv | 4,96 MB | http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531841/OCEMOTION_train1128.csv |
Este no ha tenido tiempo de descargar
Código completo:
import os, sys, glob, shutil, json
import cv2
desde PIL import Image
import numpy as np
importar antorcha
de torch.utils.data.dataset importar conjunto de datos
importar torchvision.transforms as transforms
class SVHNDataset (Conjunto de datos):
def init (self, img_path, img_label, transform = None):
self.img_path = img_path
self.img_label = img_label
si transform no es None:
self.transform = transform
else:
self.transform = None
def __getitem__(self, index):
img = Image.open(self.img_path[index]).convert('RGB')
if self.transform is not None:
img = self.transform(img)
# 原始SVHN中类别10为数字0
lbl = np.array(self.img_label[index], dtype=np.int)
lbl = list(lbl) + (5 - len(lbl)) * [10]
return img, torch.from_numpy(np.array(lbl[:5]))
def __len__(self):
return len(self.img_path)
train_path = glob.glob ('… / input / train / *. png')
train_path.sort ()
train_json = json.load (open ('… / input / train.json'))
train_label = [train_json [x] [ 'etiqueta'] para x en train_json]
datos = SVHNDataset (train_path, train_label,
transforms.Compose ([
# Escala a un tamaño fijo
transforms.Resize ((64, 128)),
# 随机颜色变换
transforms.ColorJitter(0.2, 0.2, 0.2),
# 加入随机旋转
transforms.RandomRotation(5),
# 将图片转换为pytorch 的tesntor
# transforms.ToTensor(),
# 对图像像素进行归一化
# transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
]))
通过上述代码,可以将赛题的图像数据和对应标签进行读取,在读取过程中的进行数据扩增,效果如下所示:
|1|2|3|
|----|-----|------|
|![IMG](IMG/Task02/23.png) | ![IMG](IMG/Task02/23_1.png)| ![IMG](IMG/Task02/23_2.png)|
|![IMG](IMG/Task02/144_1.png) | ![IMG](IMG/Task02/144_2.png)| ![IMG](IMG/Task02/144_3.png)|
接下来我们将在定义好的Dataset基础上构建DataLoder,你可以会问有了Dataset为什么还要有DataLoder?其实这两个是两个不同的概念,是为了实现不同的功能。
- Dataset:对数据集的封装,提供索引方式的对数据样本进行读取
- DataLoder:对Dataset进行封装,提供批量读取的迭代读取
加入DataLoder后,数据读取代码改为如下:
```python
import os, sys, glob, shutil, json
import cv2d
from PIL import Image
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms
class SVHNDataset(Dataset):
def __init__(self, img_path, img_label, transform=None):
self.img_path = img_path
self.img_label = img_label
if transform is not None:
self.transform = transform
else:
self.transform = None
def __getitem__(self, index):
img = Image.open(self.img_path[index]).convert('RGB')
if self.transform is not None:
img = self.transform(img)
# 原始SVHN中类别10为数字0
lbl = np.array(self.img_label[index], dtype=np.int)
lbl = list(lbl) + (5 - len(lbl)) * [10]
return img, torch.from_numpy(np.array(lbl[:5]))
def __len__(self):
return len(self.img_path)
train_path = glob.glob('../input/train/*.png')
train_path.sort()
train_json = json.load(open('../input/train.json'))
train_label = [train_json[x]['label'] for x in train_json]
train_loader = torch.utils.data.DataLoader(
SVHNDataset(train_path, train_label,
transforms.Compose([
transforms.Resize((64, 128)),
transforms.ColorJitter(0.3, 0.3, 0.2),
transforms.RandomRotation(5),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])),
batch_size=10, # 每批样本个数
shuffle=False, # 是否打乱顺序
num_workers=10, # 读取的线程个数
)
for data in train_loader:
break
Después de agregar DataLoder, los datos se obtienen en lotes y cada lote se denomina Conjunto de datos para leer una sola muestra para el empalme. En este momento, el formato de los datos es: el
torch.Size([10, 3, 64, 128]), torch.Size([10, 6])
primero es un archivo de imagen en el orden de tamaño de lote * canal * alto * ancho; el segundo es una etiqueta de carácter.
Necesita usar cv2