tqdm() 函数与 enumerate()函数应用于训练神经网络

目录

1、enumerate() 函数:

2、tqdm()函数:

3、tqdm 和 enumerate() 结合:

4、tqdm() 和 enumerate()用于神经网络训练:


1、enumerate() 函数:

        这个函数的基本应用就是用来遍历一个数据对象(如列表、元组或字符串)组合,它在遍历的同时还可以得到当前元素的索引位置。一般用于在 for 循环中得到访问下标。

        enumerate 的语法是: enumerate(sequence, start=0) 其中: sequence 为可迭代对象 start 指定索引起始值,默认为 0。

参数:
        sequence:是一个可迭代对象
        start:是一个可选参数,表示索引从几开始计数

例1:

lt=['a','b','c','d','e','f','g'] # 创建lt数组
for i, item in enumerate(lt):
    print(i, item) # 打印索引和值
print(type(item)) # 查看item类型


输出结果:
0 a
1 b
2 c
3 d
4 e
5 f
6 g
<class 'str'>

在这个例子中,每个元素的下标和值都被提取到了变量i和item中,可以借此方便地进行各种操作。

例2:

扫描二维码关注公众号,回复: 15866353 查看本文章
lt=['a','b','c','d','e','f','g'] # 创建lt数组
for item in enumerate(lt):
    print(item) # 打印索引和值
print(type(item))


输出结果:
(0, 'a')
(1, 'b')
(2, 'c')
(3, 'd')
(4, 'e')
(5, 'f')
(6, 'g')
<class 'tuple'>

当返回值设置一个时,返回各个元组。

2、tqdm()函数:

        tqdm是一个Python包,可以用来实现进度条的显示。它可以在控制台中显示一个进度条,用于表示代码执行的进度,帮助开发者更好地直观地看到代码运行的进展情况。

#示例
for i in tqdm(range(20), desc='It\'s a test'):
    time.sleep(0.1)

参数:
        iterable=None,可迭代对象。如上一节中的range(20)
        desc=None,传入str类型,作为进度条标题。如上一节中的desc='It\'s a test'
        total=None,预期的迭代次数。一般不填,默认为iterable的长度。
        leave=True,迭代结束时,是否保留最终的进度条。默认保留。
        file=None,输出指向位置,默认是终端,一般不需要设置。
        ncols=None,可以自定义进度条的总长度
        unit,描述处理项目的文字,默认’it’,即100it/s;处理照片设置为’img’,则为100img/s
        postfix,以字典形式传入详细信息,将显示在进度条中。例如postfix={'value': 520}
        unit_scale,自动根据国际标准进行项目处理速度单位的换算,例如100000it/s换算为100kit/s

        tqdm()的返回值是一个可迭代对象,迭代的每一个元素就是iterable的每一个参数。该返回值可以修改进度条信息。

例1:

from tqdm import tqdm

lt=['a','b','c','d','e','f','g'] # 创建lt数组

for item in tqdm(lt):
    print(item)
print(type(item))


输出

a
b
c
d
e
f
g
<class 'str'>
100%|██████████| 7/7 [00:00<?, ?it/s]

例2:

from tqdm import tqdm

lt=['a','b','c','d','e','f','g'] # 创建lt数组

for i, item in tqdm(lt):
    print(i, item)
print(type(item))


输出结果

0%|          | 0/7 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "D:/pyProjects/Build-Your-Own-Face-Model-master/recognition/forTest.py", line 5, in <module>
    for i, item in tqdm(lt):
ValueError: not enough values to unpack (expected 2, got 1)

出错了,提示没有那么多返回值

3、tqdm 和 enumerate() 结合:

from tqdm import tqdm
lt=['a','b','c']
for i,item in enumerate(tqdm(lt)):
    print(i, item)

4、tqdm() 和 enumerate()用于神经网络训练:

tqdm()函数:

for data, labels in tqdm(dataloader, desc=f"Epoch {e}/{conf.epoch}",
                             ascii=True, total=len(dataloader)):
    data = data.to(device)
    labels = labels.to(device)



enumerate()函数:

for ii, data in enumerate(trainloader):  
    data_input, labels = data
    data_input = data_input.to(device)
    labels = label.to(device).long()


这是两个函数在神经网络训练时的代码片段,可以发现:
    tqdm()函数可以直接返回data, labels值;
    enumerate()返回的是数据列表的id与对应的(data, labels)元组,因此enumerate()还
需一步: data_input, labels = data,提取对应的tata与labels。

猜你喜欢

转载自blog.csdn.net/aizsa111/article/details/131747501
今日推荐