diff --git a/spring-cloud-gateway-server-webflux/src/main/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilter.java b/spring-cloud-gateway-server-webflux/src/main/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilter.java index 584cb6b1ca..82ad39347d 100644 --- a/spring-cloud-gateway-server-webflux/src/main/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilter.java +++ b/spring-cloud-gateway-server-webflux/src/main/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilter.java @@ -16,6 +16,7 @@ package org.springframework.cloud.gateway.filter.headers; +import java.net.InetSocketAddress; import java.net.URI; import java.util.LinkedHashSet; import java.util.List; @@ -32,6 +33,7 @@ import org.springframework.core.log.LogMessage; import org.springframework.http.HttpHeaders; import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpRequestDecorator; import org.springframework.util.ObjectUtils; import org.springframework.util.StringUtils; import org.springframework.web.server.ServerWebExchange; @@ -220,11 +222,11 @@ public void setPrefixAppend(boolean prefixAppend) { @Override public HttpHeaders filter(HttpHeaders input, ServerWebExchange exchange) { ServerHttpRequest request = exchange.getRequest(); + InetSocketAddress remoteAddress = getRemoteAddress(request); - if (request.getRemoteAddress() != null - && !trustedProxies.isTrusted(request.getRemoteAddress().getHostString())) { + if (remoteAddress != null && !trustedProxies.isTrusted(remoteAddress.getHostString())) { log.trace(LogMessage.format("Remote address not trusted. pattern %s remote address %s", trustedProxies, - request.getRemoteAddress())); + remoteAddress)); return input; } @@ -237,8 +239,8 @@ public HttpHeaders filter(HttpHeaders input, ServerWebExchange exchange) { if (isForEnabled()) { String remoteAddr = null; - if (request.getRemoteAddress() != null && request.getRemoteAddress().getAddress() != null) { - remoteAddr = request.getRemoteAddress().getHostString(); + if (remoteAddress != null && remoteAddress.getAddress() != null) { + remoteAddr = remoteAddress.getHostString(); } // match xforwarded for against trusted proxies write(updated, X_FORWARDED_FOR_HEADER, remoteAddr, isForAppend(), trustedProxies::isTrusted); @@ -339,6 +341,11 @@ else if (value != null && shouldWrite.test(value)) { } } + private InetSocketAddress getRemoteAddress(ServerHttpRequest request) { + ServerHttpRequest nativeRequest = ServerHttpRequestDecorator.getNativeRequest(request); + return nativeRequest.getRemoteAddress(); + } + private int getDefaultPort(String scheme) { return HTTPS_SCHEME.equals(scheme) ? HTTPS_PORT : HTTP_PORT; } diff --git a/spring-cloud-gateway-server-webflux/src/test/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilterTests.java b/spring-cloud-gateway-server-webflux/src/test/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilterTests.java index 5a32b67125..c3e47b9675 100644 --- a/spring-cloud-gateway-server-webflux/src/test/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilterTests.java +++ b/spring-cloud-gateway-server-webflux/src/test/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilterTests.java @@ -32,6 +32,8 @@ import org.springframework.cloud.gateway.config.GatewayAutoConfiguration; import org.springframework.cloud.gateway.config.GatewayProperties; import org.springframework.http.HttpHeaders; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpRequestDecorator; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.web.server.ServerWebExchange; @@ -379,6 +381,28 @@ public void xForwardedHeadersNotTrusted() throws Exception { X_FORWARDED_PORT_HEADER, X_FORWARDED_PROTO_HEADER); } + @Test + public void trustedProxiesUsesNativeRequestRemoteAddress() throws Exception { + MockServerHttpRequest nativeRequest = MockServerHttpRequest.get("http://localhost:8080/get") + .remoteAddress(new InetSocketAddress("10.0.0.1", 80)) + .header(HttpHeaders.HOST, "myhost") + .build(); + ServerHttpRequest request = new ServerHttpRequestDecorator(nativeRequest) { + @Override + public InetSocketAddress getRemoteAddress() { + return new InetSocketAddress("192.168.0.1", 80); + } + }; + + XForwardedHeadersFilter filter = new XForwardedHeadersFilter("10\\.0\\.0\\..*"); + + HttpHeaders headers = filter.filter(request.getHeaders(), MockServerWebExchange.from(request)); + + assertThat(headers.headerNames()).contains(X_FORWARDED_FOR_HEADER, X_FORWARDED_HOST_HEADER, + X_FORWARDED_PORT_HEADER, X_FORWARDED_PROTO_HEADER); + assertThat(headers.getFirst(X_FORWARDED_FOR_HEADER)).isEqualTo("10.0.0.1"); + } + // : verify that existing x-forwarded-* headers are not forwarded // if x-forwarded-for is not trusted @Test