Week 5:装饰器、functools模块和部分习题

一、无参装饰器

给原函数增加新的附加功能,增强其前置、后置功能,而不改变原函数代码

#给函数增加一个日志记录的附加功能、测试运行时间,即装饰器logger

import datetime

def logger(fn):
    def wrapper(*args,**kwargs):
        start=datetime.datetime.now()
        ret=fn(*args,**kwargs)
        delta=(datetime.datetime.now()-start).total_seconds()
        print("Function {} took {}s".format(fn.__name__,delta))
        return ret
    return wrapper

@logger  #add=logger(add) => add=wrapper
def add(x,y):  
    return x+y

print(add(4,5))  
#wrapper(4,5) 实际调用的是wrapper,也就是logger(add)(4,5)
#在做完“start=datetime.datetime.now()”功能增强后,再执行ret=fn(*args,**kwargs),不影响原函数
#wrappper包装者,add是wappered是被包装者

二、带参装饰器

1、修改文档字符串和名称等特殊属性,即将wrapper函数的函数名和文档字符串修改为add的

def copy_properties(src):
    def _copy(dest):
        dest.__name__=src.__name__
        dest.__doc__=src.__doc__
        return dest
    return _copy

def logger(fn):
    @copy_properties(fn) 
     #wrapper=copy_properties(fn)(wrappper)=_copy(wrapper)
    def wrapper(*args,**kwargs):
        print("Before")
        ret=fn(*args,**kwargs)
        print("After")
        return ret
    return wrapper

@logger  
#add=logger(add) => add=wrapper
def add(x,y): 
    """This is add function document"""
    return x+y

print(add(4,5))
print(add.__name__, add.__doc__)

也可从模块中导入from functools import update_wrapper功能,在定义好wrapper函数后调用update_wrapper函数,传两个参数wrapper和fn,改动代码如下:

#定义好wrapper函数后,函数调用方式

from functools import update_wrapper

def logger(fn):
    def wrapper(*args,**kwargs):
        print("Before")
        ret=fn(*args,**kwargs)
        print("After")
        return ret
    update_wrapper(wrapper,fn)
    return wrapper

@logger  
#add=logger(add) => add=wrapper
def add(x,y): 
    """This is add function document"""
    return x+y

print(add(4,5))
print(add.__name__, add.__doc__)
#由于update_wrapper函数必填形参是wrapper和wrapped,返回值是wrapper,也可直接在定义好wrapper函数后,把logger函数的返回值return wrapper改写成return update_wrapper(wrapper,fn)

from functools import update_wrapper

def logger(fn):
    def wrapper(*args,**kwargs):
        print("Before")
        ret=fn(*args,**kwargs)
        print("After")
        return ret
    return update_wrapper(wrapper,fn)

@logger  
#add=logger(add) => add=wrapper
def add(x,y): 
    """This is add function document"""
    return x+y

print(add(4,5))
print(add.__name__, add.__doc__)

函数update_wrapper的带参装饰器版本为导入from functools import wraps,带参装饰器函数wraps只需传一个参数fn,代码如下:

from functools import wraps

def logger(fn):
    @wraps(fn)
    #wrapper=wraps(fn)(wrapper) #wrapper=wrappper
    def wrapper(*args,**kwargs):
        print("Before")
        ret=fn(*args,**kwargs)
        print("After")
        return ret
    return wrapper

@logger  
#add=logger(add) => add=wrapper
def add(x,y): 
    """This is add function document"""
    return x+y

print(add(4,5))
print(add.__name__, add.__doc__)

#通过引用wraps装饰器把wrapper这个内层函数的名称和文档字符串等属性做了替换,内层函数wrapper的功能并没有被改变,让wrapper显现原函数add的一些属性,所以@wraps(fn)时,wrapper=wraps(fn)(wrapper)推出wrapper=wrappper,即返回值需为wrapper才能实现

相比于update_wrapper函数调用,wraps装饰器更常用

2、给add函数加一个能调节duration的装饰器,并控制输出

需求1:获取函数的执行时长,并记录,其中时长的标准可调节

import datetime
import time

def logger(duration):
    def _logger(fn):
        def wrapper(*args,**kwargs):
            start=datetime.datetime.now()
            ret=fn(*args,**kwargs)
            delta=(datetime.datetime.now()-start).total_seconds()
            print("Function {} took {}s".format(fn.__name__,delta))
            if delta>duration:
                print('slow')
            else:
                print('fast')
            return ret
        return wrapper
    return _logger

@logger(3) #add=logger(3)(add)=_logger(add)=wrapper
def add(x,y):  
    time.sleep(3)
    return x+y

print(add(4,5))

带参装饰器的层数不一定,可能是两层、三层

使用@functionname(参数列表)方式调用,把参数都写在最外层的参数列表中,中间层形参只留一个fn,可以看做在装饰器外通过柯里化又加了一层函数;在最外层添加新参数即可,也无需再柯里化,而若在第二层添加新参数会影响装饰器的使用,不合适;

也可把超时打印这个耦合度太高的非核心功能提到最外面的参数列表,用lambda表达式输出到控制台

from functools import wraps
import datetime
import time

def logger(duration,func=lambda name,duration:print("Function {} took {}s".format(name,duration))):
    def _logger(fn):
        @wraps(fn)
        def wrapper(*args,**kwargs):
            start=datetime.datetime.now()
            ret=fn(*args,**kwargs)
            delta=(datetime.datetime.now()-start).total_seconds()
            if delta>duration:
                func(fn.__name__,delta)
            return ret
        return wrapper
    return _logger

@logger(2) #add=logger(3)(add)=_logger(add)=wrapper
def add(x,y):  
    time.sleep(2)
    return x+y

print(add(4,5))

若lambda表达式太长,也可在外面另写一个输出到日志文件、或控制台的新函数,在装饰器参数列表里把函数名传过来

from functools import wraps
import datetime
import time

def toprint(name,duration):
    print("Function {} took {}s".format(name,duration))

def logger(duration,func=toprint):
    def _logger(fn):
        @wraps(fn)
        def wrapper(*args,**kwargs):
            start=datetime.datetime.now()
            ret=fn(*args,**kwargs)
            delta=(datetime.datetime.now()-start).total_seconds()
            if delta>duration:
                func(fn.__name__,delta)
            return ret
        return wrapper
    return _logger

@logger(2) #add=logger(3)(add)=_logger(add)=wrapper
def add(x,y):  
    time.sleep(2)
    return x+y

print(add(4,5))

这样一来具有记录功能的 toprint 函数就作为缺省值添加到了logger装饰器,也可再写一些 tolog、toanywhere的函数,在给add函数加装饰器 @logger(2) 时一并传入,写成 @logger(2,toanywhere),灵活控制输出;

把一些功能函数起个明显的名字作为高阶函数传参调用,使得函数功能清晰,复用性也高

三、functools.update_wrapper

functools.update_wrapper(wrapper,wrapped,assigned =WRAPPER_ASSIGNMENTS,updated = WRAPPER_UPDATES)

其中wrapper是包装函数、被更新者,wrapped是被包装函数、数据源,元祖WRAPPER_ASSIGNMENTS中是要被覆盖的属性’module‘, ‘name‘, ‘qualname‘, ‘doc‘, ‘annotations‘,模块名、函数名称、限定名、文档、参数注解,元祖WRAPPER_UPDATES中是要被更新的属性;

四、inspect 模块

inspect模块:提供获取对象信息的函数、做类型检查
1、定义函数时做了函数参数注解后,可在传参时加一个检查参数类型的装饰器,代码如下

import inspect

def check(fn):  
    def wrapper(*args,**kwargs):
        print(inspect.isfunction(fn))  #是否是函数
        sig=inspect.signature(fn)  #获取函数签名
        print("signature:",sig)
        params=sig.parameters  #获取函数参数列表
        print("parameters:",params)
        ret=fn(*args,**kwargs)
        return ret
    return wrapper

@check
def add(x:int,y:int) -> int: 
    return x+y

print(add(4,5))

2、示例:

def add(x,y:int=7) -> int:
    return x+y
#检查用户输入是否符合参数注解的要求,不符合的话提醒用户

代码如下

import inspect

def check(fn):
    def wrapper(*args,**kwargs):
        sig=inspect.signature(fn) #拿到签名
        params=sig.parameters  #拿到参数列表,是一个有序字典
        values=list(params.values()) #取到有参数类型注解的部分,在字典的值里
        for i,p in enumerate(args):
            param=values[i]
            if param.annotation is not param.empty and not isinstance(p,param.annotation):
                print(p,'!==',values[i].annotation)
        for k,v in kwargs.items():
            if params[k].annotation is not inspect._empty and not isinstance(v,params[k].annotation):
                print(k,v,'!==',params[k].annotation)
        return fn(*args,**kwargs)
    return wrapper

@check
def add(x,y:int=7) -> int: 
    return x+y

print(add('b',y='a'))

五、functools.partial 偏函数

把函数部分参数固定下来,形成一个新函数返回

例:

import functools
import inspect

def add(x,y,*args)->int:
    print(args)
    return x+y

newadd=functools.partial(add,1,3,6,5)

print(newadd(7))
print(inspect.signature(newadd))
print(newadd(7,10))
print(inspect.signature(newadd))
print(newadd())
print(inspect.signature(newadd))

#输出如下:
(6, 5, 7)
4
(*args) -> int
(6, 5, 7, 10)
4
(*args) -> int
(6, 5)
4
(*args) -> int

六、习题-base64编码

import base64
import string

# base64 字符集
alphabet=bytes((string.ascii_uppercase+string.ascii_lowercase+string.digits+'+/').encode())

def base64encode(src:str):
    src=bytes(src.encode())
    target=bytearray()
    length=len(src)
    for offset in range(0,length,3):  #把原字符串分成3字节一组,不足补0
        tripe=src[offset:offset+3]
        r = 3-len(tripe)
        if offset+3>length:
            tripe=tripe+b'\x00'*r
        val=int.from_bytes(tripe,'big')  #把每3个字节变成int,进行右移取到每6位的索引
        for i in range(18,-1,-6):
            index=val>>i if i==18 else (val>>i)&0x3f
            target.append(alphabet[index])
    if r:
        target[-r:]=b'='*r
    return(bytes(target))

def base64decode(src:bytes):
    src = bytearray(src)
    target = []
    r = 0
    for j in range(1, 3):
        if src[-j] == 61:  #对bytearray按索引做值比较时,要赋int,ascii里'='十六进制是61
            r += 1
    if r:
        src[-r:]=b'A'*r   #注意src[0:]会全覆盖,此处应避开0
    length=len(src)
    for offset in range(0,length,4):
        #此处val的十进制数值应按6位一个单元进行计算
        val=alphabet.index(src[offset])*64**3+alphabet.index(src[offset+1])*64**2+alphabet.index(src[offset+2])*64+alphabet.index(src[offset+3])
        for i in range(16,-1,-8):
            index=val>>i if i==16 else (val>>i)&0xff
            target.append(chr(index))
    if r:
        for i in range(r):
            target.pop()
    return(''.join(target))


for i in ('abc','ab','a','abcd'):
    print("本次base64版:",i,base64encode(i),base64decode(base64encode(i)))
    print("默认base64版:",i,base64.b64encode(i.encode()),base64.b64decode(base64.b64encode(i.encode())))

#运行结果为:
本次base64版: abc b'YWJj' abc
默认base64版: abc b'YWJj' b'abc'
本次base64版: ab b'YWI=' ab
默认base64版: ab b'YWI=' b'ab'
本次base64版: a b'YQ==' a
默认base64版: a b'YQ==' b'a'
本次base64版: abcd b'YWJjZA==' abcd
默认base64版: abcd b'YWJjZA==' b'abcd'

关于解码方法的新思路:7=(1<<2)+(1<<1)+(1<<0)

import base64
import string

alphabet=bytes((string.ascii_uppercase+string.ascii_lowercase+string.digits+'+/').encode())
alphabet_dict=dict(zip(alphabet,range(64)))  #方便解码时查找索引

def base64encode(src:str):
    src=bytes(src.encode())
    target=bytearray()
    length=len(src)
    for offset in range(0,length,3):
        tripe=src[offset:offset+3]
        r = 3-len(tripe)
        if offset+3>length:
            tripe=tripe+b'\x00'*r
        val=int.from_bytes(tripe,'big')
        for i in range(18,-1,-6):
            index=val>>i if i==18 else (val>>i)&0x3f
            target.append(alphabet[index])
    if r:
        target[-r:]=b'='*r
    return(bytes(target))

def base64decode(src:bytes):   
    target=bytearray()   #目标可直接设成bytearray序列,src因为取切片的关系也不用转成bytearray
    length=len(src)
    for offset in range(0,length,4):
        block=src[offset:offset+4]
        temp=0x00   #把temp定义成十六进制的数字0
        for i in range(4):
            index=alphabet_dict.get(block[-1-i])  #不用设置缺省值,默认none,若找不到索引值说明是“=”
            if index:  #不管是“=”还是“A”,都进不来,因为它们值为0,做乘法和加法运算都对结果无影响,可直接过滤掉
                temp+=index<<(6*i)
        target.extend(temp.to_bytes(3,'big'))   #target是bytearray序列,temp为24字节,只能取长度3转一次,得到的3个bytes逐个迭代到target中,需用extend而不是append
    return bytes(target.rstrip(b'\x00'))  #最后的ascii码0需去掉,否则b'YQ=='会显示b'a\x00\x00',而不是b'a'

for i in ('abc','ab','a','abcd'):
    print("mylocal_base64:",i,base64encode(i),base64decode(base64encode(i)))
    print("default_base64:",i,base64.b64encode(i.encode()),base64.b64decode(base64.b64encode(i.encode())))

运行结果:

mylocal_base64: abc b'YWJj' b'abc'
default_base64: abc b'YWJj' b'abc'
mylocal_base64: ab b'YWI=' b'ab'
default_base64: ab b'YWI=' b'ab'
mylocal_base64: a b'YQ==' b'a'
default_base64: a b'YQ==' b'a'
mylocal_base64: abcd b'YWJjZA==' b'abcd'
default_base64: abcd b'YWJjZA==' b'abcd'

七、习题-求最长公共子串

#矩阵法:只需遍历一次

def find_substr(str1,str2):
    m=len(str1)
    n=len(str2)
    maxlen=0
    if m>n:
        str1,str2=str2,str1
    l=[[0]*m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            if str2[i]==str1[j]:
                if i==0 or j==0:
                    l[i][j]=1
                else:
                    l[i][j]=l[i-1][j-1]+1 #进到此分支说明不在边缘
                if l[i][j]>maxlen:
                    maxlen=l[i][j] 
                    maxindex=j
    return(str1[maxindex-maxlen+1:maxindex+1])

print(find_substr('mnabcf','sdabcfed'))
#字符串法:从较短子串的全长开始,逐步减少长度,与较长子串进行比较

def find_substr(str1,str2):
    m=len(str1)
    n=len(str2)
    maxlen=0
    if m>n:
        str1,str2=str2,str1
    for step in range(m,0,-1):
        for start in range(m-step+1):
            s=str1[start:start+step]
            if str2.find(s)!=-1:
                return s

print(find_substr('mnabcf','sdabcfed'))

八、习题-functools.lru_cache和缓存器的变形

lru_cache 缓存,属于带参装饰器,缓存具有时效性
在此基础上做一个跟传参顺序无关的缓存器,实现过期被清理

from functools import wraps
import inspect
import datetime,time

def logger(fn):
    def wrapper(*args,**kwargs):
        start=datetime.datetime.now()
        ret=fn(*args,**kwargs)
        delta=(datetime.datetime.now()-start).total_seconds()
        print("Function {} took {}s".format(fn.__name__,delta))
        return ret
    return wrapper

def local_cache(duration=5):
    def _local_cache(fn):
        local_cache={}
        @wraps(fn)
        def wrapper(*args,**kwargs):
            def del_outkey():
                del_keys = []  # 过期清理
                for k, (_, t) in local_cache.items():
                    if datetime.datetime.now().timestamp() - t > duration:
                        del_keys.append(k)
                for key in del_keys:  # 注意字典不能在遍历的时候清除数据
                    local_cache.pop(key)

            def make_key(args, kwargs):
                sig = inspect.signature(fn)
                params = sig.parameters
                param_dict = {}
                param_dict.update(zip(params.keys(), args))  # 加入位置参数
                param_dict.update(kwargs)  # 加入关键字参数
                for k, v in params.items():  # 加入缺省值
                    if k not in param_dict.keys():
                        param_dict[k] = v.default
                target = tuple(sorted(param_dict.items()))  # 生成排序好的目标tuple
                return target

            del_outkey()
            key = make_key(args, kwargs)

            if key not in local_cache.keys():
                local_cache[key] = fn(*args, **kwargs), datetime.datetime.now().timestamp()
            return local_cache[key][0]

        return wrapper

    return _local_cache

@logger
@local_cache()
def add(x,y,z=6):
    time.sleep(2)
    return x+y+z

print(add(y=5,x=4,z=5))
print(add(4,5,z=5))
print(add(4,y=5,z=5))
print(add(y=5,x=4,z=6))
print(add(4,y=5))
print(add(4,5))

#结果如下:

Function add took 2.003774s
14
Function add took 0.000187s
14
Function add took 0.000218s
14
Function add took 2.000844s
15
Function add took 0.00017s
15
Function add took 0.000232s
15

九、习题-任务的分发调度

#命令的分发调度,注册命令后,输入关键词即可调用

def cmds_dispatcher(defaultfn=lambda :print("Unknown Command")):
    commands={}

    def register(cmd):
        def _register(fn):
            commands[cmd]=fn
        return _register

    def dispatcher():
        while True:
            cmd = input('input:')
            if cmd == '':
                print('Bye')
                return
            commands.get(cmd,defaultfn)()

    return register,dispatcher

myreg,myrun=cmds_dispatcher()

@myreg('mg')   #register('mg')(fn)  
def foo1():
    print('magedu')
@myreg('py')
def foo2():
    print('python')

myrun()

采用偏函数,可将参数固定,代码如下:

from functools import partial

def cmds_dispatcher():
    commands={}

    def default_fn():
        print("unknown command")

    def register(cmd,*args,**kwargs):
        def _register(fn):
            func=partial(fn,*args,**kwargs)
            commands[cmd]=func
            return func
        return _register

    def dispatcher():  #调度器
        while True:
            cmd = input('input:')
            if cmd.strip() == '':
                print('Bye')
                return
            commands.get(cmd,default_fn)()

    return register,dispatcher

myreg,myrun=cmds_dispatcher()

@myreg('mg1',x=1,y=2,z=4)
@myreg('mg',x=1,y=2,z=3)   #register('mg')(fn)  
def foo1(x,y,z):
    print('magedu',x,y,z)

@myreg('py',3,b=4)
def foo2(a,b=1):
    print('python',a,b)

myrun()

猜你喜欢

转载自blog.csdn.net/weixin_42196568/article/details/82118523