Rasa课程、Rasa培训、Rasa面试、Rasa实战系列之FormAction(三)

aioresponses

模拟（mock）由 aiohttp 的 ClientSession 发出的 HTTP 请求

class aioresponses(object):
    """Mock aiohttp requests made by ClientSession.

    While active, ``aiohttp.client.ClientSession._request`` is patched so
    that each request is answered from the responses registered with
    ``add()`` (or the ``head``/``get``/``post``/... shortcuts) instead of
    going to the network. Activate it as a context manager, as a decorator,
    or with explicit ``start()`` / ``stop()`` calls.
    """
    _matches = None  # type: Dict[str, RequestMatch]
    _responses = None  # type: List[ClientResponse]
    requests = None  # type: Dict

    def __init__(self, **kwargs):
        """Create the mock; the patch is not active until ``start()``.

        Keyword Args:
            param: Name of the keyword argument under which the decorator
                passes this instance to the wrapped function. When omitted,
                the instance is appended to the positional arguments.
            passthrough: URL prefixes whose requests are forwarded to the
                real ``ClientSession._request`` instead of being mocked.
        """
        self._param = kwargs.pop('param', None)
        self._passthrough = kwargs.pop('passthrough', [])
        self.patcher = patch('aiohttp.client.ClientSession._request',
                             side_effect=self._request_mock,
                             autospec=True)
        # Log of intercepted calls: (method, normalized URL) -> [RequestCall].
        self.requests = {}

    def __enter__(self) -> 'aioresponses':
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()

    def __call__(self, f):
        """Decorate *f* (sync or async) so it runs with the mock active."""
        def _pack_arguments(ctx, *args, **kwargs) -> Tuple[Tuple, Dict]:
            # Inject the active mock by keyword (``param``) or positionally.
            if self._param:
                kwargs[self._param] = ctx
            else:
                args += (ctx,)
            return args, kwargs

        if asyncio.iscoroutinefunction(f):
            @wraps(f)
            async def wrapped(*args, **kwargs):
                with self as ctx:
                    args, kwargs = _pack_arguments(ctx, *args, **kwargs)
                    return await f(*args, **kwargs)
        else:
            @wraps(f)
            def wrapped(*args, **kwargs):
                with self as ctx:
                    args, kwargs = _pack_arguments(ctx, *args, **kwargs)
                    return f(*args, **kwargs)
        return wrapped

    def clear(self):
        """Forget all created responses and registered matches."""
        self._responses.clear()
        self._matches.clear()

    def start(self):
        """Reset internal state and activate the ``_request`` patch."""
        self._responses = []
        self._matches = {}
        self.patcher.start()
        self.patcher.return_value = self._request_mock

    def stop(self) -> None:
        """Close created responses, remove the patch and reset state."""
        for response in self._responses:
            response.close()
        self.patcher.stop()
        self.clear()

    # Convenience wrappers: register a response for one HTTP method.

    def head(self, url: 'Union[URL, str, Pattern]', **kwargs):
        self.add(url, method=hdrs.METH_HEAD, **kwargs)

    def get(self, url: 'Union[URL, str, Pattern]', **kwargs):
        self.add(url, method=hdrs.METH_GET, **kwargs)

    def post(self, url: 'Union[URL, str, Pattern]', **kwargs):
        self.add(url, method=hdrs.METH_POST, **kwargs)

    def put(self, url: 'Union[URL, str, Pattern]', **kwargs):
        self.add(url, method=hdrs.METH_PUT, **kwargs)

    def patch(self, url: 'Union[URL, str, Pattern]', **kwargs):
        self.add(url, method=hdrs.METH_PATCH, **kwargs)

    def delete(self, url: 'Union[URL, str, Pattern]', **kwargs):
        self.add(url, method=hdrs.METH_DELETE, **kwargs)

    def options(self, url: 'Union[URL, str, Pattern]', **kwargs):
        self.add(url, method=hdrs.METH_OPTIONS, **kwargs)

    def add(self, url: 'Union[URL, str, Pattern]', method: str = hdrs.METH_GET,
            status: int = 200,
            body: Union[str, bytes] = '',
            exception: 'Exception' = None,
            content_type: str = 'application/json',
            payload: Dict = None,
            headers: Dict = None,
            response_class: 'ClientResponse' = None,
            repeat: bool = False,
            timeout: bool = False,
            reason: Optional[str] = None,
            callback: Optional[Callable] = None) -> None:
        """Register a mocked response (or exception) for *method* and *url*.

        Unless ``repeat`` is True, the registration is consumed by the first
        request it matches. All arguments are forwarded to ``RequestMatch``.
        """
        self._matches[str(uuid4())] = (RequestMatch(
            url,
            method=method,
            status=status,
            content_type=content_type,
            body=body,
            exception=exception,
            payload=payload,
            headers=headers,
            response_class=response_class,
            repeat=repeat,
            timeout=timeout,
            reason=reason,
            callback=callback,
        ))

    @staticmethod
    def is_exception(resp_or_exc: Union[ClientResponse, Exception]) -> bool:
        """Return True if *resp_or_exc* is an exception class or instance."""
        if inspect.isclass(resp_or_exc):
            return issubclass(resp_or_exc, BaseException)
        return isinstance(resp_or_exc, BaseException)

    async def match(
        self, method: str, url: URL,
        allow_redirects: bool = True, **kwargs: Dict
    ) -> Optional['ClientResponse']:
        """Build the response for *method*/*url* from the registered matches.

        Redirect responses (301/302/303/307/308) with a ``Location`` header
        are followed as GET requests when ``allow_redirects`` is true; the
        intermediate responses end up in the final response's history.

        Returns:
            The final mocked response, or None if nothing matches.

        Raises:
            The registered exception, if the matching entry holds one.
        """
        history = []
        while True:
            for key, matcher in self._matches.items():
                if matcher.match(method, url):
                    response_or_exc = await matcher.build_response(
                        url, allow_redirects=allow_redirects, **kwargs
                    )
                    break
            else:
                return None

            if matcher.repeat is False:
                del self._matches[key]

            if self.is_exception(response_or_exc):
                raise response_or_exc
            is_redirect = response_or_exc.status in (301, 302, 303, 307, 308)
            if is_redirect and allow_redirects:
                if hdrs.LOCATION not in response_or_exc.headers:
                    break
                history.append(response_or_exc)
                redirect_url = URL(response_or_exc.headers[hdrs.LOCATION])
                # ``Location`` may be relative; resolve it against the
                # current URL so it can match absolute registered URLs.
                if redirect_url.is_absolute():
                    url = redirect_url
                else:
                    url = url.join(redirect_url)
                method = 'get'
                continue
            else:
                break

        response_or_exc._history = tuple(history)

        return response_or_exc

    async def _request_mock(self, orig_self: ClientSession,
                            method: str, url: 'Union[URL, str]',
                            *args: Tuple,
                            **kwargs: Dict) -> 'ClientResponse':
        """Return mocked response object or raise connection error."""
        if orig_self.closed:
            raise RuntimeError('Session is closed')

        url_origin = url
        url = normalize_url(merge_params(url, kwargs.get('params')))
        url_str = str(url)
        # Passthrough prefixes go to the original (unpatched) _request.
        for prefix in self._passthrough:
            if url_str.startswith(prefix):
                return (await self.patcher.temp_original(
                    orig_self, method, url_origin, *args, **kwargs
                ))

        key = (method, url)
        self.requests.setdefault(key, [])
        try:
            kwargs_copy = copy.deepcopy(kwargs)
        except (TypeError, ValueError):
            # Handle the fact that some values cannot be deep copied
            kwargs_copy = kwargs
        self.requests[key].append(RequestCall(args, kwargs_copy))

        response = await self.match(method, url, **kwargs)

        if response is None:
            raise ClientConnectionError(
                'Connection refused: {} {}'.format(method, url)
            )
        self._responses.append(response)

        # Automatically call response.raise_for_status() on a request if the
        # request was initialized with raise_for_status=True. Also call
        # response.raise_for_status() if the client session was initialized
        # with raise_for_status=True, unless the request was called with
        # raise_for_status=False.
        raise_for_status = kwargs.get('raise_for_status')
        if raise_for_status is None:
            raise_for_status = getattr(
                orig_self, '_raise_for_status', False
            )
        if raise_for_status:
            response.raise_for_status()

        return response

EndpointConfig

外部HTTP端点的配置

class EndpointConfig:
    """Configuration for an external HTTP endpoint."""

    def __init__(
        self,
        url: Optional[Text] = None,
        params: Optional[Dict[Text, Any]] = None,
        headers: Optional[Dict[Text, Any]] = None,
        basic_auth: Optional[Dict[Text, Text]] = None,
        token: Optional[Text] = None,
        token_name: Text = "token",
        cafile: Optional[Text] = None,
        **kwargs: Any,
    ) -> None:
        """Creates an `EndpointConfig` instance.

        Args:
            url: Base URL of the endpoint.
            params: Query parameters added to every request.
            headers: Headers set on the client session.
            basic_auth: Dict with ``username`` and ``password`` keys; when
                non-empty, the session uses HTTP basic auth.
            token: Optional token sent as the query parameter ``token_name``.
            token_name: Name of the query parameter that carries ``token``.
            cafile: Optional path to a CA file used to build the SSL context.
            **kwargs: Extra settings. ``store_type`` (or ``type``) is popped
                into ``self.type``; the remainder is kept in ``self.kwargs``.
        """
        self.url = url
        self.params = params or {}
        self.headers = headers or {}
        self.basic_auth = basic_auth or {}
        self.token = token
        self.token_name = token_name
        self.type = kwargs.pop("store_type", kwargs.pop("type", None))
        self.cafile = cafile
        self.kwargs = kwargs

    def session(self) -> aiohttp.ClientSession:
        """Creates and returns a configured aiohttp client session."""
        # create authentication parameters
        if self.basic_auth:
            auth = aiohttp.BasicAuth(
                self.basic_auth["username"], self.basic_auth["password"]
            )
        else:
            auth = None

        return aiohttp.ClientSession(
            headers=self.headers,
            auth=auth,
            timeout=aiohttp.ClientTimeout(total=DEFAULT_REQUEST_TIMEOUT),
        )

    def combine_parameters(
        self, kwargs: Optional[Dict[Text, Any]] = None
    ) -> Dict[Text, Any]:
        """Merges configured params, the token and ``kwargs["params"]``.

        Note: ``params`` is removed from *kwargs* in place so the remaining
        kwargs can be forwarded without passing ``params`` twice.
        """
        # construct GET parameters
        params = self.params.copy()

        # set the authentication token if present
        if self.token:
            params[self.token_name] = self.token

        if kwargs and "params" in kwargs:
            params.update(kwargs["params"])
            del kwargs["params"]
        return params

    async def request(
        self,
        method: Text = "post",
        subpath: Optional[Text] = None,
        content_type: Optional[Text] = "application/json",
        **kwargs: Any,
    ) -> Optional[Any]:
        """Send a HTTP request to the endpoint. Return json response, if available.

        All additional arguments will get passed through
        to aiohttp's `session.request`.

        Returns:
            The decoded JSON body, or None if the response is not JSON.

        Raises:
            FileNotFoundException: ``cafile`` is set but does not exist.
            ClientResponseError: The endpoint answered with status >= 400.
        """
        # create the appropriate headers
        headers = {}
        if content_type:
            headers["Content-Type"] = content_type

        if "headers" in kwargs:
            headers.update(kwargs["headers"])
            del kwargs["headers"]

        url = concat_url(self.url, subpath)

        sslcontext = None
        if self.cafile:
            try:
                sslcontext = ssl.create_default_context(cafile=self.cafile)
            except FileNotFoundError as e:
                # Computed up front: a multi-line expression inside an
                # f-string replacement field is a SyntaxError before 3.12.
                cafile_path = os.path.abspath(self.cafile)
                raise FileNotFoundException(
                    f"Failed to find certificate file, "
                    f"'{cafile_path}' does not exist."
                ) from e

        # Build the query params before the call: this strips ``params`` from
        # ``kwargs`` so it is not also forwarded through ``**kwargs``.
        params = self.combine_parameters(kwargs)

        async with self.session() as session:
            async with session.request(
                method,
                url,
                headers=headers,
                params=params,
                ssl=sslcontext,
                **kwargs,
            ) as response:
                if response.status >= 400:
                    raise ClientResponseError(
                        response.status, response.reason, await response.content.read()
                    )
                try:
                    return await response.json()
                except ContentTypeError:
                    return None

    @classmethod
    def from_dict(cls, data: Dict[Text, Any]) -> "EndpointConfig":
        """Creates a config (of the calling class) from a dict of kwargs."""
        return cls(**data)

    def copy(self) -> "EndpointConfig":
        """Returns a copy of this config.

        The ``params``, ``headers`` and ``basic_auth`` dicts are copied so
        that mutating the copy does not affect the original, and ``cafile``
        and ``type`` are preserved alongside the other settings.
        """
        return EndpointConfig(
            self.url,
            dict(self.params),
            dict(self.headers),
            dict(self.basic_auth),
            self.token,
            self.token_name,
            cafile=self.cafile,
            type=self.type,
            **self.kwargs,
        )

    def __eq__(self, other: Any) -> bool:
        if isinstance(self, type(other)):
            return (
                other.url == self.url
                and other.params == self.params
                and other.headers == self.headers
                and other.basic_auth == self.basic_auth
                and other.token == self.token
                and other.token_name == self.token_name
            )
        else:
            return False

    def __ne__(self, other: Any) -> bool:
        return not self.__eq__(other)

FormAction

实现并执行表单（form）逻辑的 Action


class FormAction(LoopAction):
    """Loop action that implements and executes the form logic."""

    def __init__(
        self, form_name: Text, action_endpoint: Optional[EndpointConfig]
    ) -> None:
        """Initializes the form action.

        Args:
            form_name: Name of the form; this is what ``name()`` returns.
            action_endpoint: Endpoint used to execute custom actions
                (may be ``None``).
        """
        self.action_endpoint = action_endpoint
        self._form_name = form_name
        # Building the unique entity mappings needs the domain, which is not
        # available at construction time. They start empty and are filled in
        # on the first call; the flag records whether that has happened.
        self._have_unique_entity_mappings_been_initialized = False
        self._unique_entity_mappings: Set[Text] = set()

    def name(self) -> Text:
        """Returns the name of the form this action runs."""
        form_name = self._form_name
        return form_name

RemoteAction

通过 HTTP 调用外部 Action 服务器（action endpoint）来执行自定义动作（custom action）的 Action

class RemoteAction(Action):
    """Action that is executed by calling a custom action server endpoint."""

    def __init__(self, name: Text, action_endpoint: Optional[EndpointConfig]) -> None:
        """Creates a `RemoteAction`.

        Args:
            name: Name of the custom action to run on the action server.
            action_endpoint: Endpoint of the action server. ``run`` raises a
                ``RasaException`` if this is not set.
        """
        self._name = name
        self.action_endpoint = action_endpoint

    def _action_call_format(
        self, tracker: "DialogueStateTracker", domain: "Domain"
    ) -> Dict[Text, Any]:
        """Create the request json send to the action server."""
        from rasa.shared.core.trackers import EventVerbosity

        tracker_state = tracker.current_state(EventVerbosity.ALL)

        return {
            "next_action": self._name,
            "sender_id": tracker.sender_id,
            "tracker": tracker_state,
            "domain": domain.as_dict(),
            "version": rasa.__version__,
        }

    @staticmethod
    def action_response_format_spec() -> Dict[Text, Any]:
        """Expected response schema for an Action endpoint.

        Used for validation of the response returned from the
        Action endpoint."""
        schema = {
            "type": "object",
            "properties": {
                "events": EVENTS_SCHEMA,
                "responses": {"type": "array", "items": {"type": "object"}},
            },
        }
        return schema

    def _validate_action_result(self, result: Dict[Text, Any]) -> bool:
        """Validates *result* against ``action_response_format_spec``.

        Returns:
            True if the result is valid.

        Raises:
            jsonschema.ValidationError: The result does not match the schema;
                the message is extended with a pointer to the docs.
        """
        from jsonschema import validate
        from jsonschema import ValidationError

        try:
            validate(result, self.action_response_format_spec())
            return True
        except ValidationError as e:
            e.message += (
                f". Failed to validate Action server response from API, "
                f"make sure your response from the Action endpoint is valid. "
                f"For more information about the format visit "
                f"{DOCS_BASE_URL}/custom-actions"
            )
            raise

    @staticmethod
    async def _utter_responses(
        responses: List[Dict[Text, Any]],
        output_channel: "OutputChannel",
        nlg: "NaturalLanguageGenerator",
        tracker: "DialogueStateTracker",
    ) -> List[BotUttered]:
        """Use the responses generated by the action endpoint and utter them.

        Note: each dict in *responses* is mutated (``response`` and
        ``buttons`` are popped from it).
        """
        import inspect

        bot_messages = []
        for response in responses:
            generated_response = response.pop("response", None)
            if generated_response:
                draft = nlg.generate(
                    generated_response, tracker, output_channel.name(), **response
                )
                # Accept both synchronous and coroutine-based NLG
                # implementations; the generator's type is not visible here.
                if inspect.isawaitable(draft):
                    draft = await draft
                if not draft:
                    continue
                draft["utter_action"] = generated_response
            else:
                draft = {}

            buttons = response.pop("buttons", []) or []
            if buttons:
                draft.setdefault("buttons", [])
                draft["buttons"].extend(buttons)

            # Avoid overwriting `draft` values with empty values
            response = {k: v for k, v in response.items() if v}
            draft.update(response)
            bot_messages.append(create_bot_utterance(draft))

        return bot_messages

    async def run(
        self,
        output_channel: "OutputChannel",
        nlg: "NaturalLanguageGenerator",
        tracker: "DialogueStateTracker",
        domain: "Domain",
    ) -> List[Event]:
        """Runs action. Please see parent class for the full docstring.

        This is a coroutine because ``EndpointConfig.request`` is async and
        must be awaited to obtain the action server's response.

        Raises:
            RasaException: No endpoint is configured, or the call failed.
            ActionExecutionRejection: The action server answered with 400.
        """
        json_body = self._action_call_format(tracker, domain)
        if not self.action_endpoint:
            raise RasaException(
                f"Failed to execute custom action '{self.name()}' "
                f"because no endpoint is configured to run this "
                f"custom action. Please take a look at "
                f"the docs and set an endpoint configuration via the "
                f"--endpoints flag. "
                f"{DOCS_BASE_URL}/custom-actions"
            )

        try:
            logger.debug(
                "Calling action endpoint to run action '{}'.".format(self.name())
            )
            response: Any = await self.action_endpoint.request(
                json=json_body, method="post", timeout=DEFAULT_REQUEST_TIMEOUT
            )

            self._validate_action_result(response)

            events_json = response.get("events", [])
            responses = response.get("responses", [])
            bot_messages: List[Event] = await self._utter_responses(
                responses, output_channel, nlg, tracker
            )

            evts = events.deserialise_events(events_json)
            return bot_messages + evts

        except ClientResponseError as e:
            if e.status == 400:
                response_data = json.loads(e.text)
                exception = ActionExecutionRejection(
                    response_data["action_name"], response_data.get("error")
                )
                logger.error(exception.message)
                raise exception
            else:
                raise RasaException("Failed to execute custom action.") from e

        except aiohttp.ClientConnectionError as e:
            logger.error(
                "Failed to run custom action '{}'. Couldn't connect "
                "to the server at '{}'. Is the server running? "
                "Error: {}".format(self.name(), self.action_endpoint.url, e)
            )
            raise RasaException("Failed to execute custom action.") from e

        except aiohttp.ClientError as e:
            # not all errors have a status attribute, but
            # helpful to log if they got it

            # noinspection PyUnresolvedReferences
            status = getattr(e, "status", None)
            raise RasaException(
                "Failed to run custom action '{}'. Action server "
                "responded with a non 200 status code of {}. "
                "Make sure your action server properly runs actions "
                "and returns a 200 once the action is executed. "
                "Error: {}".format(self.name(), status, e)
            ) from e

    def name(self) -> Text:
        """Returns the name of the custom action."""
        return self._name


events.py

# Schema for a list of entities. Every entity object must provide ``entity``
# and ``value``; ``value`` has an empty schema, so any JSON value is accepted.
ENTITIES_SCHEMA = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            "start": {"type": "integer"},
            "end": {"type": "integer"},
            "entity": {"type": "string"},
            "confidence": {"type": "number"},
            "extractor": {"type": ["string", "null"]},
            "value": {},
            "role": {"type": ["string", "null"]},
            "group": {"type": ["string", "null"]},
        },
        "required": ["entity", "value"],
    },
}

# Schema for one intent entry: a name plus a numeric confidence.
INTENT = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "confidence": {"type": "number"},
    },
}

# Schema for a response-selector response. ``responses`` and
# ``response_templates`` share the same shape (lists of objects with a
# ``text`` string) but are written out as two independent dicts.
RESPONSE_SCHEMA = {
    "type": "object",
    "properties": {
        "responses": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {"text": {"type": "string"}},
            },
        },
        "response_templates": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {"text": {"type": "string"}},
            },
        },
        "confidence": {"type": "number"},
        "intent_response_key": {"type": "string"},
        "utter_action": {"type": "string"},
        "template_name": {"type": "string"},
    },
}

# Schema for a response-selector ranking: a list of scored candidates.
RANKING_SCHEMA = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            "id": {"type": "number"},
            "confidence": {"type": "number"},
            "intent_response_key": {"type": "string"},
        },
    },
}

# Schema for ``user`` events. ``parse_data`` reuses the INTENT,
# ENTITIES_SCHEMA, RESPONSE_SCHEMA and RANKING_SCHEMA building blocks.
USER_UTTERED = {
    "properties": {
        "event": {"const": "user"},
        "text": {"type": ["string", "null"]},
        "input_channel": {"type": ["string", "null"]},
        "message_id": {"type": ["string", "null"]},
        "parse_data": {
            "type": "object",
            "properties": {
                "text": {"type": ["string", "null"]},
                "intent_ranking": {"type": "array", "items": INTENT},
                "intent": INTENT,
                "entities": ENTITIES_SCHEMA,
                "response_selector": {
                    "type": "object",
                    # Exactly one alternative must match: either an
                    # ``all_retrieval_intents`` array, or objects under keys
                    # matching the pattern, each with response and ranking.
                    "oneOf": [
                        {
                            "properties": {
                                "all_retrieval_intents": {"type": "array"},
                            }
                        },
                        {
                            "patternProperties": {
                                "[\\w/]": {
                                    "type": "object",
                                    "properties": {
                                        "response": RESPONSE_SCHEMA,
                                        "ranking": RANKING_SCHEMA,
                                    },
                                }
                            }
                        },
                    ],
                },
            },
        },
    }
}

# Schema for ``action`` events; every field except ``event`` is optional and
# most of them may be null.
ACTION_EXECUTED = {
    "properties": {
        "event": {"const": "action"},
        "policy": {"type": ["string", "null"]},
        "confidence": {"type": ["number", "null"]},
        "name": {"type": ["string", "null"]},
        "hide_rule_turn": {"type": "boolean"},
        "action_text": {"type": ["string", "null"]},
    }
}

# Schema for ``slot`` events: ``name`` and ``value`` are required, and the
# value may be any JSON value (empty schema).
SLOT_SET = {
    "properties": {
        "event": {"const": "slot"},
        "name": {"type": "string"},
        "value": {},
    },
    "required": ["name", "value"],
}

# Schema for ``entities`` events, which must carry an ``entities`` list.
ENTITIES_ADDED = {
    "properties": {
        "event": {"const": "entities"},
        "entities": ENTITIES_SCHEMA,
    },
    "required": ["entities"],
}

# Event schemas that only pin the ``event`` discriminator to a constant;
# the shared envelope fields are described by EVENT_SCHEMA.
USER_UTTERED_FEATURIZATION = {"properties": {"event": {"const": "user_featurization"}}}
REMINDER_CANCELLED = {"properties": {"event": {"const": "cancel_reminder"}}}
REMINDER_SCHEDULED = {"properties": {"event": {"const": "reminder"}}}
ACTION_EXECUTION_REJECTED = {
    "properties": {"event": {"const": "action_execution_rejected"}}
}
FORM_VALIDATION = {"properties": {"event": {"const": "form_validation"}}}
LOOP_INTERRUPTED = {"properties": {"event": {"const": "loop_interrupted"}}}
FORM = {"properties": {"event": {"const": "form"}}}
ACTIVE_LOOP = {"properties": {"event": {"const": "active_loop"}}}
ALL_SLOTS_RESET = {"properties": {"event": {"const": "reset_slots"}}}
CONVERSATION_RESUMED = {"properties": {"event": {"const": "resume"}}}
CONVERSATION_PAUSED = {"properties": {"event": {"const": "pause"}}}
FOLLOWUP_ACTION = {"properties": {"event": {"const": "followup"}}}
STORY_EXPORTED = {"properties": {"event": {"const": "export"}}}
RESTARTED = {"properties": {"event": {"const": "restart"}}}
ACTION_REVERTED = {"properties": {"event": {"const": "undo"}}}
USER_UTTERANCE_REVERTED = {"properties": {"event": {"const": "rewind"}}}
BOT_UTTERED = {"properties": {"event": {"const": "bot"}}}
SESSION_STARTED = {"properties": {"event": {"const": "session_started"}}}
AGENT_UTTERED = {"properties": {"event": {"const": "agent"}}}

# Schema for one serialized event. All events share the ``event`` /
# ``timestamp`` / ``metadata`` envelope, and exactly one alternative in
# ``oneOf`` must match; each alternative pins a different ``event`` value.
EVENT_SCHEMA = {
    "type": "object",
    "properties": {
        "event": {"type": "string"},
        "timestamp": {"type": ["number", "null"]},
        "metadata": {"type": ["object", "null"]},
    },
    "required": ["event"],
    "oneOf": [
        USER_UTTERED,
        ACTION_EXECUTED,
        SLOT_SET,
        ENTITIES_ADDED,
        USER_UTTERED_FEATURIZATION,
        REMINDER_CANCELLED,
        REMINDER_SCHEDULED,
        ACTION_EXECUTION_REJECTED,
        FORM_VALIDATION,
        LOOP_INTERRUPTED,
        FORM,
        ACTIVE_LOOP,
        ALL_SLOTS_RESET,
        CONVERSATION_RESUMED,
        CONVERSATION_PAUSED,
        FOLLOWUP_ACTION,
        STORY_EXPORTED,
        RESTARTED,
        ACTION_REVERTED,
        USER_UTTERANCE_REVERTED,
        BOT_UTTERED,
        SESSION_STARTED,
        AGENT_UTTERED,
    ],
}

# Schema for a list of serialized events.
EVENTS_SCHEMA = {"type": "array", "items": EVENT_SCHEMA}

Rasa 3.x系列博客分享

猜你喜欢

转载自blog.csdn.net/duan_zhihua/article/details/124218958