Django REST framework 之节流（Throttling）的源码流程剖析

视图类（通过 throttle_classes 为视图启用节流）：

 1 class UserViewset(BaseView):
 2     '''
 3         create:
 4         创建用户
 5         retrieve:
 6 
 7     '''
 8     queryset = User.objects.all()
 9     throttle_classes = [UserRateThrottle] #添加节流类
10     authentication_classes = (JSONWebTokenAuthentication, authentication.SessionAuthentication)
11     def get_serializer_class(self):
12         self.dispatch
13         if self.action == "retrieve":
14             return UserDetailSerializer
15         elif self.action == "create":
16             return UserRegSerializer
17 
18         return UserDetailSerializer
19 
20     def get_permissions(self):
21         if self.action == "retrieve":
22             return [permissions.IsAuthenticated()]
23         elif self.action == "create":
24             return []
25 
26         return []
27 
28     def create(self, request, *args, **kwargs):
29         serializer = self.get_serializer(data=request.data)
30         serializer.is_valid(raise_exception=True)
31         user = self.perform_create(serializer)
32         re_dict = serializer.data
33         payload = jwt_payload_handler(user)
34         re_dict["token"] = jwt_encode_handler(payload)
35         re_dict["name"] = user.name if user.name else user.username
36 
37         headers = self.get_success_headers(serializer.data)
38         return Response(re_dict, status=status.HTTP_201_CREATED, headers=headers)
39 
40     def get_object(self):
41         return self.request.user
42 
43     def perform_create(self, serializer):
44         return serializer.save()

同权限类一样，节流检查也是在 APIView 的 initial() 方法中被调用的，入口是 check_throttles：

def check_throttles(self, request):
    """
    Check if request should be throttled.
    Raises an appropriate exception if the request is throttled.

    (A method of DRF's APIView.) Each throttle returned by
    ``self.get_throttles()`` is asked, in order, whether the request is allowed.
    """
    for throttle in self.get_throttles():
        # Ask this throttle whether the request is within its rate limit.
        if not throttle.allow_request(request, self):
            # Rejected: hand the throttle's suggested wait time to
            # self.throttled(), which (per the docstring) is expected to raise.
            self.throttled(request, throttle.wait())

DRF 内置的节流类（继承关系：BaseThrottle → SimpleRateThrottle → UserRateThrottle）：

  1 class BaseThrottle(object):
  2     """
  3     Rate throttling of requests.
  4     """
  5 
  6     def allow_request(self, request, view):
  7         """
  8         Return `True` if the request should be allowed, `False` otherwise.
  9         """
 10         raise NotImplementedError('.allow_request() must be overridden')
 11 
 12     def get_ident(self, request): 获取访问IP
 13         """
 14         Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
 15         if present and number of proxies is > 0. If not use all of
 16         HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
 17         """
 18         xff = request.META.get('HTTP_X_FORWARDED_FOR')
 19         remote_addr = request.META.get('REMOTE_ADDR')
 20         num_proxies = api_settings.NUM_PROXIES
 21 
 22         if num_proxies is not None:
 23             if num_proxies == 0 or xff is None:
 24                 return remote_addr
 25             addrs = xff.split(',')
 26             client_addr = addrs[-min(num_proxies, len(addrs))]
 27             return client_addr.strip()
 28 
 29         return ''.join(xff.split()) if xff else remote_addr
 30 
 31     def wait(self):
 32         """
 33         Optionally, return a recommended number of seconds to wait before
 34         the next request.
 35         """
 36         return None
 37 
 38 
 39 class SimpleRateThrottle(BaseThrottle):
 40     """
 41     A simple cache implementation, that only requires `.get_cache_key()`
 42     to be overridden.
 43 
 44     The rate (requests / seconds) is set by a `rate` attribute on the View
 45     class.  The attribute is a string of the form 'number_of_requests/period'.
 46 
 47     Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
 48 
 49     Previous request information used for throttling is stored in the cache.
 50     """
 51     cache = default_cache
 52     timer = time.time
 53     cache_format = 'throttle_%(scope)s_%(ident)s'
 54     scope = None
 55     THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES  #获取配置
 56 
 57     def __init__(self):
 58         if not getattr(self, 'rate', None):
 59             self.rate = self.get_rate()
 60         self.num_requests, self.duration = self.parse_rate(self.rate)
 61 
 62     def get_cache_key(self, request, view):
 63         """
 64         Should return a unique cache-key which can be used for throttling.
 65         Must be overridden.
 66 
 67         May return `None` if the request should not be throttled.
 68         """
 69         raise NotImplementedError('.get_cache_key() must be overridden')
 70 
 71     def get_rate(self): #获取配置的参数
 72         """
 73         Determine the string representation of the allowed request rate.
 74         """
 75         if not getattr(self, 'scope', None):
 76             msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
 77                    self.__class__.__name__)
 78             raise ImproperlyConfigured(msg)
 79 
 80         try:
 81             return self.THROTTLE_RATES[self.scope]
 82         except KeyError:
 83             msg = "No default throttle rate set for '%s' scope" % self.scope
 84             raise ImproperlyConfigured(msg)
 85 
 86     def parse_rate(self, rate): #获取定义里的节流策略如3/m,每分钟访问3次
 87         """
 88         Given the request rate string, return a two tuple of:
 89         <allowed number of requests>, <period of time in seconds>
 90         """
 91         if rate is None:
 92             return (None, None)
 93         num, period = rate.split('/')
 94         num_requests = int(num)
 95         duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
 96         return (num_requests, duration)
 97 
 98     def allow_request(self, request, view):
 99         """
100         Implement the check to see if the request should be throttled.
101 
102         On success calls `throttle_success`.
103         On failure calls `throttle_failure`.
104         """
105         if self.rate is None:
106             return True
107 
108         self.key = self.get_cache_key(request, view) #获取存储的key
109         if self.key is None:
110             return True
111 
112         self.history = self.cache.get(self.key, []) #获取访问历史
113         self.now = self.timer()
114 
115         # Drop any requests from the history which have now passed the
116         # throttle duration
117         while self.history and self.history[-1] <= self.now - self.duration:
118             self.history.pop()
119         if len(self.history) >= self.num_requests: #判断
120             return self.throttle_failure()
121         return self.throttle_success()
122 
123     def throttle_success(self):
124         """
125         Inserts the current request's timestamp along with the key
126         into the cache.
127         """
128         self.history.insert(0, self.now)
129         self.cache.set(self.key, self.history, self.duration)
130         return True
131 
132     def throttle_failure(self):
133         """
134         Called when a request to the API has failed due to throttling.
135         """
136         return False
137 
138     def wait(self):  #返回响应
139         """
140         Returns the recommended next request time in seconds.
141         """
142         if self.history:
143             remaining_duration = self.duration - (self.now - self.history[-1])
144         else:
145             remaining_duration = self.duration
146 
147         available_requests = self.num_requests - len(self.history) + 1
148         if available_requests <= 0:
149             return None
150 
151         return remaining_duration / float(available_requests)

class UserRateThrottle(SimpleRateThrottle):
    """
    Limits the rate of API calls that may be made by a given user.

    The user id will be used as a unique cache key if the user is
    authenticated. For anonymous requests, the IP address of the request will
    be used.
    """
    # The rate is looked up as THROTTLE_RATES['user'].
    scope = 'user'

    def get_cache_key(self, request, view):
        if request.user.is_authenticated:
            # Logged-in user: key on the user's primary key.
            ident = request.user.pk
        else:
            # Anonymous user: key on the client address from get_ident().
            ident = self.get_ident(request)

        return self.cache_format % {
            'scope': self.scope,
            'ident': ident
        }

猜你喜欢

转载自：https://www.cnblogs.com/arrow-kejin/p/9988271.html