Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.test.Ignore
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import kotlin.test.assertTrue
Expand Down Expand Up @@ -165,7 +164,6 @@ abstract class AbstractResourceIntegrationTest : KotlinTestBase() {
assertEquals(testResourceContent, content.text, "Resource content should match")
}

@Ignore("Blocked by https://github.com/modelcontextprotocol/kotlin-sdk/issues/249")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is passing now

@Test
fun testSubscribeAndUnsubscribe() {
runBlocking(Dispatchers.IO) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ abstract class KotlinTestBase {
// Create StreamableHTTP server transport
// Using JSON response mode for simpler testing (no SSE session required)
val transport = StreamableHttpServerTransport(
enableJsonResponse = true, // Use JSON response mode for testing
StreamableHttpServerTransport.Configuration(
enableJsonResponse = true, // Use JSON response mode for testing
),
)
// Use stateless mode to skip session validation for simpler testing
transport.setSessionIdGenerator(null)
Expand Down
12 changes: 12 additions & 0 deletions kotlin-sdk-server/api/kotlin-sdk-server.api
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/StdioServerTranspor
public final class io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport {
public static final field STANDALONE_SSE_STREAM_ID Ljava/lang/String;
public fun <init> ()V
public fun <init> (Lio/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport$Configuration;)V
public fun <init> (ZZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Ljava/lang/Long;)V
public synthetic fun <init> (ZZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Ljava/lang/Long;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
Expand All @@ -211,6 +212,17 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServe
public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}

public final class io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport$Configuration {
public synthetic fun <init> (ZZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/time/Duration;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public synthetic fun <init> (ZZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/time/Duration;Lkotlin/jvm/internal/DefaultConstructorMarker;)V
public final fun getAllowedHosts ()Ljava/util/List;
public final fun getAllowedOrigins ()Ljava/util/List;
public final fun getEnableDnsRebindingProtection ()Z
public final fun getEnableJsonResponse ()Z
public final fun getEventStore ()Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;
public final fun getRetryInterval-FghU774 ()Lkotlin/time/Duration;
}

public final class io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensionsKt {
public static final fun mcpWebSocket (Lio/ktor/server/application/Application;Ljava/lang/String;Lkotlin/jvm/functions/Function0;)V
public static final fun mcpWebSocket (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function0;)V
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,11 +225,13 @@ private suspend fun RoutingContext.mcpStatelessStreamableHttpEndpoint(
block: RoutingContext.() -> Server,
) {
val transport = StreamableHttpServerTransport(
enableDnsRebindingProtection = enableDnsRebindingProtection,
allowedHosts = allowedHosts,
allowedOrigins = allowedOrigins,
eventStore = eventStore,
enableJsonResponse = true,
StreamableHttpServerTransport.Configuration(
enableDnsRebindingProtection = enableDnsRebindingProtection,
allowedHosts = allowedHosts,
allowedOrigins = allowedOrigins,
eventStore = eventStore,
enableJsonResponse = true,
),
).also { it.setSessionIdGenerator(null) }

logger.info { "New stateless StreamableHttp connection established without sessionId" }
Expand Down Expand Up @@ -305,11 +307,13 @@ private suspend fun RoutingContext.streamableTransport(
}

val transport = StreamableHttpServerTransport(
enableDnsRebindingProtection = enableDnsRebindingProtection,
allowedHosts = allowedHosts,
allowedOrigins = allowedOrigins,
eventStore = eventStore,
enableJsonResponse = true,
StreamableHttpServerTransport.Configuration(
enableDnsRebindingProtection = enableDnsRebindingProtection,
allowedHosts = allowedHosts,
allowedOrigins = allowedOrigins,
eventStore = eventStore,
enableJsonResponse = true,
),
)

transport.setOnSessionInitialized { initializedSessionId ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.decodeFromJsonElement
import kotlin.concurrent.atomics.AtomicBoolean
import kotlin.concurrent.atomics.ExperimentalAtomicApi
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlin.uuid.ExperimentalUuidApi
import kotlin.uuid.Uuid

Expand All @@ -46,8 +48,8 @@ private const val MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 // 4 MB

/**
* A holder for an active request call.
* If enableJsonResponse is true, session is null.
* Otherwise, session is not null.
* If [StreamableHttpServerTransport.Configuration.enableJsonResponse] is true, the session is null.
* Otherwise, the session is not null.
*/
private data class SessionContext(val session: ServerSSESession?, val call: ApplicationCall)

Expand All @@ -66,32 +68,87 @@ private data class SessionContext(val session: ServerSSESession?, val call: Appl
* - No Session ID is included in any responses
* - No session validation is performed
*
* @param enableJsonResponse If true, the server will return JSON responses instead of starting an SSE stream.
* This can be useful for simple request/response scenarios without streaming.
* Default is false (SSE streams are preferred).
* @param enableDnsRebindingProtection Enable DNS rebinding protection
* (requires allowedHosts and/or allowedOrigins to be configured).
* Default is false for backwards compatibility.
* @param allowedHosts List of allowed host header values for DNS rebinding protection.
* If not specified, host validation is disabled.
* @param allowedOrigins List of allowed origin header values for DNS rebinding protection.
* If not specified, origin validation is disabled.
* @param eventStore Event store for resumability support
* If provided, resumability will be enabled, allowing clients to reconnect and resume messages
* @param retryIntervalMillis Retry interval (in milliseconds) advertised via SSE priming events
* to hint the client when to reconnect. Applies only when an [eventStore] is configured.
* Defaults to `null` (no retry hint).
* @param configuration Transport configuration. See [Configuration] for available options.
*/
@OptIn(ExperimentalUuidApi::class, ExperimentalAtomicApi::class)
@Suppress("TooManyFunctions")
public class StreamableHttpServerTransport(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a method for backward source compatibility?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

private val enableJsonResponse: Boolean = false,
private val enableDnsRebindingProtection: Boolean = false,
private val allowedHosts: List<String>? = null,
private val allowedOrigins: List<String>? = null,
private val eventStore: EventStore? = null,
private val retryIntervalMillis: Long? = null,
) : AbstractTransport() {
public class StreamableHttpServerTransport(private val configuration: Configuration) : AbstractTransport() {

@Deprecated("Use default constructor with explicit Configuration()")
public constructor() : this(configuration = Configuration())

/**
* Secondary constructor for `StreamableHttpServerTransport` that simplifies initialization by directly taking the
* configurable parameters without requiring a `Configuration` instance.
*
* @param enableJsonResponse Determines whether the server should return JSON responses.
* Defaults to `false`.
* @param enableDnsRebindingProtection Enables DNS rebinding protection.
* Defaults to `false`.
* @param allowedHosts A list of hosts allowed for server communication.
* Defaults to `null`, allowing all hosts.
* @param allowedOrigins A list of allowed origins for CORS (Cross-Origin Resource Sharing).
* Defaults to `null`, allowing all origins.
* @param eventStore The `EventStore` instance for handling resumable events.
* Defaults to `null`, disabling resumability.
* @param retryIntervalMillis Retry interval in milliseconds for event handling or reconnection attempts.
* Defaults to `null`.
*/
@Suppress("MaxLineLength")
@Deprecated(
"Use constructor with Configuration: StreamableHttpServerTransport(Configuration(enableJsonResponse = ...))",
replaceWith = ReplaceWith(
"StreamableHttpServerTransport(Configuration(enableJsonResponse = enableJsonResponse, enableDnsRebindingProtection = enableDnsRebindingProtection, allowedHosts = allowedHosts, allowedOrigins = allowedOrigins, eventStore = eventStore, retryIntervalMillis = retryIntervalMillis))",
),
)
public constructor(
enableJsonResponse: Boolean = false,
enableDnsRebindingProtection: Boolean = false,
allowedHosts: List<String>? = null,
allowedOrigins: List<String>? = null,
eventStore: EventStore? = null,
retryIntervalMillis: Long? = null,
) : this(
Configuration(
enableJsonResponse = enableJsonResponse,
enableDnsRebindingProtection = enableDnsRebindingProtection,
allowedHosts = allowedHosts,
allowedOrigins = allowedOrigins,
eventStore = eventStore,
retryInterval = retryIntervalMillis?.milliseconds,
),
)

/**
* Configuration for managing various aspects of the StreamableHttpServerTransport.
*
* @property enableJsonResponse Determines whether the server should return JSON responses.
* Defaults to `false`.
*
* @property enableDnsRebindingProtection Enables DNS rebinding protection.
* Defaults to `false`.
*
* @property allowedHosts A list of hosts allowed for server communication.
* Defaults to `null`, allowing all hosts.
*
* @property allowedOrigins A list of allowed origins for CORS (Cross-Origin Resource Sharing).
* Defaults to `null`, allowing all origins.
*
* @property eventStore The `EventStore` instance for handling resumable events.
* Defaults to `null`, disabling resumability.
*
* @property retryInterval Retry interval for event handling or reconnection attempts.
* Defaults to `null`.
*/
public class Configuration(
public val enableJsonResponse: Boolean = false,
public val enableDnsRebindingProtection: Boolean = false,
public val allowedHosts: List<String>? = null,
public val allowedOrigins: List<String>? = null,
public val eventStore: EventStore? = null,
public val retryInterval: Duration? = null,
)

public var sessionId: String? = null
private set

Expand Down Expand Up @@ -177,7 +234,7 @@ public class StreamableHttpServerTransport(
?: error("No connection established for request id $routingRequestId")
val activeStream = streamsMapping[streamId]

if (!enableJsonResponse) {
if (!configuration.enableJsonResponse) {
activeStream?.let { stream ->
emitOnStream(streamId, stream.session, message)
}
Expand All @@ -194,7 +251,7 @@ public class StreamableHttpServerTransport(
streamMutex.withLock {
if (activeStream == null) error("No connection established for request ID: $routingRequestId")

if (enableJsonResponse) {
if (configuration.enableJsonResponse) {
activeStream.call.response.header(HttpHeaders.ContentType, ContentType.Application.Json.toString())
sessionId?.let { activeStream.call.response.header(MCP_SESSION_ID_HEADER, it) }
val responses = relatedIds.mapNotNull { requestToResponseMapping[it] }
Expand Down Expand Up @@ -261,7 +318,7 @@ public class StreamableHttpServerTransport(
@Suppress("CyclomaticComplexMethod", "LongMethod", "ReturnCount", "TooGenericExceptionCaught")
public suspend fun handlePostRequest(session: ServerSSESession?, call: ApplicationCall) {
try {
if (!enableJsonResponse && session == null) {
if (!configuration.enableJsonResponse && session == null) {
error("Server session can't be null for SSE responses")
}

Expand Down Expand Up @@ -328,7 +385,7 @@ public class StreamableHttpServerTransport(
}

val streamId = Uuid.random().toString()
if (!enableJsonResponse) {
if (!configuration.enableJsonResponse) {
call.appendSseHeaders()
flushSse(session) // flush headers immediately
maybeSendPrimingEvent(streamId, session)
Expand All @@ -353,7 +410,7 @@ public class StreamableHttpServerTransport(

@Suppress("ReturnCount")
public suspend fun handleGetRequest(session: ServerSSESession?, call: ApplicationCall) {
if (enableJsonResponse) {
if (configuration.enableJsonResponse) {
call.reject(
HttpStatusCode.MethodNotAllowed,
RPCError.ErrorCode.CONNECTION_CLOSED,
Expand All @@ -375,7 +432,7 @@ public class StreamableHttpServerTransport(

if (!validateSession(call) || !validateProtocolVersion(call)) return

eventStore?.let { store ->
configuration.eventStore?.let { store ->
call.request.header(MCP_RESUMPTION_TOKEN_HEADER)?.let { lastEventId ->
replayEvents(store, lastEventId, sseSession)
return
Expand Down Expand Up @@ -413,7 +470,7 @@ public class StreamableHttpServerTransport(
*/
@Suppress("ReturnCount", "TooGenericExceptionCaught")
public suspend fun closeSseStream(requestId: RequestId) {
if (enableJsonResponse) return
if (configuration.enableJsonResponse) return
val streamId = requestToStreamMapping[requestId] ?: return
val sessionContext = streamsMapping[streamId] ?: return

Expand Down Expand Up @@ -562,9 +619,9 @@ public class StreamableHttpServerTransport(

@Suppress("ReturnCount")
private fun validateHeaders(call: ApplicationCall): String? {
if (!enableDnsRebindingProtection) return null
if (!configuration.enableDnsRebindingProtection) return null

allowedHosts?.let { hosts ->
configuration.allowedHosts?.let { hosts ->
val hostHeader = call.request.headers[HttpHeaders.Host]?.lowercase()
val allowedHostsLowercase = hosts.map { it.lowercase() }

Expand All @@ -573,7 +630,7 @@ public class StreamableHttpServerTransport(
}
}

allowedOrigins?.let { origins ->
configuration.allowedOrigins?.let { origins ->
val originHeader = call.request.headers[HttpHeaders.Origin]?.lowercase()
val allowedOriginsLowercase = origins.map { it.lowercase() }

Expand Down Expand Up @@ -636,7 +693,7 @@ public class StreamableHttpServerTransport(
this?.lowercase()?.contains(mime.toString().lowercase()) == true

private suspend fun emitOnStream(streamId: String, session: ServerSSESession?, message: JSONRPCMessage) {
val eventId = eventStore?.storeEvent(streamId, message)
val eventId = configuration.eventStore?.storeEvent(streamId, message)
try {
session?.send(event = "message", id = eventId, data = McpJson.encodeToString(message))
} catch (_: Exception) {
Expand All @@ -646,11 +703,15 @@ public class StreamableHttpServerTransport(

@Suppress("TooGenericExceptionCaught")
private suspend fun maybeSendPrimingEvent(streamId: String, session: ServerSSESession?) {
val store = eventStore ?: return
val store = configuration.eventStore ?: return
val sseSession = session ?: return
try {
val primingEventId = store.storeEvent(streamId, JSONRPCEmptyMessage)
sseSession.send(id = primingEventId, retry = retryIntervalMillis, data = "")
sseSession.send(
id = primingEventId,
retry = configuration.retryInterval?.inWholeMilliseconds,
data = "",
)
} catch (e: Exception) {
_onError(e)
}
Expand Down
Loading