Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.cloud.gateway.filter.headers;

import java.net.InetSocketAddress;
import java.net.URI;
import java.util.LinkedHashSet;
import java.util.List;
Expand All @@ -32,6 +33,7 @@
import org.springframework.core.log.LogMessage;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange;
Expand Down Expand Up @@ -220,11 +222,11 @@ public void setPrefixAppend(boolean prefixAppend) {
@Override
public HttpHeaders filter(HttpHeaders input, ServerWebExchange exchange) {
ServerHttpRequest request = exchange.getRequest();
InetSocketAddress remoteAddress = getRemoteAddress(request);

if (request.getRemoteAddress() != null
&& !trustedProxies.isTrusted(request.getRemoteAddress().getHostString())) {
if (remoteAddress != null && !trustedProxies.isTrusted(remoteAddress.getHostString())) {
log.trace(LogMessage.format("Remote address not trusted. pattern %s remote address %s", trustedProxies,
request.getRemoteAddress()));
remoteAddress));
return input;
}

Expand All @@ -237,8 +239,8 @@ public HttpHeaders filter(HttpHeaders input, ServerWebExchange exchange) {

if (isForEnabled()) {
String remoteAddr = null;
if (request.getRemoteAddress() != null && request.getRemoteAddress().getAddress() != null) {
remoteAddr = request.getRemoteAddress().getHostString();
if (remoteAddress != null && remoteAddress.getAddress() != null) {
remoteAddr = remoteAddress.getHostString();
}
// match xforwarded for against trusted proxies
write(updated, X_FORWARDED_FOR_HEADER, remoteAddr, isForAppend(), trustedProxies::isTrusted);
Expand Down Expand Up @@ -339,6 +341,11 @@ else if (value != null && shouldWrite.test(value)) {
}
}

private InetSocketAddress getRemoteAddress(ServerHttpRequest request) {
ServerHttpRequest nativeRequest = ServerHttpRequestDecorator.getNativeRequest(request);
return nativeRequest.getRemoteAddress();
}

private int getDefaultPort(String scheme) {
return HTTPS_SCHEME.equals(scheme) ? HTTPS_PORT : HTTP_PORT;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import org.springframework.cloud.gateway.config.GatewayAutoConfiguration;
import org.springframework.cloud.gateway.config.GatewayProperties;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.web.server.ServerWebExchange;
Expand Down Expand Up @@ -379,6 +381,28 @@ public void xForwardedHeadersNotTrusted() throws Exception {
X_FORWARDED_PORT_HEADER, X_FORWARDED_PROTO_HEADER);
}

@Test
public void trustedProxiesUsesNativeRequestRemoteAddress() throws Exception {
MockServerHttpRequest nativeRequest = MockServerHttpRequest.get("http://localhost:8080/get")
.remoteAddress(new InetSocketAddress("10.0.0.1", 80))
.header(HttpHeaders.HOST, "myhost")
.build();
ServerHttpRequest request = new ServerHttpRequestDecorator(nativeRequest) {
@Override
public InetSocketAddress getRemoteAddress() {
return new InetSocketAddress("192.168.0.1", 80);
}
};

XForwardedHeadersFilter filter = new XForwardedHeadersFilter("10\\.0\\.0\\..*");

HttpHeaders headers = filter.filter(request.getHeaders(), MockServerWebExchange.from(request));

assertThat(headers.headerNames()).contains(X_FORWARDED_FOR_HEADER, X_FORWARDED_HOST_HEADER,
X_FORWARDED_PORT_HEADER, X_FORWARDED_PROTO_HEADER);
assertThat(headers.getFirst(X_FORWARDED_FOR_HEADER)).isEqualTo("10.0.0.1");
}

// : verify that existing x-forwarded-* headers are not forwarded
// if x-forwarded-for is not trusted
@Test
Expand Down