How to Dynamically Add Request Parameters in Spring Cloud Gateway

Background

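The project is built on the Spring Cloud stack, and one feature requirement is as follows.

In the Spring Cloud Gateway module, the gateway can already obtain the token and, once authentication passes, extract the caller's identity information from it.

We now want to put that identity information into the request parameters (several fields are wrapped in a BaseDTO object so the set can be extended later).

The downstream microservices that implement the actual business logic can then read the identity information directly: as long as the parameter object of a controller method extends BaseDTO, the method can use those fields in its business logic.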

Problem

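In short, the question is how Spring Cloud Gateway can add request parameters dynamically, with values computed per request rather than fixed in advance.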

Spring Cloud Gateway Add Request Parameter

  1. The official documentation provides the following example:

docs.spring.io/spring-clou…


However, the parameter is declared in the configuration file, so it appears it can only be a fixed value.

  2. On GitHub, someone raised a similar question:

github.com/spring-clou…


But the resulting behavior is essentially the same as the configuration-file approach: the value is still static.

  3. A similar answer can also be found on Stack Overflow:

stackoverflow.com/questions/6…


That pointed to a rough direction for the solution.

Solution

In the Spring Cloud Gateway source code, two classes turn out to be relevant: AddRequestParameterGatewayFilterFactory and ModifyRequestBodyGatewayFilterFactory.


Their (decompiled) code is as follows:

AddRequestParameterGatewayFilterFactory

//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by FernFlower decompiler)
//

package org.springframework.cloud.gateway.filter.factory;

import java.net.URI;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractNameValueGatewayFilterFactory.NameValueConfig;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;

public class AddRequestParameterGatewayFilterFactory extends AbstractNameValueGatewayFilterFactory {
    public AddRequestParameterGatewayFilterFactory() {
    }

    public GatewayFilter apply(NameValueConfig config) {
        return (exchange, chain) -> {
            URI uri = exchange.getRequest().getURI();
            StringBuilder query = new StringBuilder();
            String originalQuery = uri.getRawQuery();
            if (StringUtils.hasText(originalQuery)) {
                query.append(originalQuery);
                if (originalQuery.charAt(originalQuery.length() - 1) != '&') {
                    query.append('&');
                }
            }

            query.append(config.getName());
            query.append('=');
            query.append(config.getValue());

            try {
                URI newUri = UriComponentsBuilder.fromUri(uri).replaceQuery(query.toString()).build(true).toUri();
                ServerHttpRequest request = exchange.getRequest().mutate().uri(newUri).build();
                return chain.filter(exchange.mutate().request(request).build());
            } catch (RuntimeException var8) {
                throw new IllegalStateException("Invalid URI query: \"" + query.toString() + "\"");
            }
        };
    }
}
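
At its core, AddRequestParameterGatewayFilterFactory is string manipulation on the raw query: keep the original query, add a '&' if it does not already end with one, append name=value, and rebuild the URI with build(true), which tells UriComponentsBuilder the query is already encoded. A value containing spaces, '&' or non-ASCII characters therefore has to be percent-encoded before it is appended. The JDK-only sketch below reproduces that logic for several parameters so it can be tried in isolation; the class QueryAppender and the use of URLEncoder (form encoding, where a space becomes '+') are assumptions for illustration, not Spring Cloud Gateway API.

```java
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.util.Map;

// Hypothetical JDK-only helper (not part of Spring Cloud Gateway) that reproduces the
// query-append logic of AddRequestParameterGatewayFilterFactory for several parameters.
final class QueryAppender {

    private QueryAppender() {
    }

    // Appends each name=value pair to an already percent-encoded (raw) query string.
    static String appendParams(String rawQuery, Map<String, String> params) {
        StringBuilder query = new StringBuilder();
        if (rawQuery != null && !rawQuery.isEmpty()) {
            // Keep the existing raw query untouched, adding a separator if it lacks one
            query.append(rawQuery);
            if (rawQuery.charAt(rawQuery.length() - 1) != '&') {
                query.append('&');
            }
        }
        String separator = "";
        for (Map.Entry<String, String> entry : params.entrySet()) {
            query.append(separator)
                 .append(encode(entry.getKey()))
                 .append('=')
                 .append(encode(entry.getValue()));
            separator = "&";
        }
        return query.toString();
    }

    // Encodes only the newly added text, so the original query is never double-encoded.
    // URLEncoder applies form encoding: a space becomes '+', reserved characters become %XX.
    private static String encode(String value) {
        try {
            return URLEncoder.encode(value, "UTF-8");
        } catch (UnsupportedEncodingException e) {
            // Every JVM is required to support UTF-8, so this cannot happen in practice
            throw new IllegalStateException(e);
        }
    }
}
```

For example, appending userId=42 and userName=张三 to page=1 gives page=1&userId=42&userName=%E5%BC%A0%E4%B8%89, a query string that should be safe to hand to build(true) because every appended character is either unreserved or percent-encoded.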


ModifyRequestBodyGatewayFilterFactory

//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by FernFlower decompiler)
//

package org.springframework.cloud.gateway.filter.factory.rewrite;

import java.util.Map;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.cloud.gateway.support.BodyInserterContext;
import org.springframework.cloud.gateway.support.CachedBodyOutputMessage;
import org.springframework.cloud.gateway.support.DefaultServerRequest;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpHeaders;
import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.web.reactive.function.BodyInserter;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.server.ServerRequest;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

public class ModifyRequestBodyGatewayFilterFactory extends AbstractGatewayFilterFactory<ModifyRequestBodyGatewayFilterFactory.Config> {
    public ModifyRequestBodyGatewayFilterFactory() {
        super(ModifyRequestBodyGatewayFilterFactory.Config.class);
    }

    /** @deprecated */
    @Deprecated
    public ModifyRequestBodyGatewayFilterFactory(ServerCodecConfigurer codecConfigurer) {
        this();
    }

    public GatewayFilter apply(ModifyRequestBodyGatewayFilterFactory.Config config) {
        return (exchange, chain) -> {
            Class inClass = config.getInClass();
            ServerRequest serverRequest = new DefaultServerRequest(exchange);
            Mono<?> modifiedBody = serverRequest.bodyToMono(inClass).flatMap((o) -> {
                return config.rewriteFunction.apply(exchange, o);
            });
            BodyInserter bodyInserter = BodyInserters.fromPublisher(modifiedBody, config.getOutClass());
            CachedBodyOutputMessage outputMessage = new CachedBodyOutputMessage(exchange, exchange.getRequest().getHeaders());
            return bodyInserter.insert(outputMessage, new BodyInserterContext()).then(Mono.defer(() -> {
                ServerHttpRequestDecorator decorator = new ServerHttpRequestDecorator(exchange.getRequest()) {
                    public HttpHeaders getHeaders() {
                        HttpHeaders httpHeaders = new HttpHeaders();
                        httpHeaders.putAll(super.getHeaders());
                        httpHeaders.set("Transfer-Encoding", "chunked");
                        return httpHeaders;
                    }

                    public Flux<DataBuffer> getBody() {
                        return outputMessage.getBody();
                    }
                };
                return chain.filter(exchange.mutate().request(decorator).build());
            }));
        };
    }

    public static class Config {
        private Class inClass;
        private Class outClass;
        private Map<String, Object> inHints;
        private Map<String, Object> outHints;
        private RewriteFunction rewriteFunction;

        public Config() {
        }

        public Class getInClass() {
            return this.inClass;
        }

        public ModifyRequestBodyGatewayFilterFactory.Config setInClass(Class inClass) {
            this.inClass = inClass;
            return this;
        }

        public Class getOutClass() {
            return this.outClass;
        }

        public ModifyRequestBodyGatewayFilterFactory.Config setOutClass(Class outClass) {
            this.outClass = outClass;
            return this;
        }

        public Map<String, Object> getInHints() {
            return this.inHints;
        }

        public ModifyRequestBodyGatewayFilterFactory.Config setInHints(Map<String, Object> inHints) {
            this.inHints = inHints;
            return this;
        }

        public Map<String, Object> getOutHints() {
            return this.outHints;
        }

        public ModifyRequestBodyGatewayFilterFactory.Config setOutHints(Map<String, Object> outHints) {
            this.outHints = outHints;
            return this;
        }

        public RewriteFunction getRewriteFunction() {
            return this.rewriteFunction;
        }

        public <T, R> ModifyRequestBodyGatewayFilterFactory.Config setRewriteFunction(Class<T> inClass, Class<R> outClass, RewriteFunction<T, R> rewriteFunction) {
            this.setInClass(inClass);
            this.setOutClass(outClass);
            this.setRewriteFunction(rewriteFunction);
            return this;
        }

        public ModifyRequestBodyGatewayFilterFactory.Config setRewriteFunction(RewriteFunction rewriteFunction) {
            this.rewriteFunction = rewriteFunction;
            return this;
        }
    }
}


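In effect, these two classes can serve as official reference examples: the first shows how to rewrite the query string, the second how to rewrite the request body.

Following the same pattern, we can implement parameter injection in our own gateway filter.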

Implementation

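Main flow of the authentication filter (the helper methods shown in the sections that follow are private methods of this same class)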

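// Imports used by AuthFilter and its helper methods (Spring Cloud Gateway 2.x, SLF4J, Jackson);
// BaseDTO, UrlConstant and FilterOrderConstant are project classes whose imports are omitted
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.UnsupportedEncodingException;
import java.net.URI;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.cloud.gateway.support.BodyInserterContext;
import org.springframework.cloud.gateway.support.CachedBodyOutputMessage;
import org.springframework.cloud.gateway.support.DefaultServerRequest;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.util.AntPathMatcher;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.function.BodyInserter;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.UriComponentsBuilder;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

@Component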
public class AuthFilter implements GlobalFilter, Ordered {

    private static final Logger LOGGER = LoggerFactory.getLogger(AuthFilter.class);

    private static final AntPathMatcher antPathMatcher = new AntPathMatcher();

    @Override
    public int getOrder() {
        return FilterOrderConstant.getOrder(this.getClass().getName());
    }

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpRequest request = exchange.getRequest();
        URI uri = request.getURI();
        String url = uri.getPath();

        // Skip paths that do not require authentication
        Stream<String> skipAuthUrls = UrlConstant.skipAuthUrls.stream();
        if(skipAuthUrls.anyMatch(path -> antPathMatcher.match(path, url))){
            // Forward the request unchanged
            ServerHttpRequest.Builder builder = request.mutate();
            return chain.filter(exchange.mutate().request(builder.build()).build());
        }
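        // Read the token from the request header
        String token = request.getHeaders().getFirst("Authorization");

        // Extract the identity information carried by the token
        // (the token validation logic is not covered here)
        BaseDTO baseDTO = getClaim(token);
        if(null == baseDTO){
            // Authentication failed: no identity information could be obtained
            return illegalResponse(exchange, "{\"code\": \"401\",\"msg\": \"unauthorized.\"}");
        }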


        // Attach the current identity information to the request
        ServerHttpRequest.Builder builder = request.mutate();

        Stream<String> addRequestParameterUrls = UrlConstant.addRequestParameterUrls.stream();
        if (addRequestParameterUrls.anyMatch(path -> antPathMatcher.match(path, url))){
            // This path needs the identity parameters
            if(request.getMethod() == HttpMethod.GET){
                // GET request: append to the query string
                return addParameterForGetMethod(exchange, chain, uri, baseDTO, builder);
            }

            if(request.getMethod() == HttpMethod.POST){
                // POST request: handling depends on the content type
                MediaType contentType = request.getHeaders().getContentType();
                if(MediaType.APPLICATION_JSON.equals(contentType)
                        || MediaType.APPLICATION_JSON_UTF8.equals(contentType)){
                    // Body is application/json
                    return addParameterForPostMethod(exchange, chain, baseDTO);
                }

                if (MediaType.MULTIPART_FORM_DATA.isCompatibleWith(contentType)) {
                    // Body is multipart/form-data
                    return addParameterForFormData(exchange, chain, baseDTO, builder);
                }

            }

            if(request.getMethod() == HttpMethod.PUT){
                // PUT request: reuse the POST (JSON body) flow
                return addParameterForPostMethod(exchange, chain, baseDTO);
            }

            if(request.getMethod() == HttpMethod.DELETE){
                // DELETE request: reuse the GET (query string) flow
                return addParameterForGetMethod(exchange, chain, uri, baseDTO, builder);
            }

        }


        // This filter is done: pass the request on down the chain
        return chain.filter(exchange.mutate().request(builder.build()).build());
    }

}
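
The branching in filter() is easier to review as a table of HTTP method and Content-Type against injection strategy. The sketch below is a framework-free model of that decision, not code from the project: the enum InjectionStrategy and the class StrategySelector are hypothetical names, and, as a simplification, media types are compared only on type/subtype, whereas the filter itself compares the JSON types with MediaType.equals.

```java
import java.util.Locale;

// Hypothetical, framework-free model of the branching in AuthFilter.filter():
// how the identity information is injected for a given method and Content-Type.
enum InjectionStrategy {
    QUERY_STRING, // addParameterForGetMethod
    JSON_BODY,    // addParameterForPostMethod
    HEADERS,      // addParameterForFormData
    NONE          // forwarded without extra identity data
}

final class StrategySelector {

    private StrategySelector() {
    }

    static InjectionStrategy select(String method, String contentType) {
        if (method == null) {
            return InjectionStrategy.NONE;
        }
        // Simplification: keep only "type/subtype", ignoring parameters such as charset or boundary
        String mediaType = "";
        if (contentType != null) {
            int semicolon = contentType.indexOf(';');
            mediaType = (semicolon >= 0 ? contentType.substring(0, semicolon) : contentType)
                    .trim().toLowerCase(Locale.ROOT);
        }
        switch (method.toUpperCase(Locale.ROOT)) {
            case "GET":
            case "DELETE":
                return InjectionStrategy.QUERY_STRING;
            case "POST":
                if (mediaType.equals("application/json")) {
                    return InjectionStrategy.JSON_BODY;
                }
                if (mediaType.equals("multipart/form-data")) {
                    return InjectionStrategy.HEADERS;
                }
                // e.g. application/x-www-form-urlencoded passes through unchanged
                return InjectionStrategy.NONE;
            case "PUT":
                // PUT always takes the JSON-body path, whatever its Content-Type
                return InjectionStrategy.JSON_BODY;
            default:
                return InjectionStrategy.NONE;
        }
    }
}
```

Laid out this way, two gaps become visible: a POST with application/x-www-form-urlencoded (a classic HTML form) and methods such as PATCH reach the downstream service without identity data, and a PUT is always sent down the JSON-body path even when its body is not JSON.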

Adding parameters to a GET request

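    /**
     * Adds the identity parameters to the query string of a GET request.
     * @param exchange the current exchange
     * @param chain the filter chain
     * @param uri the original request URI
     * @param baseDTO the identity information extracted from the token
     * @param builder request builder used as a fallback if the new URI is invalid
     * @return the result of continuing the filter chain
     */
    private Mono<Void> addParameterForGetMethod(ServerWebExchange exchange, GatewayFilterChain chain, URI uri, BaseDTO baseDTO, ServerHttpRequest.Builder builder) {
        StringBuilder query = new StringBuilder();

        // Use the raw (still percent-encoded) query so that values such as %26 are not decoded into separators
        String originalQuery = uri.getRawQuery();
        if (StringUtils.hasText(originalQuery)) {
            query.append(originalQuery);
            if (originalQuery.charAt(originalQuery.length() - 1) != '&') {
                query.append('&');
            }
        }

        try {
            // Encode the new values ourselves, because build(true) expects an already-encoded query
            query.append("userId").append("=").append(URLEncoder.encode(String.valueOf(baseDTO.getUserId()), "UTF-8"))
                    .append("&").append("userName").append("=").append(URLEncoder.encode(String.valueOf(baseDTO.getUserName()), "UTF-8"));

            URI newUri = UriComponentsBuilder.fromUri(uri).replaceQuery(query.toString()).build(true).toUri();
            ServerHttpRequest request = exchange.getRequest().mutate().uri(newUri).build();
            return chain.filter(exchange.mutate().request(request).build());
        } catch (Exception e) {
            LOGGER.error("Invalid URI query: {}", query, e);
            // Fall back to forwarding the request without the extra parameters
            return chain.filter(exchange.mutate().request(builder.build()).build());
        }
    }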
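
The choice between URI.getQuery() and URI.getRawQuery() is easy to miss when rebuilding a query string. The official factory reads getRawQuery() and then calls build(true); the decoded form returned by getQuery() turns an encoded %26 inside a value back into a literal '&', which is then indistinguishable from a parameter separator. The small JDK-only demo below shows the difference; the class name RawQueryDemo and the example URL are made up for illustration.

```java
import java.net.URI;

// Hypothetical demo: how URI.getQuery() and URI.getRawQuery() differ
// when a parameter value contains an encoded '&'.
final class RawQueryDemo {

    private RawQueryDemo() {
    }

    // Counts name=value pairs by splitting on '&'.
    static int countPairs(String query) {
        return query == null || query.isEmpty() ? 0 : query.split("&").length;
    }

    public static void main(String[] args) {
        URI uri = URI.create("http://gateway/api/orders?keyword=a%26b");
        // Raw query stays encoded: a single parameter, keyword=a%26b
        System.out.println(uri.getRawQuery() + " -> " + countPairs(uri.getRawQuery()));
        // Decoded query: "%26" has become '&', so it now looks like two parameters
        System.out.println(uri.getQuery() + " -> " + countPairs(uri.getQuery()));
    }
}
```

Running main prints keyword=a%26b -> 1 followed by keyword=a&b -> 2, which is why the rebuilt query should start from the raw form.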

Adding parameters to a POST request

Request body is application/json

    /**
     * Adds the identity fields to the JSON body of a POST (or PUT) request.
     * @param exchange the current exchange
     * @param chain the filter chain
     * @param baseDTO the identity information extracted from the token
     * @return the result of continuing the filter chain
     */
    private Mono<Void> addParameterForPostMethod(ServerWebExchange exchange, GatewayFilterChain chain, BaseDTO baseDTO) {
        ServerRequest serverRequest = new DefaultServerRequest(exchange);
        Mono<String> modifiedBody = serverRequest.bodyToMono(String.class).flatMap((o) -> {
            if(o.trim().startsWith("[")){
                // The body is a JSON array: there is no object to add fields to, so pass it through unchanged
                return Mono.just(o);
            }

            ObjectMapper objectMapper = new ObjectMapper();
            try {
                Map map = objectMapper.readValue(o, Map.class);

                map.put("userId", baseDTO.getUserId());
                map.put("userName", baseDTO.getUserName());

                String json = objectMapper.writeValueAsString(map);
                LOGGER.info("addParameterForPostMethod -> json = {}", json);
                return Mono.just(json);
            }catch (Exception e){
                // Not a valid JSON object: forward the original body untouched
                LOGGER.warn("addParameterForPostMethod -> failed to parse body as JSON", e);
                return Mono.just(o);
            }
        });

        BodyInserter bodyInserter = BodyInserters.fromPublisher(modifiedBody, String.class);
        CachedBodyOutputMessage outputMessage = new CachedBodyOutputMessage(exchange, exchange.getRequest().getHeaders());
        return bodyInserter.insert(outputMessage, new BodyInserterContext()).then(Mono.defer(() -> {
            ServerHttpRequestDecorator decorator = new ServerHttpRequestDecorator(exchange.getRequest()) {
                @Override
                public HttpHeaders getHeaders() {
                    HttpHeaders httpHeaders = new HttpHeaders();
                    httpHeaders.putAll(super.getHeaders());
                    // The body length has changed, so the client's Content-Length no longer applies
                    httpHeaders.remove(HttpHeaders.CONTENT_LENGTH);
                    httpHeaders.set(HttpHeaders.TRANSFER_ENCODING, "chunked");
                    return httpHeaders;
                }

                @Override
                public Flux<DataBuffer> getBody() {
                    return outputMessage.getBody();
                }
            };
            return chain.filter(exchange.mutate().request(decorator).build());
        }));
    }
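
The getHeaders() override in the decorator is where the rewritten body and the original headers have to be reconciled. Once userId and userName are added, the body is longer than the Content-Length the client sent, and RFC 7230 (section 3.3.2) forbids a sender from combining Content-Length with Transfer-Encoding, so a rewritten request should drop the old length before switching to chunked transfer. The following is a framework-free model of that header rewrite on a plain Map, using the hypothetical class name RewrittenBodyHeaders; in the gateway, the equivalent operations would be performed on Spring's HttpHeaders inside getHeaders().

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical, framework-free model of the getHeaders() override in the request decorator.
// HTTP header names are case-insensitive, hence the case-insensitive TreeMap.
final class RewrittenBodyHeaders {

    private RewrittenBodyHeaders() {
    }

    static Map<String, List<String>> forRewrittenBody(Map<String, List<String>> original) {
        Map<String, List<String>> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
        // Copy every original header (the equivalent of httpHeaders.putAll(super.getHeaders()))
        for (Map.Entry<String, List<String>> entry : original.entrySet()) {
            headers.put(entry.getKey(), new ArrayList<>(entry.getValue()));
        }
        // The body was rewritten, so its old length is wrong: drop it and send the body chunked
        headers.remove("Content-Length");
        headers.put("Transfer-Encoding", Collections.singletonList("chunked"));
        return headers;
    }
}
```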

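Request body is multipart/form-data: here the identity information is passed as request headers rather than form fields, and userName is URL-encoded so that non-ASCII names survive as a header value.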

    /**
     * POST request with form data: passes the identity information as request headers.
     * @param exchange the current exchange
     * @param chain the filter chain
     * @param baseDTO the identity information extracted from the token
     * @param builder request builder for the current request
     * @return the result of continuing the filter chain
     */
    private Mono<Void> addParameterForFormData(ServerWebExchange exchange, GatewayFilterChain chain, BaseDTO baseDTO, ServerHttpRequest.Builder builder) {
        builder.header("userId", String.valueOf(baseDTO.getUserId()));
        try {
            // URL-encode the name so non-ASCII characters can travel in an HTTP header
            builder.header("userName", URLEncoder.encode(String.valueOf(baseDTO.getUserName()), "UTF-8"));
        } catch (UnsupportedEncodingException e) {
            builder.header("userName", String.valueOf(baseDTO.getUserName()));
        }
        ServerHttpRequest serverHttpRequest = builder.build();

        return chain.filter(exchange.mutate().request(serverHttpRequest).build());
    }
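
Because the multipart branch passes the identity via headers, a controller receiving such a request will not find userId and userName among its request parameters; it has to read the headers (for example with @RequestHeader in Spring MVC) and undo the URL encoding. The sketch below pairs both sides of that contract; the class IdentityHeaderCodec is a hypothetical name, and it assumes the gateway keeps using URLEncoder with UTF-8 as shown above.

```java
import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.net.URLEncoder;

// Hypothetical helper showing both sides of the userName header:
// the gateway URL-encodes the value, and the downstream service decodes it.
final class IdentityHeaderCodec {

    private IdentityHeaderCodec() {
    }

    // Gateway side: same encoding as addParameterForFormData
    static String encode(String userName) {
        try {
            return URLEncoder.encode(userName, "UTF-8");
        } catch (UnsupportedEncodingException e) {
            throw new IllegalStateException(e); // UTF-8 is always available
        }
    }

    // Downstream side: reverse the encoding before using the header value
    static String decode(String headerValue) {
        try {
            return URLDecoder.decode(headerValue, "UTF-8");
        } catch (UnsupportedEncodingException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

Note that URLEncoder turns a space into '+', and URLDecoder turns it back, so the two methods must be used as a pair rather than mixed with a different percent-encoding scheme.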

Writing the error response

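    /**
     * Writes a JSON message to the response.
     * @param exchange the current exchange
     * @param data the JSON string to write
     * @return completion signal of the write
     */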
    private Mono<Void> illegalResponse(ServerWebExchange exchange, String data) {
        ServerHttpResponse originalResponse = exchange.getResponse();
        originalResponse.setStatusCode(HttpStatus.OK);
        originalResponse.getHeaders().add(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_UTF8_VALUE);
        byte[] response = data.getBytes(StandardCharsets.UTF_8);
        DataBuffer buffer = originalResponse.bufferFactory().wrap(response);
        return originalResponse.writeWith(Flux.just(buffer));
    }

Result

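With the code above, the two properties userId and userName are injected into the request: into the query string for GET and DELETE, into the JSON body for JSON POST and PUT requests, and into request headers for multipart form data.

In the downstream business service, as long as a controller method's parameter object extends the BaseDTO class that contains userId and userName, the method receives this information and can use it in the actual business flow (for the form-data branch, the values have to be read from the headers instead).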
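
The article never shows BaseDTO itself. Judging from the getters used above (getUserId(), getUserName()), it is a plain Java bean; the sketch below is a guess at its shape, with the field types Long and String and the subclass OrderQueryDTO being assumptions made up for illustration.

```java
// Hypothetical reconstruction of BaseDTO: the article does not show it,
// and the field types (Long userId, String userName) are assumptions.
class BaseDTO {
    private Long userId;
    private String userName;

    public Long getUserId() {
        return userId;
    }

    public void setUserId(Long userId) {
        this.userId = userId;
    }

    public String getUserName() {
        return userName;
    }

    public void setUserName(String userName) {
        this.userName = userName;
    }
}

// Hypothetical request DTO of a business service: by extending BaseDTO it inherits
// userId/userName next to its own fields, so one object carries both.
class OrderQueryDTO extends BaseDTO {
    private String orderNo;

    public String getOrderNo() {
        return orderNo;
    }

    public void setOrderNo(String orderNo) {
        this.orderNo = orderNo;
    }
}
```

In a Spring MVC controller, a method such as list(OrderQueryDTO query) would have userId and userName bound through the inherited setters from the query string added by the gateway, and a @RequestBody OrderQueryDTO parameter would receive them from the rewritten JSON body.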


Reposted from juejin.im/post/7037745496105943070