Skip to content

Commit 4ee9a07

Browse files
committed
Implements BodyFilterFunctions.modifyResponseBody
Fixes gh-3189
1 parent 276b1cc commit 4ee9a07

File tree

5 files changed

+195
-12
lines changed

5 files changed

+195
-12
lines changed

docs/modules/ROOT/nav.adoc

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@
8585
//*** xref:spring-cloud-gateway-server-mvc/filters/local-cache-response-filter.adoc[]
8686
*** xref:spring-cloud-gateway-server-mvc/filters/maprequestheader.adoc[]
8787
*** xref:spring-cloud-gateway-server-mvc/filters/modifyrequestbody.adoc[]
88-
//*** xref:spring-cloud-gateway-server-mvc/filters/modifyresponsebody.adoc[]
88+
*** xref:spring-cloud-gateway-server-mvc/filters/modifyresponsebody.adoc[]
8989
*** xref:spring-cloud-gateway-server-mvc/filters/prefixpath.adoc[]
9090
*** xref:spring-cloud-gateway-server-mvc/filters/preservehostheader.adoc[]
9191
*** xref:spring-cloud-gateway-server-mvc/filters/redirectto.adoc[]

docs/modules/ROOT/pages/spring-cloud-gateway-server-mvc/filters/modifyresponsebody.adoc

+8-11
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@ The following listing shows how to modify a response body filter:
1010
[source,java]
1111
----
1212
@Bean
13-
public RouteLocator routes(RouteLocatorBuilder builder) {
14-
return builder.routes()
15-
.route("rewrite_response_upper", r -> r.host("*.rewriteresponseupper.org")
16-
.filters(f -> f.prefixPath("/httpbin")
17-
.modifyResponseBody(String.class, String.class,
18-
(exchange, s) -> Mono.just(s.toUpperCase()))).uri(uri))
19-
.build();
13+
public RouterFunction<ServerResponse> gatewayRouterFunctionsModifyResponseBodySimple() {
14+
return route("modify_response_body")
15+
.GET("/anything/modifyresponsebody", http())
16+
.before(new HttpbinUriResolver())
17+
.after(modifyResponseBody(String.class, String.class, null,
18+
(request, response, s) -> s.replace("fooval", "FOOVAL")))
19+
.build();
2020
}
2121
----
2222
.GatewaySampleApplication.java
@@ -35,12 +35,9 @@ class RouteConfiguration {
3535
public RouterFunction<ServerResponse> gatewayRouterFunctionsAddReqHeader() {
3636
return route("rewrite_request_obj")
3737
.route(host("*.rewriterequestobj.org"), http("https://example.org"))
38-
.before(modifyResponseBody(String.class, String.class, MediaType.APPLICATION_JSON_VALUE, (request, s) -> s.toUpperCase()))
38+
.before(modifyResponseBody(String.class, String.class, MediaType.APPLICATION_JSON_VALUE, (request, response, s) -> s.toUpperCase()))
3939
.build();
4040
}
4141
4242
}
4343
----
44-
45-
NOTE: If the response has no body, the `RewriteFilter` is passed `null`. `Mono.empty()` should be returned to assign a missing body in the response.
46-

spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/filter/AfterFilterFunctions.java

+6
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@ private static void dedupeHeader(HttpHeaders headers, String name, DedupeStrateg
9292
}
9393
}
9494

95+
public static <T, R> BiFunction<ServerRequest, ServerResponse, ServerResponse> modifyResponseBody(Class<T> inClass,
96+
Class<R> outClass, String newContentType,
97+
BodyFilterFunctions.RewriteResponseFunction<T, R> rewriteFunction) {
98+
return BodyFilterFunctions.modifyResponseBody(inClass, outClass, newContentType, rewriteFunction);
99+
}
100+
95101
public static BiFunction<ServerRequest, ServerResponse, ServerResponse> removeResponseHeader(String name) {
96102
return (request, response) -> {
97103
response.headers().remove(name);

spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/filter/BodyFilterFunctions.java

+65
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.io.ByteArrayInputStream;
2020
import java.io.ByteArrayOutputStream;
2121
import java.io.IOException;
22+
import java.io.InputStream;
2223
import java.io.OutputStream;
2324
import java.io.UncheckedIOException;
2425
import java.net.InetSocketAddress;
@@ -44,6 +45,7 @@
4445
import org.springframework.cloud.gateway.server.mvc.common.MvcUtils;
4546
import org.springframework.core.ParameterizedTypeReference;
4647
import org.springframework.http.HttpHeaders;
48+
import org.springframework.http.HttpInputMessage;
4749
import org.springframework.http.HttpMethod;
4850
import org.springframework.http.HttpOutputMessage;
4951
import org.springframework.http.MediaType;
@@ -142,6 +144,65 @@ public static <T, R> Function<ServerRequest, ServerRequest> modifyRequestBody(Cl
142144
}).orElse(request);
143145
}
144146

147+
@SuppressWarnings({"unchecked", "rawtypes"})
148+
public static <T, R> BiFunction<ServerRequest, ServerResponse, ServerResponse> modifyResponseBody(Class<T> inClass, Class<R> outClass,
149+
String newContentType, RewriteResponseFunction<T, R> rewriteFunction) {
150+
return (request, response) -> {
151+
Object o = request.attributes().get(MvcUtils.CLIENT_RESPONSE_INPUT_STREAM_ATTR);
152+
if (o instanceof InputStream inputStream) {
153+
try {
154+
List<HttpMessageConverter<?>> converters = request.messageConverters();
155+
Optional<HttpMessageConverter<?>> inConverter = converters.stream().filter(c -> c.canRead(inClass, response.headers().getContentType())).findFirst();
156+
if (inConverter.isEmpty()) {
157+
//TODO: throw exception?
158+
return response;
159+
}
160+
HttpMessageConverter<?> inputConverter = inConverter.get();
161+
T input = (T) inputConverter.read((Class)inClass, new SimpleInputMessage(inputStream, response.headers()));
162+
R output = rewriteFunction.apply(request, response, input);
163+
164+
Optional<HttpMessageConverter<?>> outConverter = converters.stream().filter(c -> c.canWrite(outClass, null)).findFirst();
165+
if (outConverter.isEmpty()) {
166+
//TODO: throw exception?
167+
return response;
168+
}
169+
HttpMessageConverter<R> byteConverter = (HttpMessageConverter<R>) outConverter.get();
170+
ByteArrayHttpOutputMessage outputMessage = new ByteArrayHttpOutputMessage(response.headers());
171+
byteConverter.write(output, null, outputMessage);
172+
request.attributes().put(MvcUtils.CLIENT_RESPONSE_INPUT_STREAM_ATTR, new ByteArrayInputStream(outputMessage.body.toByteArray()));
173+
if (StringUtils.hasText(newContentType)) {
174+
response.headers().setContentType(MediaType.parseMediaType(newContentType));
175+
}
176+
response.headers().remove(HttpHeaders.CONTENT_LENGTH);
177+
}
178+
catch (IOException e) {
179+
throw new UncheckedIOException(e);
180+
}
181+
}
182+
return response;
183+
};
184+
}
185+
186+
private final static class SimpleInputMessage implements HttpInputMessage {
187+
private final InputStream inputStream;
188+
private final HttpHeaders headers;
189+
190+
private SimpleInputMessage(InputStream inputStream, HttpHeaders headers) {
191+
this.inputStream = inputStream;
192+
this.headers = headers;
193+
}
194+
195+
@Override
196+
public InputStream getBody() throws IOException {
197+
return this.inputStream;
198+
}
199+
200+
@Override
201+
public HttpHeaders getHeaders() {
202+
return this.headers;
203+
}
204+
}
205+
145206
private final static class ByteArrayHttpOutputMessage implements HttpOutputMessage {
146207

147208
private final HttpHeaders headers;
@@ -173,6 +234,10 @@ public interface RewriteFunction<T, R> extends BiFunction<ServerRequest, T, R> {
173234

174235
}
175236

237+
public interface RewriteResponseFunction<T, R> {
238+
R apply(ServerRequest request, ServerResponse response, T t);
239+
}
240+
176241
private static class ByteArrayServletInputStream extends ServletInputStream {
177242

178243
private final ByteArrayInputStream body;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* Copyright 2013-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.cloud.gateway.server.mvc.filter;
18+
19+
import java.util.Map;
20+
21+
import org.junit.jupiter.api.Test;
22+
23+
import org.springframework.beans.factory.annotation.Autowired;
24+
import org.springframework.boot.SpringBootConfiguration;
25+
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
26+
import org.springframework.boot.test.context.SpringBootTest;
27+
import org.springframework.cloud.gateway.server.mvc.test.HttpbinTestcontainers;
28+
import org.springframework.cloud.gateway.server.mvc.test.HttpbinUriResolver;
29+
import org.springframework.cloud.gateway.server.mvc.test.TestLoadBalancerConfig;
30+
import org.springframework.cloud.gateway.server.mvc.test.client.TestRestClient;
31+
import org.springframework.cloud.loadbalancer.annotation.LoadBalancerClient;
32+
import org.springframework.context.annotation.Bean;
33+
import org.springframework.http.MediaType;
34+
import org.springframework.test.context.ContextConfiguration;
35+
import org.springframework.web.servlet.function.RouterFunction;
36+
import org.springframework.web.servlet.function.ServerResponse;
37+
38+
import static org.assertj.core.api.Assertions.assertThat;
39+
import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment.*;
40+
import static org.springframework.cloud.gateway.server.mvc.filter.AfterFilterFunctions.modifyResponseBody;
41+
import static org.springframework.cloud.gateway.server.mvc.handler.GatewayRouterFunctions.route;
42+
import static org.springframework.cloud.gateway.server.mvc.handler.HandlerFunctions.http;
43+
import static org.springframework.cloud.gateway.server.mvc.test.TestUtils.getMap;
44+
45+
@SpringBootTest(webEnvironment = RANDOM_PORT)
46+
@ContextConfiguration(initializers = HttpbinTestcontainers.class)
47+
public class BodyFilterFunctionsTests {
48+
49+
@Autowired
50+
TestRestClient restClient;
51+
52+
@Test
53+
public void modifyResponseBodySimple() {
54+
restClient.get()
55+
.uri("/anything/modifyresponsebodysimple")
56+
.header("X-Foo", "fooval")
57+
.exchange()
58+
.expectStatus()
59+
.isOk()
60+
.expectBody(Map.class)
61+
.consumeWith(res -> {
62+
Map<String, Object> headers = getMap(res.getResponseBody(), "headers");
63+
assertThat(headers).containsEntry("X-Foo", "FOOVAL");
64+
});
65+
}
66+
67+
@Test
68+
public void modifyResponseBodyComplex() {
69+
restClient.get()
70+
.uri("/deny")
71+
.header("X-Foo", "fooval")
72+
.exchange()
73+
.expectStatus()
74+
.isOk()
75+
// deny returns text/plain
76+
.expectHeader().contentType(MediaType.APPLICATION_JSON)
77+
.expectBody(Message.class)
78+
.consumeWith(res -> {
79+
assertThat(res.getResponseBody().message()).isNotEmpty();
80+
});
81+
}
82+
83+
@SpringBootConfiguration
84+
@EnableAutoConfiguration
85+
@LoadBalancerClient(name = "httpbin", configuration = TestLoadBalancerConfig.Httpbin.class)
86+
protected static class TestConfiguration {
87+
@Bean
88+
public RouterFunction<ServerResponse> gatewayRouterFunctionsModifyResponseBodySimple() {
89+
// @formatter:off
90+
return route("modify_response_body_simple")
91+
.GET("/anything/modifyresponsebodysimple", http())
92+
.before(new HttpbinUriResolver())
93+
.after(modifyResponseBody(String.class, String.class, null,
94+
(request, response, s) -> s.replace("fooval", "FOOVAL")))
95+
.build();
96+
// @formatter:on
97+
}
98+
99+
@Bean
100+
public RouterFunction<ServerResponse> gatewayRouterFunctionsModifyResponseBodyComplex() {
101+
// @formatter:off
102+
return route("modify_response_body_simple")
103+
.GET("/deny", http())
104+
.before(new HttpbinUriResolver())
105+
.after(modifyResponseBody(String.class, Message.class, MediaType.APPLICATION_JSON_VALUE,
106+
(request, response, s) -> new Message(s)))
107+
.build();
108+
// @formatter:on
109+
}
110+
}
111+
112+
record Message(String message) {
113+
114+
}
115+
}

0 commit comments

Comments
 (0)