背景介绍
有个业务需求,要提供一套API接口给第三方调用。
在处理具体业务接口之前,设计上要先做一个简单的鉴权。与第三方协商拟定了身份传参之后,考虑到项目已经用到了 Spring Cloud Gateway,就统一在网关模块做身份校验。
所以在服务端获取到请求的时候,要先拦截获取到请求传参,才能做后续的鉴权逻辑。
这里就需要解决一个问题:Spring Cloud Gateway 怎么读取请求传参?
搜索关键词:spring cloud gateway get request body
问题描述
问题:Spring Cloud Gateway 读取请求传参
这里只简单处理两种情况,get请求和post请求。
如果发现是get请求,就取url上的参数;
如果发现是post请求,就读取body的内容。
解决方案
参考 https://github.com/spring-cloud/spring-cloud-gateway/issues/747
定义了两个过滤器 filter。第一个过滤器 ApiRequestFilter 负责获取参数,并放到上下文 GatewayContext 中。
注意如果是POST请求,请求体读取完后,要重新构造,填回请求体中。
第二个过滤器 ApiVerifyFilter 从上下文可以直接获取到参数。
后面如果其他业务也有读取参数的需求,就直接从上下文获取,不用再重复写获取参数的逻辑。
实现代码
GatewayContext
/**
 * Per-request holder for data that the gateway filters extract from the incoming
 * request, so later filters can read it without consuming the body again.
 *
 * <p>An instance is stored in the exchange attributes under
 * {@link #CACHE_GATEWAY_CONTEXT}. Getters and setters are generated by Lombok's
 * {@code @Data}.
 */
@Data
public class GatewayContext {
/** Exchange attribute key under which the context instance is stored. */
public static final String CACHE_GATEWAY_CONTEXT = "cacheGatewayContext";
/**
 * Cached request body as text. ApiRequestFilter fills it with the JSON body for
 * JSON requests and with the text of a non-file field for multipart requests.
 * Stays {@code null} when nothing was cached.
 */
private String cacheBody;
/**
 * Cached multipart form data, keyed by part name.
 * NOTE(review): no code in this file assigns this field - confirm whether it is
 * populated elsewhere.
 */
private MultiValueMap<String, Part> formData;
/**
 * Cached request path (the path within the application, as set by ApiRequestFilter).
 */
private String path;
}
ApiRequestFilter
/**
 * Global filter that reads the body of POST requests once, caches it in a
 * {@link GatewayContext} stored in the exchange attributes, and hands a replayable
 * copy of the request to the rest of the filter chain.
 *
 * <p>Only JSON and multipart/form-data bodies are cached; every other request is
 * passed through untouched. The whole body is buffered in memory, so large uploads
 * are held entirely on the heap while the request is processed.
 *
 * <p>NOTE(review): filters that read the cached data must run after this one; the
 * relative order comes from {@code FilterOrderConstant}, which is not visible here.
 */
@Component
@Slf4j
public class ApiRequestFilter implements GlobalFilter, Ordered {

    /** Path matcher; not referenced by any method of this class at the moment. */
    private static final AntPathMatcher antPathMatcher = new AntPathMatcher();

    /**
     * default HttpMessageReader
     */
    private static final List<HttpMessageReader<?>> messageReaders =
            HandlerStrategies.withDefaults().messageReaders();

    private static final ResolvableType MULTIPART_DATA_TYPE =
            ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, Part.class);

    private static final Mono<MultiValueMap<String, Part>> EMPTY_MULTIPART_DATA =
            Mono.just(CollectionUtils.unmodifiableMultiValueMap(new LinkedMultiValueMap<String, Part>(0))).cache();

    /** Codec configuration used to locate the multipart reader; built once instead of per request. */
    private static final ServerCodecConfigurer MULTIPART_CODEC_CONFIGURER = ServerCodecConfigurer.create();

    /** Name of the multipart field whose content is never copied into the cached body. */
    private static final String FILE_PART_NAME = "file";

    /**
     * Dispatches on the HTTP method: GET and POST get dedicated handling, everything
     * else continues down the chain unchanged.
     *
     * @param exchange current exchange
     * @param chain    remaining filter chain
     * @return completion signal of the downstream chain
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpRequest request = exchange.getRequest();
        HttpMethod method = request.getMethod();
        if (HttpMethod.GET.equals(method)) {
            return handleGetMethod(exchange, chain, request);
        }
        if (HttpMethod.POST.equals(method)) {
            return handlePostMethod(exchange, chain, request);
        }
        return chain.filter(exchange);
    }

    /**
     * Handles GET requests. Query parameters are available directly from the request,
     * so nothing is cached here.
     *
     * @param exchange current exchange
     * @param chain    remaining filter chain
     * @param request  current request
     * @return completion signal of the downstream chain
     */
    private Mono<Void> handleGetMethod(ServerWebExchange exchange, GatewayFilterChain chain, ServerHttpRequest request) {
        // TODO nothing to do for now
        return chain.filter(exchange);
    }

    /**
     * Handles POST requests: registers a fresh {@link GatewayContext} on the exchange
     * and caches the body when the content type is JSON or multipart/form-data.
     *
     * @param exchange current exchange
     * @param chain    remaining filter chain
     * @param request  current request
     * @return completion signal of the downstream chain
     */
    private Mono<Void> handlePostMethod(ServerWebExchange exchange, GatewayFilterChain chain, ServerHttpRequest request) {
        GatewayContext gatewayContext = new GatewayContext();
        gatewayContext.setPath(request.getPath().pathWithinApplication().value());
        // save gateway context into exchange
        exchange.getAttributes().put(GatewayContext.CACHE_GATEWAY_CONTEXT, gatewayContext);

        MediaType contentType = contentTypeOf(request);
        // includes() ignores parameters, so "application/json" with any charset or
        // extra parameter is cached, not only the two exact constants.
        if (MediaType.APPLICATION_JSON.includes(contentType)) {
            return readJsonBody(exchange, chain, gatewayContext);
        }
        if (MediaType.MULTIPART_FORM_DATA.isCompatibleWith(contentType)) {
            return readFormData(exchange, chain, gatewayContext);
        }
        return chain.filter(exchange);
    }

    /**
     * Returns the parsed Content-Type header, or {@code null} when it is absent or
     * cannot be parsed, so a malformed header passes through uncached instead of
     * throwing out of the filter.
     */
    private static MediaType contentTypeOf(ServerHttpRequest request) {
        try {
            return request.getHeaders().getContentType();
        } catch (InvalidMediaTypeException ex) {
            log.warn("Ignoring unparsable Content-Type header", ex);
            return null;
        }
    }

    /**
     * Reads a JSON body into {@link GatewayContext#setCacheBody(String)} and continues
     * the chain with a request whose body can be read again.
     *
     * <p>An empty body still continues the chain; the cached body is then left unset.
     *
     * @param exchange       current exchange
     * @param chain          remaining filter chain
     * @param gatewayContext context that receives the body text
     * @return completion signal of the downstream chain
     */
    private Mono<Void> readJsonBody(ServerWebExchange exchange, GatewayFilterChain chain, GatewayContext gatewayContext) {
        return DataBufferUtils.join(exchange.getRequest().getBody())
                .map(ApiRequestFilter::drain)
                // join() completes empty for a body-less request; without a default
                // the flatMap below would be skipped and the chain never invoked.
                .defaultIfEmpty(new byte[0])
                .flatMap(bytes -> {
                    ServerHttpRequest replayable = replayableRequest(exchange, bytes);
                    ServerWebExchange mutatedExchange = exchange.mutate().request(replayable).build();
                    // Decode with the default readers so the charset of the
                    // Content-Type header is honoured.
                    return ServerRequest.create(mutatedExchange, messageReaders)
                            .bodyToMono(String.class)
                            .doOnNext(gatewayContext::setCacheBody)
                            .then(Mono.defer(() -> chain.filter(mutatedExchange)));
                });
    }

    /**
     * Parses a multipart/form-data body, stores the parsed parts and the text of the
     * non-file fields in the context, and continues the chain with a replayable request.
     *
     * @param exchange       current exchange
     * @param chain          remaining filter chain
     * @param gatewayContext context that receives the form data
     * @return completion signal of the downstream chain
     */
    private Mono<Void> readFormData(ServerWebExchange exchange, GatewayFilterChain chain, GatewayContext gatewayContext) {
        return DataBufferUtils.join(exchange.getRequest().getBody())
                .map(ApiRequestFilter::drain)
                .defaultIfEmpty(new byte[0])
                .flatMap(bytes -> {
                    ServerHttpRequest replayable = replayableRequest(exchange, bytes);
                    ServerWebExchange mutatedExchange = exchange.mutate().request(replayable).build();
                    return repackageMultipartData(replayable, MULTIPART_CODEC_CONFIGURER)
                            .flatMap(parts -> {
                                gatewayContext.setFormData(parts);
                                return cacheFieldValues(parts, gatewayContext);
                            })
                            // Only continue once the field values are cached, so
                            // later filters always observe them.
                            .then(Mono.defer(() -> chain.filter(mutatedExchange)));
                });
    }

    /**
     * Copies the text of the first part of every non-file field into the cached body,
     * in key iteration order. Each field overwrites the previous one, so the last
     * field wins.
     * NOTE(review): the overwrite mirrors the previous behaviour; confirm whether
     * a single field or a merged representation is actually wanted.
     */
    private static Mono<Void> cacheFieldValues(MultiValueMap<String, Part> parts, GatewayContext gatewayContext) {
        return Flux.fromIterable(parts.keySet())
                .filter(key -> !FILE_PART_NAME.equals(key) && parts.getFirst(key) != null)
                // join() merges a value that arrives in several buffers, so the
                // cached text is the full field and not just its last chunk.
                .concatMap(key -> DataBufferUtils.join(parts.getFirst(key).content()))
                .doOnNext(buffer -> gatewayContext.setCacheBody(
                        new String(drain(buffer), java.nio.charset.StandardCharsets.UTF_8)))
                .then();
    }

    /**
     * Copies the readable bytes out of the buffer and releases it, even when the
     * copy fails.
     */
    private static byte[] drain(DataBuffer dataBuffer) {
        try {
            byte[] bytes = new byte[dataBuffer.readableByteCount()];
            dataBuffer.read(bytes);
            return bytes;
        } finally {
            DataBufferUtils.release(dataBuffer);
        }
    }

    /**
     * Wraps the original request so that {@code getBody()} replays the given bytes.
     * Every subscription gets its own buffer wrapper, created lazily, so nothing is
     * allocated for a body nobody reads and no extra retain is required.
     */
    private static ServerHttpRequest replayableRequest(ServerWebExchange exchange, byte[] bytes) {
        final Flux<DataBuffer> replay;
        if (bytes.length == 0) {
            replay = Flux.<DataBuffer>empty();
        } else {
            replay = Flux.defer(() -> {
                DataBuffer buffer = exchange.getResponse().bufferFactory().wrap(bytes);
                return Flux.just(buffer);
            });
        }
        return new ServerHttpRequestDecorator(exchange.getRequest()) {
            @Override
            public Flux<DataBuffer> getBody() {
                return replay;
            }
        };
    }

    /**
     * Parses the request body as multipart data with the first reader of the
     * configurer that supports it.
     *
     * @param request    request whose body is parsed
     * @param configurer source of the available message readers
     * @return the parsed parts, or an empty unmodifiable map when the request is not
     *         multipart, its content type is invalid, or the reader emits nothing
     * @throws IllegalStateException if no multipart reader is configured
     */
    @SuppressWarnings("unchecked")
    private static Mono<MultiValueMap<String, Part>> repackageMultipartData(ServerHttpRequest request, ServerCodecConfigurer configurer) {
        try {
            final MediaType contentType = request.getHeaders().getContentType();
            if (MediaType.MULTIPART_FORM_DATA.isCompatibleWith(contentType)) {
                // The cast is safe: the reader was selected by canRead(MULTIPART_DATA_TYPE, ...).
                HttpMessageReader<MultiValueMap<String, Part>> multipartReader =
                        (HttpMessageReader<MultiValueMap<String, Part>>) configurer.getReaders().stream()
                                .filter(reader -> reader.canRead(MULTIPART_DATA_TYPE, MediaType.MULTIPART_FORM_DATA))
                                .findFirst()
                                .orElseThrow(() -> new IllegalStateException("No multipart HttpMessageReader."));
                return multipartReader.readMono(MULTIPART_DATA_TYPE, request, Collections.emptyMap())
                        .switchIfEmpty(EMPTY_MULTIPART_DATA)
                        .cache();
            }
        } catch (InvalidMediaTypeException ignored) {
            // An unparsable Content-Type is treated like "not multipart".
        }
        return EMPTY_MULTIPART_DATA;
    }

    /**
     * Concatenates two byte arrays into a new array.
     *
     * @param first  leading bytes
     * @param second trailing bytes
     * @return a new array holding {@code first} followed by {@code second}
     */
    public byte[] addBytes(byte[] first, byte[] second) {
        final byte[] result = Arrays.copyOf(first, first.length + second.length);
        System.arraycopy(second, 0, result, first.length, second.length);
        return result;
    }

    /**
     * @return the order registered for this class name in {@code FilterOrderConstant}
     */
    @Override
    public int getOrder() {
        return FilterOrderConstant.getOrder(this.getClass().getName());
    }
}
ApiVerifyFilter
@Component
@Slf4j
public class ApiVerifyFilter implements GlobalFilter, Ordered {
@Override
public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
ServerHttpRequest request = exchange.getRequest();
String url = request.getURI().getPath();
if(request.getMethod() == HttpMethod.GET){
// get请求 校验参数
return verifyGetMethod(exchange, chain, request);
}
if(request.getMethod() == HttpMethod.POST){
// post请求 校验参数
return verifyPostMethod(exchange, chain, request);
}
return chain.filter(exchange);
}
/**
* get请求 校验参数
* @param exchange
* @param chain
* @param request
* @return
*/
private Mono<Void> verifyGetMethod(ServerWebExchange exchange, GatewayFilterChain chain, ServerHttpRequest request) {
// get请求获取参数
Map<String, String> queryParamMap = request.getQueryParams().toSingleValueMap();
// 具体业务参数
String secretId = queryParamMap.get("secretId");
String secretKey = queryParamMap.get("secretKey");
// 校验参数逻辑
return verifyParams(exchange, chain, secretId, secretKey);
}
/**
* post请求 校验参数
* @param exchange
* @param chain
* @param request
* @return
*/
private Mono<Void> verifyPostMethod(ServerWebExchange exchange, GatewayFilterChain chain, ServerHttpRequest request) {
try {
GatewayContext gatewayContext = (GatewayContext)exchange.getAttributes().get(GatewayContext.CACHE_GATEWAY_CONTEXT);
// get body from gatewayContext
String cacheBody = gatewayContext.getCacheBody();
Map map = new ObjectMapper().readValue(cacheBody, Map.class);
// 具体业务参数
String secretId = String.valueOf(map.get("secretId"));
String secretKey = String.valueOf(map.get("secretKey"));
// 校验参数逻辑
return verifyParams(exchange, chain, secretId, secretKey);
} catch (Exception e){
log.error("解析body内容失败:{}", e);
// 403
return response(exchange, R.fail().enumCode(HttpCode.FORBIDDEN));
}
}
/**
* 校验参数
* @param exchange
* @param chain
* @param secretId
* @param secretKey
* @return
*/
private Mono<Void> verifyParams(ServerWebExchange exchange, GatewayFilterChain chain, String secretId, String secretKey) {
// 校验失败,则返回相应提示
// return response(exchange, R.fail().enumCode(HttpCode.UNAUTHORIZED));
// todo
// 校验成功,则当前过滤器执行完毕
return chain.filter(exchange);
}
/**
* response 返回code
* @param exchange
* @param r
* @return
*/
private Mono<Void> response(ServerWebExchange exchange, R r) {
ServerHttpResponse originalResponse = exchange.getResponse();
originalResponse.setStatusCode(HttpStatus.OK);
originalResponse.getHeaders().add(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_UTF8_VALUE);
try {
byte[] bytes = new ObjectMapper().writeValueAsBytes(r);
DataBuffer buffer = originalResponse.bufferFactory().wrap(bytes);
return originalResponse.writeWith(Flux.just(buffer));
} catch (JsonProcessingException e) {
e.printStackTrace();
return null;
}
}
@Override
public int getOrder() {
return FilterOrderConstant.getOrder(this.getClass().getName());
}
}
相关文章
暂无评论...