使用数据集工具

一.数据集工具介绍
HuggingFace通过API提供了统一的数据集处理工具,它提供的数据集如下所示:

该界面左侧可以根据不同的任务类型、类库、语言、License等来筛选数据集,右侧为具体的数据集列表,其中有经典的glue、super_glue数据集,问答数据集squad,情感分类数据集imdb,纯文本数据集wikitext等。进入sgugger/glue-mrpc数据集页面,可看到对该数据集的相关介绍,如下所示:

二.使用数据集工具
1.数据集加载和保存
以加载seamew/ChnSentiCorp数据集为例,在线加载如下所示:

#第3章/加载数据集
from datasets import load_dataset
dataset = load_dataset(path='seamew/ChnSentiCorp')
print(dataset)

load_dataset()函数的定义为:

def load_dataset(
    path: str,
    name: Optional[str] = None,
    data_dir: Optional[str] = None,
    data_files: Optional[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]] = None,
    split: Optional[Union[str, Split]] = None,
    cache_dir: Optional[str] = None,
    features: Optional[Features] = None,
    download_config: Optional[DownloadConfig] = None,
    download_mode: Optional[Union[DownloadMode, str]] = None,
    verification_mode: Optional[Union[VerificationMode, str]] = None,
    ignore_verifications="deprecated",
    keep_in_memory: Optional[bool] = None,
    save_infos: bool = False,
    revision: Optional[Union[str, Version]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    task: Optional[Union[str, TaskTemplate]] = None,
    streaming: bool = False,
    num_proc: Optional[int] = None,
    storage_options: Optional[Dict] = None,
    **config_kwargs,
) -> Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]:

重点介绍几个参数,比如使用path指定数据集,name指定数据子集,split指定要加载的数据部分:

#第3章/加载glue数据集
dataset = load_dataset(path='glue', name='sst2', split='train')
print(dataset)

2.将数据集保存到本地磁盘

#第3章/将数据集保存到磁盘
dataset.save_to_disk(dataset_dict_path='./data/ChnSentiCorp')

3.从本地磁盘加载数据集

#第3章/从磁盘加载数据集
from datasets import load_from_disk
dataset = load_from_disk('./data/ChnSentiCorp')

4.取出数据部分

#使用train数据子集做后续的实验
dataset = dataset['train']

5.查看数据内容

#第3章/查看数据样例
for i in [12, 17, 20, 26, 56]:
    print(dataset[i])

6.数据排序
使用sort()函数让数据按照某个字段排序:

#第3章/排序数据
#数据中的label是无序的
print(dataset['label'][:10])
#让数据按照label排序
sorted_dataset = dataset.sort('label')
print(sorted_dataset['label'][:10])
print(sorted_dataset['label'][-10:])

7.打乱数据
使用shuffle()函数打乱数据:

#第3章/打乱数据顺序
shuffled_dataset=sorted_dataset.shuffle(seed=42)
shuffled_dataset['label'][:10]

8.数据抽样
使用select()函数从数据集中选择某些数据,然后组装成一个数据子集:

#第3章/从数据集中选择某些数据
dataset.select([0, 10, 20, 30, 40, 50])

9.数据过滤
使用filter()函数可以按照自定义的规则过滤数据:

#第3章/过滤数据
def f(data):
    return data['text'].startswith('非常不错')
dataset.filter(f)

10.训练测试集拆分
可以使用train_test_split()函数将数据集切分为训练集和测试集:

#第3章/切分训练集和测试集
dataset.train_test_split(test_size=0.1)

11.数据分桶
使用shared()函数把数据均匀地分为n部分:

#第3章/数据分桶
dataset.shard(num_shards=4, index=0)

其中,num_shards表示要把数据均匀地分为几部分,index表示要取出第几份数据。
12.重命名字段
使用rename_column()函数可以重命名字段:

#第3章/字段重命名
dataset.rename_column('text', 'text_rename')

13.删除字段
使用remove_columns()函数可以删除字段:

#第3章/删除字段
dataset.remove_columns(['text'])

14.映射函数
使用map()函数遍历数据,并且对每条数据都进行修改:

#第3章/应用函数
def f(data):
    data['text'] = 'My sentence: ' + data['text']
    return data
maped_datatset = dataset.map(f)
print(dataset['text'][20])
print(maped_datatset['text'][20])

15.使用批处理加速

#第3章/使用批处理加速
def f(data):
    text=data['text']
    text=['My sentence: ' + i for i in text]
    data['text']=text
    return data
maped_datatset=dataset.map(function=f, batched=True, batch_size=1000, num_proc=4)
print(dataset['text'][20])
print(maped_datatset['text'][20])

16.设置数据格式
使用set_format()函数修改数据格式:

#第3章/设置数据格式
dataset.set_format(type='torch', columns=['label'], output_all_columns=True)
print(dataset[20])

其中,type表示要修改的数据类型(numpy|torch|tensorflow|pandas等),columns表示要修改格式的字段,output_all_columns表示是否要保留其它字段,设置为True表示保留。
17.将数据保存为CSV格式

#第3章/导出为CSV格式
dataset = load_dataset(path='seamew/ChnSentiCorp', split='train')
dataset.to_csv(path_or_buf='./data/ChnSentiCorp.csv')
#加载CSV格式数据
csv_dataset = load_dataset(path='csv', data_files='./data/ChnSentiCorp.csv', split='train')
print(csv_dataset[20])

18.保存数据为JSON格式

#第3章/导出为JSON格式
dataset=load_dataset(path='seamew/ChnSentiCorp', split='train')
dataset.to_json(path_or_buf='./data/ChnSentiCorp.json')
#加载JSON格式数据
json_dataset=load_dataset(path='json', data_files='./data/ChnSentiCorp.json', split='train')
print(json_dataset[20])

参考文献:
[1]《HuggingFace自然语言处理详解:基于BERT中文模型的任务实战》
[2]https://huggingface.co/datasets/seamew/ChnSentiCorp

猜你喜欢

转载自blog.csdn.net/shengshengwang/article/details/131423213
今日推荐