内容:
1、-pytorchの予測画像シーケンスcrnnコードに基づいて、 -独自のデータセットロード
、2 予測-pytorch crnn画像シーケンスコードをベース-で説明したモデル
3、-pytorchの予測画像シーケンスcrnnコードに基づい-訓練プロセスそして、一般的なエラー
ミッションブリーフィング:ここではリカレントニューラルネットワークに基づくコンボリューションは、予測画像シーケンスを作ります。スリーステート・ラベルの画像に対応する画像のシーケンスの各。各特徴抽出の9つの連続ネットワークイメージの特徴付け畳み込み配列は、抽出した後、ニューラルネットワークサイクル(LSTM)に入力され、次の画像は、第十の計算画像と元の画像シーケンスの状態を予測し損失。
導入データセットの2つの工業(訓練期間、テスト期間)からビデオデータ、(時間的に連続)画像シーケンスに従ってビデオフレーム抽出。画像タグ、ラベル三つの状態(0,1,2)、一つの状態に対応する各画像。
データ前処理:00001.jpg-15000.jpgから名付けられたトレーニングセットを含む15,000画像、3×256×256のサイズ、。画像のパスを保存し、txtファイルにタグ情報を、スペースで区切って。図:
ロード:連続する各反復は9つの第十イメージとイメージのラベルを返したように、ここでは、__ DataSetクラス()関数を__getitem書き換えます。次のように詳細なコードは次のとおりです。
def default_loader(path):
return Image.open(path).convert('RGB')
class MyDataset(Dataset):
def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
fh = open(txt, 'r')
imgs = []
for line in fh:
line = line.strip('\n')
line = line.rstrip()
words = line.split()
imgs.append((words[0], int(words[1])))
imgs.sort(key=lambda x: x[0], reverse=False)
self.num_samples = len(imgs)
self.num_samples_per_iteration = 9
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
self.loader = loader
def __getitem__(self, index):
current_index = np.random.choice(range(self.num_samples_per_iteration, self.num_samples))
current_imgs = []
current_label = self.imgs[current_index][1]
for i in range(current_index - self.num_samples_per_iteration, current_index):
fn, label = self.imgs[i]
img = self.loader(fn)
if self.transform is not None:
img = self.transform(img)
current_imgs.append(img)
batch_cur_imgs = np.stack(current_imgs, axis=0) # [9, 3, 256, 256]
return batch_cur_imgs, current_label
def __len__(self):
return len(self.imgs)
train_data = MyDataset(txt='trainset256.txt', transform=transforms.ToTensor())
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=False)
test_data = MyDataset(txt='testset256.txt', transform=transforms.ToTensor())
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)
print('num_of_trainData:', len(train_data))
print('num_of_testData:', len(test_data))