diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java index 047aeebe8..876bb5e87 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java @@ -142,6 +142,12 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) return; } + String contentType = request.getContentType(); + if (contentType == null || !contentType.startsWith(APPLICATION_JSON)) { + response.sendError(HttpServletResponse.SC_UNSUPPORTED_MEDIA_TYPE, "Content-Type must be application/json"); + return; + } + McpTransportContext transportContext = this.contextExtractor.extract(request); String accept = request.getHeader(ACCEPT); diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java index 9a785e150..83872ae76 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -415,6 +415,12 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) return; } + String contentType = request.getContentType(); + if (contentType == null || !contentType.startsWith(APPLICATION_JSON)) { + response.sendError(HttpServletResponse.SC_UNSUPPORTED_MEDIA_TYPE, "Content-Type must be application/json"); + return; + } + List badRequestErrors = new ArrayList<>(); String accept = request.getHeader(ACCEPT); @@ -450,6 +456,17 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) McpSchema.InitializeRequest initializeRequest = jsonMapper.convertValue(jsonrpcRequest.params(), new TypeRef() { }); + + String headerVersion = request.getHeader(HttpHeaders.PROTOCOL_VERSION); + if (headerVersion != null && !headerVersion.equals(initializeRequest.protocolVersion())) { + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, McpError + .builder(McpSchema.ErrorCodes.INVALID_REQUEST) + .message("MCP-Protocol-Version header '" + headerVersion + + "' does not match body protocolVersion '" + initializeRequest.protocolVersion() + "'") + .build()); + return; + } + McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory .startSession(initializeRequest); this.sessions.put(init.session().getId(), init.session()); diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/server/transport/HttpTransportValidationTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/transport/HttpTransportValidationTests.java new file mode 100644 index 000000000..be476e724 --- /dev/null +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/transport/HttpTransportValidationTests.java @@ -0,0 +1,178 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.stream.Stream; + +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; +import jakarta.servlet.http.HttpServlet; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.BeforeParameterizedClassInvocation; +import org.junit.jupiter.params.Parameter; +import org.junit.jupiter.params.ParameterizedClass; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Named.named; +import static org.junit.jupiter.params.provider.Arguments.arguments; + +/** + * Validates Content-Type and protocol version enforcement in HTTP servlet transports. + * + * @author Gorre Surya + */ +@ParameterizedClass +@MethodSource("transports") +class HttpTransportValidationTests { + + private static final String ACCEPT_HEADER = "application/json, text/event-stream"; + + private static final String INITIALIZE_BODY = """ + {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"%s","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}} + """ + .formatted(ProtocolVersions.MCP_2025_11_25) + .strip(); + + @Parameter + private static TransportServer transportServer; + + private static Tomcat tomcat; + + private static String baseUrl; + + private static HttpClient httpClient; + + @BeforeParameterizedClassInvocation + static void setUp(TransportServer transport) { + transportServer = transport; + var port = TomcatTestUtil.findAvailablePort(); + baseUrl = "http://localhost:" + port; + tomcat = TomcatTestUtil.createTomcatServer("", port, transportServer.servlet()); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + httpClient = HttpClient.newBuilder().connectTimeout(Duration.ofSeconds(5)).build(); + } + + @AfterAll + static void tearDown() { + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Test + void postWithNonJsonContentTypeReturns415() throws Exception { + var request = HttpRequest.newBuilder() + .uri(URI.create(baseUrl + "/mcp")) + .header("Content-Type", "text/plain") + .header("Accept", ACCEPT_HEADER) + .POST(HttpRequest.BodyPublishers.ofString(INITIALIZE_BODY)) + .build(); + + var response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); + + assertThat(response.statusCode()).isEqualTo(415); + } + + @Test + void postWithMissingContentTypeReturns415() throws Exception { + var request = HttpRequest.newBuilder() + .uri(URI.create(baseUrl + "/mcp")) + .header("Accept", ACCEPT_HEADER) + .POST(HttpRequest.BodyPublishers.ofString(INITIALIZE_BODY)) + .build(); + + var response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); + + assertThat(response.statusCode()).isEqualTo(415); + } + + @Test + void postWithJsonContentTypeIncludingCharsetSucceeds() throws Exception { + var request = HttpRequest.newBuilder() + .uri(URI.create(baseUrl + "/mcp")) + .header("Content-Type", "application/json; charset=utf-8") + .header("Accept", ACCEPT_HEADER) + .POST(HttpRequest.BodyPublishers.ofString(INITIALIZE_BODY)) + .build(); + + var response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); + + assertThat(response.statusCode()).isIn(200, 202); + } + + static Stream transports() { + return Stream.of(arguments(named("Streamable HTTP", new StreamableHttpTransportServer())), + arguments(named("Stateless", new StatelessTransportServer()))); + } + + interface TransportServer { + + HttpServlet servlet(); + + } + + static class StreamableHttpTransportServer implements TransportServer { + + private final HttpServletStreamableServerTransportProvider transport; + + StreamableHttpTransportServer() { + transport = HttpServletStreamableServerTransportProvider.builder().build(); + McpServer.sync(transport) + .serverInfo("test-server", "1.0.0") + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .build(); + } + + @Override + public HttpServlet servlet() { + return transport; + } + + } + + static class StatelessTransportServer implements TransportServer { + + private final HttpServletStatelessServerTransport transport; + + StatelessTransportServer() { + transport = HttpServletStatelessServerTransport.builder().build(); + McpServer.sync(transport) + .serverInfo("test-server", "1.0.0") + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .build(); + } + + @Override + public HttpServlet servlet() { + return transport; + } + + } + +} diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/server/transport/StreamableTransportProtocolVersionTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/transport/StreamableTransportProtocolVersionTests.java new file mode 100644 index 000000000..30b2d7f1b --- /dev/null +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/transport/StreamableTransportProtocolVersionTests.java @@ -0,0 +1,130 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; + +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Validates MCP-Protocol-Version header consistency enforcement in + * {@link HttpServletStreamableServerTransportProvider}. + * + * @author Gorre Surya + */ +class StreamableTransportProtocolVersionTests { + + private static final String ACCEPT_HEADER = "application/json, text/event-stream"; + + private static final String CONTENT_TYPE = "application/json"; + + private static Tomcat tomcat; + + private static String baseUrl; + + private static HttpClient httpClient; + + @BeforeAll + static void setUp() throws Exception { + var port = TomcatTestUtil.findAvailablePort(); + baseUrl = "http://localhost:" + port; + + var transport = HttpServletStreamableServerTransportProvider.builder().build(); + McpServer.sync(transport) + .serverInfo("test-server", "1.0.0") + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .build(); + + tomcat = TomcatTestUtil.createTomcatServer("", port, transport); + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + + httpClient = HttpClient.newBuilder().connectTimeout(Duration.ofSeconds(5)).build(); + } + + @AfterAll + static void tearDown() { + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Test + void initializeWithMatchingProtocolVersionHeaderSucceeds() throws Exception { + var body = initializeBody(ProtocolVersions.MCP_2025_11_25); + var request = HttpRequest.newBuilder() + .uri(URI.create(baseUrl + "/mcp")) + .header("Content-Type", CONTENT_TYPE) + .header("Accept", ACCEPT_HEADER) + .header(HttpHeaders.PROTOCOL_VERSION, ProtocolVersions.MCP_2025_11_25) + .POST(HttpRequest.BodyPublishers.ofString(body)) + .build(); + + var response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); + + assertThat(response.statusCode()).isIn(200, 202); + } + + @Test + void initializeWithAbsentProtocolVersionHeaderSucceeds() throws Exception { + var body = initializeBody(ProtocolVersions.MCP_2025_11_25); + var request = HttpRequest.newBuilder() + .uri(URI.create(baseUrl + "/mcp")) + .header("Content-Type", CONTENT_TYPE) + .header("Accept", ACCEPT_HEADER) + .POST(HttpRequest.BodyPublishers.ofString(body)) + .build(); + + var response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); + + assertThat(response.statusCode()).isIn(200, 202); + } + + @Test + void initializeWithMismatchedProtocolVersionHeaderReturns400() throws Exception { + var body = initializeBody(ProtocolVersions.MCP_2025_11_25); + var request = HttpRequest.newBuilder() + .uri(URI.create(baseUrl + "/mcp")) + .header("Content-Type", CONTENT_TYPE) + .header("Accept", ACCEPT_HEADER) + .header(HttpHeaders.PROTOCOL_VERSION, ProtocolVersions.MCP_2024_11_05) + .POST(HttpRequest.BodyPublishers.ofString(body)) + .build(); + + var response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); + + assertThat(response.statusCode()).isEqualTo(400); + } + + private static String initializeBody(String protocolVersion) { + return """ + {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"%s","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}} + """ + .formatted(protocolVersion) + .strip(); + } + +}