人工智能训练过程中的进度条(python)
1.基于print的进度条
1.1 基础介绍
这里使用了转义字符\r
,其目的是下次打印时从本行开头进行打印。
如果print里设置end=''
,则不会换行,这时下次打印会对上次结果进行覆盖。
如果print里设置end='\n'
或者不设置end
的值,就会进行换行,那么就不会覆盖上次结果。
转义字符1 | 意义 |
---|---|
\n | 换行(LF) ,将当前位置移到下一行开头 |
\r | 回车(CR) ,将当前位置移到本行开头 |
1.2 代码举例
import time
epochs = 3
n_batch = 6
for i in range(epochs):
for j in range(n_batch):
time.sleep(0.5)
loss =1/((i+1)*(j+1))
print("\rEpoch: {:d} batch: {:d} loss: {:.4f} ".format(i+1, j+1, loss), end='')
print("\rEpoch: {:d}/{:d} epoch_loss: {:.4f} ".format(i+1, epochs, loss, end='\n'))
最终结果如下:
显然,print("\rEpoch: {:d} batch: {:d} loss: {:.4f} ".format(i+1, j+1, loss), end='')
的结果被print("\rEpoch: {:d}/{:d} epoch_loss: {:.4f} ".format(i+1, epochs, loss, end='\n'))
进行了覆盖。
2.使用sys.stdout.write进行打印
2.1 一般情况
sys.stdout.write
和print
类似,不同的是:
print
不设置,则默认换行,而sys.stdout.write
则默认不换行。
import time
import sys
epochs = 3
n_batch = 6
for i in range(epochs):
for j in range(n_batch):
time.sleep(0.5)
loss =1/((i+1)*(j+1))
sys.stdout.write("\rEpoch: {:d} batch: {:d} loss: {:.4f}".format(i+1, j+1, loss))
sys.stdout.write("\rEpoch: {:d}/{:d} train_loss: {:.4f} \n".format(i+1, epochs, loss))
#sys.stdout.write("\n")
输出结果与1完全一致。
2.2 存在多个sys.stdout.write的情况
2.2.1 不符合预期的打印结果
如果有多个sys.stdout.write
,同时还要让输出结果在同一行刷新,有时会存在一些不正常的情形。
如下:
import time
import sys
epochs = 3
n_batch = 6
def do_something():
time.sleep(0.3)
for i in range(epochs):
for j in range(n_batch):
sys.stdout.write("\rEpoch: {:d} ".format(i+1))
loss =1/((i+1)*(j+1))
do_something()
sys.stdout.write("batch: {:d} ". format(j+1))
sys.stdout.write("loss: {:.4f} ".format(loss))
sys.stdout.flush()
最终打印结果为:
可以看出,与我们的期望不太一致。
这有可能是因为程序打印过程中某些步骤存在占用时间较长的执行过程,造成各个sys.stdout.write不同步造成的。
程序如果是下面这个样子,打印就是正常的:
import time
import sys
epochs = 3
n_batch = 6
def do_something():
time.sleep(0.1)
for i in range(epochs):
for j in range(n_batch):
sys.stdout.write("\rEpoch: {:d} ".format(i+1))
loss =1/((i+1)*(j+1))
sys.stdout.write("batch: {:d} ". format(j+1))
sys.stdout.write("loss: {:.4f} ".format(loss))
do_something()
2.2.2 一般性的解决方法
针对这种情况,给出一种更一般性的方法:
即将各个需要打印的字符串先存起来,最后再统一使用一个sys.stdout.write
进行打印。
import time
import sys
epochs = 3
n_batch = 6
for i in range(epochs):
prt1="\rEpoch: {:d} ".format(i+1)
#sys.stdout.write(prt1)
for j in range(n_batch):
prt2="batch: {:d} ". format(j+1)
#sys.stdout.write("batch: {:d}". format(j+1))
time.sleep(0.5)
loss =1/((i+1)*(j+1))
prt3="loss: {:.4f} ".format(loss)
sys.stdout.write(prt1+prt2+prt3)
结果如下:
3.利用tqdm
import time
from tqdm import tqdm
pbar=tqdm(range(100))
for i in pbar:
time.sleep(.01)
pbar.set_description("Processing %s" % i)
看个人喜欢哪种风格了~
参考文献
[1] 如何用 Python 给程序加个进度条?
[2] python输出结果刷新及进度条的实现操作