比较复杂的一个mock测试

其中测试对象:
特别注意:
api_res = self.cache_api_wrapper.get_api(self.network_api_svc.network_whole_features, company_name)
与以前的直接通过调用api获取网络图特征不同,此处外面还加了多进程get_api函数;
直接调用它会报错!

def get_api(self, function, *params):
    """
    Coordinate (schedule and cache) one API call across several web processes.

    The cache key is derived from ``function`` and ``params``. If Redis
    already holds a value under that key it is returned. Otherwise the
    process tries to take a per-key token: the winner performs the real
    API call, everybody else polls Redis once per second until the value
    shows up, and after 100 polls gives up waiting and calls the API itself.

    :param function: the API callable to invoke
    :param params: positional arguments forwarded to ``function``
    :return: the cached value from Redis, or the result of the API call
    """
    key = self.__cache_api.func_and_params_2_key(function, params)
    if self.__redis_helper.exists(key):  # cache hit: serve from Redis
        return self.__get_redis_value(key)

    if self.__token_cache_db.try_add(key):  # this process won the call token
        try:
            return self.__cache_api.get_api(function, *params)
        finally:
            # Release the token even when the API call raises; otherwise
            # the token stays behind and every other process keeps polling
            # until its own timeout expires.
            self.__token_cache_db.delete_by_token_string(key)

    # Token is held elsewhere: block until the value appears in Redis.
    # NOTE(review): presumably the token holder's get_api call is what
    # populates Redis -- confirm against CacheApi.get_api.
    second_count = 0
    while True:
        second_count += 1
        if self.__redis_helper.exists(key):  # the cache has been filled
            self.__token_cache_db.delete_by_token_string(key)
            return self.__get_redis_value(key)
        if second_count >= 100:  # waited too long: call the API directly
            self.__token_cache_db.delete_by_token_string(key)
            return self.__cache_api.get_api(function, *params)
        sleep(1)

测试对象,mock对象:


class NetworkFeaturesSingle(object):
    """
    Collects network-graph feature data for a single node.
    """
    network_api_svc = NetworkApiSvc()
    cache_api_wrapper = CacheApiWrapper()

    def get_features(self, company_name):
        """
        Fetch the network-graph features for ``company_name``.

        The API is not called directly: ``network_whole_features`` is handed
        to ``cache_api_wrapper.get_api`` together with the company name.
        Every counter missing from the response defaults to 0.

        :param company_name: company name forwarded to the API
        :return: a NetworkFeatureDto filled from the API response
        """
        fetch_features = self.network_api_svc.network_whole_features
        api_res = self.cache_api_wrapper.get_api(fetch_features, company_name)

        # (DTO attribute, response key) pairs, in assignment order.
        field_to_key = (
            ('network_share_cancel_cnt', 'shareOrPosRevokedCnt'),
            ('cancel_cnt', 'frRevokedCnt'),
            ('fr_zhi_xing_cnt', 'frZhixingCnt'),
            ('network_share_zhixing_cnt', 'shareOrPosZhixingCnt'),
            ('network_share_judge_doc_cnt', 'shareOrPosJudgeDocCnt'),
            ('net_judgedoc_defendant_cnt', 'allLinkJudgedocCnt'),
            ('judge_doc_cnt', 'frJudgedocCnt'),
        )

        network_dto = NetworkFeatureDto()
        for field, response_key in field_to_key:
            setattr(network_dto, field, api_res.get(response_key, 0))
        return network_dto

测试脚本,写法1:
报错:AttributeError: __name__(MagicMock 默认没有 __name__ 属性,而 get_api 生成缓存 key 时应该会用到被调用函数的名字,因此直接 mock 会报错)

from unittest import TestCase

import mock

from api.network_api_svc import NetworkApiSvc
from biz.biz_utils.cache_api import CacheApi
from biz.integration_api.common.network_features_single import NetworkFeaturesSingle
from common.helper.test_helper import TestHelper
from test.test_clean.test_common.test_utils import TestUtils

class TestNetworkFeaturesSingle(TestCase):
    """
    Tests for NetworkFeaturesSingle.get_features (variant 1: the stub is
    set directly on the patched mock).

    NOTE(review): this variant is reported to fail with
    ``AttributeError: __name__``. A MagicMock has no ``__name__`` by default,
    and presumably the cache-key builder behind get_api reads it -- confirm.
    """
    t_h = TestHelper()
    t_u = TestUtils()

    def tearDown(self):
        """Reset CacheApi.REFRESH_CACHE to False after every test."""
        CacheApi.REFRESH_CACHE = False

    @mock.patch.object(NetworkApiSvc, 'network_whole_features')
    def test_get_features(self, network_whole_features):
        """Each stubbed API counter should land in its DTO field."""
        # given
        network_whole_features.return_value = self.__get_api_value()
        features = NetworkFeaturesSingle()
        company_name = u'测试公司'
        # when
        network_dto = features.get_features(company_name)
        # then
        expectations = (
            (1, 'network_share_judge_doc_cnt'),
            (2, 'judge_doc_cnt'),
            (3, 'network_share_cancel_cnt'),
            (4, 'cancel_cnt'),
            (5, 'network_share_zhixing_cnt'),
            (6, 'net_judgedoc_defendant_cnt'),
            (7, 'fr_zhi_xing_cnt'),
        )
        for expected, field in expectations:
            self.assertEqual(expected, getattr(network_dto, field))

    def test_get_features_no_mock(self):
        """Call get_features without any mock and print the resulting DTO."""
        # given
        features = NetworkFeaturesSingle()
        # when
        network_dto = features.get_features(u'小米科技有限责任公司')
        # then
        self.t_u.print_domain(network_dto)
        assert network_dto

    @staticmethod
    def __get_api_value():
        """Return the stub API response: seven counters numbered 1 to 7."""
        response_keys = (
            u'shareOrPosJudgeDocCnt',
            u'frJudgedocCnt',
            u'shareOrPosRevokedCnt',
            u'frRevokedCnt',
            u'shareOrPosZhixingCnt',
            u'allLinkJudgedocCnt',
            u'frZhixingCnt',
        )
        return dict(zip(response_keys, range(1, len(response_keys) + 1)))

测试脚本,写法2:

from unittest import TestCase

import mock

from api.network_api_svc import NetworkApiSvc
from biz.biz_utils.cache_api import CacheApi
from biz.integration_api.common.network_features_single import NetworkFeaturesSingle
from common.helper.test_helper import TestHelper
from test.test_clean.test_common.test_utils import TestUtils


class TestNetworkFeaturesSingle(TestCase):
    """
    Tests for NetworkFeaturesSingle.get_features (variant 2: the stub is
    installed through TestHelper.set_api_return_value).
    """
    t_h = TestHelper()
    t_u = TestUtils()

    def tearDown(self):
        """Reset CacheApi.REFRESH_CACHE to False after every test."""
        CacheApi.REFRESH_CACHE = False

    @mock.patch.object(NetworkApiSvc, 'network_whole_features')
    def test_get_features(self, network_whole_features):
        """Each stubbed API counter should land in its DTO field."""
        # given
        company_name = u'测试公司'
        # locals() only holds the names bound before this call: self, the
        # mocked network_whole_features and company_name.
        self.t_h.set_api_return_value(
            locals(), network_whole_features, company_name,
            self.__get_api_value())
        features = NetworkFeaturesSingle()
        # when
        network_dto = features.get_features(company_name)
        # then
        expectations = (
            (1, 'network_share_judge_doc_cnt'),
            (2, 'judge_doc_cnt'),
            (3, 'network_share_cancel_cnt'),
            (4, 'cancel_cnt'),
            (5, 'network_share_zhixing_cnt'),
            (6, 'net_judgedoc_defendant_cnt'),
            (7, 'fr_zhi_xing_cnt'),
        )
        for expected, field in expectations:
            self.assertEqual(expected, getattr(network_dto, field))

**locals()用法:**
**locals() 函数会以字典类型返回当前位置的全部局部变量。**
(注:原文此处"对于函数, 方法, lambda 函式, 类, 以及实现了 __call__ 方法的类实例, 它都返回 True"一句描述的是 callable() 函数,与 locals() 无关。)
>>>def runoob(arg):    # 两个局部变量:arg、z
...     z = 1
...     print (locals())
... 
>>> runoob(4)
{'z': 1, 'arg': 4}      # 返回一个名字/值对的字典
>>>

 当前locals()返回的是:
 {'network_whole_features': <MagicMock name='network_whole_features' id='140258690673808'>, 'company_name': u'\u6d4b\u8bd5\u516c\u53f8', 'self': <test_networkFeaturesSingle.TestNetworkFeaturesSingle testMethod=test_get_features>}
 //////////////////////////////////////////////
    def test_get_features_no_mock(self):
        """Call get_features without any mock and print the resulting DTO."""
        # given
        features = NetworkFeaturesSingle()
        # when
        network_dto = features.get_features(u'小米科技有限责任公司')
        # then
        self.t_u.print_domain(network_dto)
        assert network_dto

    @staticmethod
    def __get_api_value():
        """Return the stub API response: seven counters numbered 1 to 7."""
        response_keys = (
            u'shareOrPosJudgeDocCnt',
            u'frJudgedocCnt',
            u'shareOrPosRevokedCnt',
            u'frRevokedCnt',
            u'shareOrPosZhixingCnt',
            u'allLinkJudgedocCnt',
            u'frZhixingCnt',
        )
        return dict(zip(response_keys, range(1, len(response_keys) + 1)))
增加的类函数:
    def set_api_return_value(self, local_vars, api, company_name, return_value):
        """
        Stub a mocked API for tests whose call path goes through the Redis cache.

        :param local_vars: the caller's local variables, obtained via locals()
        :param api: the mocked API to stub
        :param company_name: company name, used as part of the Redis key
        :param return_value: the data the stub should return
        :raises ValueError: when ``api`` is not bound to any name in
            ``local_vars``
        """
        api_name = self.__get_variable_name(local_vars, api)
        if api_name is None:
            # Fail before touching the mock or Redis, and with a clear
            # message instead of a TypeError from ``None + u'__'``.
            raise ValueError(
                'api was not found in local_vars; pass locals() from the '
                'scope in which the mocked api is bound')

        api.return_value = return_value
        # Delete the Redis entry stored under "<api name>__<company name>".
        # NOTE(review): presumably so a cached value cannot shadow the stub.
        RedisHelper().delete(api_name + u'__' + company_name)
        # A mock has no __name__ by default, so give it the variable's name.
        # NOTE(review): presumably the cache-key builder reads it -- confirm.
        api.__name__ = api_name

    @staticmethod
    def __get_variable_name(local_vars, x):
        """
        Look up the name under which ``x`` is bound in ``local_vars``.

        :param local_vars: name-to-value mapping, obtained via locals()
        :param x: the object whose variable name is wanted
        :return: the first name whose value is the very same object as ``x``
            (identity, not equality); None when no name matches
        """
        matching_names = (
            name for name, value in local_vars.items() if value is x
        )
        return next(matching_names, None)

猜你喜欢

转载自blog.csdn.net/sinat_26566137/article/details/81065744