Django-drf架构 认证、权限、节流的详解

Django-drf架构 认证、权限、节流的详解

一、Token认证:

Token是服务端产生,如果前端使用用户名或密码向服务器请求认证,服务端认证成功,那么在服务端会返回Token给前端。前端可以在每次请求的时候带上Token证明自己的合法请求。如果Token在服务端持久化(比如存储Mysql或Redis中),那么它就是一个永久的身份令牌。

使用Token可以解决哪些问题?

1.Token完全由应用管理,所以它可以避开同源策略

2.Token可以避免CSRF攻击

3.Token可以是无状态的,可以在多个服务之间共享

图解:

  • 在登录时,在服务器生成token(可以存入Redis中,设置token的过期时间),登录成功后,将token返回给前端,前端再次请求服务器接口时,将会校验token是否正确,若不正确,不可以进行访问。
    在这里插入图片描述

二、DRF 认证:

1.源码流程详解:

  • 所有客户端请求到来,都要执行dispatch()方法,dispatch()方法会根据请求方式的不同触发GET/POST/PUT/DELETE等方法

  • dispatch方法的源码:

    第一步:initialize_request(request, *args, **kwargs):对request进行加工

    第二步:initial(request, *args, **kwargs):

    ​ a.处理版权信息

    ​ b.认证

    ​ c.权限

    ​ d.请求用户进行访问频率的限制

    第三步:handler(request, *args, **kwargs):执行GET/PUT/POST/DELETE方法

    第四步:finalize_response(request, response, *args, **kwargs):对第三步的返回结果进行加工

    def dispatch(self, request, *args, **kwargs):
            """
            `.dispatch()` is pretty much the same as Django's regular dispatch,
            but with extra hooks for startup, finalize, and exception handling.
            """
            self.args = args
            self.kwargs = kwargs
    
            # 第一步:对request进行加工(添加数据)
            request = self.initialize_request(request, *args, **kwargs)
            self.request = request
            self.headers = self.default_response_headers  # deprecate?
    
            try:
                #第二步:
                    # 处理版权信息
                    # 认证
                    # 权限
                    # 请求用户进行访问频率的限制
                self.initial(request, *args, **kwargs)
    
                # Get the appropriate handler method
                if request.method.lower() in self.http_method_names:
                    handler = getattr(self, request.method.lower(),
                                      self.http_method_not_allowed)
                else:
                    handler = self.http_method_not_allowed
    
                # 第三步、执行:get/post/put/delete函数
                response = handler(request, *args, **kwargs)
    
            except Exception as exc:
                response = self.handle_exception(exc)
    
            #第四步、 对返回结果再次进行加工
            self.response = self.finalize_response(request, response, *args, **kwargs)
            return self.response
    

    接下来对每个步骤进行详解:

    第一步:对request进行加工(添加数据):

    def initialize_request(self, request, *args, **kwargs):
            """
            Returns the initial request object.
            """
            # 请求弄成一个字典返回了
            parser_context = self.get_parser_context(request)
    
            return Request(
                request,
                parsers=self.get_parsers(),  # 解析数据,默认的有三种方式,可点进去看
                #self.get_authenticator优先找自己的,没有就找父类的
                authenticators=self.get_authenticators(), # 获取认证相关的所有类并实例化,传入request对象供Request使用
                negotiator=self.get_content_negotiator(),
                parser_context=parser_context
            )
    

    获取认证相关的类具体: authenticators=self.get_authenticators()

    def get_authenticators(self):
            """
            Instantiates and returns the list of authenticators that this view can use.
            """
            #返回的是对象列表
            return [auth() for auth in self.authentication_classes]  # authentication_classes:就是在View视图 authentication_classes = [xxxxx, xxxx],将每个类实例化
    

    查看认证的类:self.authentication_classes

    authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES  # 默认的,如果自己有会优先执行自己的
    

    接着点击:api_settings

    api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS)  #点击继承的DEFAULTSDEFAULTS = {
        # Base API policies
        'DEFAULT_AUTHENTICATION_CLASSES': (
            'rest_framework.authentication.SessionAuthentication',   #这时候就找到了他默认认证的类了,可以导入看看
            'rest_framework.authentication.BasicAuthentication'
        ),
    

    导入SessionAuthentication、BasicAuthentication类:

    from rest_framework.authentication import SessionAuthentication
    from rest_framework.authentication import BaseAuthentication
    

    看看authenticate方法和authenticate_header方法

    class BaseAuthentication(object):
        """
        All authentication classes should extend BaseAuthentication.
        """
    
        def authenticate(self, request):
            """
            Authenticate the request and return a two-tuple of (user, token).
            """
            raise NotImplementedError(".authenticate() must be overridden.")
    
        def authenticate_header(self, request):
            """
            Return a string to be used as the value of the `WWW-Authenticate`
            header in a `401 Unauthenticated` response, or `None` if the
            authentication scheme should return `403 Permission Denied` responses.
            """
            pass
    

    第二步:查看self.initial(request, *args, * *kwargs):

    def initial(self, request, *args, **kwargs):
            """
            Runs anything that needs to occur prior to calling the method handler.
            """
            self.format_kwarg = self.get_format_suffix(**kwargs)
    
            # Perform content negotiation and store the accepted info on the request
            neg = self.perform_content_negotiation(request)
            request.accepted_renderer, request.accepted_media_type = neg
    
            # Determine the API version, if versioning is in use.
            #2.1 处理版本信息
            version, scheme = self.determine_version(request, *args, **kwargs)
            request.version, request.versioning_scheme = version, scheme
    
            # Ensure that the incoming request is permitted
            #2.2 认证
            self.perform_authentication(request)
            # 2.3 权限
            self.check_permissions(request)
            # 2.4 请求用户进行访问频率的限制
            self.check_throttles(request)
    

    在查看一下认证的self.perform_authentication(request)

    def perform_authentication(self, request):
            """
            Perform authentication on the incoming request.
    
            Note that if you override this and simply 'pass', then authentication
            will instead be performed lazily, the first time either
            `request.user` or `request.auth` is accessed.
            """
            request.user   #执行request的user,这是的request已经是加工后的request了
    

    然后找到request.user这个属性:

     @property
        def user(self):
            """
            Returns the user associated with the current request, as authenticated
            by the authentication classes provided to the request.
            """
            if not hasattr(self, '_user'):
                with wrap_attributeerrors():
                    self._authenticate()  # 开始用户认证咯
            return self._user  #返回user
    

    执行self._authenticate() 开始用户认证,如果验证成功后返回元组: (用户,用户Token)

    def _authenticate(self):
            """
            Attempt to authenticate the request using each authentication instance
            in turn.
            """
            #循环对象列表
            for authenticator in self.authenticators:
                try:
                    #执行每一个对象的authenticate 方法
                    user_auth_tuple = authenticator.authenticate(self)   
                except exceptions.APIException:
                    self._not_authenticated()
                    raise
    
                if user_auth_tuple is not None:
                    self._authenticator = authenticator
                    self.user, self.auth = user_auth_tuple  #返回一个元组,user,和auth,赋给了self,
                    # 只要实例化Request,就会有一个request对象,就可以request.user,request.auth了
                    return
    
            self._not_authenticated()
    

    在user_auth_tuple = authenticator.authenticate(self) 进行验证,如果验证成功,执行类里的authenticatie方法

    如果用户没有认证成功:self._not_authenticated()

    def _not_authenticated(self):
            """
            Set authenticator, user & authtoken representing an unauthenticated request.
    
            Defaults are None, AnonymousUser & None.
            """
            #如果跳过了所有认证,默认用户和Token和使用配置文件进行设置
            self._authenticator = None  #
    
            if api_settings.UNAUTHENTICATED_USER:
                self.user = api_settings.UNAUTHENTICATED_USER() # 默认值为:匿名用户AnonymousUser
            else:
                self.user = None  # None 表示跳过该认证
    
            if api_settings.UNAUTHENTICATED_TOKEN:
                self.auth = api_settings.UNAUTHENTICATED_TOKEN()  # 默认值为:None
            else:
                self.auth = None
    
        # (user, token)
        # 表示验证通过并设置用户名和Token;
        # AuthenticationFailed异常
    

    第三步:执行GET/POST/PUT/DELETE等方法

    第四步:对第三步返回结果进行加工

2.局部认证Demo:

写一个demo,大体思路流程如下:

1.登录时,将token(随机字符串+时间戳) 存储到redis数据库中,然后将token进行返回给客户端
2.客户端将token放到请求头中,访问服务器接口
3.服务器每个接口 authentication_classes = (EasyAuthentication,) 进行认证
4.自定义写EasyAuthentication类,从请求头中获取token
5.将取出的token去Redis数据库中查看,若存在把redis中对应的值取出来
6.进行返回
7.token认证成功,可以获取到API接口所有的数据

在类视图中(继承APIView),加上认证:

class MyCourseListViews(APIView):
    authentication_classes = [EasyAuthentication,]

    def get(self, request, *args, **kwargs):
		...................逻辑处理..............	

写一个认证类:

from rest_framework.authentication import BaseAuthentication

class EasyAuthentication(BaseAuthentication):
    redis = SingleRedis(db=2).conn

    def authenticate(self, request):
        """从请求中获取token和user_id,与redis中取结果进行对比;验证通过重置有效期
        """
        token = request.META.get('HTTP_AUTHORIZATION')

        if not token:
            raise exceptions.NotAuthenticated(error_constants.ERR_TOKEN_ERROR)

        val = self.redis.get(token)
        if val is None:
            raise exceptions.AuthenticationFailed(error_constants.ERR_TOKEN_ERROR)

        # user_info: {'user_id': xxx, 'role': xxx, 'expire': xxx, 'mute': xxx}
        user_info = json.loads(val)

        self.redis.expire(token, user_info['expire'])
        return user_info, token

3.配置全局的认证:

在settings.py中,设置全局认证:

#设置全局认证
REST_FRAMEWORK = {
    "DEFAULT_AUTHENTICATION_CLASSES":['api.utils.auth.EasyAuthentication',]   # 路径 + 认证类
}

注意:在settings.py里面设置的全局认证,所有业务都需要经过认证,如果不想让某个API需要认证,将设置为空,如下所示

# 在登录的类视图,要获取token,所以不需要经过认证,可以将authentication_classes = [] 即可
class LoginViews(APIView):
    authentication_classes = []

    def get(self, request, *args, **kwargs):
		...................逻辑处理..............	

4.总结:

1.创建认证类:继承BaseAuthentication、重写authenticate方法

2.authenticate()返回值(三种)

  • None:当前认证不管,等下一个认证来执行
  • raise exceptions.AuthenticationFailed('用户认证失败')导入:from rest_framework import exceptions
  • 有返回值元祖形式:(元素1,元素2)元素1复制给request.user、元素2复制给request.auth

3.局部使用:authentication_classes = [EasyAuthentication,]

4.全局使用:在settings.py中配置

#设置全局认证
REST_FRAMEWORK = {
    "DEFAULT_AUTHENTICATION_CLASSES":['api.utils.auth.EasyAuthentication',]   # 路径 + 认证类
}

三、DRF 权限:

1.源码流程详解:

  • 大致流程同认证。

2.局部权限Demo:

写一个demo,大体流程如下:

1.在类视图OrderView中,区分访问权限,只有VIP才可以进行请求。
2.写一个权限类,继承BasePermission。
3.当每个角色进行请求OrderView时,校验是否有权限可以访问。

写一个OrderView类视图:

class OrderView(APIView):
    '''订单相关业务(只有VIP用户才可以访问)'''
    permission_classes = [MyPremission,]    # 不用全局的权限配置的话,这里就要写自己的局部权限
    def get(self,request,*args,**kwargs):
        print(request.user)  # request.user 会携带一个字段,
		# 比如user_type=1就是VIP user_type=0就是普通用户
        return HttpResponse('用户信息')

在一个权限类:

from rest_framework.permissions import BasePermission

class MyPremission(BasePermission):
 	message = "必须是VIP才可以进行访问"  # 重写返回的提示信息

    def has_permission(self,request,view):
		# 如果user_type = 0 说明是普通用户, 返回False 页面也显示我们重写的message属性
        if request.user.user_type == 0:
            return False
		# 返回True 就是权限认证成功
        return True

3.配置全局的权限:

在settings.py中,配置如下:

# 全局
REST_FRAMEWORK = {
   "DEFAULT_PERMISSION_CLASSES":['api.utils.permission.MyPremission'], # 路径 + 权限类
}

注意:在settings.py里面设置的全局认证,所有业务都需要经过校验权限,如果不想让某个API需要权限都可以进行访问,将设置为空,如下所示

class OrderView(APIView):
    '''订单相关业务(所有用户都可以进行访问)'''
    permission_classes = []  # 设置为[] 即可
    def get(self,request,*args,**kwargs):
        print(request.user)  # request.user 会携带一个字段,
		# 比如user_type=1就是VIP user_type=0就是普通用户
        return HttpResponse('用户信息')

4.总结:

1.使用:必须继承BasePermission类、必须实现has_permission方法

2.返回值:True 有权访问、False 无权访问

3.局部:permission_classes = [MyPremission,]

4.全局:在settings.py 配置

#全局
REST_FRAMEWORK = {
   "DEFAULT_PERMISSION_CLASSES":['api.utils.permission.MyPremission'],
}

四、DRF 节流:

1.源码流程详解:

  • 大致流程同认证。

2.自定义节流方法:

需求60s内允许访问3次

from rest_framework.throttling import BaseThrottle
import time
VISIT_RECORD = {}   #保存访问记录

class VisitThrottle(BaseThrottle):
    '''60s内只能访问3次'''

    def __init__(self):
        self.history = None   #初始化访问记录

    def allow_request(self,request,view):
        #获取用户ip (get_ident)
        remote_addr = self.get_ident(request)
        ctime = time.time()
        #如果当前IP不在访问记录里面,就添加到记录
        if remote_addr not in VISIT_RECORD:
            VISIT_RECORD[remote_addr] = [ctime,]     #键值对的形式保存
            return True    #True表示可以访问
        #获取当前ip的历史访问记录
        history = VISIT_RECORD.get(remote_addr)
        #初始化访问记录
        self.history = history

        #如果有历史访问记录,并且最早一次的访问记录离当前时间超过60s,就删除最早的那个访问记录,
        #只要为True,就一直循环删除最早的一次访问记录
        while history and history[-1] < ctime - 60:
            history.pop()
        #如果访问记录不超过三次,就把当前的访问记录插到第一个位置(pop删除最后一个)
        if len(history) < 3:
            history.insert(0,ctime)
            return True

    def wait(self):
        '''还需要等多久才能访问'''
        ctime = time.time()
        return 60 - (ctime - self.history[-1])

在settings.py中配置节流:

#全局
REST_FRAMEWORK = {
    #节流
    "DEFAULT_THROTTLE_CLASSES":['api.utils.throttle.VisitThrottle'],  # 路径 + 节流类
}

在60s内连续访问超过3次接口时,就会出现:

# 告诉还用6s就可以进行访问
{
    "detail": "Request was throttled. Expected avaiable in 6 seconds"
}

3.介绍内置节流类:

BaseThrottle:需要自己写allow_request和wait方法 (get_ident就是获取ip),上面示例用的就是BaseThrottle

class BaseThrottle(object):
    """
    Rate throttling of requests.
    """

    def allow_request(self, request, view):
        """
        Return `True` if the request should be allowed, `False` otherwise.
        """
        raise NotImplementedError('.allow_request() must be overridden')

    def get_ident(self, request):
        """
        Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
        if present and number of proxies is > 0. If not use all of
        HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
        """
        xff = request.META.get('HTTP_X_FORWARDED_FOR')
        remote_addr = request.META.get('REMOTE_ADDR')
        num_proxies = api_settings.NUM_PROXIES

        if num_proxies is not None:
            if num_proxies == 0 or xff is None:
                return remote_addr
            addrs = xff.split(',')
            client_addr = addrs[-min(num_proxies, len(addrs))]
            return client_addr.strip()

        return ''.join(xff.split()) if xff else remote_addr

    def wait(self):
        """
        Optionally, return a recommended number of seconds to wait before
        the next request.
        """
        return None

SimpleRateThrottle

class SimpleRateThrottle(BaseThrottle):
    """
    A simple cache implementation, that only requires `.get_cache_key()`
    to be overridden.

    The rate (requests / seconds) is set by a `rate` attribute on the View
    class.  The attribute is a string of the form 'number_of_requests/period'.

    Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')

    Previous request information used for throttling is stored in the cache.
    """
    cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None   #这个值自定义,写什么都可以
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

    def get_cache_key(self, request, view):
        """
        Should return a unique cache-key which can be used for throttling.
        Must be overridden.

        May return `None` if the request should not be throttled.
        """
        raise NotImplementedError('.get_cache_key() must be overridden')

    def get_rate(self):
        """
        Determine the string representation of the allowed request rate.
        """
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def parse_rate(self, rate):
        """
        Given the request rate string, return a two tuple of:
        <allowed number of requests>, <period of time in seconds>
        """
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)

    def allow_request(self, request, view):
        """
        Implement the check to see if the request should be throttled.

        On success calls `throttle_success`.
        On failure calls `throttle_failure`.
        """
        if self.rate is None:
            return True

        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True

        self.history = self.cache.get(self.key, [])
        self.now = self.timer()

        # Drop any requests from the history which have now passed the
        # throttle duration
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()

    def throttle_success(self):
        """
        Inserts the current request's timestamp along with the key
        into the cache.
        """
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
        """
        Called when a request to the API has failed due to throttling.
        """
        return False

    def wait(self):
        """
        Returns the recommended next request time in seconds.
        """
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        return remaining_duration / float(available_requests)

我们可以通过继承SimpleRateThrottle类,来实现节流,会更加的简单,因为SimpleRateThrottle里面都帮我们写好了

4.使用SimpleRateThrottle类实现节流:

from rest_framework.throttling import SimpleRateThrottle

class VisitThrottle(SimpleRateThrottle):
    '''匿名用户60s只能访问三次(根据ip)'''
    scope = 'throttle'   #这里面的值,自己随便定义,settings里面根据这个值配置throttle

    def get_cache_key(self, request, view):
        #通过ip限制节流
        return self.get_ident(request)

class UserThrottle(SimpleRateThrottle):
    '''登录用户60s可以访问10次'''
    scope = 'userThrottle'    #这里面的值,自己随便定义,settings里面根据这个值配置userThrottle

    def get_cache_key(self, request, view):
        return request.user.user_id

在settings.py配置节流:

#全局
REST_FRAMEWORK = {
    # 设置全局节流
    "DEFAULT_THROTTLE_CLASSES":['api.utils.throttle.UserThrottle'],   #全局配置,登录用户节流限制(10/m)
    # 设置访问频率
    "DEFAULT_THROTTLE_RATES":{
        'throttle':'3/m',         #没登录用户3/m,throttle就是scope定义的值,通过IP地址
        'userThrottle':'10/m',    #登录用户10/m,userThrottle就是scope定义的值, 通过user_id
    }
}

在views.py,设置局部配置方法:

class AuthView(APIView):
    # 默认的节流是登录用户(10/m),AuthView不需要登录,这里用匿名用户的节流(3/m)
    throttle_classes = [VisitThrottle,]

简单说明:

1.在settings.py 中的'api.utils.throttle.UserThrottle',是全局配置(根据登录的user_id,进行限制 10/m)

2.在setiings.py中的"DEFAULT_THROTTLE_RATES",设置访问频率

3.throttle_classes = [VisitThrottle,] ,局部使用,不需要再settings.py配置全局

4.总结:

1.创建节流类,继承BaseThrottle类, 实现:allow_request方法 ,wait 方法

2.创建类,继承SimpleRateThrottle类, 实现: get_cache_key方法, scope='throttle' (配置文件中的throttle)

3.全局与局部:

#全局
REST_FRAMEWORK = {
    # 设置全局节流
    "DEFAULT_THROTTLE_CLASSES":['api.utils.throttle.UserThrottle'],   #全局配置,登录用户节流限制(10/m)
    # 设置访问频率
    "DEFAULT_THROTTLE_RATES":{
        'throttle':'3/m',         #没登录用户3/m,throttle就是scope定义的值,通过IP地址
        'userThrottle':'10/m',    #登录用户10/m,userThrottle就是scope定义的值, 通过user_id
    }
}

# 局部:再类视图中添加
throttle_classes = [VisitThrottle,]

猜你喜欢

转载自blog.csdn.net/Fe_cow/article/details/91489476