Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 35 additions & 15 deletions src/mcp/server/transport_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import logging

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from starlette.requests import Request
from starlette.responses import Response
from typing_extensions import Self

logger = logging.getLogger(__name__)

Expand All @@ -31,6 +32,17 @@ class TransportSecuritySettings(BaseModel):
Only applies when `enable_dns_rebinding_protection` is `True`.
"""

@model_validator(mode="after")
def _warn_if_protection_enabled_with_empty_allowlist(self) -> Self:
if self.enable_dns_rebinding_protection and not self.allowed_hosts:
logger.warning(
"TransportSecuritySettings has DNS rebinding protection enabled but "
"allowed_hosts is empty — all requests will be rejected with HTTP 421. "
"Set allowed_hosts to your server's hostname(s), e.g. "
'TransportSecuritySettings(allowed_hosts=["your-host.example.com:*"])'
)
return self


# TODO(Marcelo): This should be a proper ASGI middleware. I'm sad to see this.
class TransportSecurityMiddleware:
Expand All @@ -40,7 +52,7 @@ def __init__(self, settings: TransportSecuritySettings | None = None):
# If not specified, disable DNS rebinding protection by default for backwards compatibility
self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False)

def _validate_host(self, host: str | None) -> bool: # pragma: no cover
def _validate_host(self, host: str | None) -> bool:
"""Validate the Host header against allowed values."""
if not host:
logger.warning("Missing Host header in request")
Expand All @@ -62,7 +74,7 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover
logger.warning(f"Invalid Host header: {host}")
return False

def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover
def _validate_origin(self, origin: str | None) -> bool:
"""Validate the Origin header against allowed values."""
# Origin can be absent for same-origin requests
if not origin:
Expand Down Expand Up @@ -94,7 +106,7 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res
Returns None if validation passes, or an error Response if validation fails.
"""
# Always validate Content-Type for POST requests
if is_post: # pragma: no branch
if is_post:
content_type = request.headers.get("content-type")
if not self._validate_content_type(content_type):
return Response("Invalid Content-Type header", status_code=400)
Expand All @@ -103,14 +115,22 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res
if not self.settings.enable_dns_rebinding_protection:
return None

# Validate Host header # pragma: no cover
host = request.headers.get("host") # pragma: no cover
if not self._validate_host(host): # pragma: no cover
return Response("Invalid Host header", status_code=421) # pragma: no cover

# Validate Origin header # pragma: no cover
origin = request.headers.get("origin") # pragma: no cover
if not self._validate_origin(origin): # pragma: no cover
return Response("Invalid Origin header", status_code=403) # pragma: no cover

return None # pragma: no cover
# Validate Host header
host = request.headers.get("host")
if not self._validate_host(host):
return Response(
f"Invalid Host header: {host!r}. "
"Configure TransportSecuritySettings(allowed_hosts=[...]) with your server's hostname.",
status_code=421,
)

# Validate Origin header
origin = request.headers.get("origin")
if not self._validate_origin(origin):
return Response(
f"Invalid Origin header: {origin!r}. "
"Configure TransportSecuritySettings(allowed_origins=[...]) with your server's origin.",
status_code=403,
)

return None
6 changes: 3 additions & 3 deletions tests/server/test_sse_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ async def test_sse_security_invalid_host_header(server_port: int):
async with httpx.AsyncClient() as client:
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
assert response.status_code == 421
assert response.text == "Invalid Host header"
assert "Invalid Host header" in response.text

finally:
process.terminate()
Expand All @@ -128,7 +128,7 @@ async def test_sse_security_invalid_origin_header(server_port: int):
async with httpx.AsyncClient() as client:
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
assert response.status_code == 403
assert response.text == "Invalid Origin header"
assert "Invalid Origin header" in response.text

finally:
process.terminate()
Expand Down Expand Up @@ -215,7 +215,7 @@ async def test_sse_security_custom_allowed_hosts(server_port: int):
async with httpx.AsyncClient() as client:
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
assert response.status_code == 421
assert response.text == "Invalid Host header"
assert "Invalid Host header" in response.text

finally:
process.terminate()
Expand Down
6 changes: 3 additions & 3 deletions tests/server/test_streamable_http_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ async def test_streamable_http_security_invalid_host_header(server_port: int):
headers=headers,
)
assert response.status_code == 421
assert response.text == "Invalid Host header"
assert "Invalid Host header" in response.text

finally:
process.terminate()
Expand Down Expand Up @@ -154,7 +154,7 @@ async def test_streamable_http_security_invalid_origin_header(server_port: int):
headers=headers,
)
assert response.status_code == 403
assert response.text == "Invalid Origin header"
assert "Invalid Origin header" in response.text

finally:
process.terminate()
Expand Down Expand Up @@ -269,7 +269,7 @@ async def test_streamable_http_security_get_request(server_port: int):
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers)
assert response.status_code == 421
assert response.text == "Invalid Host header"
assert "Invalid Host header" in response.text

# Test GET request with valid host header
headers = {
Expand Down
182 changes: 182 additions & 0 deletions tests/server/test_transport_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
"""Unit tests for TransportSecuritySettings and TransportSecurityMiddleware."""

import logging

import pytest
from starlette.requests import Request

from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings


def make_request(headers: dict[str, str], method: str = "GET") -> Request:
scope = {
"type": "http",
"method": method,
"headers": [(k.lower().encode(), v.encode()) for k, v in headers.items()],
"path": "/",
"query_string": b"",
}
return Request(scope)


# ---------------------------------------------------------------------------
# TransportSecuritySettings — construction-time warning
# ---------------------------------------------------------------------------


def test_no_warning_when_protection_disabled(caplog: pytest.LogCaptureFixture) -> None:
with caplog.at_level(logging.WARNING, logger="mcp.server.transport_security"):
TransportSecuritySettings(enable_dns_rebinding_protection=False)
assert not caplog.records


def test_no_warning_when_allowed_hosts_populated(caplog: pytest.LogCaptureFixture) -> None:
with caplog.at_level(logging.WARNING, logger="mcp.server.transport_security"):
TransportSecuritySettings(
enable_dns_rebinding_protection=True,
allowed_hosts=["example.com"],
)
assert not caplog.records


def test_warning_when_protection_enabled_with_empty_allowed_hosts(caplog: pytest.LogCaptureFixture) -> None:
with caplog.at_level(logging.WARNING, logger="mcp.server.transport_security"):
TransportSecuritySettings(enable_dns_rebinding_protection=True)
assert len(caplog.records) == 1
assert "allowed_hosts is empty" in caplog.records[0].message
assert "HTTP 421" in caplog.records[0].message
assert "allowed_hosts=" in caplog.records[0].message


# ---------------------------------------------------------------------------
# TransportSecurityMiddleware._validate_host
# ---------------------------------------------------------------------------


def test_validate_host_missing_host() -> None:
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["example.com"]))
assert m._validate_host(None) is False


def test_validate_host_exact_match() -> None:
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["example.com"]))
assert m._validate_host("example.com") is True


def test_validate_host_exact_no_match() -> None:
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["example.com"]))
assert m._validate_host("other.com") is False


def test_validate_host_port_wildcard_match() -> None:
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["localhost:*"]))
assert m._validate_host("localhost:8080") is True


def test_validate_host_port_wildcard_different_base() -> None:
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["localhost:*"]))
assert m._validate_host("other:8080") is False


def test_validate_host_port_wildcard_no_port() -> None:
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["localhost:*"]))
assert m._validate_host("localhost") is False


# ---------------------------------------------------------------------------
# TransportSecurityMiddleware._validate_origin
# ---------------------------------------------------------------------------


def test_validate_origin_absent_is_allowed() -> None:
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_origins=["http://example.com"]))
assert m._validate_origin(None) is True


def test_validate_origin_exact_match() -> None:
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_origins=["http://example.com"]))
assert m._validate_origin("http://example.com") is True


def test_validate_origin_exact_no_match() -> None:
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_origins=["http://example.com"]))
assert m._validate_origin("http://other.com") is False


def test_validate_origin_port_wildcard_match() -> None:
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_origins=["http://localhost:*"]))
assert m._validate_origin("http://localhost:3000") is True


def test_validate_origin_port_wildcard_different_base() -> None:
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_origins=["http://localhost:*"]))
assert m._validate_origin("http://other:3000") is False


# ---------------------------------------------------------------------------
# TransportSecurityMiddleware.validate_request
# ---------------------------------------------------------------------------


@pytest.mark.anyio
async def test_validate_request_post_valid_content_type() -> None:
m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False))
request = make_request({"content-type": "application/json"}, method="POST")
assert await m.validate_request(request, is_post=True) is None


@pytest.mark.anyio
async def test_validate_request_post_invalid_content_type() -> None:
m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False))
request = make_request({"content-type": "text/plain"}, method="POST")
response = await m.validate_request(request, is_post=True)
assert response is not None
assert response.status_code == 400


@pytest.mark.anyio
async def test_validate_request_get_skips_content_type() -> None:
m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False))
request = make_request({})
assert await m.validate_request(request, is_post=False) is None


@pytest.mark.anyio
async def test_validate_request_protection_disabled_allows_any_host() -> None:
m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False))
request = make_request({"host": "attacker.example.com"})
assert await m.validate_request(request) is None


@pytest.mark.anyio
async def test_validate_request_valid_host_and_no_origin() -> None:
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["example.com"]))
request = make_request({"host": "example.com"})
assert await m.validate_request(request) is None


@pytest.mark.anyio
async def test_validate_request_invalid_host_returns_421_with_detail() -> None:
m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["example.com"]))
request = make_request({"host": "attacker.com"})
response = await m.validate_request(request)
assert response is not None
assert response.status_code == 421
assert b"attacker.com" in response.body
assert b"allowed_hosts" in response.body


@pytest.mark.anyio
async def test_validate_request_invalid_origin_returns_403_with_detail() -> None:
m = TransportSecurityMiddleware(
TransportSecuritySettings(
allowed_hosts=["example.com"],
allowed_origins=["http://example.com"],
)
)
request = make_request({"host": "example.com", "origin": "http://attacker.com"})
response = await m.validate_request(request)
assert response is not None
assert response.status_code == 403
assert b"attacker.com" in response.body
assert b"allowed_origins" in response.body
Loading