diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt index fc088103..40b95520 100644 --- a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt @@ -19,7 +19,6 @@ import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import java.util.concurrent.atomic.AtomicBoolean -import kotlin.test.Ignore import kotlin.test.assertEquals import kotlin.test.assertNotNull import kotlin.test.assertTrue @@ -165,7 +164,6 @@ abstract class AbstractResourceIntegrationTest : KotlinTestBase() { assertEquals(testResourceContent, content.text, "Resource content should match") } - @Ignore("Blocked by https://github.com/modelcontextprotocol/kotlin-sdk/issues/249") @Test fun testSubscribeAndUnsubscribe() { runBlocking(Dispatchers.IO) { diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt index f8e8740e..2f8b3872 100644 --- a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt @@ -148,7 +148,9 @@ abstract class KotlinTestBase { // Create StreamableHTTP server transport // Using JSON response mode for simpler testing (no SSE session required) val transport = StreamableHttpServerTransport( - enableJsonResponse = true, // Use JSON response mode for testing + StreamableHttpServerTransport.Configuration( + enableJsonResponse = true, // Use JSON response mode for testing + ), ) // Use stateless mode to skip session validation for simpler testing transport.setSessionIdGenerator(null) diff --git a/kotlin-sdk-server/api/kotlin-sdk-server.api b/kotlin-sdk-server/api/kotlin-sdk-server.api index 15bbe5ee..1b7198e0 100644 --- a/kotlin-sdk-server/api/kotlin-sdk-server.api +++ b/kotlin-sdk-server/api/kotlin-sdk-server.api @@ -38,10 +38,10 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/KtorServerKt { public static final fun mcp (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function1;)V public static final fun mcp (Lio/ktor/server/routing/Route;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)V public static final fun mcp (Lio/ktor/server/routing/Route;Lkotlin/jvm/functions/Function1;)V - public static final fun mcpStatelessStreamableHttp (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;)V - public static synthetic fun mcpStatelessStreamableHttp$default (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V - public static final fun mcpStreamableHttp (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;)V - public static synthetic fun mcpStreamableHttp$default (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public static final fun mcpStatelessStreamableHttp (Lio/ktor/server/application/Application;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport$Configuration;Lkotlin/jvm/functions/Function1;)V + public static synthetic fun mcpStatelessStreamableHttp$default (Lio/ktor/server/application/Application;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport$Configuration;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public static final fun mcpStreamableHttp (Lio/ktor/server/application/Application;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport$Configuration;Lkotlin/jvm/functions/Function1;)V + public static synthetic fun mcpStreamableHttp$default (Lio/ktor/server/application/Application;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport$Configuration;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V } public final class io/modelcontextprotocol/kotlin/sdk/server/RegisteredPrompt : io/modelcontextprotocol/kotlin/sdk/server/Feature { @@ -194,7 +194,7 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/StdioServerTranspor public final class io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { public static final field STANDALONE_SSE_STREAM_ID Ljava/lang/String; - public fun ()V + public fun (Lio/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport$Configuration;)V public fun (ZZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Ljava/lang/Long;)V public synthetic fun (ZZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Ljava/lang/Long;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; @@ -211,6 +211,17 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServe public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } +public final class io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport$Configuration { + public synthetic fun (ZZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/time/Duration;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public synthetic fun (ZZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/time/Duration;Lkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun getAllowedHosts ()Ljava/util/List; + public final fun getAllowedOrigins ()Ljava/util/List; + public final fun getEnableDnsRebindingProtection ()Z + public final fun getEnableJsonResponse ()Z + public final fun getEventStore ()Lio/modelcontextprotocol/kotlin/sdk/server/EventStore; + public final fun getRetryInterval-FghU774 ()Lkotlin/time/Duration; +} + public final class io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensionsKt { public static final fun mcpWebSocket (Lio/ktor/server/application/Application;Ljava/lang/String;Lkotlin/jvm/functions/Function0;)V public static final fun mcpWebSocket (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function0;)V diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt index 5a750a5d..a6533170 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt @@ -92,6 +92,11 @@ public fun Route.mcp(block: ServerSSESession.() -> Server) { } } +/** + * Configures the application to use Server-Sent Events (SSE) and sets up routing for the provided server logic. + * + * @param block A lambda function that defines the server logic within the context of a [ServerSSESession]. + */ @KtorDsl public fun Application.mcp(block: ServerSSESession.() -> Server) { install(SSE) @@ -101,14 +106,20 @@ public fun Application.mcp(block: ServerSSESession.() -> Server) { } } +/** + * Sets up HTTP endpoints for an application to support MCP streamable interactions + * using the Server-Sent Events (SSE) protocol and other HTTP methods. + * + * @param path The base URL path for the MCP streamable HTTP routes. Defaults to "/mcp". + * @param configuration An instance of `StreamableHttpServerTransport.Configuration` used to configure + * the behavior of the transport layer. + * @param block A lambda with a `RoutingContext` receiver, allowing the user to define server logic + * for handling streamable transport. + */ @KtorDsl -@Suppress("LongParameterList") public fun Application.mcpStreamableHttp( path: String = "/mcp", - enableDnsRebindingProtection: Boolean = false, - allowedHosts: List? = null, - allowedOrigins: List? = null, - eventStore: EventStore? = null, + configuration: StreamableHttpServerTransport.Configuration = StreamableHttpServerTransport.Configuration(), block: RoutingContext.() -> Server, ) { install(SSE) @@ -125,10 +136,7 @@ public fun Application.mcpStreamableHttp( post { val transport = streamableTransport( transportManager = transportManager, - enableDnsRebindingProtection = enableDnsRebindingProtection, - allowedHosts = allowedHosts, - allowedOrigins = allowedOrigins, - eventStore = eventStore, + configuration = configuration, block = block, ) ?: return@post @@ -144,14 +152,19 @@ public fun Application.mcpStreamableHttp( } } +/** + * Sets up a stateless and streamable HTTP endpoint within the application using the specified path and configuration. + * This method installs the SSE feature and defines specific routing behavior for HTTP methods. + * + * @param path The URL path where the endpoint will be accessible. Defaults to "/mcp". + * @param configuration The configuration object used to customize the behavior of the streamable HTTP server transport. + * @param block A lambda function that provides the routing context to define the server behavior. + */ @KtorDsl @Suppress("LongParameterList") public fun Application.mcpStatelessStreamableHttp( path: String = "/mcp", - enableDnsRebindingProtection: Boolean = false, - allowedHosts: List? = null, - allowedOrigins: List? = null, - eventStore: EventStore? = null, + configuration: StreamableHttpServerTransport.Configuration = StreamableHttpServerTransport.Configuration(), block: RoutingContext.() -> Server, ) { install(SSE) @@ -160,10 +173,7 @@ public fun Application.mcpStatelessStreamableHttp( route(path) { post { mcpStatelessStreamableHttpEndpoint( - enableDnsRebindingProtection = enableDnsRebindingProtection, - allowedHosts = allowedHosts, - allowedOrigins = allowedOrigins, - eventStore = eventStore, + configuration = configuration, block = block, ) } @@ -218,18 +228,11 @@ private fun ServerSSESession.mcpSseTransport( } private suspend fun RoutingContext.mcpStatelessStreamableHttpEndpoint( - enableDnsRebindingProtection: Boolean = false, - allowedHosts: List? = null, - allowedOrigins: List? = null, - eventStore: EventStore? = null, + configuration: StreamableHttpServerTransport.Configuration = StreamableHttpServerTransport.Configuration(), block: RoutingContext.() -> Server, ) { val transport = StreamableHttpServerTransport( - enableDnsRebindingProtection = enableDnsRebindingProtection, - allowedHosts = allowedHosts, - allowedOrigins = allowedOrigins, - eventStore = eventStore, - enableJsonResponse = true, + configuration, ).also { it.setSessionIdGenerator(null) } logger.info { "New stateless StreamableHttp connection established without sessionId" } @@ -292,10 +295,7 @@ private suspend fun existingStreamableTransport( private suspend fun RoutingContext.streamableTransport( transportManager: TransportManager, - enableDnsRebindingProtection: Boolean, - allowedHosts: List?, - allowedOrigins: List?, - eventStore: EventStore?, + configuration: StreamableHttpServerTransport.Configuration, block: RoutingContext.() -> Server, ): StreamableHttpServerTransport? { val sessionId = call.request.sessionId() @@ -304,13 +304,7 @@ private suspend fun RoutingContext.streamableTransport( return transport ?: existingStreamableTransport(call, transportManager) } - val transport = StreamableHttpServerTransport( - enableDnsRebindingProtection = enableDnsRebindingProtection, - allowedHosts = allowedHosts, - allowedOrigins = allowedOrigins, - eventStore = eventStore, - enableJsonResponse = true, - ) + val transport = StreamableHttpServerTransport(configuration) transport.setOnSessionInitialized { initializedSessionId -> transportManager.addTransport(initializedSessionId, transport) diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index 1894cdc4..e30ece05 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -36,6 +36,8 @@ import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.decodeFromJsonElement import kotlin.concurrent.atomics.AtomicBoolean import kotlin.concurrent.atomics.ExperimentalAtomicApi +import kotlin.time.Duration +import kotlin.time.Duration.Companion.milliseconds import kotlin.uuid.ExperimentalUuidApi import kotlin.uuid.Uuid @@ -46,8 +48,8 @@ private const val MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 // 4 MB /** * A holder for an active request call. - * If enableJsonResponse is true, session is null. - * Otherwise, session is not null. + * If [StreamableHttpServerTransport.Configuration.enableJsonResponse] is true, the session is null. + * Otherwise, the session is not null. */ private data class SessionContext(val session: ServerSSESession?, val call: ApplicationCall) @@ -66,32 +68,81 @@ private data class SessionContext(val session: ServerSSESession?, val call: Appl * - No Session ID is included in any responses * - No session validation is performed * - * @param enableJsonResponse If true, the server will return JSON responses instead of starting an SSE stream. - * This can be useful for simple request/response scenarios without streaming. - * Default is false (SSE streams are preferred). - * @param enableDnsRebindingProtection Enable DNS rebinding protection - * (requires allowedHosts and/or allowedOrigins to be configured). - * Default is false for backwards compatibility. - * @param allowedHosts List of allowed host header values for DNS rebinding protection. - * If not specified, host validation is disabled. - * @param allowedOrigins List of allowed origin header values for DNS rebinding protection. - * If not specified, origin validation is disabled. - * @param eventStore Event store for resumability support - * If provided, resumability will be enabled, allowing clients to reconnect and resume messages - * @param retryIntervalMillis Retry interval (in milliseconds) advertised via SSE priming events - * to hint the client when to reconnect. Applies only when an [eventStore] is configured. - * Defaults to `null` (no retry hint). + * @param configuration Transport configuration. See [Configuration] for available options. */ @OptIn(ExperimentalUuidApi::class, ExperimentalAtomicApi::class) @Suppress("TooManyFunctions") -public class StreamableHttpServerTransport( - private val enableJsonResponse: Boolean = false, - private val enableDnsRebindingProtection: Boolean = false, - private val allowedHosts: List? = null, - private val allowedOrigins: List? = null, - private val eventStore: EventStore? = null, - private val retryIntervalMillis: Long? = null, -) : AbstractTransport() { +public class StreamableHttpServerTransport(private val configuration: Configuration) : AbstractTransport() { + + /** + * Secondary constructor for `StreamableHttpServerTransport` that simplifies initialization by directly taking the + * configurable parameters without requiring a `Configuration` instance. + * + * @param enableJsonResponse Determines whether the server should return JSON responses. + * Defaults to `false`. + * @param enableDnsRebindingProtection Enables DNS rebinding protection. + * Defaults to `false`. + * @param allowedHosts A list of hosts allowed for server communication. + * Defaults to `null`, allowing all hosts. + * @param allowedOrigins A list of allowed origins for CORS (Cross-Origin Resource Sharing). + * Defaults to `null`, allowing all origins. + * @param eventStore The `EventStore` instance for handling resumable events. + * Defaults to `null`, disabling resumability. + * @param retryIntervalMillis Retry interval in milliseconds for event handling or reconnection attempts. + * Defaults to `null`. + */ + @Deprecated( + "Use constructor with Configuration: StreamableHttpServerTransport(Configuration(enableJsonResponse = ...))", + level = DeprecationLevel.WARNING, + ) + public constructor( + enableJsonResponse: Boolean = false, + enableDnsRebindingProtection: Boolean = false, + allowedHosts: List? = null, + allowedOrigins: List? = null, + eventStore: EventStore? = null, + retryIntervalMillis: Long? = null, + ) : this( + Configuration( + enableJsonResponse = enableJsonResponse, + enableDnsRebindingProtection = enableDnsRebindingProtection, + allowedHosts = allowedHosts, + allowedOrigins = allowedOrigins, + eventStore = eventStore, + retryInterval = retryIntervalMillis?.milliseconds, + ), + ) + + /** + * Configuration for managing various aspects of the StreamableHttpServerTransport. + * + * @property enableJsonResponse Determines whether the server should return JSON responses. + * Defaults to `false`. + * + * @property enableDnsRebindingProtection Enables DNS rebinding protection. + * Defaults to `false`. + * + * @property allowedHosts A list of hosts allowed for server communication. + * Defaults to `null`, allowing all hosts. + * + * @property allowedOrigins A list of allowed origins for CORS (Cross-Origin Resource Sharing). + * Defaults to `null`, allowing all origins. + * + * @property eventStore The `EventStore` instance for handling resumable events. + * Defaults to `null`, disabling resumability. + * + * @property retryInterval Retry interval for event handling or reconnection attempts. + * Defaults to `null`. + */ + public class Configuration( + public val enableJsonResponse: Boolean = false, + public val enableDnsRebindingProtection: Boolean = false, + public val allowedHosts: List? = null, + public val allowedOrigins: List? = null, + public val eventStore: EventStore? = null, + public val retryInterval: Duration? = null, + ) + public var sessionId: String? = null private set @@ -177,7 +228,7 @@ public class StreamableHttpServerTransport( ?: error("No connection established for request id $routingRequestId") val activeStream = streamsMapping[streamId] - if (!enableJsonResponse) { + if (!configuration.enableJsonResponse) { activeStream?.let { stream -> emitOnStream(streamId, stream.session, message) } @@ -194,7 +245,7 @@ public class StreamableHttpServerTransport( streamMutex.withLock { if (activeStream == null) error("No connection established for request ID: $routingRequestId") - if (enableJsonResponse) { + if (configuration.enableJsonResponse) { activeStream.call.response.header(HttpHeaders.ContentType, ContentType.Application.Json.toString()) sessionId?.let { activeStream.call.response.header(MCP_SESSION_ID_HEADER, it) } val responses = relatedIds.mapNotNull { requestToResponseMapping[it] } @@ -261,7 +312,7 @@ public class StreamableHttpServerTransport( @Suppress("CyclomaticComplexMethod", "LongMethod", "ReturnCount", "TooGenericExceptionCaught") public suspend fun handlePostRequest(session: ServerSSESession?, call: ApplicationCall) { try { - if (!enableJsonResponse && session == null) { + if (!configuration.enableJsonResponse && session == null) { error("Server session can't be null for SSE responses") } @@ -328,7 +379,7 @@ public class StreamableHttpServerTransport( } val streamId = Uuid.random().toString() - if (!enableJsonResponse) { + if (!configuration.enableJsonResponse) { call.appendSseHeaders() flushSse(session) // flush headers immediately maybeSendPrimingEvent(streamId, session) @@ -353,7 +404,7 @@ public class StreamableHttpServerTransport( @Suppress("ReturnCount") public suspend fun handleGetRequest(session: ServerSSESession?, call: ApplicationCall) { - if (enableJsonResponse) { + if (configuration.enableJsonResponse) { call.reject( HttpStatusCode.MethodNotAllowed, RPCError.ErrorCode.CONNECTION_CLOSED, @@ -375,7 +426,7 @@ public class StreamableHttpServerTransport( if (!validateSession(call) || !validateProtocolVersion(call)) return - eventStore?.let { store -> + configuration.eventStore?.let { store -> call.request.header(MCP_RESUMPTION_TOKEN_HEADER)?.let { lastEventId -> replayEvents(store, lastEventId, sseSession) return @@ -413,7 +464,7 @@ public class StreamableHttpServerTransport( */ @Suppress("ReturnCount", "TooGenericExceptionCaught") public suspend fun closeSseStream(requestId: RequestId) { - if (enableJsonResponse) return + if (configuration.enableJsonResponse) return val streamId = requestToStreamMapping[requestId] ?: return val sessionContext = streamsMapping[streamId] ?: return @@ -562,9 +613,9 @@ public class StreamableHttpServerTransport( @Suppress("ReturnCount") private fun validateHeaders(call: ApplicationCall): String? { - if (!enableDnsRebindingProtection) return null + if (!configuration.enableDnsRebindingProtection) return null - allowedHosts?.let { hosts -> + configuration.allowedHosts?.let { hosts -> val hostHeader = call.request.headers[HttpHeaders.Host]?.lowercase() val allowedHostsLowercase = hosts.map { it.lowercase() } @@ -573,7 +624,7 @@ public class StreamableHttpServerTransport( } } - allowedOrigins?.let { origins -> + configuration.allowedOrigins?.let { origins -> val originHeader = call.request.headers[HttpHeaders.Origin]?.lowercase() val allowedOriginsLowercase = origins.map { it.lowercase() } @@ -636,7 +687,7 @@ public class StreamableHttpServerTransport( this?.lowercase()?.contains(mime.toString().lowercase()) == true private suspend fun emitOnStream(streamId: String, session: ServerSSESession?, message: JSONRPCMessage) { - val eventId = eventStore?.storeEvent(streamId, message) + val eventId = configuration.eventStore?.storeEvent(streamId, message) try { session?.send(event = "message", id = eventId, data = McpJson.encodeToString(message)) } catch (_: Exception) { @@ -646,11 +697,15 @@ public class StreamableHttpServerTransport( @Suppress("TooGenericExceptionCaught") private suspend fun maybeSendPrimingEvent(streamId: String, session: ServerSSESession?) { - val store = eventStore ?: return + val store = configuration.eventStore ?: return val sseSession = session ?: return try { val primingEventId = store.storeEvent(streamId, JSONRPCEmptyMessage) - sseSession.send(id = primingEventId, retry = retryIntervalMillis, data = "") + sseSession.send( + id = primingEventId, + retry = configuration.retryInterval?.inWholeMilliseconds, + data = "", + ) } catch (e: Exception) { _onError(e) }