PyTorch:二、构建卷积神经网络

一、制作自己的数据集
  1. 源代码
import torch
from torch.utils.data import Dataset

import pandas as pd
import numpy as np
# txt文件内容 路径 \t 类别 \t 长度 \n
txt_path = 'G:/stock/path.txt'

class SocktData(Dataset):
    dataset = []
    # 【data,label】形式初始化
    def __init__(self,txt_path):
        fh = open(txt_path)
        for line in fh:
            # 移除空格
            line = line.rstrip('\n')
            # 移除制表符
            line = line.split('\t')
            # 因为label=line[1]是字符串,训练时用的是数字,进行转换
            if line[1] == 'Rising':
                label = 0
            elif line[1] == 'Falling':
                label = 1
            else:
                label = 2
            self.dataset.append([line[0],label])
     # 返回data长度
    def __len__(self):
        return len(self.data)
    # 获取单个元素
    def __getitem__(self, index):
        data_path,label = self.dataset[0]
        # pd 读取数据 pd.DataFrame格式
        data = pd.read_csv(data_path)
        # data_array np.array格式
        data_array = np.array(data)
        # data_tensor Torch.tensor格式
        data_tensor = torch.Tensor(data_array)
        print('----------------------getitem--------------------')
        print('------------------打印data_tensor类型-------------')
        print(data_tensor.size())
        return data_tensor,label
# 制作数据集
train_set = SocktData(txt_path)
# 加载数据集
data,label = train_set[0]
print('-----------打印第一个数据内容和标签类型-------------')
print('type(data) = ',type(data))
print('type(label) = ',type(label))
  1. 执行结果:
    执行结果
  2. 评价:目前数据集构建好了~ 准备构建卷积神经网络。
二、构建卷积神经网络
  1. 源代码

  1. 执行结果
发布了19 篇原创文章 · 获赞 6 · 访问量 1549

猜你喜欢

转载自blog.csdn.net/qq_42759370/article/details/104365858