1.说明
使用requests库可以实现网络请求,但如果用于下载大文件,单线程下载确实不能很好地利用带宽,改为多线程会更好一点。
2.实现思路
1.当我们请求下载文件的时候,可以使用head请求看一下该文件有多大,响应头里的“Content-Length”字段表示文件的字节数。
2.拿到了文件大小之后,根据线程数划分为多个数据块,即每个线程都请求一部分,在请求头的“Range”字段指定下载范围。
3.因为是同一个进程同时写入一个文件,所以一定要注意上锁。
4.使用requests库下载的时候参数要指定stream=True,不然全加载到内存中就不好了。
5.如果发现其中一个块下载失败了,那就相当于整个文件都失败了,不过我还是希望再尝试下载2次才确定失败。
3.参考代码
import logging
import os.path
import threading
import time
from contextlib import closing
import requests
class MultiDownloader:
    """Multi-threaded HTTP file downloader.

    Probes the server with a HEAD request to learn the file size, splits
    the file into byte ranges (HTTP ``Range`` header), downloads each range
    in its own thread, and writes every chunk into one shared file handle
    guarded by a lock.
    """

    def __init__(self, url, save_path=None, file_name=None, thread_count=10, headers=None):
        """
        :param url: download URL.
        :param save_path: directory to save into; defaults to ``multi_download``
            next to this file.
        :param file_name: target file name; when omitted it is taken from the
            ``Content-Disposition`` response header, falling back to the last
            path segment of the URL.
        :param thread_count: number of concurrent range-download threads.
        :param headers: extra request headers (dict).
        """
        self.url = url
        self.headers = headers if isinstance(headers, dict) else dict()
        current_file_path = os.path.dirname(os.path.abspath(__file__))
        self.save_path = save_path if save_path else os.path.join(current_file_path, "multi_download")
        self.total_range = None
        self.logger = self.get_logger()
        # Fills in total_range, file_name (maybe empty) and the final redirected URL.
        self.get_resp_header_info()
        if file_name:
            self.file_name = file_name
        if not self.file_name:
            # Server did not suggest a name -> fall back to the URL path segment.
            self.file_name = os.path.split(url)[1]
        self.file_lock = threading.Lock()
        self.thread_count = thread_count
        self.failed_thread_list = list()
        self.finished_thread_count = 0
        self.chunk_size = 1024 * 100
        self.logger.info(f"init multi task, url:{self.url}")
        # Fixed log-message typo: "sava_path" -> "save_path".
        self.logger.info(f"init multi task, save_path:{self.save_path}")
        self.logger.info(f"init multi task, file_name:{self.file_name}")
        self.logger.info(f"init multi task, thread_count:{self.thread_count}")
        self.logger.info(f"init multi task, headers:{self.headers}")

    def get_logger(self):
        """Build a logger writing to both the console and ``<save_path>/download.log``."""
        logger = logging.getLogger("MultiDownloader")
        logger.setLevel(logging.INFO)
        # makedirs with exist_ok avoids both the TOCTOU race and failure on
        # missing parent directories that os.mkdir had.
        os.makedirs(self.save_path, exist_ok=True)
        if not logger.handlers:
            # Guard: logging.getLogger returns a process-wide singleton, so
            # re-instantiating the class must not stack duplicate handlers.
            formatter = logging.Formatter(
                "%(asctime)s-%(filename)s-line:%(lineno)d-%(levelname)s-%(process)s: %(message)s"
            )
            console_handler = logging.StreamHandler()
            console_handler.setLevel(logging.INFO)
            console_handler.setFormatter(formatter)
            file_handler = logging.FileHandler(
                os.path.join(self.save_path, "download.log"), encoding="utf-8"
            )
            file_handler.setLevel(logging.INFO)
            file_handler.setFormatter(formatter)
            logger.addHandler(console_handler)
            logger.addHandler(file_handler)
        return logger

    def get_resp_header_info(self):
        """HEAD the URL to learn file size, suggested file name and final URL.

        Sets ``total_range`` (bytes, 0 when the server omits Content-Length),
        ``file_name`` (possibly empty) and rewrites ``url`` to the
        post-redirect location so the range GETs skip the redirect chain.
        """
        res = requests.head(self.url, headers=self.headers, allow_redirects=True)
        res_header = res.headers
        self.logger.info(f"get_resp_header_info() res_header: {res_header}")
        self.total_range = int(res_header.get("Content-Length", "0"))
        self.file_name = (
            res_header.get("Content-Disposition", "")
            .replace("attachment;filename=", "")
            .replace('"', '')
        )
        self.url = res.url

    def page_dispatcher(self, content_size):
        """Yield ``{'start_pos', 'end_pos'}`` dicts partitioning [0, content_size).

        Positions are inclusive byte offsets, matching HTTP Range semantics;
        the last page absorbs the remainder of the division.
        """
        page_size = content_size // self.thread_count
        start_pos = 0
        while start_pos + page_size < content_size:
            yield {
                'start_pos': start_pos,
                'end_pos': start_pos + page_size
            }
            start_pos += page_size + 1
        yield {
            'start_pos': start_pos,
            'end_pos': content_size - 1
        }

    def download_range(self, thread_name, page, file_handler):
        """Download bytes page['start_pos']..page['end_pos'] (inclusive), retrying up to 3 times.

        ``page['start_pos']`` is advanced as chunks are written, so a retry
        resumes from the first byte that has not been persisted yet.
        """
        self.logger.info(f"thread {thread_name} start to download")
        # Copy user headers first, then set Range, so a caller-supplied
        # "Range" key can never clobber our per-thread byte range.
        range_headers = dict(self.headers)
        try:
            start_time = time.time()
            try_times = 3
            is_success = False
            for _ in range(try_times):
                # Rebuild Range every attempt: start_pos moves forward as data
                # is written, so a retry after a partial failure must NOT
                # re-request (and re-write) bytes already on disk.
                range_headers["Range"] = f'bytes={page["start_pos"]}-{page["end_pos"]}'
                try:
                    with closing(requests.get(url=self.url, headers=range_headers, stream=True, timeout=30)) as res:
                        # Log the announced size instead of len(res.content):
                        # touching res.content would pull the whole body into
                        # memory and defeat stream=True.
                        self.logger.info(
                            f"thread {thread_name} download length: {res.headers.get('Content-Length')}"
                        )
                        if res.status_code == 206:  # Partial Content
                            for data in res.iter_content(chunk_size=self.chunk_size):
                                # Lock covers seek+write: other threads share
                                # this file handle and would move the cursor.
                                with self.file_lock:
                                    file_handler.seek(page["start_pos"])
                                    file_handler.write(data)
                                page["start_pos"] += len(data)
                            is_success = True
                            break
                except Exception as e:
                    self.logger.error(f"download_range() request error: {e}")
            with self.file_lock:
                # '+=' is not atomic; guard the shared progress counter.
                self.finished_thread_count += 1
            spent_time = time.time() - start_time
            if is_success:
                self.logger.info("thread {} download success, spent_time: {}, progress: {}/{}".format(
                    thread_name, spent_time, self.finished_thread_count, self.thread_count
                ))
            else:
                self.logger.error(f"thread {thread_name} download {try_times} times but failed")
                self.failed_thread_list.append(thread_name)
        except Exception as e:
            self.logger.error(f"thread {thread_name} download failed: {e}")
            self.failed_thread_list.append(thread_name)

    def run(self):
        """Dispatch range pages to worker threads and report the final result.

        :raises Exception: when the file size could not be determined or is
            below 1 KiB (not worth multi-threading).
        """
        self.logger.info(f"run() get file total range: {self.total_range}")
        if not self.total_range or self.total_range < 1024:
            raise Exception("get file total size failed")
        thread_list = list()
        full_path = os.path.join(self.save_path, self.file_name)
        self.logger.info(f"ready to download, full_path: {full_path}")
        start_time = time.time()
        with open(full_path, "wb+") as f:
            for i, page in enumerate(self.page_dispatcher(self.total_range)):
                self.logger.info("page: {}, page difference: {}".format(page, page["end_pos"] - page["start_pos"]))
                thread_list.append(threading.Thread(target=self.download_range, args=(i, page, f)))
            for thread in thread_list:
                thread.start()
            for thread in thread_list:
                thread.join()
        try:
            actual_size = os.path.getsize(full_path)
        except Exception as e:
            actual_size = 0
            self.logger.warning(f"get actual file size failed:, full_path: {full_path}, error: {e}")
        # Reuse actual_size here: calling os.path.getsize again could raise
        # unguarded right after the except branch above set actual_size = 0.
        if os.path.exists(full_path) and actual_size == 0:
            self.logger.warning(f"file size is 0, remove, full_path:{full_path}")
            os.remove(full_path)
        total_time = time.time() - start_time
        self.logger.info("download finishing..........")
        self.logger.info("total size %d Bytes (%.2f MB), actual file size %d Bytes, are they equal? %s" % (
            self.total_range, self.total_range / (1024 * 1024), actual_size, self.total_range == actual_size,
        ))
        # max() guards the (theoretical) zero-duration division.
        self.logger.info("total spent time: %.2f second, average download speed: %.2f MB/s" % (
            total_time, actual_size / (1024 * 1024) / max(total_time, 1e-6)
        ))
        if self.failed_thread_list:
            self.logger.info(f"failed_thread_list: {self.failed_thread_list}")
        final_result = "download success!" if self.total_range == actual_size else "download failed"
        self.logger.info(final_result)
if __name__ == '__main__':
    # Demo run: fetch the QQ installer with 10 parallel range requests.
    downloader = MultiDownloader(
        url="https://dldir1.qq.com/qqfile/qq/PCQQ9.6.9/QQ9.6.9.28878.exe",
        save_path="",  # directory for the downloaded file; empty -> default location
        # file_name="QQ9.6.9.28878.exe",  # optional; auto-detected when omitted
        headers={
            # "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/106.0.0.0 Safari/537.36",
        },
        thread_count=10,  # number of concurrent download threads
    )
    downloader.run()
该项目已开源
GitHub链接:https://github.com/panmeibing/python_downloader