Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,48 @@ Postgres MCP Pro supports multiple *access modes* to give you control over the o
To use restricted mode, replace `--access-mode=unrestricted` with `--access-mode=restricted` in the configuration examples above.


##### Transport Security Configuration

Postgres MCP Pro includes DNS rebinding protection to secure the server against certain types of attacks.
By default, the server allows connections from common local and Docker hostnames.
Transport security applies only to network transports (`sse` and `streamable-http`), not `stdio`.

You can customize this behavior using CLI flags or environment variables (env vars take precedence over CLI flags):

| CLI Flag | Environment Variable | Description | Default |
|---|---|---|---|
| `--disable-dns-rebinding-protection` | `MCP_ENABLE_DNS_REBINDING_PROTECTION` | Enable/disable DNS rebinding protection | Enabled |
| `--allowed-hosts` | `MCP_ALLOWED_HOSTS` | Comma-separated allowed host patterns | `localhost:*,127.0.0.1:*,0.0.0.0:*,postgres-mcp-server:*,host.docker.internal:*` |
| `--allowed-origins` | `MCP_ALLOWED_ORIGINS` | Comma-separated allowed origins | Empty (allows any origin) |

For example, to restrict allowed hosts in your configuration:

```json
{
"mcpServers": {
"postgres": {
"command": "docker",
"args": [
"run",
"-i",
"--rm",
"-e",
"DATABASE_URI",
"-e",
"MCP_ALLOWED_HOSTS",
"crystaldba/postgres-mcp",
"--access-mode=unrestricted"
],
"env": {
"DATABASE_URI": "postgresql://username:password@localhost:5432/dbname",
"MCP_ALLOWED_HOSTS": "localhost:*,myapp.example.com:*"
}
}
}
}
```


#### Other MCP Clients

Many MCP clients have similar configuration files to Claude Desktop, and you can adapt the examples above to work with the client of your choice.
Expand Down
33 changes: 33 additions & 0 deletions src/postgres_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import mcp.types as types
from mcp.server.fastmcp import FastMCP
from mcp.server.transport_security import TransportSecuritySettings
from mcp.types import ToolAnnotations
from pydantic import Field
from pydantic import validate_call
Expand Down Expand Up @@ -596,6 +597,24 @@ async def main():
default=8000,
help="Port for streamable HTTP server (default: 8000)",
)
parser.add_argument(
"--disable-dns-rebinding-protection",
action="store_true",
default=False,
help="Disable DNS rebinding protection (not recommended for production)",
)
parser.add_argument(
"--allowed-hosts",
type=str,
default=None,
help="Comma-separated allowed Host header values for DNS rebinding protection (e.g. 'localhost:*,127.0.0.1:*')",
)
parser.add_argument(
"--allowed-origins",
type=str,
default=None,
help="Comma-separated allowed Origin header values for DNS rebinding protection (e.g. 'http://localhost:*')",
)

args = parser.parse_args()

Expand Down Expand Up @@ -656,6 +675,20 @@ async def main():
logger.warning("Signal handling not supported on Windows")
pass

# Apply transport security settings (SSE and streamable-http only)
if args.transport in ("sse", "streamable-http"):
dns_env = os.environ.get("MCP_ENABLE_DNS_REBINDING_PROTECTION")
protection_off = dns_env.lower() in ("false", "0", "no") if dns_env else args.disable_dns_rebinding_protection
hosts = os.environ.get("MCP_ALLOWED_HOSTS", args.allowed_hosts)
origins = os.environ.get("MCP_ALLOWED_ORIGINS", args.allowed_origins)

if protection_off or hosts or origins:
mcp.settings.transport_security = TransportSecuritySettings(
enable_dns_rebinding_protection=not protection_off,
**{"allowed_hosts": [h.strip() for h in hosts.split(",") if h.strip()]} if hosts else {},
**{"allowed_origins": [o.strip() for o in origins.split(",") if o.strip()]} if origins else {},
)

# Run the server with the selected transport (always async)
if args.transport == "stdio":
await mcp.run_stdio_async()
Expand Down
224 changes: 224 additions & 0 deletions tests/unit/test_transport_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
import sys
from unittest.mock import AsyncMock
from unittest.mock import patch

import pytest

_TRANSPORT_MOCK_MAP = {
"sse": "postgres_mcp.server.mcp.run_sse_async",
"streamable-http": "postgres_mcp.server.mcp.run_streamable_http_async",
}

_MCP_ENV_KEYS = [
"MCP_ENABLE_DNS_REBINDING_PROTECTION",
"MCP_ALLOWED_HOSTS",
"MCP_ALLOWED_ORIGINS",
]


@pytest.mark.parametrize("transport", ["sse", "streamable-http"])
class TestTransportSecurityIntegration:
@pytest.fixture(autouse=True)
def _preserve_mcp_state(self, monkeypatch: pytest.MonkeyPatch):
from postgres_mcp.server import mcp

original_argv = sys.argv
original_security = mcp.settings.transport_security
for key in _MCP_ENV_KEYS:
monkeypatch.delenv(key, raising=False)
yield
sys.argv = original_argv
mcp.settings.transport_security = original_security

@pytest.mark.asyncio
async def test_disable_dns_rebinding_via_cli_flag(self, transport: str):
from postgres_mcp.server import main
from postgres_mcp.server import mcp

sys.argv = [
"postgres_mcp",
"postgresql://user:password@localhost/db",
f"--transport={transport}",
"--disable-dns-rebinding-protection",
]

with (
patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()),
patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()),
):
await main()
assert mcp.settings.transport_security is not None
assert mcp.settings.transport_security.enable_dns_rebinding_protection is False

@pytest.mark.asyncio
async def test_disable_dns_rebinding_via_env(self, transport: str):
from postgres_mcp.server import main
from postgres_mcp.server import mcp

sys.argv = [
"postgres_mcp",
"postgresql://user:password@localhost/db",
f"--transport={transport}",
]

with (
patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()),
patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()),
patch.dict("os.environ", {"MCP_ENABLE_DNS_REBINDING_PROTECTION": "false"}),
):
await main()
assert mcp.settings.transport_security is not None
assert mcp.settings.transport_security.enable_dns_rebinding_protection is False

@pytest.mark.asyncio
async def test_allowed_hosts_via_cli(self, transport: str):
from postgres_mcp.server import main
from postgres_mcp.server import mcp

sys.argv = [
"postgres_mcp",
"postgresql://user:password@localhost/db",
f"--transport={transport}",
"--allowed-hosts",
"localhost:*,127.0.0.1:*",
]

with (
patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()),
patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()),
):
await main()
assert mcp.settings.transport_security is not None
assert "localhost:*" in mcp.settings.transport_security.allowed_hosts
assert "127.0.0.1:*" in mcp.settings.transport_security.allowed_hosts

@pytest.mark.asyncio
async def test_allowed_hosts_env_overrides_cli(self, transport: str):
from postgres_mcp.server import main
from postgres_mcp.server import mcp

sys.argv = [
"postgres_mcp",
"postgresql://user:password@localhost/db",
f"--transport={transport}",
"--allowed-hosts",
"cli-host:*",
]

with (
patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()),
patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()),
patch.dict("os.environ", {"MCP_ALLOWED_HOSTS": "env-host:*"}),
):
await main()
assert mcp.settings.transport_security is not None
assert "env-host:*" in mcp.settings.transport_security.allowed_hosts
assert "cli-host:*" not in mcp.settings.transport_security.allowed_hosts

@pytest.mark.asyncio
async def test_allowed_origins_via_cli(self, transport: str):
from postgres_mcp.server import main
from postgres_mcp.server import mcp

sys.argv = [
"postgres_mcp",
"postgresql://user:password@localhost/db",
f"--transport={transport}",
"--allowed-origins",
"http://localhost:*",
]

with (
patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()),
patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()),
):
await main()
assert mcp.settings.transport_security is not None
assert "http://localhost:*" in mcp.settings.transport_security.allowed_origins

@pytest.mark.asyncio
async def test_allowed_origins_env_overrides_cli(self, transport: str):
from postgres_mcp.server import main
from postgres_mcp.server import mcp

sys.argv = [
"postgres_mcp",
"postgresql://user:password@localhost/db",
f"--transport={transport}",
"--allowed-origins",
"http://cli-origin:*",
]

with (
patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()),
patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()),
patch.dict("os.environ", {"MCP_ALLOWED_ORIGINS": "http://env-origin:*"}),
):
await main()
assert mcp.settings.transport_security is not None
assert "http://env-origin:*" in mcp.settings.transport_security.allowed_origins
assert "http://cli-origin:*" not in mcp.settings.transport_security.allowed_origins

@pytest.mark.asyncio
async def test_env_protection_true_overrides_cli_disable(self, transport: str):
from postgres_mcp.server import main
from postgres_mcp.server import mcp

sys.argv = [
"postgres_mcp",
"postgresql://user:password@localhost/db",
f"--transport={transport}",
"--disable-dns-rebinding-protection",
]

with (
patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()),
patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()),
patch.dict("os.environ", {"MCP_ENABLE_DNS_REBINDING_PROTECTION": "true"}),
):
await main()
assert mcp.settings.transport_security is not None
assert mcp.settings.transport_security.enable_dns_rebinding_protection is True

@pytest.mark.asyncio
async def test_default_defers_to_fastmcp(self, transport: str):
from postgres_mcp.server import main
from postgres_mcp.server import mcp

sys.argv = [
"postgres_mcp",
"postgresql://user:password@localhost/db",
f"--transport={transport}",
]

with (
patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()),
patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()),
):
await main()
assert mcp.settings.transport_security is not None
assert mcp.settings.transport_security.enable_dns_rebinding_protection is True

@pytest.mark.asyncio
async def test_database_url_after_flags_not_consumed(self, transport: str):
from postgres_mcp.server import main
from postgres_mcp.server import mcp

sys.argv = [
"postgres_mcp",
f"--transport={transport}",
"--allowed-hosts",
"localhost:*,my-gateway:8080",
"--allowed-origins",
"http://localhost:*,http://my-gateway:*",
"postgresql://user:password@localhost/db",
]

with (
patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()),
patch(_TRANSPORT_MOCK_MAP[transport], AsyncMock()),
):
await main()
assert mcp.settings.transport_security is not None
assert "localhost:*" in mcp.settings.transport_security.allowed_hosts
assert "my-gateway:8080" in mcp.settings.transport_security.allowed_hosts