diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 51f98e96..20b9a7fa 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -16,6 +16,8 @@ jobs: validate-pr: runs-on: macos-latest-xlarge name: Validate PR + env: + JAVA_OPTS: "-Xmx8g -Dfile.encoding=UTF-8 -Djava.awt.headless=true -Dkotlin.daemon.jvm.options=-Xmx6g" steps: - uses: actions/checkout@v5 @@ -32,18 +34,41 @@ jobs: cache-read-only: true - name: Build with Gradle - run: ./gradlew clean ktlintCheck build koverLog koverHtmlReport - env: - JAVA_OPTS: "-Xmx8g -Dfile.encoding=UTF-8 -Djava.awt.headless=true -Dkotlin.daemon.jvm.options=-Xmx6g" + run: |- + ./gradlew clean ktlintCheck build koverLog koverHtmlReport + ./gradlew :kotlin-sdk-core:publishToMavenLocal :kotlin-sdk-client:publishToMavenLocal :kotlin-sdk-server:publishToMavenLocal + + - name: Build Kotlin-MCP-Client Sample + working-directory: ./samples/kotlin-mcp-client + run: ./../../gradlew clean build + + - name: Build Kotlin-MCP-Server Sample + working-directory: ./samples/kotlin-mcp-server + run: ./../../gradlew clean build + + - name: Build Weather-Stdio-Server Sample + working-directory: ./samples/weather-stdio-server + run: ./../../gradlew clean build - name: Upload Reports - if: always() + if: ${{ !cancelled() }} uses: actions/upload-artifact@v4 with: name: reports path: | **/build/reports/ + - name: Publish Test Report + uses: mikepenz/action-junit-report@v5 + if: ${{ !cancelled() }} # always run even if the previous step fails + with: + report_paths: '**/test-results/**/TEST-*.xml' + detailed_summary: true + flaky_summary: true + include_empty_in_summary: false + include_time_in_summary: true + annotate_only: true + - name: Disable Auto-Merge on Fail if: failure() && github.event_name == 'pull_request' run: gh pr merge --disable-auto "$PR_URL" diff --git a/.github/workflows/gradle-publish.yml b/.github/workflows/gradle-publish.yml index 1d85ccc6..53dde064 100644 --- a/.github/workflows/gradle-publish.yml +++ b/.github/workflows/gradle-publish.yml @@ -33,36 +33,16 @@ jobs: - name: Setup Gradle uses: gradle/actions/setup-gradle@v4 - - name: Verify publication configuration - run: ./gradlew jreleaserConfig - env: - JRELEASER_MAVENCENTRAL_USERNAME: ${{ secrets.OSSRH_USERNAME }} - JRELEASER_MAVENCENTRAL_PASSWORD: ${{ secrets.OSSRH_TOKEN }} - JRELEASER_GPG_PUBLIC_KEY: ${{ secrets.GPG_PUBLIC_KEY }} - JRELEASER_GPG_SECRET_KEY: ${{ secrets.GPG_SECRET_KEY }} - JRELEASER_GPG_PASSPHRASE: ${{ secrets.SIGNING_PASSPHRASE }} - JRELEASER_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - name: Clean Build with Gradle run: ./gradlew clean build - env: - JRELEASER_MAVENCENTRAL_USERNAME: ${{ secrets.OSSRH_USERNAME }} - JRELEASER_MAVENCENTRAL_PASSWORD: ${{ secrets.OSSRH_TOKEN }} - JRELEASER_GPG_PUBLIC_KEY: ${{ secrets.GPG_PUBLIC_KEY }} - JRELEASER_GPG_SECRET_KEY: ${{ secrets.GPG_SECRET_KEY }} - JRELEASER_GPG_PASSPHRASE: ${{ secrets.SIGNING_PASSPHRASE }} - GPG_SECRET_KEY: ${{ secrets.GPG_SECRET_KEY }} - SIGNING_PASSPHRASE: ${{ secrets.SIGNING_PASSPHRASE }} - name: Publish to Maven Central Portal id: publish - run: ./gradlew publish jreleaserFullRelease --info --stacktrace -Djreleaser.verbose=true + run: ./gradlew publishToMavenCentral --no-configuration-cache env: - JRELEASER_MAVENCENTRAL_USERNAME: ${{ secrets.OSSRH_USERNAME }} - JRELEASER_MAVENCENTRAL_PASSWORD: ${{ secrets.OSSRH_TOKEN }} - JRELEASER_GPG_PUBLIC_KEY: ${{ secrets.GPG_PUBLIC_KEY }} - JRELEASER_GPG_SECRET_KEY: ${{ secrets.GPG_SECRET_KEY }} - JRELEASER_GPG_PASSPHRASE: ${{ secrets.SIGNING_PASSPHRASE }} - JRELEASER_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + ORG_GRADLE_PROJECT_mavenCentralUsername: ${{ secrets.OSSRH_USERNAME }} + ORG_GRADLE_PROJECT_mavenCentralPassword: ${{ secrets.OSSRH_TOKEN }} + ORG_GRADLE_PROJECT_signingInMemoryKey: ${{ secrets.GPG_SECRET_KEY }} + ORG_GRADLE_PROJECT_signingInMemoryKeyPassword: ${{ secrets.SIGNING_PASSPHRASE }} GPG_SECRET_KEY: ${{ secrets.GPG_SECRET_KEY }} SIGNING_PASSPHRASE: ${{ secrets.SIGNING_PASSPHRASE }} diff --git a/build.gradle.kts b/build.gradle.kts index 81117d2a..f3497c34 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -6,7 +6,7 @@ plugins { allprojects { group = "io.modelcontextprotocol" - version = "0.7.1" + version = "0.7.2" } dependencies { diff --git a/buildSrc/build.gradle.kts b/buildSrc/build.gradle.kts index 16c0eb9b..84676a4f 100644 --- a/buildSrc/build.gradle.kts +++ b/buildSrc/build.gradle.kts @@ -12,5 +12,5 @@ dependencies { implementation(libs.kotlin.serialization) implementation(libs.kotlinx.atomicfu.gradle) implementation(libs.dokka.gradle) - implementation(libs.jreleaser.gradle) -} \ No newline at end of file + implementation(libs.maven.publish) +} diff --git a/buildSrc/src/main/kotlin/mcp.jreleaser.gradle.kts b/buildSrc/src/main/kotlin/mcp.jreleaser.gradle.kts deleted file mode 100644 index 9761aa7b..00000000 --- a/buildSrc/src/main/kotlin/mcp.jreleaser.gradle.kts +++ /dev/null @@ -1,75 +0,0 @@ -import org.jreleaser.model.Active - -plugins { - id("org.jreleaser") - id("mcp.publishing") -} - -jreleaser { - gitRootSearch = true - strict = true - - signing { - active = Active.ALWAYS - armored = true - artifacts = true - files = true - } - - deploy { - active = Active.ALWAYS - maven { - active = Active.ALWAYS - mavenCentral.create("ossrh") { - active = Active.ALWAYS - sign = true - url = "https://central.sonatype.com/api/v1/publisher" - applyMavenCentralRules = false - maxRetries = 240 - stagingRepository(layout.buildDirectory.dir("staging-deploy").get().asFile.path) - - // workaround: https://github.com/jreleaser/jreleaser/issues/1784 - afterEvaluate { - publishing.publications.forEach { publication -> - if (publication is MavenPublication) { - val pubName = publication.name - - if (!pubName.contains("jvm", ignoreCase = true) - && !pubName.contains("metadata", ignoreCase = true) - && !pubName.contains("kotlinMultiplatform", ignoreCase = true) - ) { - artifactOverride { - artifactId = when { - pubName.contains("wasm", ignoreCase = true) -> - "${project.name}-wasm-${pubName.lowercase().substringAfter("wasm")}" - - else -> "${project.name}-${pubName.lowercase()}" - } - jar = false - verifyPom = false - sourceJar = false - javadocJar = false - } - } - } - } - } - } - } - - checksum { - individual = false - artifacts = false - files = false - } - } - - release { - github { - skipRelease = true - skipTag = true - overwrite = false - token = "none" - } - } -} diff --git a/buildSrc/src/main/kotlin/mcp.multiplatform.gradle.kts b/buildSrc/src/main/kotlin/mcp.multiplatform.gradle.kts index 86569842..e1157168 100644 --- a/buildSrc/src/main/kotlin/mcp.multiplatform.gradle.kts +++ b/buildSrc/src/main/kotlin/mcp.multiplatform.gradle.kts @@ -10,25 +10,6 @@ plugins { id("org.jetbrains.kotlinx.atomicfu") } -// Generation library versions -val generateLibVersion by tasks.registering { - val outputDir = layout.buildDirectory.dir("generated-sources/libVersion") - outputs.dir(outputDir) - - doLast { - val sourceFile = outputDir.get().file("io/modelcontextprotocol/kotlin/sdk/LibVersion.kt").asFile - sourceFile.parentFile.mkdirs() - sourceFile.writeText( - """ - package io.modelcontextprotocol.kotlin.sdk - - public const val LIB_VERSION: String = "${project.version}" - - """.trimIndent() - ) - } -} - kotlin { jvm { compilerOptions.jvmTarget = JvmTarget.JVM_1_8 @@ -41,10 +22,4 @@ kotlin { explicitApi = ExplicitApiMode.Strict jvmToolchain(21) - - sourceSets { - commonMain { - kotlin.srcDir(generateLibVersion) - } - } } diff --git a/buildSrc/src/main/kotlin/mcp.publishing.gradle.kts b/buildSrc/src/main/kotlin/mcp.publishing.gradle.kts index bef396de..13551379 100644 --- a/buildSrc/src/main/kotlin/mcp.publishing.gradle.kts +++ b/buildSrc/src/main/kotlin/mcp.publishing.gradle.kts @@ -1,63 +1,52 @@ +import com.vanniktech.maven.publish.MavenPublishBaseExtension + plugins { `maven-publish` + id("com.vanniktech.maven.publish") signing } -val javadocJar by tasks.registering(Jar::class) { - archiveClassifier.set("javadoc") -} - -publishing { - publications.withType().configureEach { - if (name.contains("jvm", ignoreCase = true)) { - artifact(javadocJar) - } +mavenPublishing { + publishToMavenCentral(automaticRelease = true) + configureSigning(this) - pom { - name = project.name - description = "Kotlin implementation of the Model Context Protocol (MCP)" - url = "https://github.com/modelcontextprotocol/kotlin-sdk" - - licenses { - license { - name = "MIT License" - url = "https://github.com/modelcontextprotocol/kotlin-sdk/blob/main/LICENSE" - distribution = "repo" - } - } + pom { + name = project.name + description = "Kotlin implementation of the Model Context Protocol (MCP)" + url = "https://github.com/modelcontextprotocol/kotlin-sdk" - organization { - name = "Anthropic" - url = "https://www.anthropic.com" + licenses { + license { + name = "MIT License" + url = "https://github.com/modelcontextprotocol/kotlin-sdk/blob/main/LICENSE" + distribution = "repo" } + } - developers { - developer { - id = "JetBrains" - name = "JetBrains Team" - organization = "JetBrains" - organizationUrl = "https://www.jetbrains.com" - } - } + organization { + name = "Anthropic" + url = "https://www.anthropic.com" + } - scm { - url = "https://github.com/modelcontextprotocol/kotlin-sdk" - connection = "scm:git:git://github.com/modelcontextprotocol/kotlin-sdk.git" - developerConnection = "scm:git:git@github.com:modelcontextprotocol/kotlin-sdk.git" + developers { + developer { + id = "JetBrains" + name = "JetBrains Team" + organization = "JetBrains" + organizationUrl = "https://www.jetbrains.com" } } - } - repositories { - maven { - name = "staging" - url = uri(layout.buildDirectory.dir("staging-deploy")) + scm { + url = "https://github.com/modelcontextprotocol/kotlin-sdk" + connection = "scm:git:git://github.com/modelcontextprotocol/kotlin-sdk.git" + developerConnection = "scm:git:git@github.com:modelcontextprotocol/kotlin-sdk.git" } } } -signing { - val gpgKeyName = "GPG_SIGNING_KEY" +private fun Project.configureSigning(mavenPublishing: MavenPublishBaseExtension) { + val gpgKeyName = "GPG_SECRET_KEY" val gpgPassphraseName = "SIGNING_PASSPHRASE" val signingKey = providers.environmentVariable(gpgKeyName) .orElse(providers.gradleProperty(gpgKeyName)) @@ -65,7 +54,7 @@ signing { .orElse(providers.gradleProperty(gpgPassphraseName)) if (signingKey.isPresent) { - useInMemoryPgpKeys(signingKey.get(), signingPassphrase.get()) - sign(publishing.publications) + mavenPublishing.signAllPublications() + signing.useInMemoryPgpKeys(signingKey.get(), signingPassphrase.get()) } -} \ No newline at end of file +} diff --git a/gradle.properties b/gradle.properties index 85b95662..7e636348 100644 --- a/gradle.properties +++ b/gradle.properties @@ -7,6 +7,8 @@ org.jetbrains.dokka.experimental.gradle.pluginMode.noWarn=true # Kotlin kotlin.code.style=official kotlin.daemon.jvmargs=-Xmx4G +# Build JS targets using npm package manager https://kotlinlang.org/docs/js-project-setup.html#npm-dependencies +kotlin.js.yarn=false # MPP kotlin.mpp.enableCInteropCommonization=true kotlin.native.ignoreDisabledTargets=true diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 6728830a..088dd9cd 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -5,6 +5,8 @@ dokka = "2.0.0" atomicfu = "0.29.0" ktlint = "13.1.0" kover = "0.9.1" +mavenPublish = "0.34.0" +binaryCompatibilityValidatorPlugin = "0.18.1" # libraries version serialization = "1.9.0" @@ -13,14 +15,12 @@ coroutines = "1.10.2" kotlinx-io = "0.8.0" ktor = "3.3.0" logging = "7.0.13" -jreleaser = "1.19.0" -binaryCompatibilityValidatorPlugin = "0.18.1" slf4j = "2.0.17" -kotest = "5.9.1" +kotest = "6.0.3" awaitility = "4.3.0" # Samples -mcp-kotlin = "0.7.1" +mcp-kotlin = "0.7.2" anthropic = "2.7.0" shadow = "8.1.1" @@ -30,7 +30,7 @@ kotlin-gradle = { module = "org.jetbrains.kotlin:kotlin-gradle-plugin", version. kotlin-serialization = { module = "org.jetbrains.kotlin:kotlin-serialization", version.ref = "kotlin" } kotlinx-atomicfu-gradle = { module = "org.jetbrains.kotlinx:atomicfu-gradle-plugin", version.ref = "atomicfu" } dokka-gradle = { module = "org.jetbrains.dokka:dokka-gradle-plugin", version.ref = "dokka" } -jreleaser-gradle = { module = "org.jreleaser:jreleaser-gradle-plugin", version.ref = "jreleaser" } +maven-publish = { module = "com.vanniktech:gradle-maven-publish-plugin", version.ref = "mavenPublish" } # Kotlinx libraries kotlinx-serialization-json = { group = "org.jetbrains.kotlinx", name = "kotlinx-serialization-json", version.ref = "serialization" } @@ -57,7 +57,8 @@ slf4j-simple = { group = "org.slf4j", name = "slf4j-simple", version.ref = "slf4 # Samples ktor-client-cio = { group = "io.ktor", name = "ktor-client-cio", version.ref = "ktor" } ktor-server-cio = { group = "io.ktor", name = "ktor-server-cio", version.ref = "ktor" } -mcp-kotlin = { group = "io.modelcontextprotocol", name = "kotlin-sdk", version.ref = "mcp-kotlin" } +mcp-kotlin-client = { group = "io.modelcontextprotocol", name = "kotlin-sdk-client", version.ref = "mcp-kotlin" } +mcp-kotlin-server = { group = "io.modelcontextprotocol", name = "kotlin-sdk-server", version.ref = "mcp-kotlin" } anthropic-java = { group = "com.anthropic", name = "anthropic-java", version.ref = "anthropic" } ktor-client-content-negotiation = { group = "io.ktor", name = "ktor-client-content-negotiation", version.ref = "ktor" } ktor-serialization-kotlinx-json = { group = "io.ktor", name = "ktor-serialization-kotlinx-json", version.ref = "ktor" } diff --git a/kotlin-sdk-client/api/kotlin-sdk-client.api b/kotlin-sdk-client/api/kotlin-sdk-client.api index 00d80eb4..f0782da5 100644 --- a/kotlin-sdk-client/api/kotlin-sdk-client.api +++ b/kotlin-sdk-client/api/kotlin-sdk-client.api @@ -1,7 +1,3 @@ -public final class io/modelcontextprotocol/kotlin/sdk/LibVersionKt { - public static final field LIB_VERSION Ljava/lang/String; -} - public class io/modelcontextprotocol/kotlin/sdk/client/Client : io/modelcontextprotocol/kotlin/sdk/shared/Protocol { public fun (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/client/ClientOptions;)V public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/client/ClientOptions;ILkotlin/jvm/internal/DefaultConstructorMarker;)V diff --git a/kotlin-sdk-client/build.gradle.kts b/kotlin-sdk-client/build.gradle.kts index 5da48e6b..7ba3de25 100644 --- a/kotlin-sdk-client/build.gradle.kts +++ b/kotlin-sdk-client/build.gradle.kts @@ -6,7 +6,6 @@ plugins { id("mcp.multiplatform") id("mcp.publishing") id("mcp.dokka") - id("mcp.jreleaser") alias(libs.plugins.kotlinx.binary.compatibility.validator) } diff --git a/kotlin-sdk-core/build.gradle.kts b/kotlin-sdk-core/build.gradle.kts index bea4162c..849c841f 100644 --- a/kotlin-sdk-core/build.gradle.kts +++ b/kotlin-sdk-core/build.gradle.kts @@ -6,10 +6,28 @@ plugins { id("mcp.multiplatform") id("mcp.publishing") id("mcp.dokka") - id("mcp.jreleaser") alias(libs.plugins.kotlinx.binary.compatibility.validator) } +// Generation library versions +val generateLibVersion by tasks.registering { + val outputDir = layout.buildDirectory.dir("generated-sources/libVersion") + outputs.dir(outputDir) + + doLast { + val sourceFile = outputDir.get().file("io/modelcontextprotocol/kotlin/sdk/LibVersion.kt").asFile + sourceFile.parentFile.mkdirs() + sourceFile.writeText( + """ + package io.modelcontextprotocol.kotlin.sdk + + public const val LIB_VERSION: String = "${project.version}" + + """.trimIndent(), + ) + } +} + kotlin { iosArm64() iosX64() @@ -31,6 +49,7 @@ kotlin { sourceSets { commonMain { + kotlin.srcDir(generateLibVersion) dependencies { api(libs.kotlinx.serialization.json) api(libs.kotlinx.coroutines.core) diff --git a/kotlin-sdk-server/api/kotlin-sdk-server.api b/kotlin-sdk-server/api/kotlin-sdk-server.api index 7e2ed4e1..ec09fea2 100644 --- a/kotlin-sdk-server/api/kotlin-sdk-server.api +++ b/kotlin-sdk-server/api/kotlin-sdk-server.api @@ -1,7 +1,3 @@ -public final class io/modelcontextprotocol/kotlin/sdk/LibVersionKt { - public static final field LIB_VERSION Ljava/lang/String; -} - public final class io/modelcontextprotocol/kotlin/sdk/server/KtorServerKt { public static final fun MCP (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function1;)V public static final fun mcp (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function1;)V diff --git a/kotlin-sdk-server/build.gradle.kts b/kotlin-sdk-server/build.gradle.kts index 80adddcc..c5feae0a 100644 --- a/kotlin-sdk-server/build.gradle.kts +++ b/kotlin-sdk-server/build.gradle.kts @@ -2,7 +2,6 @@ plugins { id("mcp.multiplatform") id("mcp.publishing") id("mcp.dokka") - id("mcp.jreleaser") alias(libs.plugins.kotlinx.binary.compatibility.validator) } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt similarity index 56% rename from kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptIntegrationTest.kt rename to kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt index 54fd5fc8..d5644bbc 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptIntegrationTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt @@ -2,25 +2,26 @@ package io.modelcontextprotocol.kotlin.sdk.integration.kotlin import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest import io.modelcontextprotocol.kotlin.sdk.GetPromptResult -import io.modelcontextprotocol.kotlin.sdk.ImageContent import io.modelcontextprotocol.kotlin.sdk.PromptArgument import io.modelcontextprotocol.kotlin.sdk.PromptMessage -import io.modelcontextprotocol.kotlin.sdk.PromptMessageContent import io.modelcontextprotocol.kotlin.sdk.Role import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.TextContent import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import kotlin.test.assertEquals import kotlin.test.assertNotNull import kotlin.test.assertTrue -class PromptIntegrationTest : KotlinTestBase() { +abstract class AbstractPromptIntegrationTest : KotlinTestBase() { + + private val basicPromptName = "basic-prompt" + private val basicPromptDescription = "A basic prompt for testing" - private val testPromptName = "greeting" - private val testPromptDescription = "A simple greeting prompt" private val complexPromptName = "multimodal-prompt" private val complexPromptDescription = "A prompt with multiple content types" private val conversationPromptName = "conversation" @@ -28,6 +29,14 @@ class PromptIntegrationTest : KotlinTestBase() { private val strictPromptName = "strict-prompt" private val strictPromptDescription = "A prompt with required arguments" + private val largePromptName = "large-prompt" + private val largePromptDescription = "A very large prompt for testing" + private val largePromptContent = "X".repeat(100_000) // 100KB of data + + private val specialCharsPromptName = "special-chars-prompt" + private val specialCharsPromptDescription = "A prompt with special characters" + private val specialCharsContent = "!@#$%^&*()_+{}|:\"<>?~`-=[]\\;',./\n\t" + override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( prompts = ServerCapabilities.Prompts( listChanged = true, @@ -35,10 +44,10 @@ class PromptIntegrationTest : KotlinTestBase() { ) override fun configureServer() { - // simple prompt with a name parameter + // basic prompt with a name parameter server.addPrompt( - name = testPromptName, - description = testPromptDescription, + name = basicPromptName, + description = basicPromptDescription, arguments = listOf( PromptArgument( name = "name", @@ -50,7 +59,7 @@ class PromptIntegrationTest : KotlinTestBase() { val name = request.arguments?.get("name") ?: "World" GetPromptResult( - description = testPromptDescription, + description = basicPromptDescription, messages = listOf( PromptMessage( role = Role.user, @@ -64,57 +73,117 @@ class PromptIntegrationTest : KotlinTestBase() { ) } - // prompt with multiple content types + // special chars prompt server.addPrompt( - name = complexPromptName, - description = complexPromptDescription, + name = specialCharsPromptName, + description = specialCharsPromptDescription, arguments = listOf( PromptArgument( - name = "topic", - description = "The topic to discuss", + name = "special", + description = "Special characters to include", required = false, ), + ), + ) { request -> + val special = request.arguments?.get("special") ?: specialCharsContent + + GetPromptResult( + description = specialCharsPromptDescription, + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent(text = "Special characters: $special"), + ), + PromptMessage( + role = Role.assistant, + content = TextContent(text = "Received special characters: $special"), + ), + ), + ) + } + + // very large prompt + server.addPrompt( + name = largePromptName, + description = largePromptDescription, + arguments = listOf( PromptArgument( - name = "includeImage", - description = "Whether to include an image", + name = "size", + description = "Size multiplier", required = false, ), ), ) { request -> - val topic = request.arguments?.get("topic") ?: "general knowledge" - val includeImage = request.arguments?.get("includeImage")?.toBoolean() ?: true - - val messages = mutableListOf() + val size = request.arguments?.get("size")?.toIntOrNull() ?: 1 + val content = largePromptContent.repeat(size) - messages.add( - PromptMessage( - role = Role.user, - content = TextContent(text = "I'd like to discuss $topic."), + GetPromptResult( + description = largePromptDescription, + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent(text = "Generate a large response"), + ), + PromptMessage( + role = Role.assistant, + content = TextContent(text = content), + ), ), ) + } - val assistantContents = mutableListOf() - assistantContents.add(TextContent(text = "I'd be happy to discuss $topic with you.")) - - if (includeImage) { - assistantContents.add( - ImageContent( - data = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BmMIQAAAABJRU5ErkJggg==", - mimeType = "image/png", - ), - ) + // complext prompt + server.addPrompt( + name = complexPromptName, + description = complexPromptDescription, + arguments = listOf( + PromptArgument(name = "arg1", description = "Argument 1", required = true), + PromptArgument(name = "arg2", description = "Argument 2", required = true), + PromptArgument(name = "arg3", description = "Argument 3", required = true), + PromptArgument(name = "arg4", description = "Argument 4", required = false), + PromptArgument(name = "arg5", description = "Argument 5", required = false), + PromptArgument(name = "arg6", description = "Argument 6", required = false), + PromptArgument(name = "arg7", description = "Argument 7", required = false), + PromptArgument(name = "arg8", description = "Argument 8", required = false), + PromptArgument(name = "arg9", description = "Argument 9", required = false), + PromptArgument(name = "arg10", description = "Argument 10", required = false), + ), + ) { request -> + // validate required arguments + val requiredArgs = listOf("arg1", "arg2", "arg3") + for (argName in requiredArgs) { + if (request.arguments?.get(argName) == null) { + throw IllegalArgumentException("Missing required argument: $argName") + } } - messages.add( - PromptMessage( - role = Role.assistant, - content = assistantContents[0], - ), - ) + val args = mutableMapOf() + for (i in 1..10) { + val argName = "arg$i" + val argValue = request.arguments?.get(argName) + if (argValue != null) { + args[argName] = argValue + } + } GetPromptResult( description = complexPromptDescription, - messages = messages, + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent( + text = "Arguments: ${ + args.entries.joinToString { + "${it.key}=${it.value}" + } + }", + ), + ), + PromptMessage( + role = Role.assistant, + content = TextContent(text = "Received ${args.size} arguments"), + ), + ), ) } @@ -224,10 +293,10 @@ class PromptIntegrationTest : KotlinTestBase() { assertNotNull(result, "List prompts result should not be null") assertTrue(result.prompts.isNotEmpty(), "Prompts list should not be empty") - val testPrompt = result.prompts.find { it.name == testPromptName } + val testPrompt = result.prompts.find { it.name == basicPromptName } assertNotNull(testPrompt, "Test prompt should be in the list") assertEquals( - testPromptDescription, + basicPromptDescription, testPrompt.description, "Prompt description should match", ) @@ -250,14 +319,14 @@ class PromptIntegrationTest : KotlinTestBase() { val testName = "Alice" val result = client.getPrompt( GetPromptRequest( - name = testPromptName, + name = basicPromptName, arguments = mapOf("name" to testName), ), ) assertNotNull(result, "Get prompt result should not be null") assertEquals( - testPromptDescription, + basicPromptDescription, result.description, "Prompt description should match", ) @@ -363,107 +432,242 @@ class PromptIntegrationTest : KotlinTestBase() { } @Test - fun testComplexContentTypes() = runBlocking(Dispatchers.IO) { - val topic = "artificial intelligence" + fun testMultipleMessagesAndRoles() = runBlocking(Dispatchers.IO) { + val topic = "climate change" val result = client.getPrompt( GetPromptRequest( - name = complexPromptName, - arguments = mapOf( - "topic" to topic, - "includeImage" to "true", - ), + name = conversationPromptName, + arguments = mapOf("topic" to topic), ), ) assertNotNull(result, "Get prompt result should not be null") assertEquals( - complexPromptDescription, + conversationPromptDescription, result.description, "Prompt description should match", ) assertTrue(result.messages.isNotEmpty(), "Prompt messages should not be empty") + assertEquals(6, result.messages.size, "Prompt should have 6 messages") + + val userMessages = result.messages.filter { it.role == Role.user } + val assistantMessages = result.messages.filter { it.role == Role.assistant } + + assertEquals(3, userMessages.size, "Should have 3 user messages") + assertEquals(3, assistantMessages.size, "Should have 3 assistant messages") + + for (i in 0 until result.messages.size) { + val expectedRole = if (i % 2 == 0) Role.user else Role.assistant + assertEquals( + expectedRole, + result.messages[i].role, + "Message $i should have role $expectedRole", + ) + } + + for (message in result.messages) { + val content = message.content as? TextContent + assertNotNull(content, "Message content should be TextContent") + val text = requireNotNull(content.text) + + // Either the message contains the topic or it's a generic conversation message + val containsTopic = text.contains(topic) + val isGenericMessage = text.contains("thank you") || text.contains("welcome") + + assertTrue( + containsTopic || isGenericMessage, + "Message should either contain the topic or be a generic conversation message", + ) + } + } + + @Test + fun testBasicPrompt() = runBlocking(Dispatchers.IO) { + val testName = "Alice" + val result = client.getPrompt( + GetPromptRequest( + name = basicPromptName, + arguments = mapOf("name" to testName), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals(basicPromptDescription, result.description, "Prompt description should match") + assertEquals(2, result.messages.size, "Prompt should have 2 messages") val userMessage = result.messages.find { it.role == Role.user } assertNotNull(userMessage, "User message should be in the list") val userContent = userMessage.content as? TextContent assertNotNull(userContent, "User message content should be TextContent") - val userText2 = requireNotNull(userContent.text) - assertTrue(userText2.contains(topic), "User message should contain the topic") + assertEquals("Hello, $testName!", userContent.text, "User message content should match") val assistantMessage = result.messages.find { it.role == Role.assistant } assertNotNull(assistantMessage, "Assistant message should be in the list") val assistantContent = assistantMessage.content as? TextContent assertNotNull(assistantContent, "Assistant message content should be TextContent") - val assistantText = requireNotNull(assistantContent.text) - assertTrue( - assistantText.contains(topic), - "Assistant message should contain the topic", + assertEquals( + "Greetings, $testName! How can I assist you today?", + assistantContent.text, + "Assistant message content should match", ) + } + + @Test + fun testComplexPromptWithManyArguments() = runBlocking(Dispatchers.IO) { + val arguments = (1..10).associate { i -> "arg$i" to "value$i" } - val resultNoImage = client.getPrompt( + val result = client.getPrompt( GetPromptRequest( name = complexPromptName, - arguments = mapOf( - "topic" to topic, - "includeImage" to "false", - ), + arguments = arguments, ), ) - assertNotNull(resultNoImage, "Get prompt result (no image) should not be null") - assertEquals(2, resultNoImage.messages.size, "Prompt should have 2 messages") + assertNotNull(result, "Get prompt result should not be null") + assertEquals(complexPromptDescription, result.description, "Prompt description should match") + + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val userMessage = result.messages.find { it.role == Role.user } + assertNotNull(userMessage, "User message should be in the list") + val userContent = userMessage.content as? TextContent + assertNotNull(userContent, "User message content should be TextContent") + + // verify all arguments + val text = userContent.text ?: "" + for (i in 1..10) { + assertTrue(text.contains("arg$i=value$i"), "Message should contain arg$i=value$i") + } + + val assistantMessage = result.messages.find { it.role == Role.assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + assertEquals( + "Received 10 arguments", + assistantContent.text, + "Assistant message should indicate 10 arguments", + ) } @Test - fun testMultipleMessagesAndRoles() = runBlocking(Dispatchers.IO) { - val topic = "climate change" + fun testLargePrompt() = runBlocking(Dispatchers.IO) { val result = client.getPrompt( GetPromptRequest( - name = conversationPromptName, - arguments = mapOf("topic" to topic), + name = largePromptName, + arguments = mapOf("size" to "1"), ), ) assertNotNull(result, "Get prompt result should not be null") - assertEquals( - conversationPromptDescription, - result.description, - "Prompt description should match", + assertEquals(largePromptDescription, result.description, "Prompt description should match") + + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val assistantMessage = result.messages.find { it.role == Role.assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + val text = assistantContent.text ?: "" + assertEquals(100_000, text.length, "Assistant message should be 100KB in size") + } + + @Test + fun testSpecialCharacters() = runBlocking(Dispatchers.IO) { + val result = client.getPrompt( + GetPromptRequest( + name = specialCharsPromptName, + arguments = mapOf("special" to specialCharsContent), + ), ) - assertTrue(result.messages.isNotEmpty(), "Prompt messages should not be empty") - assertEquals(6, result.messages.size, "Prompt should have 6 messages") + assertNotNull(result, "Get prompt result should not be null") + assertEquals(specialCharsPromptDescription, result.description, "Prompt description should match") - val userMessages = result.messages.filter { it.role == Role.user } - val assistantMessages = result.messages.filter { it.role == Role.assistant } + assertEquals(2, result.messages.size, "Prompt should have 2 messages") - assertEquals(3, userMessages.size, "Should have 3 user messages") - assertEquals(3, assistantMessages.size, "Should have 3 assistant messages") + val userMessage = result.messages.find { it.role == Role.user } + assertNotNull(userMessage, "User message should be in the list") + val userContent = userMessage.content as? TextContent + assertNotNull(userContent, "User message content should be TextContent") + val userText = userContent.text ?: "" + assertTrue(userText.contains(specialCharsContent), "User message should contain special characters") - for (i in 0 until result.messages.size) { - val expectedRole = if (i % 2 == 0) Role.user else Role.assistant - assertEquals( - expectedRole, - result.messages[i].role, - "Message $i should have role $expectedRole", - ) + val assistantMessage = result.messages.find { it.role == Role.assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + val assistantText = assistantContent.text ?: "" + assertTrue( + assistantText.contains(specialCharsContent), + "Assistant message should contain special characters", + ) + } + + @Test + fun testConcurrentPromptRequests() = runTest { + val concurrentCount = 10 + val results = mutableListOf() + + runBlocking { + repeat(concurrentCount) { index -> + launch { + val promptName = when (index % 4) { + 0 -> basicPromptName + 1 -> complexPromptName + 2 -> largePromptName + else -> specialCharsPromptName + } + + val arguments = when (promptName) { + basicPromptName -> mapOf("name" to "User$index") + complexPromptName -> mapOf("arg1" to "v1", "arg2" to "v2", "arg3" to "v3") + largePromptName -> mapOf("size" to "1") + else -> mapOf("special" to "!@#$%^&*()") + } + + val result = client.getPrompt( + GetPromptRequest( + name = promptName, + arguments = arguments, + ), + ) + + synchronized(results) { + results.add(result) + } + } + } } - for (message in result.messages) { - val content = message.content as? TextContent - assertNotNull(content, "Message content should be TextContent") - val text = requireNotNull(content.text) + assertEquals(concurrentCount, results.size, "All concurrent operations should complete") - // Either the message contains the topic or it's a generic conversation message - val containsTopic = text.contains(topic) - val isGenericMessage = text.contains("thank you") || text.contains("welcome") + results.forEach { result -> + assertNotNull(result, "Result should not be null") + assertTrue(result.messages.isNotEmpty(), "Result messages should not be empty") + } + } - assertTrue( - containsTopic || isGenericMessage, - "Message should either contain the topic or be a generic conversation message", - ) + @Test + fun testNonExistentPrompt() = runTest { + val nonExistentPromptName = "non-existent-prompt" + + val exception = assertThrows { + runBlocking { + client.getPrompt( + GetPromptRequest( + name = nonExistentPromptName, + arguments = mapOf("name" to "Test"), + ), + ) + } } + + val msg = exception.message ?: "" + val expectedMessage = "JSONRPCError(code=InternalError, message=Prompt not found: non-existent-prompt, data={})" + + assertEquals(expectedMessage, msg, "Unexpected error message for non-existent prompt") } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt similarity index 85% rename from kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceEdgeCasesTest.kt rename to kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt index 9e23a3c0..95c643a8 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt @@ -20,7 +20,7 @@ import kotlin.test.assertEquals import kotlin.test.assertNotNull import kotlin.test.assertTrue -class ResourceEdgeCasesTest : KotlinTestBase() { +abstract class AbstractResourceIntegrationTest : KotlinTestBase() { private val testResourceUri = "test://example.txt" private val testResourceName = "Test Resource" @@ -67,6 +67,31 @@ class ResourceEdgeCasesTest : KotlinTestBase() { ) } + server.setRequestHandler(Method.Defined.ResourcesSubscribe) { _, _ -> + EmptyRequestResult() + } + + server.setRequestHandler(Method.Defined.ResourcesUnsubscribe) { _, _ -> + EmptyRequestResult() + } + + server.addResource( + uri = testResourceUri, + name = testResourceName, + description = testResourceDescription, + mimeType = "text/plain", + ) { request -> + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = testResourceContent, + uri = request.uri, + mimeType = "text/plain", + ), + ), + ) + } + server.addResource( uri = binaryResourceUri, name = binaryResourceName, @@ -117,13 +142,41 @@ class ResourceEdgeCasesTest : KotlinTestBase() { ), ) } + } - server.setRequestHandler(Method.Defined.ResourcesSubscribe) { _, _ -> - EmptyRequestResult() - } + @Test + fun testListResources() = runBlocking(Dispatchers.IO) { + val result = client.listResources() - server.setRequestHandler(Method.Defined.ResourcesUnsubscribe) { _, _ -> - EmptyRequestResult() + assertNotNull(result, "List resources result should not be null") + assertTrue(result.resources.isNotEmpty(), "Resources list should not be empty") + + val testResource = result.resources.find { it.uri == testResourceUri } + assertNotNull(testResource, "Test resource should be in the list") + assertEquals(testResourceName, testResource.name, "Resource name should match") + assertEquals(testResourceDescription, testResource.description, "Resource description should match") + } + + @Test + fun testReadResource() = runBlocking(Dispatchers.IO) { + val result = client.readResource(ReadResourceRequest(uri = testResourceUri)) + + assertNotNull(result, "Read resource result should not be null") + assertTrue(result.contents.isNotEmpty(), "Resource contents should not be empty") + + val content = result.contents.firstOrNull() as? TextResourceContents + assertNotNull(content, "Resource content should be TextResourceContents") + assertEquals(testResourceContent, content.text, "Resource content should match") + } + + @Test + fun testSubscribeAndUnsubscribe() { + runBlocking(Dispatchers.IO) { + val subscribeResult = client.subscribeResource(SubscribeRequest(uri = testResourceUri)) + assertNotNull(subscribeResult, "Subscribe result should not be null") + + val unsubscribeResult = client.unsubscribeResource(UnsubscribeRequest(uri = testResourceUri)) + assertNotNull(unsubscribeResult, "Unsubscribe result should not be null") } } @@ -257,15 +310,4 @@ class ResourceEdgeCasesTest : KotlinTestBase() { assertTrue(result.contents.isNotEmpty(), "Result contents should not be empty") } } - - @Test - fun testSubscribeAndUnsubscribe() { - runBlocking(Dispatchers.IO) { - val subscribeResult = client.subscribeResource(SubscribeRequest(uri = testResourceUri)) - assertNotNull(subscribeResult, "Subscribe result should not be null") - - val unsubscribeResult = client.unsubscribeResource(UnsubscribeRequest(uri = testResourceUri)) - assertNotNull(unsubscribeResult, "Unsubscribe result should not be null") - } - } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt similarity index 61% rename from kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt rename to kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt index 84ae233d..3b0de299 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt @@ -1,6 +1,7 @@ package io.modelcontextprotocol.kotlin.sdk.integration.kotlin import io.kotest.assertions.json.shouldEqualJson +import io.modelcontextprotocol.kotlin.sdk.CallToolRequest import io.modelcontextprotocol.kotlin.sdk.CallToolResult import io.modelcontextprotocol.kotlin.sdk.CallToolResultBase import io.modelcontextprotocol.kotlin.sdk.ImageContent @@ -9,7 +10,10 @@ import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.Tool import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.test.runTest import kotlinx.serialization.json.JsonArray import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.add @@ -25,7 +29,7 @@ import kotlin.test.assertEquals import kotlin.test.assertNotNull import kotlin.test.assertTrue -class ToolIntegrationTest : KotlinTestBase() { +abstract class AbstractToolIntegrationTest : KotlinTestBase() { private val testToolName = "echo" private val testToolDescription = "A simple echo tool that returns the input text" private val complexToolName = "calculator" @@ -35,6 +39,20 @@ class ToolIntegrationTest : KotlinTestBase() { private val multiContentToolName = "multi-content" private val multiContentToolDescription = "A tool that returns multiple content types" + private val basicToolName = "basic-tool" + private val basicToolDescription = "A basic tool for testing" + + private val largeToolName = "large-tool" + private val largeToolDescription = "A tool that returns a large response" + private val largeToolContent = "X".repeat(100_000) // 100KB of data + + private val slowToolName = "slow-tool" + private val slowToolDescription = "A tool that takes time to respond" + + private val specialCharsToolName = "special-chars-tool" + private val specialCharsToolDescription = "A tool that handles special characters" + private val specialCharsContent = "!@#$%^&*()_+{}|:\"<>?~`-=[]\\;',./\n\t" + override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( tools = ServerCapabilities.Tools( listChanged = true, @@ -77,6 +95,114 @@ class ToolIntegrationTest : KotlinTestBase() { } private fun setupCalculatorTool() { + server.addTool( + name = basicToolName, + description = basicToolDescription, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "text", + buildJsonObject { + put("type", "string") + put("description", "The text to echo back") + }, + ) + }, + required = listOf("text"), + ), + ) { request -> + val text = (request.arguments["text"] as? JsonPrimitive)?.content ?: "No text provided" + + CallToolResult( + content = listOf(TextContent(text = "Echo: $text")), + structuredContent = buildJsonObject { + put("result", text) + }, + ) + } + + server.addTool( + name = specialCharsToolName, + description = specialCharsToolDescription, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "special", + buildJsonObject { + put("type", "string") + put("description", "Special characters to process") + }, + ) + }, + ), + ) { request -> + val special = (request.arguments["special"] as? JsonPrimitive)?.content ?: specialCharsContent + + CallToolResult( + content = listOf(TextContent(text = "Received special characters: $special")), + structuredContent = buildJsonObject { + put("special", special) + put("length", special.length) + }, + ) + } + + server.addTool( + name = slowToolName, + description = slowToolDescription, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "delay", + buildJsonObject { + put("type", "integer") + put("description", "Delay in milliseconds") + }, + ) + }, + ), + ) { request -> + val delay = (request.arguments["delay"] as? JsonPrimitive)?.content?.toIntOrNull() ?: 1000 + + // simulate slow operation + runBlocking { + delay(delay.toLong()) + } + + CallToolResult( + content = listOf(TextContent(text = "Completed after ${delay}ms delay")), + structuredContent = buildJsonObject { + put("delay", delay) + }, + ) + } + + server.addTool( + name = largeToolName, + description = largeToolDescription, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "size", + buildJsonObject { + put("type", "integer") + put("description", "Size multiplier") + }, + ) + }, + ), + ) { request -> + val size = (request.arguments["size"] as? JsonPrimitive)?.content?.toIntOrNull() ?: 1 + val content = largeToolContent.take(largeToolContent.length.coerceAtMost(size * 1000)) + + CallToolResult( + content = listOf(TextContent(text = content)), + structuredContent = buildJsonObject { + put("size", content.length) + }, + ) + } + server.addTool( name = complexToolName, description = complexToolDescription, @@ -466,4 +592,193 @@ class ToolIntegrationTest : KotlinTestBase() { "Text-only result should have 1 content item", ) } + + @Test + fun testComplexNestedSchema(): Unit = runBlocking(Dispatchers.IO) { + val userJson = buildJsonObject { + put("name", JsonPrimitive("John Galt")) + put("age", JsonPrimitive(30)) + put( + "address", + buildJsonObject { + put("street", JsonPrimitive("123 Main St")) + put("city", JsonPrimitive("New York")) + put("country", JsonPrimitive("USA")) + }, + ) + } + + val optionsJson = buildJsonArray { + add(JsonPrimitive("option1")) + add(JsonPrimitive("option2")) + add(JsonPrimitive("option3")) + } + + val arguments = buildJsonObject { + put("user", userJson) + put("options", optionsJson) + } + + val result = client.callTool( + CallToolRequest( + name = complexToolName, + arguments = arguments, + ), + ) as CallToolResultBase + + val actualContent = result.structuredContent.toString() + val expectedContent = """ + { + "operation": "add", + "a": 0.0, + "b": 0.0, + "result": 0.0, + "formattedResult": "0,00", + "precision": 2, + "tags": [] + } + """.trimIndent() + + actualContent shouldEqualJson expectedContent + } + + @Test + fun testLargeResponse(): Unit = runBlocking(Dispatchers.IO) { + val size = 10 + val arguments = mapOf("size" to size) + + val result = client.callTool(largeToolName, arguments) as CallToolResultBase + + val content = result.content.firstOrNull() as TextContent + assertNotNull(content, "Tool result content should be TextContent") + + val actualContent = result.structuredContent.toString() + val expectedContent = """ + { + "size" : 10000 + } + """.trimIndent() + + actualContent shouldEqualJson expectedContent + } + + @Test + fun testSlowTool(): Unit = runBlocking(Dispatchers.IO) { + val delay = 500 + val arguments = mapOf("delay" to delay) + + val startTime = System.currentTimeMillis() + val result = client.callTool(slowToolName, arguments) as CallToolResultBase + val endTime = System.currentTimeMillis() + + val content = result.content.firstOrNull() as? TextContent + assertNotNull(content, "Tool result content should be TextContent") + + assertTrue(endTime - startTime >= delay, "Tool should take at least the specified delay") + + val actualContent = result.structuredContent.toString() + val expectedContent = """ + { + "delay" : 500 + } + """.trimIndent() + + actualContent shouldEqualJson expectedContent + } + + @Test + fun testSpecialCharacters() { + runBlocking(Dispatchers.IO) { + val arguments = mapOf("special" to specialCharsContent) + + val result = client.callTool(specialCharsToolName, arguments) as CallToolResultBase + + val content = result.content.firstOrNull() as? TextContent + assertNotNull(content, "Tool result content should be TextContent") + val text = content.text ?: "" + + assertTrue(text.contains(specialCharsContent), "Result should contain the special characters") + + val actualContent = result.structuredContent.toString() + val expectedContent = """ + { + "special" : "!@#$%^&*()_+{}|:\"<>?~`-=[]\\;',./\n\t", + "length" : 34 + } + """.trimIndent() + + actualContent shouldEqualJson expectedContent + } + } + + @Test + fun testConcurrentToolCalls() = runTest { + val concurrentCount = 10 + val results = mutableListOf() + + runBlocking { + repeat(concurrentCount) { index -> + launch { + val toolName = when (index % 5) { + 0 -> basicToolName + 1 -> complexToolName + 2 -> largeToolName + 3 -> slowToolName + else -> specialCharsToolName + } + + val arguments = when (toolName) { + basicToolName -> mapOf("text" to "Concurrent call $index") + + complexToolName -> mapOf( + "user" to mapOf( + "name" to "User $index", + "age" to 20 + index, + "address" to mapOf( + "street" to "Street $index", + "city" to "City $index", + "country" to "Country $index", + ), + ), + ) + + largeToolName -> mapOf("size" to 1) + + slowToolName -> mapOf("delay" to 100) + + else -> mapOf("special" to "!@#$%^&*()") + } + + val result = client.callTool(toolName, arguments) + + synchronized(results) { + results.add(result) + } + } + } + } + + assertEquals(concurrentCount, results.size, "All concurrent operations should complete") + results.forEach { result -> + assertNotNull(result, "Result should not be null") + assertTrue(result.content.isNotEmpty(), "Result content should not be empty") + } + } + + @Test + fun testNonExistentTool() = runTest { + val nonExistentToolName = "non-existent-tool" + val arguments = mapOf("text" to "Test") + + val exception = assertThrows { + runBlocking { + client.callTool(nonExistentToolName, arguments) + } + } + + val msg = exception.message ?: "" + val expectedMessage = "JSONRPCError(code=InternalError, message=Tool not found: non-existent-tool, data={})" + + assertEquals(expectedMessage, msg, "Unexpected error message for non-existent tool") + } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt index e0ccd39b..563fa853 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt @@ -11,15 +11,24 @@ import io.modelcontextprotocol.kotlin.sdk.Implementation import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.client.Client import io.modelcontextprotocol.kotlin.sdk.client.SseClientTransport +import io.modelcontextprotocol.kotlin.sdk.client.StdioClientTransport import io.modelcontextprotocol.kotlin.sdk.integration.utils.Retry import io.modelcontextprotocol.kotlin.sdk.server.Server import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.server.StdioServerTransport import io.modelcontextprotocol.kotlin.sdk.server.mcp import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withTimeout +import kotlinx.io.Sink +import kotlinx.io.Source +import kotlinx.io.asSink +import kotlinx.io.asSource +import kotlinx.io.buffered import org.awaitility.kotlin.await import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.BeforeEach +import java.io.PipedInputStream +import java.io.PipedOutputStream import kotlin.time.Duration.Companion.seconds import io.ktor.server.cio.CIO as ServerCIO import io.ktor.server.sse.SSE as ServerSSE @@ -34,34 +43,62 @@ abstract class KotlinTestBase { protected lateinit var client: Client protected lateinit var serverEngine: EmbeddedServer<*, *> + // Transport selection + protected enum class TransportKind { SSE, STDIO } + protected open val transportKind: TransportKind = TransportKind.STDIO + + // STDIO-specific fields + private var stdioServerTransport: StdioServerTransport? = null + private var stdioClientInput: Source? = null + private var stdioClientOutput: Sink? = null + protected abstract fun configureServerCapabilities(): ServerCapabilities protected abstract fun configureServer() @BeforeEach fun setUp() { setupServer() - await - .ignoreExceptions() - .until { - port = runBlocking { serverEngine.engine.resolvedConnectors().first().port } - port != 0 - } + if (transportKind == TransportKind.SSE) { + await + .ignoreExceptions() + .until { + port = runBlocking { serverEngine.engine.resolvedConnectors().first().port } + port != 0 + } + } runBlocking { setupClient() } } protected suspend fun setupClient() { - val transport = SseClientTransport( - HttpClient(CIO) { - install(SSE) - }, - "http://$host:$port", - ) - client = Client( - Implementation("test", "1.0"), - ) - client.connect(transport) + when (transportKind) { + TransportKind.SSE -> { + val transport = SseClientTransport( + HttpClient(CIO) { + install(SSE) + }, + "http://$host:$port", + ) + client = Client( + Implementation("test", "1.0"), + ) + client.connect(transport) + } + + TransportKind.STDIO -> { + val input = checkNotNull(stdioClientInput) { "STDIO client input not initialized" } + val output = checkNotNull(stdioClientOutput) { "STDIO client output not initialized" } + val transport = StdioClientTransport( + input = input, + output = output, + ) + client = Client( + Implementation("test", "1.0"), + ) + client.connect(transport) + } + } } protected fun setupServer() { @@ -74,12 +111,37 @@ abstract class KotlinTestBase { configureServer() - serverEngine = embeddedServer(ServerCIO, host = host, port = port) { - install(ServerSSE) - routing { - mcp { server } + if (transportKind == TransportKind.SSE) { + serverEngine = embeddedServer(ServerCIO, host = host, port = port) { + install(ServerSSE) + routing { + mcp { server } + } + }.start(wait = false) + } else { + // Create in-memory stdio pipes: client->server and server->client + val clientToServerOut = PipedOutputStream() + val clientToServerIn = PipedInputStream(clientToServerOut) + + val serverToClientOut = PipedOutputStream() + val serverToClientIn = PipedInputStream(serverToClientOut) + + // Server transport reads from client and writes to client + val serverTransport = StdioServerTransport( + inputStream = clientToServerIn.asSource().buffered(), + outputStream = serverToClientOut.asSink().buffered(), + ) + stdioServerTransport = serverTransport + + // Prepare client-side streams for later client initialization + stdioClientInput = serverToClientIn.asSource().buffered() + stdioClientOutput = clientToServerOut.asSink().buffered() + + // Start server transport by connecting the server + runBlocking { + server.connect(serverTransport) } - }.start(wait = false) + } } @AfterEach @@ -98,11 +160,25 @@ abstract class KotlinTestBase { } // stop server - if (::serverEngine.isInitialized) { - try { - serverEngine.stop(500, 1000) - } catch (e: Exception) { - println("Warning: Error during server stop: ${e.message}") + if (transportKind == TransportKind.SSE) { + if (::serverEngine.isInitialized) { + try { + serverEngine.stop(500, 1000) + } catch (e: Exception) { + println("Warning: Error during server stop: ${e.message}") + } + } + } else { + stdioServerTransport?.let { + try { + runBlocking { it.close() } + } catch (e: Exception) { + println("Warning: Error during stdio server stop: ${e.message}") + } finally { + stdioServerTransport = null + stdioClientInput = null + stdioClientOutput = null + } } } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptEdgeCasesTest.kt deleted file mode 100644 index 31376332..00000000 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptEdgeCasesTest.kt +++ /dev/null @@ -1,391 +0,0 @@ -package io.modelcontextprotocol.kotlin.sdk.integration.kotlin - -import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest -import io.modelcontextprotocol.kotlin.sdk.GetPromptResult -import io.modelcontextprotocol.kotlin.sdk.PromptArgument -import io.modelcontextprotocol.kotlin.sdk.PromptMessage -import io.modelcontextprotocol.kotlin.sdk.Role -import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities -import io.modelcontextprotocol.kotlin.sdk.TextContent -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.launch -import kotlinx.coroutines.runBlocking -import kotlinx.coroutines.test.runTest -import org.junit.jupiter.api.Test -import org.junit.jupiter.api.assertThrows -import kotlin.test.assertEquals -import kotlin.test.assertNotNull -import kotlin.test.assertTrue - -class PromptEdgeCasesTest : KotlinTestBase() { - - private val basicPromptName = "basic-prompt" - private val basicPromptDescription = "A basic prompt for testing" - - private val complexPromptName = "complex-prompt" - private val complexPromptDescription = "A complex prompt with many arguments" - - private val largePromptName = "large-prompt" - private val largePromptDescription = "A very large prompt for testing" - private val largePromptContent = "X".repeat(100_000) // 100KB of data - - private val specialCharsPromptName = "special-chars-prompt" - private val specialCharsPromptDescription = "A prompt with special characters" - private val specialCharsContent = "!@#$%^&*()_+{}|:\"<>?~`-=[]\\;',./\n\t" - - override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( - prompts = ServerCapabilities.Prompts( - listChanged = true, - ), - ) - - override fun configureServer() { - server.addPrompt( - name = basicPromptName, - description = basicPromptDescription, - arguments = listOf( - PromptArgument( - name = "name", - description = "The name to greet", - required = true, - ), - ), - ) { request -> - val name = request.arguments?.get("name") ?: "World" - - GetPromptResult( - description = basicPromptDescription, - messages = listOf( - PromptMessage( - role = Role.user, - content = TextContent(text = "Hello, $name!"), - ), - PromptMessage( - role = Role.assistant, - content = TextContent(text = "Greetings, $name! How can I assist you today?"), - ), - ), - ) - } - - server.addPrompt( - name = complexPromptName, - description = complexPromptDescription, - arguments = listOf( - PromptArgument(name = "arg1", description = "Argument 1", required = true), - PromptArgument(name = "arg2", description = "Argument 2", required = true), - PromptArgument(name = "arg3", description = "Argument 3", required = true), - PromptArgument(name = "arg4", description = "Argument 4", required = false), - PromptArgument(name = "arg5", description = "Argument 5", required = false), - PromptArgument(name = "arg6", description = "Argument 6", required = false), - PromptArgument(name = "arg7", description = "Argument 7", required = false), - PromptArgument(name = "arg8", description = "Argument 8", required = false), - PromptArgument(name = "arg9", description = "Argument 9", required = false), - PromptArgument(name = "arg10", description = "Argument 10", required = false), - ), - ) { request -> - // validate required arguments - val requiredArgs = listOf("arg1", "arg2", "arg3") - for (argName in requiredArgs) { - if (request.arguments?.get(argName) == null) { - throw IllegalArgumentException("Missing required argument: $argName") - } - } - - val args = mutableMapOf() - for (i in 1..10) { - val argName = "arg$i" - val argValue = request.arguments?.get(argName) - if (argValue != null) { - args[argName] = argValue - } - } - - GetPromptResult( - description = complexPromptDescription, - messages = listOf( - PromptMessage( - role = Role.user, - content = TextContent( - text = "Arguments: ${ - args.entries.joinToString { - "${it.key}=${it.value}" - } - }", - ), - ), - PromptMessage( - role = Role.assistant, - content = TextContent(text = "Received ${args.size} arguments"), - ), - ), - ) - } - - // very large prompt - server.addPrompt( - name = largePromptName, - description = largePromptDescription, - arguments = listOf( - PromptArgument( - name = "size", - description = "Size multiplier", - required = false, - ), - ), - ) { request -> - val size = request.arguments?.get("size")?.toIntOrNull() ?: 1 - val content = largePromptContent.repeat(size) - - GetPromptResult( - description = largePromptDescription, - messages = listOf( - PromptMessage( - role = Role.user, - content = TextContent(text = "Generate a large response"), - ), - PromptMessage( - role = Role.assistant, - content = TextContent(text = content), - ), - ), - ) - } - - server.addPrompt( - name = specialCharsPromptName, - description = specialCharsPromptDescription, - arguments = listOf( - PromptArgument( - name = "special", - description = "Special characters to include", - required = false, - ), - ), - ) { request -> - val special = request.arguments?.get("special") ?: specialCharsContent - - GetPromptResult( - description = specialCharsPromptDescription, - messages = listOf( - PromptMessage( - role = Role.user, - content = TextContent(text = "Special characters: $special"), - ), - PromptMessage( - role = Role.assistant, - content = TextContent(text = "Received special characters: $special"), - ), - ), - ) - } - } - - @Test - fun testBasicPrompt() = runBlocking(Dispatchers.IO) { - val testName = "Alice" - val result = client.getPrompt( - GetPromptRequest( - name = basicPromptName, - arguments = mapOf("name" to testName), - ), - ) - - assertNotNull(result, "Get prompt result should not be null") - assertEquals(basicPromptDescription, result.description, "Prompt description should match") - - assertEquals(2, result.messages.size, "Prompt should have 2 messages") - - val userMessage = result.messages.find { it.role == Role.user } - assertNotNull(userMessage, "User message should be in the list") - val userContent = userMessage.content as? TextContent - assertNotNull(userContent, "User message content should be TextContent") - assertEquals("Hello, $testName!", userContent.text, "User message content should match") - - val assistantMessage = result.messages.find { it.role == Role.assistant } - assertNotNull(assistantMessage, "Assistant message should be in the list") - val assistantContent = assistantMessage.content as? TextContent - assertNotNull(assistantContent, "Assistant message content should be TextContent") - assertEquals( - "Greetings, $testName! How can I assist you today?", - assistantContent.text, - "Assistant message content should match", - ) - } - - @Test - fun testComplexPromptWithManyArguments() = runBlocking(Dispatchers.IO) { - val arguments = (1..10).associate { i -> "arg$i" to "value$i" } - - val result = client.getPrompt( - GetPromptRequest( - name = complexPromptName, - arguments = arguments, - ), - ) - - assertNotNull(result, "Get prompt result should not be null") - assertEquals(complexPromptDescription, result.description, "Prompt description should match") - - assertEquals(2, result.messages.size, "Prompt should have 2 messages") - - val userMessage = result.messages.find { it.role == Role.user } - assertNotNull(userMessage, "User message should be in the list") - val userContent = userMessage.content as? TextContent - assertNotNull(userContent, "User message content should be TextContent") - - // verify all arguments - val text = userContent.text ?: "" - for (i in 1..10) { - assertTrue(text.contains("arg$i=value$i"), "Message should contain arg$i=value$i") - } - - val assistantMessage = result.messages.find { it.role == Role.assistant } - assertNotNull(assistantMessage, "Assistant message should be in the list") - val assistantContent = assistantMessage.content as? TextContent - assertNotNull(assistantContent, "Assistant message content should be TextContent") - assertEquals( - "Received 10 arguments", - assistantContent.text, - "Assistant message should indicate 10 arguments", - ) - } - - @Test - fun testLargePrompt() = runBlocking(Dispatchers.IO) { - val result = client.getPrompt( - GetPromptRequest( - name = largePromptName, - arguments = mapOf("size" to "1"), - ), - ) - - assertNotNull(result, "Get prompt result should not be null") - assertEquals(largePromptDescription, result.description, "Prompt description should match") - - assertEquals(2, result.messages.size, "Prompt should have 2 messages") - - val assistantMessage = result.messages.find { it.role == Role.assistant } - assertNotNull(assistantMessage, "Assistant message should be in the list") - val assistantContent = assistantMessage.content as? TextContent - assertNotNull(assistantContent, "Assistant message content should be TextContent") - val text = assistantContent.text ?: "" - assertEquals(100_000, text.length, "Assistant message should be 100KB in size") - } - - @Test - fun testSpecialCharacters() = runBlocking(Dispatchers.IO) { - val result = client.getPrompt( - GetPromptRequest( - name = specialCharsPromptName, - arguments = mapOf("special" to specialCharsContent), - ), - ) - - assertNotNull(result, "Get prompt result should not be null") - assertEquals(specialCharsPromptDescription, result.description, "Prompt description should match") - - assertEquals(2, result.messages.size, "Prompt should have 2 messages") - - val userMessage = result.messages.find { it.role == Role.user } - assertNotNull(userMessage, "User message should be in the list") - val userContent = userMessage.content as? TextContent - assertNotNull(userContent, "User message content should be TextContent") - val userText = userContent.text ?: "" - assertTrue(userText.contains(specialCharsContent), "User message should contain special characters") - - val assistantMessage = result.messages.find { it.role == Role.assistant } - assertNotNull(assistantMessage, "Assistant message should be in the list") - val assistantContent = assistantMessage.content as? TextContent - assertNotNull(assistantContent, "Assistant message content should be TextContent") - val assistantText = assistantContent.text ?: "" - assertTrue( - assistantText.contains(specialCharsContent), - "Assistant message should contain special characters", - ) - } - - @Test - fun testMissingRequiredArguments() = runTest { - val exception = assertThrows { - runBlocking { - client.getPrompt( - GetPromptRequest( - name = complexPromptName, - arguments = mapOf("arg4" to "value4", "arg5" to "value5"), - ), - ) - } - } - - val msg = exception.message ?: "" - val expectedMessage = "JSONRPCError(code=InternalError, message=Missing required argument: arg1, data={})" - - assertEquals(expectedMessage, msg, "Unexpected error message for missing required argument") - } - - @Test - fun testConcurrentPromptRequests() = runTest { - val concurrentCount = 10 - val results = mutableListOf() - - runBlocking { - repeat(concurrentCount) { index -> - launch { - val promptName = when (index % 4) { - 0 -> basicPromptName - 1 -> complexPromptName - 2 -> largePromptName - else -> specialCharsPromptName - } - - val arguments = when (promptName) { - basicPromptName -> mapOf("name" to "User$index") - complexPromptName -> mapOf("arg1" to "v1", "arg2" to "v2", "arg3" to "v3") - largePromptName -> mapOf("size" to "1") - else -> mapOf("special" to "!@#$%^&*()") - } - - val result = client.getPrompt( - GetPromptRequest( - name = promptName, - arguments = arguments, - ), - ) - - synchronized(results) { - results.add(result) - } - } - } - } - - assertEquals(concurrentCount, results.size, "All concurrent operations should complete") - - results.forEach { result -> - assertNotNull(result, "Result should not be null") - assertTrue(result.messages.isNotEmpty(), "Result messages should not be empty") - } - } - - @Test - fun testNonExistentPrompt() = runTest { - val nonExistentPromptName = "non-existent-prompt" - - val exception = assertThrows { - runBlocking { - client.getPrompt( - GetPromptRequest( - name = nonExistentPromptName, - arguments = mapOf("name" to "Test"), - ), - ) - } - } - - val msg = exception.message ?: "" - val expectedMessage = "JSONRPCError(code=InternalError, message=Prompt not found: non-existent-prompt, data={})" - - assertEquals(expectedMessage, msg, "Unexpected error message for non-existent prompt") - } -} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceIntegrationTest.kt deleted file mode 100644 index 5ea9bbd0..00000000 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceIntegrationTest.kt +++ /dev/null @@ -1,94 +0,0 @@ -package io.modelcontextprotocol.kotlin.sdk.integration.kotlin - -import io.modelcontextprotocol.kotlin.sdk.EmptyRequestResult -import io.modelcontextprotocol.kotlin.sdk.Method -import io.modelcontextprotocol.kotlin.sdk.ReadResourceRequest -import io.modelcontextprotocol.kotlin.sdk.ReadResourceResult -import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities -import io.modelcontextprotocol.kotlin.sdk.SubscribeRequest -import io.modelcontextprotocol.kotlin.sdk.TextResourceContents -import io.modelcontextprotocol.kotlin.sdk.UnsubscribeRequest -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.runBlocking -import org.junit.jupiter.api.Test -import kotlin.test.assertEquals -import kotlin.test.assertNotNull -import kotlin.test.assertTrue - -class ResourceIntegrationTest : KotlinTestBase() { - - private val testResourceUri = "test://example.txt" - private val testResourceName = "Test Resource" - private val testResourceDescription = "A test resource for integration testing" - private val testResourceContent = "This is the content of the test resource." - - override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( - resources = ServerCapabilities.Resources( - subscribe = true, - listChanged = true, - ), - ) - - override fun configureServer() { - server.addResource( - uri = testResourceUri, - name = testResourceName, - description = testResourceDescription, - mimeType = "text/plain", - ) { request -> - ReadResourceResult( - contents = listOf( - TextResourceContents( - text = testResourceContent, - uri = request.uri, - mimeType = "text/plain", - ), - ), - ) - } - - server.setRequestHandler(Method.Defined.ResourcesSubscribe) { _, _ -> - EmptyRequestResult() - } - - server.setRequestHandler(Method.Defined.ResourcesUnsubscribe) { _, _ -> - EmptyRequestResult() - } - } - - @Test - fun testListResources() = runBlocking(Dispatchers.IO) { - val result = client.listResources() - - assertNotNull(result, "List resources result should not be null") - assertTrue(result.resources.isNotEmpty(), "Resources list should not be empty") - - val testResource = result.resources.find { it.uri == testResourceUri } - assertNotNull(testResource, "Test resource should be in the list") - assertEquals(testResourceName, testResource.name, "Resource name should match") - assertEquals(testResourceDescription, testResource.description, "Resource description should match") - } - - @Test - fun testReadResource() = runBlocking(Dispatchers.IO) { - val result = client.readResource(ReadResourceRequest(uri = testResourceUri)) - - assertNotNull(result, "Read resource result should not be null") - assertTrue(result.contents.isNotEmpty(), "Resource contents should not be empty") - - val content = result.contents.firstOrNull() as? TextResourceContents - assertNotNull(content, "Resource content should be TextResourceContents") - assertEquals(testResourceContent, content.text, "Resource content should match") - } - - @Test - fun testSubscribeAndUnsubscribe() { - runBlocking(Dispatchers.IO) { - val subscribeResult = client.subscribeResource(SubscribeRequest(uri = testResourceUri)) - assertNotNull(subscribeResult, "Subscribe result should not be null") - - val unsubscribeResult = client.unsubscribeResource(UnsubscribeRequest(uri = testResourceUri)) - assertNotNull(unsubscribeResult, "Unsubscribe result should not be null") - } - } -} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt deleted file mode 100644 index c30cffb0..00000000 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt +++ /dev/null @@ -1,489 +0,0 @@ -package io.modelcontextprotocol.kotlin.sdk.integration.kotlin - -import io.kotest.assertions.json.shouldEqualJson -import io.modelcontextprotocol.kotlin.sdk.CallToolRequest -import io.modelcontextprotocol.kotlin.sdk.CallToolResult -import io.modelcontextprotocol.kotlin.sdk.CallToolResultBase -import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities -import io.modelcontextprotocol.kotlin.sdk.TextContent -import io.modelcontextprotocol.kotlin.sdk.Tool -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.delay -import kotlinx.coroutines.launch -import kotlinx.coroutines.runBlocking -import kotlinx.coroutines.test.runTest -import kotlinx.serialization.json.JsonArray -import kotlinx.serialization.json.JsonObject -import kotlinx.serialization.json.JsonPrimitive -import kotlinx.serialization.json.add -import kotlinx.serialization.json.buildJsonArray -import kotlinx.serialization.json.buildJsonObject -import kotlinx.serialization.json.put -import org.junit.jupiter.api.Test -import org.junit.jupiter.api.assertThrows -import kotlin.test.assertEquals -import kotlin.test.assertNotNull -import kotlin.test.assertTrue - -class ToolEdgeCasesTest : KotlinTestBase() { - - private val basicToolName = "basic-tool" - private val basicToolDescription = "A basic tool for testing" - - private val complexToolName = "complex-tool" - private val complexToolDescription = "A complex tool with nested schema" - - private val largeToolName = "large-tool" - private val largeToolDescription = "A tool that returns a large response" - private val largeToolContent = "X".repeat(100_000) // 100KB of data - - private val slowToolName = "slow-tool" - private val slowToolDescription = "A tool that takes time to respond" - - private val specialCharsToolName = "special-chars-tool" - private val specialCharsToolDescription = "A tool that handles special characters" - private val specialCharsContent = "!@#$%^&*()_+{}|:\"<>?~`-=[]\\;',./\n\t" - - override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( - tools = ServerCapabilities.Tools( - listChanged = true, - ), - ) - - override fun configureServer() { - server.addTool( - name = basicToolName, - description = basicToolDescription, - inputSchema = Tool.Input( - properties = buildJsonObject { - put( - "text", - buildJsonObject { - put("type", "string") - put("description", "The text to echo back") - }, - ) - }, - required = listOf("text"), - ), - ) { request -> - val text = (request.arguments["text"] as? JsonPrimitive)?.content ?: "No text provided" - - CallToolResult( - content = listOf(TextContent(text = "Echo: $text")), - structuredContent = buildJsonObject { - put("result", text) - }, - ) - } - - server.addTool( - name = complexToolName, - description = complexToolDescription, - inputSchema = Tool.Input( - properties = buildJsonObject { - put( - "user", - buildJsonObject { - put("type", "object") - put("description", "User information") - put( - "properties", - buildJsonObject { - put( - "name", - buildJsonObject { - put("type", "string") - put("description", "User's name") - }, - ) - put( - "age", - buildJsonObject { - put("type", "integer") - put("description", "User's age") - }, - ) - put( - "address", - buildJsonObject { - put("type", "object") - put("description", "User's address") - put( - "properties", - buildJsonObject { - put( - "street", - buildJsonObject { - put("type", "string") - }, - ) - put( - "city", - buildJsonObject { - put("type", "string") - }, - ) - put( - "country", - buildJsonObject { - put("type", "string") - }, - ) - }, - ) - }, - ) - }, - ) - }, - ) - put( - "options", - buildJsonObject { - put("type", "array") - put("description", "Additional options") - put( - "items", - buildJsonObject { - put("type", "string") - }, - ) - }, - ) - }, - required = listOf("user"), - ), - ) { request -> - val user = request.arguments["user"] as? JsonObject - val name = (user?.get("name") as? JsonPrimitive)?.content ?: "Unknown" - val age = (user?.get("age") as? JsonPrimitive)?.content?.toIntOrNull() ?: 0 - - val address = user?.get("address") as? JsonObject - val street = (address?.get("street") as? JsonPrimitive)?.content ?: "Unknown" - val city = (address?.get("city") as? JsonPrimitive)?.content ?: "Unknown" - val country = (address?.get("country") as? JsonPrimitive)?.content ?: "Unknown" - - val options = (request.arguments["options"] as? JsonArray)?.mapNotNull { - (it as? JsonPrimitive)?.content - } ?: emptyList() - - val summary = - "User: $name, Age: $age, Address: $street, $city, $country, Options: ${options.joinToString(", ")}" - - CallToolResult( - content = listOf(TextContent(text = summary)), - structuredContent = buildJsonObject { - put("name", name) - put("age", age) - put( - "address", - buildJsonObject { - put("street", street) - put("city", city) - put("country", country) - }, - ) - put( - "options", - buildJsonArray { - options.forEach { add(it) } - }, - ) - }, - ) - } - - server.addTool( - name = largeToolName, - description = largeToolDescription, - inputSchema = Tool.Input( - properties = buildJsonObject { - put( - "size", - buildJsonObject { - put("type", "integer") - put("description", "Size multiplier") - }, - ) - }, - ), - ) { request -> - val size = (request.arguments["size"] as? JsonPrimitive)?.content?.toIntOrNull() ?: 1 - val content = largeToolContent.take(largeToolContent.length.coerceAtMost(size * 1000)) - - CallToolResult( - content = listOf(TextContent(text = content)), - structuredContent = buildJsonObject { - put("size", content.length) - }, - ) - } - - server.addTool( - name = slowToolName, - description = slowToolDescription, - inputSchema = Tool.Input( - properties = buildJsonObject { - put( - "delay", - buildJsonObject { - put("type", "integer") - put("description", "Delay in milliseconds") - }, - ) - }, - ), - ) { request -> - val delay = (request.arguments["delay"] as? JsonPrimitive)?.content?.toIntOrNull() ?: 1000 - - // simulate slow operation - runBlocking { - delay(delay.toLong()) - } - - CallToolResult( - content = listOf(TextContent(text = "Completed after ${delay}ms delay")), - structuredContent = buildJsonObject { - put("delay", delay) - }, - ) - } - - server.addTool( - name = specialCharsToolName, - description = specialCharsToolDescription, - inputSchema = Tool.Input( - properties = buildJsonObject { - put( - "special", - buildJsonObject { - put("type", "string") - put("description", "Special characters to process") - }, - ) - }, - ), - ) { request -> - val special = (request.arguments["special"] as? JsonPrimitive)?.content ?: specialCharsContent - - CallToolResult( - content = listOf(TextContent(text = "Received special characters: $special")), - structuredContent = buildJsonObject { - put("special", special) - put("length", special.length) - }, - ) - } - } - - @Test - fun testBasicTool(): Unit = runBlocking(Dispatchers.IO) { - val testText = "Hello, world!" - val arguments = mapOf("text" to testText) - - val result = client.callTool(basicToolName, arguments) as CallToolResultBase - - val expectedToolResult = "[TextContent(text=Echo: Hello, world!, annotations=null)]" - assertEquals(expectedToolResult, result.content.toString(), "Unexpected tool result") - - val actualContent = result.structuredContent.toString() - val expectedContent = """ - { - "result" : "Hello, world!" - } - """.trimIndent() - - actualContent shouldEqualJson expectedContent - } - - @Test - fun testComplexNestedSchema(): Unit = runBlocking(Dispatchers.IO) { - val userJson = buildJsonObject { - put("name", JsonPrimitive("John Galt")) - put("age", JsonPrimitive(30)) - put( - "address", - buildJsonObject { - put("street", JsonPrimitive("123 Main St")) - put("city", JsonPrimitive("New York")) - put("country", JsonPrimitive("USA")) - }, - ) - } - - val optionsJson = buildJsonArray { - add(JsonPrimitive("option1")) - add(JsonPrimitive("option2")) - add(JsonPrimitive("option3")) - } - - val arguments = buildJsonObject { - put("user", userJson) - put("options", optionsJson) - } - - val result = client.callTool( - CallToolRequest( - name = complexToolName, - arguments = arguments, - ), - ) as CallToolResultBase - - val actualContent = result.structuredContent.toString() - val expectedContent = """ - { - "name" : "John Galt", - "age" : 30, - "address" : { - "street" : "123 Main St", - "city" : "New York", - "country" : "USA" - }, - "options" : [ "option1", "option2", "option3" ] - } - """.trimIndent() - - actualContent shouldEqualJson expectedContent - } - - @Test - fun testLargeResponse(): Unit = runBlocking(Dispatchers.IO) { - val size = 10 - val arguments = mapOf("size" to size) - - val result = client.callTool(largeToolName, arguments) as CallToolResultBase - - val content = result.content.firstOrNull() as TextContent - assertNotNull(content, "Tool result content should be TextContent") - - val actualContent = result.structuredContent.toString() - val expectedContent = """ - { - "size" : 10000 - } - """.trimIndent() - - actualContent shouldEqualJson expectedContent - } - - @Test - fun testSlowTool(): Unit = runBlocking(Dispatchers.IO) { - val delay = 500 - val arguments = mapOf("delay" to delay) - - val startTime = System.currentTimeMillis() - val result = client.callTool(slowToolName, arguments) as CallToolResultBase - val endTime = System.currentTimeMillis() - - val content = result.content.firstOrNull() as? TextContent - assertNotNull(content, "Tool result content should be TextContent") - - assertTrue(endTime - startTime >= delay, "Tool should take at least the specified delay") - - val actualContent = result.structuredContent.toString() - val expectedContent = """ - { - "delay" : 500 - } - """.trimIndent() - - actualContent shouldEqualJson expectedContent - } - - @Test - fun testSpecialCharacters() { - runBlocking(Dispatchers.IO) { - val arguments = mapOf("special" to specialCharsContent) - - val result = client.callTool(specialCharsToolName, arguments) as CallToolResultBase - - val content = result.content.firstOrNull() as? TextContent - assertNotNull(content, "Tool result content should be TextContent") - val text = content.text ?: "" - - assertTrue(text.contains(specialCharsContent), "Result should contain the special characters") - - val actualContent = result.structuredContent.toString() - val expectedContent = """ - { - "special" : "!@#$%^&*()_+{}|:\"<>?~`-=[]\\;',./\n\t", - "length" : 34 - } - """.trimIndent() - - actualContent shouldEqualJson expectedContent - } - } - - @Test - fun testConcurrentToolCalls() = runTest { - val concurrentCount = 10 - val results = mutableListOf() - - runBlocking { - repeat(concurrentCount) { index -> - launch { - val toolName = when (index % 5) { - 0 -> basicToolName - 1 -> complexToolName - 2 -> largeToolName - 3 -> slowToolName - else -> specialCharsToolName - } - - val arguments = when (toolName) { - basicToolName -> mapOf("text" to "Concurrent call $index") - - complexToolName -> mapOf( - "user" to mapOf( - "name" to "User $index", - "age" to 20 + index, - "address" to mapOf( - "street" to "Street $index", - "city" to "City $index", - "country" to "Country $index", - ), - ), - ) - - largeToolName -> mapOf("size" to 1) - - slowToolName -> mapOf("delay" to 100) - - else -> mapOf("special" to "!@#$%^&*()") - } - - val result = client.callTool(toolName, arguments) - - synchronized(results) { - results.add(result) - } - } - } - } - - assertEquals(concurrentCount, results.size, "All concurrent operations should complete") - results.forEach { result -> - assertNotNull(result, "Result should not be null") - assertTrue(result.content.isNotEmpty(), "Result content should not be empty") - } - } - - @Test - fun testNonExistentTool() = runTest { - val nonExistentToolName = "non-existent-tool" - val arguments = mapOf("text" to "Test") - - val exception = assertThrows { - runBlocking { - client.callTool(nonExistentToolName, arguments) - } - } - - val msg = exception.message ?: "" - val expectedMessage = "JSONRPCError(code=InternalError, message=Tool not found: non-existent-tool, data={})" - - assertEquals(expectedMessage, msg, "Unexpected error message for non-existent tool") - } -} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/sse/PromptIntegrationTestSse.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/sse/PromptIntegrationTestSse.kt new file mode 100644 index 00000000..d8f218b9 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/sse/PromptIntegrationTestSse.kt @@ -0,0 +1,7 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.sse + +import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.AbstractPromptIntegrationTest + +class PromptIntegrationTestSse : AbstractPromptIntegrationTest() { + override val transportKind: TransportKind = TransportKind.SSE +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/sse/ResourceIntegrationTestSse.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/sse/ResourceIntegrationTestSse.kt new file mode 100644 index 00000000..bf1240df --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/sse/ResourceIntegrationTestSse.kt @@ -0,0 +1,7 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.sse + +import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.AbstractResourceIntegrationTest + +class ResourceIntegrationTestSse : AbstractResourceIntegrationTest() { + override val transportKind: TransportKind = TransportKind.SSE +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/sse/ToolIntegrationTestSse.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/sse/ToolIntegrationTestSse.kt new file mode 100644 index 00000000..dd007c6e --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/sse/ToolIntegrationTestSse.kt @@ -0,0 +1,7 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.sse + +import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.AbstractToolIntegrationTest + +class ToolIntegrationTestSse : AbstractToolIntegrationTest() { + override val transportKind: TransportKind = TransportKind.SSE +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/stdio/PromptIntegrationTestStdio.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/stdio/PromptIntegrationTestStdio.kt new file mode 100644 index 00000000..88be1e80 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/stdio/PromptIntegrationTestStdio.kt @@ -0,0 +1,7 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.stdio + +import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.AbstractPromptIntegrationTest + +class PromptIntegrationTestStdio : AbstractPromptIntegrationTest() { + override val transportKind: TransportKind = TransportKind.STDIO +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/stdio/ResourceIntegrationTestStdio.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/stdio/ResourceIntegrationTestStdio.kt new file mode 100644 index 00000000..88eca7b0 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/stdio/ResourceIntegrationTestStdio.kt @@ -0,0 +1,7 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.stdio + +import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.AbstractResourceIntegrationTest + +class ResourceIntegrationTestStdio : AbstractResourceIntegrationTest() { + override val transportKind: TransportKind = TransportKind.STDIO +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/stdio/ToolIntegrationTestStdio.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/stdio/ToolIntegrationTestStdio.kt new file mode 100644 index 00000000..673d44bb --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/stdio/ToolIntegrationTestStdio.kt @@ -0,0 +1,7 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.stdio + +import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.AbstractToolIntegrationTest + +class ToolIntegrationTestStdio : AbstractToolIntegrationTest() { + override val transportKind: TransportKind = TransportKind.STDIO +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/AbstractKotlinClientTsServerTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/AbstractKotlinClientTsServerTest.kt new file mode 100644 index 00000000..51179982 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/AbstractKotlinClientTsServerTest.kt @@ -0,0 +1,78 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript + +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.client.Client +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.runBlocking +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import java.util.concurrent.TimeUnit +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +abstract class AbstractKotlinClientTsServerTest : TsTestBase() { + protected abstract suspend fun useClient(block: suspend (Client) -> T): T + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun connectsAndPings() = runBlocking(Dispatchers.IO) { + useClient { client -> + assertNotNull(client, "Client should be initialized") + val ping = client.ping() + assertNotNull(ping, "Ping result should not be null") + val serverImpl = client.serverVersion + assertNotNull(serverImpl, "Server implementation should not be null") + println("Connected to TypeScript server: ${serverImpl.name} v${serverImpl.version}") + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun listsTools() = runBlocking(Dispatchers.IO) { + useClient { client -> + val result = client.listTools() + assertNotNull(result, "Tools list should not be null") + assertTrue(result.tools.isNotEmpty(), "Tools list should not be empty") + val toolNames = result.tools.map { it.name } + assertTrue("greet" in toolNames, "Greet tool should be available") + assertTrue("multi-greet" in toolNames, "Multi-greet tool should be available") + // Some tests also check collect-user-info; keep base minimal and non-breaking + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun callGreet() = runBlocking(Dispatchers.IO) { + useClient { client -> + val testName = "TestUser" + val arguments = mapOf("name" to testName) + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") + assertEquals("Hello, $testName!", textContent.text) + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun multipleClients() = runBlocking(Dispatchers.IO) { + useClient { client1 -> + useClient { client2 -> + val tools1 = client1.listTools() + val tools2 = client2.listTools() + assertTrue(tools1.tools.isNotEmpty(), "Tools list for first client should not be empty") + assertTrue(tools2.tools.isNotEmpty(), "Tools list for second client should not be empty") + val toolNames1 = tools1.tools.map { it.name } + val toolNames2 = tools2.tools.map { it.name } + assertTrue("greet" in toolNames1, "Greet tool should be available to first client") + assertTrue("multi-greet" in toolNames1, "Multi-greet tool should be available to first client") + assertTrue("greet" in toolNames2, "Greet tool should be available to second client") + assertTrue("multi-greet" in toolNames2, "Multi-greet tool should be available to second client") + } + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/AbstractTsClientKotlinServerTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/AbstractTsClientKotlinServerTest.kt new file mode 100644 index 00000000..e7950e79 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/AbstractTsClientKotlinServerTest.kt @@ -0,0 +1,87 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript + +import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import java.util.concurrent.TimeUnit +import kotlin.test.assertTrue + +abstract class AbstractTsClientKotlinServerTest : TsTestBase() { + + protected open fun beforeServer() {} + protected open fun afterServer() {} + + /** + * Run the TypeScript client against the prepared server and return its console output. + */ + protected abstract fun runClient(vararg args: String): String + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun toolCall() = runTest { + beforeServer() + try { + val testName = "TestUser" + val out = runClient("greet", testName) + assertTrue(out.contains("Text content:"), "Output should contain the text content section.\n$out") + assertTrue(out.contains("Hello, $testName!"), "Tool response should contain the greeting.\n$out") + assertTrue( + out.contains("Structured content:"), + "Output should contain the structured content section.\n$out", + ) + assertTrue( + out.contains( + "\"greeting\": \"Hello, $testName!\"", + ) || + out.contains("greeting") || + out.contains("greet"), + "Structured content should contain the greeting.\n$out", + ) + } finally { + afterServer() + } + } + + @Test + @Timeout(60, unit = TimeUnit.SECONDS) + fun notifications() = runTest { + beforeServer() + try { + val name = "NotifUser" + val out = runClient("multi-greet", name) + assertTrue( + out.contains("Multiple greetings") || out.contains("greeting"), + "Tool response should contain greeting message.\n$out", + ) + assertTrue( + out.contains("\"notificationCount\": 3") || out.contains("notificationCount: 3"), + "Structured content should indicate that 3 notifications were emitted by the server.\nOutput:\n$out", + ) + } finally { + afterServer() + } + } + + @Test + @Timeout(120, unit = TimeUnit.SECONDS) + fun multipleClientSequence() = runTest { + beforeServer() + try { + val out1 = runClient("greet", "FirstClient") + assertTrue(out1.contains("Hello, FirstClient!"), "Should greet first client.\n$out1") + + val out2 = runClient("multi-greet", "SecondClient") + assertTrue( + out2.contains("Multiple greetings") || out2.contains("greeting"), + "Should respond for second client.\n$out2", + ) + + val out3 = runClient() + assertTrue(out3.contains("Available utils:"), "Should list available utils.\n$out3") + assertTrue(out3.contains("greet"), "Greet tool should be available.\n$out3") + assertTrue(out3.contains("multi-greet"), "Multi-greet tool should be available.\n$out3") + } finally { + afterServer() + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt deleted file mode 100644 index eca06be1..00000000 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt +++ /dev/null @@ -1,140 +0,0 @@ -package io.modelcontextprotocol.kotlin.sdk.integration.typescript - -import io.modelcontextprotocol.kotlin.sdk.CallToolResult -import io.modelcontextprotocol.kotlin.sdk.TextContent -import io.modelcontextprotocol.kotlin.sdk.client.Client -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.runBlocking -import kotlinx.coroutines.withTimeout -import org.junit.jupiter.api.AfterEach -import org.junit.jupiter.api.BeforeEach -import org.junit.jupiter.api.Test -import org.junit.jupiter.api.Timeout -import java.util.concurrent.TimeUnit -import kotlin.test.assertEquals -import kotlin.test.assertNotNull -import kotlin.test.assertTrue -import kotlin.time.Duration.Companion.seconds - -class KotlinClientTypeScriptServerTest : TypeScriptTestBase() { - - private var port: Int = 0 - private val host = "localhost" - private lateinit var serverUrl: String - - private lateinit var client: Client - private lateinit var tsServerProcess: Process - - @BeforeEach - fun setUp() { - port = findFreePort() - serverUrl = "http://$host:$port/mcp" - tsServerProcess = startTypeScriptServer(port) - println("TypeScript server started on port $port") - } - - @AfterEach - fun tearDown() { - if (::client.isInitialized) { - try { - runBlocking { - withTimeout(3.seconds) { - client.close() - } - } - } catch (e: Exception) { - println("Warning: Error during client close: ${e.message}") - } - } - - if (::tsServerProcess.isInitialized) { - try { - println("Stopping TypeScript server") - stopProcess(tsServerProcess) - } catch (e: Exception) { - println("Warning: Error during TypeScript server stop: ${e.message}") - } - } - } - - @Test - @Timeout(30, unit = TimeUnit.SECONDS) - fun testKotlinClientConnectsToTypeScriptServer(): Unit = runBlocking(Dispatchers.IO) { - withClient(serverUrl) { client -> - assertNotNull(client, "Client should be initialized") - - val pingResult = client.ping() - assertNotNull(pingResult, "Ping result should not be null") - - val serverImpl = client.serverVersion - assertNotNull(serverImpl, "Server implementation should not be null") - println("Connected to TypeScript server: ${serverImpl.name} v${serverImpl.version}") - } - } - - @Test - @Timeout(30, unit = TimeUnit.SECONDS) - fun testListTools(): Unit = runBlocking(Dispatchers.IO) { - withClient(serverUrl) { client -> - val result = client.listTools() - assertNotNull(result, "Tools list should not be null") - assertTrue(result.tools.isNotEmpty(), "Tools list should not be empty") - - // Verify specific utils are available - val toolNames = result.tools.map { it.name } - assertTrue(toolNames.contains("greet"), "Greet tool should be available") - assertTrue(toolNames.contains("multi-greet"), "Multi-greet tool should be available") - assertTrue(toolNames.contains("collect-user-info"), "Collect-user-info tool should be available") - - println("Available utils: ${toolNames.joinToString()}") - } - } - - @Test - @Timeout(30, unit = TimeUnit.SECONDS) - fun testToolCall(): Unit = runBlocking(Dispatchers.IO) { - withClient(serverUrl) { client -> - val testName = "TestUser" - val arguments = mapOf("name" to testName) - - val result = client.callTool("greet", arguments) - assertNotNull(result, "Tool call result should not be null") - - val callResult = result as CallToolResult - val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent - assertNotNull(textContent, "Text content should be present in the result") - assertEquals( - "Hello, $testName!", - textContent.text, - "Tool response should contain the greeting with the provided name", - ) - } - } - - @Test - @Timeout(30, unit = TimeUnit.SECONDS) - fun testMultipleClients(): Unit = runBlocking(Dispatchers.IO) { - val client1 = newClient(serverUrl) - val client2 = newClient(serverUrl) - try { - val tools1 = client1.listTools() - assertNotNull(tools1, "Tools list for first client should not be null") - assertTrue(tools1.tools.isNotEmpty(), "Tools list for first client should not be empty") - - val tools2 = client2.listTools() - assertNotNull(tools2, "Tools list for second client should not be null") - assertTrue(tools2.tools.isNotEmpty(), "Tools list for second client should not be empty") - - val toolNames1 = tools1.tools.map { it.name } - val toolNames2 = tools2.tools.map { it.name } - - assertTrue(toolNames1.contains("greet"), "Greet tool should be available to first client") - assertTrue(toolNames1.contains("multi-greet"), "Multi-greet tool should be available to first client") - assertTrue(toolNames2.contains("greet"), "Greet tool should be available to second client") - assertTrue(toolNames2.contains("multi-greet"), "Multi-greet tool should be available to second client") - } finally { - client1.close() - client2.close() - } - } -} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TsTestBase.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TsTestBase.kt new file mode 100644 index 00000000..801f5820 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TsTestBase.kt @@ -0,0 +1,454 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript + +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.sse.SSE +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.StdioClientTransport +import io.modelcontextprotocol.kotlin.sdk.client.mcpStreamableHttp +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.sse.KotlinServerForTsClient +import io.modelcontextprotocol.kotlin.sdk.integration.utils.Retry +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.StdioServerTransport +import kotlinx.coroutines.withTimeout +import kotlinx.io.Sink +import kotlinx.io.Source +import kotlinx.io.asSink +import kotlinx.io.asSource +import kotlinx.io.buffered +import org.awaitility.kotlin.await +import org.junit.jupiter.api.BeforeAll +import java.io.BufferedReader +import java.io.File +import java.io.InputStreamReader +import java.net.ServerSocket +import java.net.Socket +import java.util.concurrent.TimeUnit +import kotlin.io.path.createTempDirectory +import kotlin.time.Duration.Companion.seconds + +enum class TransportKind { SSE, STDIO, DEFAULT } + +@Retry(times = 3) +abstract class TsTestBase { + + protected open val transportKind: TransportKind = TransportKind.DEFAULT + + protected val projectRoot: File get() = File(System.getProperty("user.dir")) + protected val tsClientDir: File + get() { + val base = File( + projectRoot, + "src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript", + ) + + // Allow override via system property for CI: -Dts.transport=stdio|sse + val fromProp = System.getProperty("ts.transport")?.lowercase() + val overrideSubDir = when (fromProp) { + "stdio" -> "stdio" + "sse" -> "sse" + else -> null + } + + val subDirName = overrideSubDir ?: when (transportKind) { + TransportKind.STDIO -> "stdio" + TransportKind.SSE -> "sse" + TransportKind.DEFAULT -> null + } + if (subDirName != null) { + val sub = File(base, subDirName) + if (sub.exists()) return sub + } + return base + } + + companion object { + @JvmStatic + private val tempRootDir: File = createTempDirectory("typescript-sdk-").toFile().apply { deleteOnExit() } + + @JvmStatic + protected val sdkDir: File = File(tempRootDir, "typescript-sdk") + + @JvmStatic + @BeforeAll + fun setupTypeScriptSdk() { + println("Cloning TypeScript SDK repository") + + if (!sdkDir.exists()) { + val process = ProcessBuilder( + "git", + "clone", + "--depth", + "1", + "https://github.com/modelcontextprotocol/typescript-sdk.git", + sdkDir.absolutePath, + ) + .redirectErrorStream(true) + .start() + val exitCode = process.waitFor() + if (exitCode != 0) { + throw RuntimeException("Failed to clone TypeScript SDK repository: exit code $exitCode") + } + } + + println("Installing TypeScript SDK dependencies") + executeCommand("npm install", sdkDir, allowFailure = false, timeoutSeconds = null) + } + + @JvmStatic + protected fun killProcessOnPort(port: Int) { + val isWindows = System.getProperty("os.name").lowercase().contains("windows") + val killCommand = if (isWindows) { + "netstat -ano | findstr :$port | for /f \"tokens=5\" %a in ('more')" + + " do taskkill /F /PID %a 2>nul || echo No process found" + } else { + "lsof -ti:$port | xargs kill -9 2>/dev/null || true" + } + executeCommand(killCommand, File("."), allowFailure = true, timeoutSeconds = null) + } + + @JvmStatic + protected fun findFreePort(): Int { + ServerSocket(0).use { socket -> + return socket.localPort + } + } + + @JvmStatic + protected fun executeCommand( + command: String, + workingDir: File, + allowFailure: Boolean = false, + timeoutSeconds: Long? = null, + ): String { + if (!workingDir.exists()) { + if (!workingDir.mkdirs()) { + throw RuntimeException("Failed to create working directory: ${workingDir.absolutePath}") + } + } + + if (!workingDir.isDirectory || !workingDir.canRead()) { + throw RuntimeException("Working directory is not accessible: ${workingDir.absolutePath}") + } + + val isWindows = System.getProperty("os.name").lowercase().contains("windows") + val processBuilder = if (isWindows) { + ProcessBuilder() + .command("cmd.exe", "/c", "set TYPESCRIPT_SDK_DIR=${sdkDir.absolutePath} && $command") + } else { + ProcessBuilder() + .command("bash", "-c", "TYPESCRIPT_SDK_DIR='${sdkDir.absolutePath}' $command") + } + + val process = processBuilder + .directory(workingDir) + .redirectErrorStream(true) + .start() + + val output = StringBuilder() + BufferedReader(InputStreamReader(process.inputStream)).use { reader -> + var line: String? + while (reader.readLine().also { line = it } != null) { + println(line) + output.append(line).append("\n") + } + } + + if (timeoutSeconds == null) { + val exitCode = process.waitFor() + if (!allowFailure && exitCode != 0) { + throw RuntimeException( + "Command execution failed with exit code $exitCode: $command\n" + + "Working dir: ${workingDir.absolutePath}\nOutput:\n$output", + ) + } + } else { + process.waitFor(timeoutSeconds, TimeUnit.SECONDS) + } + + return output.toString() + } + } + + private fun waitForProcessTermination(process: Process, timeoutSeconds: Long): Boolean { + if (process.isAlive && !process.waitFor(timeoutSeconds, TimeUnit.SECONDS)) { + process.destroyForcibly() + process.waitFor(2, TimeUnit.SECONDS) + return false + } + return true + } + + private fun createProcessOutputReader(process: Process, prefix: String = "TS-SERVER"): Thread { + val outputReader = Thread { + try { + process.inputStream.bufferedReader().useLines { lines -> + for (line in lines) { + println("[$prefix] $line") + } + } + } catch (e: Exception) { + println("Warning: Error reading process output: ${e.message}") + } + } + outputReader.isDaemon = true + return outputReader + } + + private fun createProcessErrorReader(process: Process, prefix: String = "TS-SERVER"): Thread { + val errorReader = Thread { + try { + process.errorStream.bufferedReader().useLines { lines -> + for (line in lines) { + println("[$prefix][err] $line") + } + } + } catch (e: Exception) { + println("Warning: Error reading process error stream: ${e.message}") + } + } + errorReader.isDaemon = true + return errorReader + } + + protected fun waitForPort(host: String = "localhost", port: Int, timeoutSeconds: Long = 10): Boolean = try { + await.atMost(timeoutSeconds, TimeUnit.SECONDS) + .pollDelay(200, TimeUnit.MILLISECONDS) + .pollInterval(100, TimeUnit.MILLISECONDS) + .until { + try { + Socket(host, port).use { true } + } catch (_: Exception) { + false + } + } + true + } catch (_: Exception) { + false + } + + protected fun executeCommandAllowingFailure(command: String, workingDir: File, timeoutSeconds: Long = 20): String = + executeCommand(command, workingDir, allowFailure = true, timeoutSeconds = timeoutSeconds) + + protected fun startTypeScriptServer(port: Int): Process { + killProcessOnPort(port) + + if (!sdkDir.exists() || !sdkDir.isDirectory) { + throw IllegalStateException( + "TypeScript SDK directory does not exist or is not accessible: ${sdkDir.absolutePath}", + ) + } + + val isWindows = System.getProperty("os.name").lowercase().contains("windows") + val localServerPath = File(tsClientDir, "simpleStreamableHttp.ts").absolutePath + val processBuilder = if (isWindows) { + ProcessBuilder() + .command( + "cmd.exe", + "/c", + "set MCP_PORT=$port && set NODE_PATH=${sdkDir.absolutePath}\\node_modules && npx --prefix \"${sdkDir.absolutePath}\" tsx \"$localServerPath\"", + ) + } else { + ProcessBuilder() + .command( + "bash", + "-c", + "MCP_PORT=$port NODE_PATH='${sdkDir.absolutePath}/node_modules' npx --prefix '${sdkDir.absolutePath}' tsx \"$localServerPath\"", + ) + } + + processBuilder.environment()["TYPESCRIPT_SDK_DIR"] = sdkDir.absolutePath + + val process = processBuilder + .directory(tsClientDir) + .redirectErrorStream(true) + .start() + + createProcessOutputReader(process).start() + + if (!waitForPort(port = port, timeoutSeconds = 20)) { + throw IllegalStateException("TypeScript server did not become ready on localhost:$port within timeout") + } + return process + } + + protected fun stopProcess(process: Process, waitSeconds: Long = 3, name: String = "TypeScript server") { + process.destroy() + if (waitForProcessTermination(process, waitSeconds)) { + println("$name stopped gracefully") + } else { + println("$name did not stop gracefully, forced termination") + } + } + + // ===== SSE client helpers ===== + protected suspend fun newClient(serverUrl: String): Client = + HttpClient(CIO) { install(SSE) }.mcpStreamableHttp(serverUrl) + + protected suspend fun withClient(serverUrl: String, block: suspend (Client) -> T): T { + val client = newClient(serverUrl) + return try { + withTimeout(20.seconds) { block(client) } + } finally { + try { + withTimeout(3.seconds) { client.close() } + } catch (_: Exception) { + // ignore errors + } + } + } + + // ===== STDIO client + server helpers ===== + protected fun startTypeScriptServerStdio(): Process { + if (!sdkDir.exists() || !sdkDir.isDirectory) { + throw IllegalStateException( + "TypeScript SDK directory does not exist or is not accessible: ${sdkDir.absolutePath}", + ) + } + val isWindows = System.getProperty("os.name").lowercase().contains("windows") + val localServerPath = File(tsClientDir, "simpleStdio.ts").absolutePath + val processBuilder = if (isWindows) { + ProcessBuilder() + .command( + "cmd.exe", + "/c", + "set NODE_PATH=${sdkDir.absolutePath}\\node_modules && npx --prefix \"${sdkDir.absolutePath}\" tsx \"$localServerPath\"", + ) + } else { + ProcessBuilder() + .command( + "bash", + "-c", + "NODE_PATH='${sdkDir.absolutePath}/node_modules' npx --prefix '${sdkDir.absolutePath}' tsx \"$localServerPath\"", + ) + } + processBuilder.environment()["TYPESCRIPT_SDK_DIR"] = sdkDir.absolutePath + val process = processBuilder + .directory(tsClientDir) + .redirectErrorStream(false) + .start() + // For stdio transports, do NOT read from stdout (it's used for protocol). Read stderr for logs only. + createProcessErrorReader(process, prefix = "TS-SERVER-STDIO").start() + // Give the process a moment to start + await.atMost(2, TimeUnit.SECONDS) + .pollDelay(200, TimeUnit.MILLISECONDS) + .pollInterval(100, TimeUnit.MILLISECONDS) + .until { process.isAlive } + return process + } + + protected suspend fun newClientStdio(process: Process): Client { + val input: Source = process.inputStream.asSource().buffered() + val output: Sink = process.outputStream.asSink().buffered() + val transport = StdioClientTransport(input = input, output = output) + val client = Client(Implementation("test", "1.0")) + client.connect(transport) + return client + } + + protected suspend fun withClientStdio(block: suspend (Client, Process) -> T): T { + val proc = startTypeScriptServerStdio() + val client = newClientStdio(proc) + return try { + withTimeout(20.seconds) { block(client, proc) } + } finally { + try { + withTimeout(3.seconds) { client.close() } + } catch (_: Exception) { + } + try { + stopProcess(proc, name = "TypeScript stdio server") + } catch (_: Exception) { + } + } + } + + // ===== Helpers to run TypeScript client over STDIO against Kotlin server over STDIO ===== + protected fun runStdioClient(vararg args: String): String { + // Start Node stdio client (it will speak MCP over its stdout/stdin) + val isWindows = System.getProperty("os.name").lowercase().contains("windows") + val clientPath = File(tsClientDir, "myClient.ts").absolutePath + + val process = if (isWindows) { + ProcessBuilder() + .command( + "cmd.exe", + "/c", + ( + "set TYPESCRIPT_SDK_DIR=${sdkDir.absolutePath} && " + + "set NODE_PATH=${sdkDir.absolutePath}\\node_modules && " + + "npx --prefix \"${sdkDir.absolutePath}\" tsx \"$clientPath\" " + + args.joinToString(" ") + ), + ) + .directory(tsClientDir) + .redirectErrorStream(false) + .start() + } else { + ProcessBuilder() + .command( + "bash", + "-c", + ( + "TYPESCRIPT_SDK_DIR='${sdkDir.absolutePath}' " + + "NODE_PATH='${sdkDir.absolutePath}/node_modules' " + + "npx --prefix '${sdkDir.absolutePath}' tsx \"$clientPath\" " + + args.joinToString(" ") + ), + ) + .directory(tsClientDir) + .redirectErrorStream(false) + .start() + } + + // Create Kotlin server and attach stdio transport to the process streams + val server: Server = KotlinServerForTsClient().createMcpServer() + val transport = StdioServerTransport( + inputStream = process.inputStream.asSource().buffered(), + outputStream = process.outputStream.asSink().buffered(), + ) + + // Connect server in a background thread to avoid blocking + val serverThread = Thread { + try { + kotlinx.coroutines.runBlocking { server.connect(transport) } + } catch (e: Exception) { + println("[STDIO-SERVER] Error connecting: ${e.message}") + } + } + serverThread.isDaemon = true + serverThread.start() + + // Read ONLY stderr from client for human-readable output + val output = StringBuilder() + val errReader = Thread { + try { + process.errorStream.bufferedReader().useLines { lines -> + lines.forEach { line -> + println("[TS-CLIENT-STDIO][err] $line") + output.append(line).append('\n') + } + } + } catch (e: Exception) { + println("Warning: Error reading stdio client stderr: ${e.message}") + } + } + errReader.isDaemon = true + errReader.start() + + // Wait up to 25s for client to exit + val finished = process.waitFor(25, TimeUnit.SECONDS) + if (!finished) { + println("Stdio client did not finish in time; destroying") + process.destroyForcibly() + } + + try { + kotlinx.coroutines.runBlocking { transport.close() } + } catch (_: Exception) { + } + + return output.toString() + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt deleted file mode 100644 index d25dbebb..00000000 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt +++ /dev/null @@ -1,191 +0,0 @@ -package io.modelcontextprotocol.kotlin.sdk.integration.typescript - -import kotlinx.coroutines.test.runTest -import org.junit.jupiter.api.AfterEach -import org.junit.jupiter.api.BeforeEach -import org.junit.jupiter.api.Test -import org.junit.jupiter.api.Timeout -import java.util.concurrent.TimeUnit -import kotlin.test.Ignore -import kotlin.test.assertTrue - -class TypeScriptClientKotlinServerTest : TypeScriptTestBase() { - - private var port: Int = 0 - private lateinit var serverUrl: String - private var httpServer: KotlinServerForTypeScriptClient? = null - - @BeforeEach - fun setUp() { - port = findFreePort() - serverUrl = "http://localhost:$port/mcp" - killProcessOnPort(port) - httpServer = KotlinServerForTypeScriptClient() - httpServer?.start(port) - if (!waitForPort(port = port)) { - throw IllegalStateException("Kotlin test server did not become ready on localhost:$port within timeout") - } - println("Kotlin server started on port $port") - } - - @AfterEach - fun tearDown() { - try { - httpServer?.stop() - println("HTTP server stopped") - } catch (e: Exception) { - println("Error during server shutdown: ${e.message}") - } - } - - @Test - @Timeout(30, unit = TimeUnit.SECONDS) - fun testToolCall() = runTest { - val testName = "TestUser" - val command = "npx tsx myClient.ts $serverUrl greet $testName" - val output = executeCommand(command, tsClientDir) - - assertTrue( - output.contains("Hello, $testName!"), - "Tool response should contain the greeting with the provided name", - ) - assertTrue(output.contains("Tool result:"), "Output should indicate a successful tool call") - assertTrue(output.contains("Text content:"), "Output should contain the text content section") - assertTrue(output.contains("Structured content:"), "Output should contain the structured content section") - assertTrue( - output.contains("\"greeting\": \"Hello, $testName!\""), - "Structured content should contain the greeting", - ) - } - - @Test - @Timeout(30, unit = TimeUnit.SECONDS) - fun testNotifications() = runTest { - val name = "NotifUser" - val command = "npx tsx myClient.ts $serverUrl multi-greet $name" - val output = executeCommand(command, tsClientDir) - - assertTrue( - output.contains("Multiple greetings") || output.contains("greeting"), - "Tool response should contain greeting message", - ) - // verify that the server sent 3 notifications - assertTrue( - output.contains("\"notificationCount\": 3") || output.contains("notificationCount: 3"), - "Structured content should indicate that 3 notifications were emitted by the server.\nOutput:\n$output", - ) - } - - @Test - @Timeout(120, unit = TimeUnit.SECONDS) - fun testMultipleClientSequence() = runTest { - val testName1 = "FirstClient" - val command1 = "npx tsx myClient.ts $serverUrl greet $testName1" - val output1 = executeCommand(command1, tsClientDir) - - assertTrue(output1.contains("Connected to server"), "First client should connect to server") - assertTrue(output1.contains("Hello, $testName1!"), "Tool response should contain the greeting for first client") - assertTrue(output1.contains("Disconnected from server"), "First client should disconnect cleanly") - - val testName2 = "SecondClient" - val command2 = "npx tsx myClient.ts $serverUrl multi-greet $testName2" - val output2 = executeCommand(command2, tsClientDir) - - assertTrue(output2.contains("Connected to server"), "Second client should connect to server") - assertTrue( - output2.contains("Multiple greetings") || output2.contains("greeting"), - "Tool response should contain greeting message", - ) - assertTrue(output2.contains("Disconnected from server"), "Second client should disconnect cleanly") - - val command3 = "npx tsx myClient.ts $serverUrl" - val output3 = executeCommand(command3, tsClientDir) - - assertTrue(output3.contains("Connected to server"), "Third client should connect to server") - assertTrue(output3.contains("Available utils:"), "Third client should list available utils") - assertTrue(output3.contains("greet"), "Greet tool should be available to third client") - assertTrue(output3.contains("multi-greet"), "Multi-greet tool should be available to third client") - assertTrue(output3.contains("Disconnected from server"), "Third client should disconnect cleanly") - } - - @Test - @Timeout(30, unit = TimeUnit.SECONDS) - @Ignore // Ignored due to flaky, see issue https://github.com/modelcontextprotocol/kotlin-sdk/issues/262 - fun testMultipleClientParallel() = runTest { - val clientCount = 3 - val clients = listOf( - "FirstClient" to "greet", - "SecondClient" to "multi-greet", - "ThirdClient" to "", - ) - - val threads = mutableListOf() - val outputs = mutableListOf>() - val exceptions = mutableListOf() - - for (i in 0 until clientCount) { - val (clientName, toolName) = clients[i] - val thread = Thread { - try { - val command = if (toolName.isEmpty()) { - "npx tsx myClient.ts $serverUrl" - } else { - "npx tsx myClient.ts $serverUrl $toolName $clientName" - } - - val output = executeCommand(command, tsClientDir) - synchronized(outputs) { - outputs.add(i to output) - } - } catch (e: Exception) { - synchronized(exceptions) { - exceptions.add(e) - } - } - } - threads.add(thread) - thread.start() - Thread.sleep(500) - } - - threads.forEach { it.join() } - - if (exceptions.isNotEmpty()) { - println( - "Exceptions occurred in parallel clients: ${ - exceptions.joinToString { - it.message ?: it.toString() - } - }", - ) - } - - val sortedOutputs = outputs.sortedBy { it.first }.map { it.second } - - sortedOutputs.forEachIndexed { index, output -> - val clientName = clients[index].first - val toolName = clients[index].second - - when (toolName) { - "greet" -> { - val containsGreeting = output.contains("Hello, $clientName!") || - output.contains("\"greeting\": \"Hello, $clientName!\"") - assertTrue( - containsGreeting, - "Tool response should contain the greeting for $clientName", - ) - } - - "multi-greet" -> { - val containsGreeting = output.contains("Multiple greetings") || - output.contains("greeting") || - output.contains("greet") - assertTrue( - containsGreeting, - "Tool response should contain greeting message for $clientName", - ) - } - } - } - } -} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt deleted file mode 100644 index a19f00ec..00000000 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt +++ /dev/null @@ -1,242 +0,0 @@ -package io.modelcontextprotocol.kotlin.sdk.integration.typescript - -import io.ktor.client.HttpClient -import io.ktor.client.engine.cio.CIO -import io.ktor.client.plugins.sse.SSE -import io.modelcontextprotocol.kotlin.sdk.client.Client -import io.modelcontextprotocol.kotlin.sdk.client.mcpStreamableHttp -import io.modelcontextprotocol.kotlin.sdk.integration.utils.Retry -import kotlinx.coroutines.withTimeout -import org.junit.jupiter.api.BeforeAll -import java.io.BufferedReader -import java.io.File -import java.io.InputStreamReader -import java.net.ServerSocket -import java.net.Socket -import java.nio.file.Files -import java.util.concurrent.TimeUnit -import kotlin.time.Duration.Companion.seconds - -@Retry(times = 3) -abstract class TypeScriptTestBase { - - protected val projectRoot: File get() = File(System.getProperty("user.dir")) - protected val tsClientDir: File - get() = File( - projectRoot, - "src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript", - ) - - companion object { - @JvmStatic - private val tempRootDir: File = Files.createTempDirectory("typescript-sdk-").toFile().apply { deleteOnExit() } - - @JvmStatic - protected val sdkDir: File = File(tempRootDir, "typescript-sdk") - - @JvmStatic - @BeforeAll - fun setupTypeScriptSdk() { - println("Cloning TypeScript SDK repository") - - if (!sdkDir.exists()) { - val process = ProcessBuilder( - "git", - "clone", - "--depth", - "1", - "https://github.com/modelcontextprotocol/typescript-sdk.git", - sdkDir.absolutePath, - ) - .redirectErrorStream(true) - .start() - val exitCode = process.waitFor() - if (exitCode != 0) { - throw RuntimeException("Failed to clone TypeScript SDK repository: exit code $exitCode") - } - } - - println("Installing TypeScript SDK dependencies") - executeCommand("npm install", sdkDir, allowFailure = false, timeoutSeconds = null) - } - - @JvmStatic - protected fun killProcessOnPort(port: Int) { - val isWindows = System.getProperty("os.name").lowercase().contains("windows") - val killCommand = if (isWindows) { - "netstat -ano | findstr :$port | for /f \"tokens=5\" %a in ('more') do taskkill /F /PID %a 2>nul || echo No process found" - } else { - "lsof -ti:$port | xargs kill -9 2>/dev/null || true" - } - executeCommand(killCommand, File("."), allowFailure = true, timeoutSeconds = null) - } - - @JvmStatic - protected fun findFreePort(): Int { - ServerSocket(0).use { socket -> - return socket.localPort - } - } - - @JvmStatic - protected fun executeCommand( - command: String, - workingDir: File, - allowFailure: Boolean = false, - timeoutSeconds: Long? = null, - ): String { - if (!workingDir.exists()) { - if (!workingDir.mkdirs()) { - throw RuntimeException("Failed to create working directory: ${workingDir.absolutePath}") - } - } - - if (!workingDir.isDirectory || !workingDir.canRead()) { - throw RuntimeException("Working directory is not accessible: ${workingDir.absolutePath}") - } - - val isWindows = System.getProperty("os.name").lowercase().contains("windows") - val processBuilder = if (isWindows) { - ProcessBuilder() - .command("cmd.exe", "/c", "set TYPESCRIPT_SDK_DIR=${sdkDir.absolutePath} && $command") - } else { - ProcessBuilder() - .command("bash", "-c", "TYPESCRIPT_SDK_DIR='${sdkDir.absolutePath}' $command") - } - - val process = processBuilder - .directory(workingDir) - .redirectErrorStream(true) - .start() - - val output = StringBuilder() - BufferedReader(InputStreamReader(process.inputStream)).use { reader -> - var line: String? - while (reader.readLine().also { line = it } != null) { - println(line) - output.append(line).append("\n") - } - } - - if (timeoutSeconds == null) { - val exitCode = process.waitFor() - if (!allowFailure && exitCode != 0) { - throw RuntimeException( - "Command execution failed with exit code $exitCode: $command\nWorking dir: ${workingDir.absolutePath}\nOutput:\n$output", - ) - } - } else { - process.waitFor(timeoutSeconds, TimeUnit.SECONDS) - } - - return output.toString() - } - } - - private fun waitForProcessTermination(process: Process, timeoutSeconds: Long): Boolean { - if (process.isAlive && !process.waitFor(timeoutSeconds, TimeUnit.SECONDS)) { - process.destroyForcibly() - process.waitFor(2, TimeUnit.SECONDS) - return false - } - return true - } - - private fun createProcessOutputReader(process: Process, prefix: String = "TS-SERVER"): Thread { - val outputReader = Thread { - try { - process.inputStream.bufferedReader().useLines { lines -> - for (line in lines) { - println("[$prefix] $line") - } - } - } catch (e: Exception) { - println("Warning: Error reading process output: ${e.message}") - } - } - outputReader.isDaemon = true - return outputReader - } - - protected fun waitForPort(host: String = "localhost", port: Int, timeoutSeconds: Long = 10): Boolean { - val deadline = System.currentTimeMillis() + timeoutSeconds * 1000 - while (System.currentTimeMillis() < deadline) { - try { - Socket(host, port).use { return true } - } catch (_: Exception) { - Thread.sleep(100) - } - } - return false - } - - protected fun executeCommandAllowingFailure(command: String, workingDir: File, timeoutSeconds: Long = 20): String = - executeCommand(command, workingDir, allowFailure = true, timeoutSeconds = timeoutSeconds) - - protected fun startTypeScriptServer(port: Int): Process { - killProcessOnPort(port) - - if (!sdkDir.exists() || !sdkDir.isDirectory) { - throw IllegalStateException( - "TypeScript SDK directory does not exist or is not accessible: ${sdkDir.absolutePath}", - ) - } - - val isWindows = System.getProperty("os.name").lowercase().contains("windows") - val localServerPath = File(tsClientDir, "simpleStreamableHttp.ts").absolutePath - val processBuilder = if (isWindows) { - ProcessBuilder() - .command( - "cmd.exe", - "/c", - "set MCP_PORT=$port && set NODE_PATH=${sdkDir.absolutePath}\\node_modules && npx --prefix \"${sdkDir.absolutePath}\" tsx \"$localServerPath\"", - ) - } else { - ProcessBuilder() - .command( - "bash", - "-c", - "MCP_PORT=$port NODE_PATH='${sdkDir.absolutePath}/node_modules' npx --prefix '${sdkDir.absolutePath}' tsx \"$localServerPath\"", - ) - } - - processBuilder.environment()["TYPESCRIPT_SDK_DIR"] = sdkDir.absolutePath - - val process = processBuilder - .directory(tsClientDir) - .redirectErrorStream(true) - .start() - - createProcessOutputReader(process).start() - - if (!waitForPort(port = port, timeoutSeconds = 20)) { - throw IllegalStateException("TypeScript server did not become ready on localhost:$port within timeout") - } - return process - } - - protected fun stopProcess(process: Process, waitSeconds: Long = 3, name: String = "TypeScript server") { - process.destroy() - if (waitForProcessTermination(process, waitSeconds)) { - println("$name stopped gracefully") - } else { - println("$name did not stop gracefully, forced termination") - } - } - - protected suspend fun newClient(serverUrl: String): Client = - HttpClient(CIO) { install(SSE) }.mcpStreamableHttp(serverUrl) - - protected suspend fun withClient(serverUrl: String, block: suspend (Client) -> T): T { - val client = newClient(serverUrl) - return try { - withTimeout(20.seconds) { block(client) } - } finally { - try { - withTimeout(3.seconds) { client.close() } - } catch (_: Exception) { - // ignore errors - } - } - } -} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinClientTsServerEdgeCasesTestSse.kt similarity index 95% rename from kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt rename to kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinClientTsServerEdgeCasesTestSse.kt index 7b15fbc7..1585aa5e 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinClientTsServerEdgeCasesTestSse.kt @@ -1,11 +1,14 @@ -package io.modelcontextprotocol.kotlin.sdk.integration.typescript +package io.modelcontextprotocol.kotlin.sdk.integration.typescript.sse import io.modelcontextprotocol.kotlin.sdk.CallToolResult import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TransportKind +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TsTestBase import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withTimeout import kotlinx.serialization.json.JsonObject @@ -21,7 +24,9 @@ import kotlin.test.assertNotNull import kotlin.test.assertTrue import kotlin.time.Duration.Companion.seconds -class KotlinClientTypeScriptServerEdgeCasesTest : TypeScriptTestBase() { +class KotlinClientTsServerEdgeCasesTestSse : TsTestBase() { + + override val transportKind = TransportKind.SSE private var port: Int = 0 private val host = "localhost" @@ -132,7 +137,7 @@ class KotlinClientTypeScriptServerEdgeCasesTest : TypeScriptTestBase() { fun testConcurrentRequests(): Unit = runBlocking(Dispatchers.IO) { withClient(serverUrl) { client -> val concurrentCount = 5 - val responses = kotlinx.coroutines.coroutineScope { + val responses = coroutineScope { val results = (1..concurrentCount).map { i -> async { val name = "ConcurrentClient$i" diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinClientTsServerTestSse.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinClientTsServerTestSse.kt new file mode 100644 index 00000000..95a80f0b --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinClientTsServerTestSse.kt @@ -0,0 +1,49 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript.sse + +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.AbstractKotlinClientTsServerTest +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TransportKind +import kotlinx.coroutines.withTimeout +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import kotlin.time.Duration.Companion.seconds + +class KotlinClientTsServerTestSse : AbstractKotlinClientTsServerTest() { + + override val transportKind = TransportKind.SSE + + private var port: Int = 0 + private val host = "localhost" + private lateinit var serverUrl: String + private lateinit var tsServerProcess: Process + + @BeforeEach + fun setUpSse() { + port = findFreePort() + serverUrl = "http://$host:$port/mcp" + tsServerProcess = startTypeScriptServer(port) + println("TypeScript server started on port $port") + } + + @AfterEach + fun tearDownSse() { + if (::tsServerProcess.isInitialized) { + try { + println("Stopping TypeScript server") + stopProcess(tsServerProcess) + } catch (e: Exception) { + println("Warning: Error during TypeScript server stop: ${e.message}") + } + } + } + + override suspend fun useClient(block: suspend (Client) -> T): T = withClient(serverUrl) { client -> + try { + withTimeout(20.seconds) { block(client) } + } finally { + try { + withTimeout(3.seconds) { client.close() } + } catch (_: Exception) {} + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinServerForTypeScriptClient.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinServerForTsClientSse.kt similarity index 99% rename from kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinServerForTypeScriptClient.kt rename to kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinServerForTsClientSse.kt index 5757fcbc..56abf31d 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinServerForTypeScriptClient.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinServerForTsClientSse.kt @@ -1,4 +1,4 @@ -package io.modelcontextprotocol.kotlin.sdk.integration.typescript +package io.modelcontextprotocol.kotlin.sdk.integration.typescript.sse import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.http.ContentType @@ -55,12 +55,13 @@ import kotlinx.serialization.json.buildJsonObject import kotlinx.serialization.json.contentOrNull import kotlinx.serialization.json.decodeFromJsonElement import kotlinx.serialization.json.jsonPrimitive +import org.awaitility.Awaitility.await import java.util.UUID import java.util.concurrent.ConcurrentHashMap private val logger = KotlinLogging.logger {} -class KotlinServerForTypeScriptClient { +class KotlinServerForTsClient { private val serverTransports = ConcurrentHashMap() private val jsonFormat = Json { ignoreUnknownKeys = true } private var server: EmbeddedServer<*, *>? = null @@ -195,7 +196,7 @@ class KotlinServerForTypeScriptClient { server = null } - private fun createMcpServer(): Server { + fun createMcpServer(): Server { val server = Server( Implementation( name = "kotlin-http-server", @@ -487,6 +488,6 @@ class HttpServerTransport(private val sessionId: String) : AbstractTransport() { } fun main() { - val server = KotlinServerForTypeScriptClient() + val server = KotlinServerForTsClient() server.start() } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/TsClientKotlinServerTestSse.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/TsClientKotlinServerTestSse.kt new file mode 100644 index 00000000..edee4279 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/TsClientKotlinServerTestSse.kt @@ -0,0 +1,50 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript.sse + +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.AbstractTsClientKotlinServerTest +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TransportKind +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach + +class TsClientKotlinServerTestSse : AbstractTsClientKotlinServerTest() { + + override val transportKind = TransportKind.SSE + + private var port: Int = 0 + private lateinit var serverUrl: String + private var httpServer: KotlinServerForTsClient? = null + + @BeforeEach + fun setUp() { + port = findFreePort() + serverUrl = "http://localhost:$port/mcp" + killProcessOnPort(port) + httpServer = KotlinServerForTsClient().also { it.start(port) } + check(waitForPort(port = port)) { "Kotlin test server did not become ready on localhost:$port within timeout" } + println("Kotlin server started on port $port") + } + + @AfterEach + fun tearDown() { + try { + httpServer?.stop() + println("HTTP server stopped") + } catch (e: Exception) { + println("Error during server shutdown: ${e.message}") + } + } + + override fun beforeServer() {} + override fun afterServer() {} + + override fun runClient(vararg args: String): String { + val cmd = buildString { + append("npx tsx myClient.ts ") + append(serverUrl) + if (args.isNotEmpty()) { + append(' ') + append(args.joinToString(" ")) + } + } + return executeCommand(cmd, tsClientDir) + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/TsEdgeCasesTestSse.kt similarity index 55% rename from kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt rename to kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/TsEdgeCasesTestSse.kt index 6504b49e..65b003df 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/TsEdgeCasesTestSse.kt @@ -1,5 +1,10 @@ -package io.modelcontextprotocol.kotlin.sdk.integration.typescript +package io.modelcontextprotocol.kotlin.sdk.integration.typescript.sse +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TransportKind +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TsTestBase +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.BeforeEach @@ -11,19 +16,22 @@ import java.io.File import java.util.concurrent.TimeUnit import kotlin.test.assertEquals import kotlin.test.assertTrue +import kotlin.test.fail -class TypeScriptEdgeCasesTest : TypeScriptTestBase() { +class TsEdgeCasesTestSse : TsTestBase() { + + override val transportKind = TransportKind.SSE private var port: Int = 0 private lateinit var serverUrl: String - private var httpServer: KotlinServerForTypeScriptClient? = null + private var httpServer: KotlinServerForTsClient? = null @BeforeEach fun setUp() { port = findFreePort() serverUrl = "http://localhost:$port/mcp" killProcessOnPort(port) - httpServer = KotlinServerForTypeScriptClient() + httpServer = KotlinServerForTsClient() httpServer?.start(port) if (!waitForPort(port = port)) { throw IllegalStateException("Kotlin test server did not become ready on localhost:$port within timeout") @@ -116,6 +124,35 @@ class TypeScriptEdgeCasesTest : TypeScriptTestBase() { @Test @Timeout(60, unit = TimeUnit.SECONDS) fun testComplexConcurrentRequests() = runTest { + fun prettyFail( + index: Int, + command: String, + expectation: String, + output: String, + ): Nothing { + val msg = buildString { + appendLine("Assertion failed for client #$index") + appendLine("Expectation: $expectation") + appendLine("Command: $command") + appendLine("----- OUTPUT BEGIN -----") + appendLine(output.trimEnd()) + appendLine("----- OUTPUT END -----") + } + fail(msg) + } + + fun assertContains( + index: Int, + command: String, + output: String, + needle: String, + description: String, + ) { + if (!output.contains(needle)) { + prettyFail(index, command, "$description — expected to contain: \"$needle\"", output) + } + } + val commands = listOf( "npx tsx myClient.ts $serverUrl greet \"Client1\"", "npx tsx myClient.ts $serverUrl multi-greet \"Client2\"", @@ -124,41 +161,78 @@ class TypeScriptEdgeCasesTest : TypeScriptTestBase() { "npx tsx myClient.ts $serverUrl multi-greet \"Client5\"", ) - val threads = commands.mapIndexed { index, command -> - Thread { - println("Starting client $index") - val output = executeCommand(command, tsClientDir) - println("Client $index completed") + coroutineScope { + val jobs = commands.mapIndexed { index, command -> + async(kotlinx.coroutines.Dispatchers.IO) { + println("Starting client $index") + val output = executeCommand(command, tsClientDir) + println("Client $index completed") - assertTrue( - output.contains("Connected to server"), - "Client $index should connect to server", - ) - assertTrue( - output.contains("Disconnected from server"), - "Client $index should disconnect cleanly", - ) + assertContains( + index, + command, + output, + "Connected to server", + "Client should connect to server", + ) + assertContains( + index, + command, + output, + "Disconnected from server", + "Client should disconnect cleanly", + ) - when { - command.contains("greet \"Client1\"") -> - assertTrue(output.contains("Hello, Client1!"), "Client 1 should receive correct greeting") + when { + command.contains("greet \"Client1\"") -> + assertContains( + index, + command, + output, + "Hello, Client1!", + "Client 1 should receive correct greeting", + ) - command.contains("multi-greet \"Client2\"") -> - assertTrue(output.contains("Multiple greetings"), "Client 2 should receive multiple greetings") + command.contains("multi-greet \"Client2\"") -> + assertContains( + index, + command, + output, + "Multiple greetings", + "Client 2 should receive multiple greetings", + ) - command.contains("greet \"Client3\"") -> - assertTrue(output.contains("Hello, Client3!"), "Client 3 should receive correct greeting") + command.contains("greet \"Client3\"") -> + assertContains( + index, + command, + output, + "Hello, Client3!", + "Client 3 should receive correct greeting", + ) - !command.contains("greet") && !command.contains("multi-greet") -> - assertTrue(output.contains("Available utils:"), "Client 4 should list available tools") + !command.contains("greet") && !command.contains("multi-greet") -> + assertContains( + index, + command, + output, + "Available utils:", + "Client 4 should list available tools", + ) - command.contains("multi-greet \"Client5\"") -> - assertTrue(output.contains("Multiple greetings"), "Client 5 should receive multiple greetings") + command.contains("multi-greet \"Client5\"") -> + assertContains( + index, + command, + output, + "Multiple greetings", + "Client 5 should receive multiple greetings", + ) + } } - }.apply { start() } + } + jobs.awaitAll() } - - threads.forEach { it.join() } } @Test diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/myClient.ts b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/myClient.ts similarity index 100% rename from kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/myClient.ts rename to kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/myClient.ts diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/simpleStreamableHttp.ts b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/simpleStreamableHttp.ts similarity index 100% rename from kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/simpleStreamableHttp.ts rename to kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/simpleStreamableHttp.ts diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/KotlinClientTsServerEdgeCasesTestStdio.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/KotlinClientTsServerEdgeCasesTestStdio.kt new file mode 100644 index 00000000..88a46856 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/KotlinClientTsServerEdgeCasesTestStdio.kt @@ -0,0 +1,182 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript.stdio + +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TransportKind +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TsTestBase +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import org.junit.jupiter.api.assertThrows +import java.util.concurrent.TimeUnit +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class KotlinClientTsServerEdgeCasesTestStdio : TsTestBase() { + + override val transportKind = TransportKind.STDIO + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testNonExistentToolOverStdio(): Unit = runBlocking(Dispatchers.IO) { + withClientStdio { client: Client, _ -> + val nonExistentToolName = "non-existent-tool" + val arguments = mapOf("name" to "TestUser") + + val exception = assertThrows { + client.callTool(nonExistentToolName, arguments) + } + + val expectedMessage = + "JSONRPCError(code=InvalidParams, message=MCP error -32602: Tool non-existent-tool not found, data={})" + assertEquals( + expectedMessage, + exception.message, + "Unexpected error message for non-existent tool", + ) + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testSpecialCharactersInArgumentsOverStdio(): Unit = runBlocking(Dispatchers.IO) { + withClientStdio { client: Client, _ -> + val specialChars = "!@#$%^&*()_+{}[]|\\:;\"'<>.,?/" + val arguments = mapOf("name" to specialChars) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") + + val text = textContent.text ?: "" + assertTrue( + text.contains(specialChars), + "Tool response should contain the special characters", + ) + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testLargePayloadOverStdio(): Unit = runBlocking(Dispatchers.IO) { + withClientStdio { client: Client, _ -> + val largeName = "A".repeat(10 * 1024) + val arguments = mapOf("name" to largeName) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") + + val text = textContent.text ?: "" + assertTrue( + text.contains("Hello,") && text.contains("A"), + "Tool response should contain the greeting with the large name", + ) + } + } + + @Test + @Timeout(60, unit = TimeUnit.SECONDS) + fun testConcurrentRequestsOverStdio(): Unit = runBlocking(Dispatchers.IO) { + withClientStdio { client: Client, _ -> + val concurrentCount = 5 + val responses = coroutineScope { + val results = (1..concurrentCount).map { i -> + async { + val name = "ConcurrentClient$i" + val arguments = mapOf("name" to name) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null for client $i") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present for client $i") + + textContent.text ?: "" + } + } + results.awaitAll() + } + + for (i in 1..concurrentCount) { + val expectedName = "ConcurrentClient$i" + val matchingResponses = responses.filter { it.contains("Hello, $expectedName!") } + assertEquals( + 1, + matchingResponses.size, + "Should have exactly one response for $expectedName", + ) + } + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testInvalidArgumentsOverStdio(): Unit = runBlocking(Dispatchers.IO) { + withClientStdio { client: Client, _ -> + val invalidArguments = mapOf( + "name" to JsonObject(mapOf("nested" to JsonPrimitive("value"))), + ) + + val exception = assertThrows { + client.callTool("greet", invalidArguments) + } + + val msg = exception.message ?: "" + val expectedMessage = """ + JSONRPCError(code=InvalidParams, message=MCP error -32602: Invalid arguments for tool greet: [ + { + "code": "invalid_type", + "expected": "string", + "received": "object", + "path": [ + "name" + ], + "message": "Expected string, received object" + } + ], data={}) + """.trimIndent() + + assertEquals(expectedMessage, msg, "Unexpected error message for invalid arguments") + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testMultipleToolCallsOverStdio(): Unit = runBlocking(Dispatchers.IO) { + withClientStdio { client: Client, _ -> + repeat(10) { i -> + val name = "SequentialClient$i" + val arguments = mapOf("name" to name) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null for call $i") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present for call $i") + + assertEquals( + "Hello, $name!", + textContent.text, + "Tool response should contain the greeting with the provided name", + ) + } + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/KotlinClientTsServerTestStdio.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/KotlinClientTsServerTestStdio.kt new file mode 100644 index 00000000..82737249 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/KotlinClientTsServerTestStdio.kt @@ -0,0 +1,23 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript.stdio + +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.AbstractKotlinClientTsServerTest +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TransportKind + +class KotlinClientTsServerTestStdio : AbstractKotlinClientTsServerTest() { + + override val transportKind = TransportKind.STDIO + + override suspend fun useClient(block: suspend (Client) -> T): T = withClientStdio { client, proc -> + try { + block(client) + } finally { + try { + client.close() + } catch (_: Exception) {} + try { + stopProcess(proc, name = "TypeScript stdio server") + } catch (_: Exception) {} + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/TsClientKotlinServerTestStdio.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/TsClientKotlinServerTestStdio.kt new file mode 100644 index 00000000..abf821b5 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/TsClientKotlinServerTestStdio.kt @@ -0,0 +1,9 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript.stdio + +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.AbstractTsClientKotlinServerTest +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TransportKind + +class TsClientKotlinServerTestStdio : AbstractTsClientKotlinServerTest() { + override val transportKind = TransportKind.STDIO + override fun runClient(vararg args: String): String = runStdioClient(*args) +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/TsEdgeCasesTestStdio.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/TsEdgeCasesTestStdio.kt new file mode 100644 index 00000000..dafda6d5 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/TsEdgeCasesTestStdio.kt @@ -0,0 +1,124 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript.stdio + +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TransportKind +import io.modelcontextprotocol.kotlin.sdk.integration.typescript.TsTestBase +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import org.junit.jupiter.api.condition.EnabledOnOs +import org.junit.jupiter.api.condition.OS +import java.io.File +import java.util.concurrent.TimeUnit +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class TsEdgeCasesTestStdio : TsTestBase() { + + override val transportKind: TransportKind = TransportKind.STDIO + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testNonExistentToolOverStdio() { + val output = runStdioClient("non-existent-tool", "TestUser") + assertTrue(output.contains("Tool \"non-existent-tool\" not found"), "Should report non-existent tool.\n$output") + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testSpecialCharactersOverStdio() { + val specialChars = "!@#$+-[].,?" + val tempFile = File.createTempFile("special_chars", ".txt").apply { + writeText(specialChars) + deleteOnExit() + } + val content = tempFile.readText() + val output = runStdioClient("greet", content) + assertTrue(output.contains("Hello, $specialChars!"), "Tool should handle special characters.\n$output") + assertTrue(output.contains("Disconnected from server"), "Client should disconnect cleanly.\n$output") + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + @EnabledOnOs(OS.MAC, OS.LINUX) + fun testLargePayloadOverStdio() { + val largeName = "A".repeat(10 * 1024) + val tempFile = File.createTempFile("large_name", ".txt").apply { + writeText(largeName) + deleteOnExit() + } + val content = tempFile.readText() + val output = runStdioClient("greet", content) + tempFile.delete() + assertTrue( + output.contains("Hello,") && output.contains("A".repeat(20)), + "Should handle large payloads.\n$output", + ) + assertTrue(output.contains("Disconnected from server"), "Client should disconnect cleanly.\n$output") + } + + @Test + @Timeout(60, unit = TimeUnit.SECONDS) + fun testComplexConcurrentRequestsOverStdio() { + val commands: List> = listOf( + arrayOf("greet", "Client1"), + arrayOf("multi-greet", "Client2"), + arrayOf("greet", "Client3"), + emptyArray(), + arrayOf("multi-greet", "Client5"), + ) + + val threads = commands.mapIndexed { index, args -> + Thread { + val output = runStdioClient(*args) + assertTrue( + output.contains("Disconnected from server"), + "Client $index should disconnect cleanly.\n$output", + ) + when { + args.contentEquals(arrayOf("greet", "Client1")) -> + assertTrue( + output.contains("Hello, Client1!"), + "Client 1 should receive correct greeting.\n$output", + ) + + args.contentEquals(arrayOf("multi-greet", "Client2")) -> + assertTrue( + output.contains("Multiple greetings") || output.contains("greeting"), + "Client 2 should receive multiple greetings.\n$output", + ) + + args.contentEquals(arrayOf("greet", "Client3")) -> + assertTrue( + output.contains("Hello, Client3!"), + "Client 3 should receive correct greeting.\n$output", + ) + + args.isEmpty() -> + assertTrue( + output.contains("Available utils:"), + "Client 4 should list available tools.\n$output", + ) + + args.contentEquals(arrayOf("multi-greet", "Client5")) -> + assertTrue( + output.contains("Multiple greetings") || output.contains("greeting"), + "Client 5 should receive multiple greetings.\n$output", + ) + } + }.apply { start() } + } + + threads.forEach { it.join() } + } + + @Test + @Timeout(120, unit = TimeUnit.SECONDS) + fun testRapidSequentialRequestsOverStdio() { + val outputs = (1..10).map { i -> + val output = runStdioClient("greet", "RapidClient$i") + assertTrue(output.contains("Hello, RapidClient$i!"), "Client $i should receive correct greeting.\n$output") + assertTrue(output.contains("Disconnected from server"), "Client $i should disconnect cleanly.\n$output") + output + } + assertEquals(10, outputs.size, "All 10 rapid requests should complete successfully") + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/myClient.ts b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/myClient.ts new file mode 100644 index 00000000..e4aae91d --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/myClient.ts @@ -0,0 +1,159 @@ +// @ts-nocheck +const args = process.argv.slice(2); +const toolName = args[0]; +const toolArgs = args.slice(1); +const PROTOCOL_VERSION = "2024-11-05"; + +async function main() { + const sdkDirRaw = process.env.TYPESCRIPT_SDK_DIR; + const sdkDir = sdkDirRaw ? sdkDirRaw.trim() : undefined; + + let Client: any; + + if (!sdkDir) { + throw new Error('TYPESCRIPT_SDK_DIR environment variable is not set. It should point to the cloned typescript-sdk directory.'); + } + + const path = await import('path'); + const { pathToFileURL } = await import('url'); + const clientUrl = pathToFileURL(path.join(sdkDir, 'src', 'client', 'index.ts')).href; + ({ Client } = await import(clientUrl)); + + if (!toolName) { + console.error('Available utils will be listed after connection'); + } else { + console.error(`Will call tool: ${toolName} with args: ${toolArgs.join(', ')}`); + } + + class MinimalStdioClientTransport { + onmessage: ((msg: any) => void) | undefined; + private buffer: string = ''; + private closed = false; + + async start(): Promise { + process.stdin.setEncoding('utf8'); + process.stdin.resume(); + process.stdin.on('data', (chunk: string) => { + if (this.closed) return; + this.buffer += chunk; + this.processBuffer(); + }); + } + + private processBuffer() { + while (true) { + const idx = this.buffer.indexOf('\n'); + if (idx === -1) break; + const line = this.buffer.slice(0, idx); + this.buffer = this.buffer.slice(idx + 1); + const trimmed = line.trim(); + if (!trimmed) continue; + try { + const msg = JSON.parse(trimmed); + this.onmessage && this.onmessage(msg); + } catch (e) { + console.error('Parse error in client stdio (line):', trimmed, e); + } + } + } + + async send(msg: any): Promise { + if (this.closed) return; + const json = JSON.stringify(msg); + const payload = json + '\n'; + await new Promise((resolve, reject) => { + process.stdout.write(payload, 'utf8', (err?: Error | null) => err ? reject(err) : resolve()); + }); + } + + async close(): Promise { + this.closed = true; + try { process.stdin.pause(); } catch {} + } + } + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new MinimalStdioClientTransport(); + + try { + await client.connect(transport, { protocolVersion: PROTOCOL_VERSION }); + console.error('Connected to server over stdio'); + + try { + if (typeof (client as any).on === 'function') { + (client as any).on('notification', (n: any) => { + try { + const method = (n && (n.method || (n.params && n.params.method))) || 'unknown'; + console.error('Notification:', method, JSON.stringify(n)); + } catch { + console.error('Notification: '); + } + }); + } + } catch { + // ignore + } + + const toolsResult = await client.listTools(); + const tools = toolsResult.tools; + console.error('Available utils:', tools.map((t: { name: any; }) => t.name).join(', ')); + + if (!toolName) { + await client.close(); + return; + } + + const tool = tools.find((t: { name: string; }) => t.name === toolName); + if (!tool) { + console.error(`Tool "${toolName}" not found`); + process.exit(1); + } + + const toolArguments: any = {}; + + if (toolName === 'greet' && toolArgs.length > 0) { + toolArguments['name'] = toolArgs[0]; + } else if (tool.input && tool.input.properties) { + const propNames = Object.keys(tool.input.properties); + if (propNames.length > 0 && toolArgs.length > 0) { + toolArguments[propNames[0]] = toolArgs[0]; + } + } + + console.error(`Calling tool ${toolName} with arguments:`, toolArguments); + + const result = await client.callTool({ + name: toolName, + arguments: toolArguments + }); + console.error('Tool result:', JSON.stringify(result)); + + if (result.content) { + for (const content of result.content) { + if (content.type === 'text') { + console.error('Text content:', content.text); + } + } + } + + if (result.structuredContent) { + console.error('Structured content:', JSON.stringify(result.structuredContent, null, 2)); + } + + } catch (error) { + console.error('Error:', error); + process.exit(1); + } finally { + await client.close(); + console.error('Disconnected from server'); + } +} + +main().catch(error => { + console.error('Unhandled error:', error); + process.exit(1); +}); diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/simpleStdio.ts b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/simpleStdio.ts new file mode 100644 index 00000000..29863091 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/stdio/simpleStdio.ts @@ -0,0 +1,74 @@ +// @ts-nocheck +import { z } from 'zod'; +import path from 'node:path'; +import { pathToFileURL } from 'node:url'; + +const SDK_DIR = process.env.TYPESCRIPT_SDK_DIR; +if (!SDK_DIR) { + throw new Error('TYPESCRIPT_SDK_DIR environment variable is not set. It should point to the cloned typescript-sdk directory.'); +} + +async function importFromSdk(rel: string): Promise { + const full = path.resolve(SDK_DIR!, rel); + const url = pathToFileURL(full).href; + return await import(url); +} + +async function main() { + const { McpServer } = await importFromSdk('src/server/mcp.ts'); + const { StdioServerTransport } = await importFromSdk('src/server/stdio.ts'); + + const server = new McpServer({ + name: 'simple-stdio-server', + version: '1.0.0', + }, { capabilities: { logging: {} } }); + + // Simple tools mirroring ones from HTTP test server + server.registerTool('greet', { + title: 'Greeting Tool', + description: 'A simple greeting tool', + inputSchema: { name: z.string().describe('Name to greet') }, + }, async ({ name }): Promise => { + return { content: [{ type: 'text', text: `Hello, ${name}!` }] }; + }); + + server.tool('multi-greet', 'A tool that sends different greetings with delays between them', + { name: z.string().describe('Name to greet') }, + { title: 'Multiple Greeting Tool', readOnlyHint: true, openWorldHint: false }, + async ({ name }, extra): Promise => { + const sleep = (ms: number) => new Promise(r => setTimeout(r, ms)); + await server.sendLoggingMessage({ level: 'debug', data: `Starting multi-greet for ${name}` }, extra.sessionId); + await sleep(200); + await server.sendLoggingMessage({ level: 'info', data: `Sending first greeting to ${name}` }, extra.sessionId); + await sleep(200); + await server.sendLoggingMessage({ level: 'info', data: `Sending second greeting to ${name}` }, extra.sessionId); + return { content: [{ type: 'text', text: `Good morning, ${name}!` }] }; + } + ); + + server.registerPrompt('greeting-template', { + title: 'Greeting Template', + description: 'A simple greeting prompt template', + argsSchema: { name: z.string().describe('Name to include in greeting') }, + }, async ({ name }): Promise => { + return { + messages: [{ role: 'user', content: { type: 'text', text: `Please greet ${name} in a friendly manner.` } }], + }; + }); + + server.registerResource('greeting-resource', 'https://example.com/greetings/default', { + title: 'Default Greeting', + description: 'A simple greeting resource', + mimeType: 'text/plain', + }, async (): Promise => { + return { contents: [{ uri: 'https://example.com/greetings/default', text: 'Hello, world!' }] }; + }); + + const transport = new StdioServerTransport(); + await server.connect(transport); +} + +main().catch((err) => { + console.error('Failed to start stdio server:', err); + process.exit(1); +}); diff --git a/kotlin-sdk/build.gradle.kts b/kotlin-sdk/build.gradle.kts index 0ac37c91..e59df60e 100644 --- a/kotlin-sdk/build.gradle.kts +++ b/kotlin-sdk/build.gradle.kts @@ -1,7 +1,6 @@ plugins { id("mcp.multiplatform") id("mcp.publishing") - id("mcp.jreleaser") } kotlin { diff --git a/kotlin-sdk/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/package.kt b/kotlin-sdk/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/package.kt new file mode 100644 index 00000000..59faee4f --- /dev/null +++ b/kotlin-sdk/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/package.kt @@ -0,0 +1,25 @@ +@file:Suppress("ktlint:standard:no-empty-class-body", "ktlint:standard:kdoc") +/** + * # MCP Kotlin SDK + * + * A Kotlin Multiplatform implementation of the Model Context Protocol (MCP). + * + * This is the main SDK module that provides a convenient single dependency + * for all MCP functionality including: + * + * - Core protocol types and utilities ([kotlin-sdk-core]) + * - Client implementations ([kotlin-sdk-client]) + * - Server implementations ([kotlin-sdk-server]) + * + * ## Usage + * + * Add this dependency to your project to get access to all MCP Kotlin SDK functionality: + * + * ```kotlin + * implementation("io.modelcontextprotocol:kotlin-sdk:$version") + * ``` + * + * This will transitively include all core, client, and server components. + */ + +package io.modelcontextprotocol.kotlin.sdk diff --git a/samples/kotlin-mcp-client/build.gradle.kts b/samples/kotlin-mcp-client/build.gradle.kts index b34e8ed2..3e3ad859 100644 --- a/samples/kotlin-mcp-client/build.gradle.kts +++ b/samples/kotlin-mcp-client/build.gradle.kts @@ -12,7 +12,7 @@ group = "org.example" version = "0.1.0" dependencies { - implementation(libs.mcp.kotlin) + implementation(libs.mcp.kotlin.client) implementation(libs.ktor.client.cio) implementation(libs.anthropic.java) implementation(libs.slf4j.simple) diff --git a/samples/kotlin-mcp-client/settings.gradle.kts b/samples/kotlin-mcp-client/settings.gradle.kts index 6e2f160d..c5f225b8 100644 --- a/samples/kotlin-mcp-client/settings.gradle.kts +++ b/samples/kotlin-mcp-client/settings.gradle.kts @@ -9,6 +9,7 @@ pluginManagement { dependencyResolutionManagement { repositories { + mavenLocal() mavenCentral() } versionCatalogs { @@ -16,4 +17,4 @@ dependencyResolutionManagement { from(files("../../gradle/libs.versions.toml")) } } -} \ No newline at end of file +} diff --git a/samples/kotlin-mcp-server/build.gradle.kts b/samples/kotlin-mcp-server/build.gradle.kts index 6a649687..e85100d7 100644 --- a/samples/kotlin-mcp-server/build.gradle.kts +++ b/samples/kotlin-mcp-server/build.gradle.kts @@ -11,10 +11,6 @@ plugins { group = "org.example" version = "0.1.0" -repositories { - mavenCentral() -} - val jvmMainClass = "Main_jvmKt" kotlin { @@ -43,7 +39,7 @@ kotlin { sourceSets { commonMain.dependencies { - implementation(libs.mcp.kotlin) + implementation(libs.mcp.kotlin.server) implementation(libs.ktor.server.cio) } jvmMain.dependencies { diff --git a/samples/kotlin-mcp-server/settings.gradle.kts b/samples/kotlin-mcp-server/settings.gradle.kts index 08039844..cac44ab8 100644 --- a/samples/kotlin-mcp-server/settings.gradle.kts +++ b/samples/kotlin-mcp-server/settings.gradle.kts @@ -9,6 +9,7 @@ pluginManagement { dependencyResolutionManagement { repositories { + mavenLocal() mavenCentral() } versionCatalogs { @@ -16,4 +17,4 @@ dependencyResolutionManagement { from(files("../../gradle/libs.versions.toml")) } } -} \ No newline at end of file +} diff --git a/samples/weather-stdio-server/build.gradle.kts b/samples/weather-stdio-server/build.gradle.kts index f82306cc..f39057f9 100644 --- a/samples/weather-stdio-server/build.gradle.kts +++ b/samples/weather-stdio-server/build.gradle.kts @@ -9,22 +9,19 @@ application { mainClass.set("io.modelcontextprotocol.sample.server.MainKt") } -repositories { - mavenCentral() -} - group = "org.example" version = "0.1.0" dependencies { implementation(libs.ktor.client.content.negotiation) implementation(libs.ktor.serialization.kotlinx.json) - implementation(libs.mcp.kotlin) + implementation(libs.mcp.kotlin.server) implementation(libs.ktor.server.cio) implementation(libs.ktor.client.cio) implementation(libs.slf4j.simple) testImplementation(kotlin("test")) + testImplementation(libs.mcp.kotlin.client) testImplementation(libs.kotlinx.coroutines.test) } diff --git a/samples/weather-stdio-server/settings.gradle.kts b/samples/weather-stdio-server/settings.gradle.kts index 1121bf5f..67e786b1 100644 --- a/samples/weather-stdio-server/settings.gradle.kts +++ b/samples/weather-stdio-server/settings.gradle.kts @@ -9,6 +9,7 @@ pluginManagement { dependencyResolutionManagement { repositories { + mavenLocal() mavenCentral() } versionCatalogs { @@ -16,4 +17,4 @@ dependencyResolutionManagement { from(files("../../gradle/libs.versions.toml")) } } -} \ No newline at end of file +}