データの独自のセットをロード - 予測-pytorchコードcrnn達成するために、画像のシーケンスに基づいています

内容
1、-pytorchの予測画像シーケンスcrnnコードに基づいて、 -独自のデータセットロード
、2 予測-pytorch crnn画像シーケンスコードをベース-で説明したモデル
3、-pytorchの予測画像シーケンスcrnnコードに基づい-訓練プロセスそして、一般的なエラー
ミッションブリーフィング:ここではリカレントニューラルネットワークに基づくコンボリューションは、予測画像シーケンスを作ります。スリーステート・ラベルの画像に対応する画像のシーケンスの各。各特徴抽出の9つの連続ネットワークイメージの特徴付け畳み込み配列は、抽出した後、ニューラルネットワークサイクル(LSTM)に入力され、次の画像は、第十の計算画像と元の画像シーケンスの状態を予測し損失。

導入データセットの2つの工業(訓練期間、テスト期間)からビデオデータ、(時間的に連続)画像シーケンスに従ってビデオフレーム抽出。画像タグ、ラベル三つの状態(0,1,2)、一つの状態に対応する各画像。

データ前処理:00001.jpg-15000.jpgから名付けられたトレーニングセットを含む15,000画像、3×256×256のサイズ、。画像のパスを保存し、txtファイルにタグ情報を、スペースで区切って。図:
ここに画像を挿入説明
ロード:連続する各反復は9つの第十イメージとイメージのラベルを返したように、ここでは、__ DataSetクラス()関数を__getitem書き換えます。次のように詳細なコードは次のとおりです。

def default_loader(path):
    return Image.open(path).convert('RGB')


class MyDataset(Dataset):
    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        fh = open(txt, 'r')
        imgs = []
        for line in fh:
            line = line.strip('\n')
            line = line.rstrip()
            words = line.split()
            imgs.append((words[0], int(words[1])))
        imgs.sort(key=lambda x: x[0], reverse=False)
        self.num_samples = len(imgs)
        self.num_samples_per_iteration = 9
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        current_index = np.random.choice(range(self.num_samples_per_iteration, self.num_samples))
        current_imgs = []
        current_label = self.imgs[current_index][1]
        for i in range(current_index - self.num_samples_per_iteration, current_index):
            fn, label = self.imgs[i]
            img = self.loader(fn)
            if self.transform is not None:
                img = self.transform(img)
            current_imgs.append(img)
        batch_cur_imgs = np.stack(current_imgs, axis=0)  # [9, 3, 256, 256]
        return batch_cur_imgs, current_label

    def __len__(self):
        return len(self.imgs)


train_data = MyDataset(txt='trainset256.txt', transform=transforms.ToTensor())
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=False)

test_data = MyDataset(txt='testset256.txt', transform=transforms.ToTensor())
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)
print('num_of_trainData:', len(train_data))
print('num_of_testData:', len(test_data))

おすすめ

転載: blog.csdn.net/hnu_zzt/article/details/86494331