diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 08bf61ee..ccb649dd 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -49,6 +49,7 @@ ktor-client-logging = { group = "io.ktor", name = "ktor-client-logging", version ktor-server-content-negotiation = { group = "io.ktor", name = "ktor-server-content-negotiation", version.ref = "ktor" } ktor-client-content-negotiation = { group = "io.ktor", name = "ktor-client-content-negotiation", version.ref = "ktor" } ktor-serialization = { group = "io.ktor", name = "ktor-serialization-kotlinx-json", version.ref = "ktor" } +ktor-server-auth = { group = "io.ktor", name = "ktor-server-auth", version.ref = "ktor" } ktor-server-core = { group = "io.ktor", name = "ktor-server-core", version.ref = "ktor" } ktor-server-sse = { group = "io.ktor", name = "ktor-server-sse", version.ref = "ktor" } ktor-server-websockets = { group = "io.ktor", name = "ktor-server-websockets", version.ref = "ktor" } diff --git a/integration-test/build.gradle.kts b/integration-test/build.gradle.kts index 82078e7c..30afe9e8 100644 --- a/integration-test/build.gradle.kts +++ b/integration-test/build.gradle.kts @@ -22,6 +22,7 @@ kotlin { implementation(libs.ktor.server.content.negotiation) implementation(libs.ktor.serialization) implementation(libs.ktor.server.websockets) + implementation(libs.ktor.server.auth) implementation(libs.ktor.server.test.host) implementation(libs.ktor.server.content.negotiation) implementation(libs.ktor.serialization) diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/AbstractAuthenticationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/AbstractAuthenticationTest.kt new file mode 100644 index 00000000..1e88558f --- /dev/null +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/AbstractAuthenticationTest.kt @@ -0,0 +1,188 @@ +package io.modelcontextprotocol.kotlin.sdk.integration + +import io.kotest.matchers.shouldBe +import io.ktor.client.HttpClient +import io.ktor.client.request.basicAuth +import io.ktor.client.request.get +import io.ktor.http.HttpStatusCode +import io.ktor.serialization.kotlinx.json.json +import io.ktor.server.application.Application +import io.ktor.server.application.ApplicationCall +import io.ktor.server.application.install +import io.ktor.server.auth.Authentication +import io.ktor.server.auth.UserIdPrincipal +import io.ktor.server.auth.authenticate +import io.ktor.server.auth.basic +import io.ktor.server.auth.principal +import io.ktor.server.engine.embeddedServer +import io.ktor.server.plugins.contentnegotiation.ContentNegotiation +import io.ktor.server.routing.Route +import io.ktor.server.routing.routing +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.shared.Transport +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.McpJson +import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequest +import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceResult +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.TextResourceContents +import io.modelcontextprotocol.kotlin.test.utils.actualPort +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import kotlin.test.Test +import kotlin.time.Duration.Companion.seconds +import io.ktor.client.engine.cio.CIO as ClientCIO +import io.ktor.server.cio.CIO as ServerCIO +import io.ktor.server.sse.SSE as ServerSSE + +/** + * Base class for MCP authentication integration tests. + */ +@Suppress("InjectDispatcher") +abstract class AbstractAuthenticationTest { + + protected companion object { + const val HOST = "127.0.0.1" + const val AUTH_REALM = "mcp-auth" + const val WHOAMI_URI = "whoami://me" + const val VALID_USER = "test-user" + const val VALID_PASSWORD = "valid-password-123" + const val INVALID_USER = "invalid-user" + const val INVALID_PASSWORD = "invalid-password" + } + + /** + * Installs Ktor plugins required by the transport under test. + */ + protected open fun Application.configurePlugins() { + install(ServerSSE) + // ContentNegotiation is required by the StreamableHttp transport for JSON body handling. + // Installing it for SSE tests as well is harmless. + install(ContentNegotiation) { json(McpJson) } + } + + /** + * Registers the MCP server on the given route. + */ + abstract fun Route.registerMcpServer(serverFactory: ApplicationCall.() -> Server) + + /** + * Creates a client transport configured with the given credentials. + */ + abstract fun createClientTransport(baseUrl: String, user: String, pass: String): Transport + + @Test + fun `mcp behind basic auth rejects unauthenticated requests with 401`(): Unit = runBlocking(Dispatchers.IO) { + val server = startAuthenticatedServer() + + val httpClient = HttpClient(ClientCIO) { expectSuccess = false } + try { + httpClient.get("http://$HOST:${server.actualPort()}").status shouldBe HttpStatusCode.Unauthorized + } finally { + httpClient.close() + server.stopSuspend(1000, 2000) + } + } + + @Test + fun `mcp rejects requests with invalid credentials`(): Unit = runBlocking(Dispatchers.IO) { + val server = startAuthenticatedServer() + + val httpClient = HttpClient(ClientCIO) { + expectSuccess = false + } + try { + httpClient.get("http://$HOST:${server.actualPort()}") { + basicAuth(INVALID_USER, INVALID_PASSWORD) + }.status shouldBe HttpStatusCode.Unauthorized + } finally { + httpClient.close() + server.stopSuspend(1000, 2000) + } + } + + @Test + fun `authenticated mcp client can read resource scoped to principal`(): Unit = runBlocking(Dispatchers.IO) { + val server = startAuthenticatedServer() + + val baseUrl = "http://$HOST:${server.actualPort()}" + var mcpClient: Client? = null + try { + mcpClient = Client(Implementation(name = "test-client", version = "1.0.0")) + withTimeout(5.seconds) { + mcpClient.connect(createClientTransport(baseUrl, VALID_USER, VALID_PASSWORD)) + } + + val result = mcpClient.readResource( + ReadResourceRequest(ReadResourceRequestParams(uri = WHOAMI_URI)), + ) + + result.contents shouldBe listOf( + TextResourceContents( + text = VALID_USER, + uri = WHOAMI_URI, + mimeType = "text/plain", + ), + ) + } finally { + mcpClient?.close() + server.stopSuspend(1000, 2000) + } + } + + private suspend fun startAuthenticatedServer() = embeddedServer(ServerCIO, host = HOST, port = 0) { + configurePlugins() + installBasicAuth() + routing { + authenticate(AUTH_REALM) { + registerMcpServer { + createMcpServer { principal()?.name } + } + } + } + }.startSuspend(wait = false) + + private fun Application.installBasicAuth() { + install(Authentication) { + basic(AUTH_REALM) { + validate { credentials -> + if (credentials.name == VALID_USER && credentials.password == VALID_PASSWORD) { + UserIdPrincipal(credentials.name) + } else { + null + } + } + } + } + } + + protected fun createMcpServer(principalProvider: () -> String?): Server = Server( + serverInfo = Implementation(name = "test-server", version = "1.0.0"), + options = ServerOptions( + capabilities = ServerCapabilities( + resources = ServerCapabilities.Resources(), + ), + ), + ).apply { + addResource( + uri = WHOAMI_URI, + name = "Current User", + description = "Returns the name of the authenticated user", + mimeType = "text/plain", + ) { + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = principalProvider() ?: "anonymous", + uri = WHOAMI_URI, + mimeType = "text/plain", + ), + ), + ) + } + } +} diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/sse/SseAuthenticationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/sse/SseAuthenticationTest.kt new file mode 100644 index 00000000..da763195 --- /dev/null +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/sse/SseAuthenticationTest.kt @@ -0,0 +1,41 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.sse + +import io.ktor.client.HttpClient +import io.ktor.client.plugins.sse.SSE +import io.ktor.client.request.basicAuth +import io.ktor.server.application.ApplicationCall +import io.ktor.server.routing.Route +import io.modelcontextprotocol.kotlin.sdk.client.SseClientTransport +import io.modelcontextprotocol.kotlin.sdk.integration.AbstractAuthenticationTest +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.mcp +import io.modelcontextprotocol.kotlin.sdk.shared.Transport +import kotlin.test.AfterTest +import io.ktor.client.engine.cio.CIO as ClientCIO + +class SseAuthenticationTest : AbstractAuthenticationTest() { + + private var httpClient: HttpClient? = null + + @AfterTest + fun closeHttpClient() { + httpClient?.close() + httpClient = null + } + + override fun Route.registerMcpServer(serverFactory: ApplicationCall.() -> Server) { + mcp { + serverFactory(call) + } + } + + override fun createClientTransport(baseUrl: String, user: String, pass: String): Transport { + val client = HttpClient(ClientCIO) { install(SSE) } + httpClient = client + return SseClientTransport( + client = client, + urlString = baseUrl, + requestBuilder = { basicAuth(user, pass) }, + ) + } +}