KotlinDL is a high-level Deep Learning API written in Kotlin and inspired by Keras. Under the hood, it uses SKaiNET as its tensor computation backend, enabling Kotlin Multiplatform support across all SKaiNET-supported platforms.
KotlinDL offers simple APIs for training deep learning models from scratch and for using transfer learning to tailor existing pre-trained models to your tasks.
By leveraging SKaiNET, KotlinDL runs on all platforms supported by SKaiNET:
- JVM (Linux, macOS, Windows)
- Android
- Native (Linux x64/ARM64, macOS ARM64, iOS ARM64/Simulator)
- JavaScript (Browser, Node.js)
- WebAssembly (Browser, Node.js)
Write your deep learning code once and deploy it anywhere.
Here's an example of training a simple neural network:
// KotlinDL wrapper imports
import org.jetbrains.kotlinx.dl.api.core.dsl.deepLearning
import org.jetbrains.kotlinx.dl.api.core.loss.MeanSquaredError
import org.jetbrains.kotlinx.dl.api.core.metric.accuracy
import org.jetbrains.kotlinx.dl.api.core.tensor.FloatType
// SKaiNET tensor DSL and operations
import sk.ainet.lang.tensor.dsl.tensor
import sk.ainet.lang.tensor.matmul
import sk.ainet.lang.nn.optim.sgd
import sk.ainet.lang.nn.topology.ModuleParameter

fun main() {
    deepLearning {
        // Create training data (the XOR truth table) using the tensor DSL
        val x = tensor<FloatType, Float> {
            shape(4, 2) { from(0f, 0f, 0f, 1f, 1f, 0f, 1f, 1f) }
        }
        val y = tensor<FloatType, Float> {
            shape(4, 1) { from(0f, 1f, 1f, 0f) }
        }

        // Define model parameters with gradient tracking.
        // Note: a single linear map without a bias cannot fit XOR exactly;
        // this example only demonstrates the training API.
        val w = tensor<FloatType, Float> {
            shape(2, 1) { from(0.1f, 0.2f) }
        }.withRequiresGrad()
        val wParam = ModuleParameter.WeightParameter("w", w)

        val optimizer = sgd(lr = 0.1f)
        val lossFunction = MeanSquaredError()

        // Training loop
        repeat(100) { epoch ->
            trainStep(optimizer, wParam) {
                val predictions = x.matmul(wParam.value)
                lossFunction.forward(predictions, y, ctx)
            }
            if (epoch % 10 == 0) {
                println("Epoch $epoch")
            }
        }
    }
}

- Library Structure
- How to configure KotlinDL in your project
- Requirements
- Documentation
- Examples and tutorials
- Logging
- Contributing
- Reporting issues/Support
- Code of Conduct
- License
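To make the mechanics of the training example above concrete, here is a minimal sketch in plain Kotlin (no KotlinDL or SKaiNET calls) of what each step computes: predictions x·w, a mean squared error loss, and a plain SGD update with lr = 0.1 over 100 full-batch steps on the same data and initial weights. It assumes that MeanSquaredError averages over all elements and that sgd applies the plain update w <- w - lr * grad with no momentum; the helpers matVec, mse, and trainLinear are hypothetical names for illustration, not library API.

```kotlin
// Plain-Kotlin sketch (hypothetical helpers, not KotlinDL/SKaiNET API) of the
// computation behind the training example: x · w, MSE loss, and SGD updates.
// Assumption: the loss is the mean over all elements; SGD has no momentum.

/** Multiplies an (n x k) matrix by a length-k vector, returning a length-n vector. */
fun matVec(x: Array<DoubleArray>, w: DoubleArray): DoubleArray =
    DoubleArray(x.size) { i -> x[i].indices.sumOf { j -> x[i][j] * w[j] } }

/** Mean squared error between predictions and targets. */
fun mse(pred: DoubleArray, target: DoubleArray): Double =
    pred.indices.sumOf { i -> (pred[i] - target[i]).let { it * it } } / pred.size

/** Runs [epochs] full-batch gradient descent steps on w and returns the trained weights. */
fun trainLinear(
    x: Array<DoubleArray>,
    y: DoubleArray,
    initial: DoubleArray,
    lr: Double,
    epochs: Int,
): DoubleArray {
    val w = initial.copyOf()
    repeat(epochs) {
        val pred = matVec(x, w)
        // Gradient of the mean squared error: dL/dw_j = (2 / n) * sum_i (pred_i - y_i) * x_ij
        val grad = DoubleArray(w.size) { j ->
            2.0 / x.size * x.indices.sumOf { i -> (pred[i] - y[i]) * x[i][j] }
        }
        // Plain SGD update: w <- w - lr * grad
        for (j in w.indices) w[j] -= lr * grad[j]
    }
    return w
}

fun main() {
    val x = arrayOf(
        doubleArrayOf(0.0, 0.0),
        doubleArrayOf(0.0, 1.0),
        doubleArrayOf(1.0, 0.0),
        doubleArrayOf(1.0, 1.0),
    )
    val y = doubleArrayOf(0.0, 1.0, 1.0, 0.0)
    val w = trainLinear(x, y, initial = doubleArrayOf(0.1, 0.2), lr = 0.1, epochs = 100)
    println("w = ${w.toList()}, loss = ${mse(matVec(x, w), y)}")
}
```

Running main prints weights close to (1/3, 1/3) and a loss close to 1/3. That is the least-squares optimum for this data, w* = (XᵀX)⁻¹Xᵀy with XᵀX = [[2, 1], [1, 2]] and Xᵀy = (1, 1), which shows why a linear model without a bias or hidden layer cannot drive the loss to zero on XOR.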
KotlinDL consists of several modules:
- kotlin-deeplearning-api: API interfaces and classes
- kotlin-deeplearning-impl: implementation classes, utilities, and SKaiNET integration
- kotlin-deeplearning-dataset: dataset classes and data loading utilities
All modules are built with Kotlin Multiplatform and support all SKaiNET-compatible platforms (JVM, Android, Native, JS/Wasm).
To use KotlinDL in your project, ensure that mavenCentral(), google(), and mavenLocal() (for SKaiNET snapshots) are added to the repositories list:
// settings.gradle.kts
dependencyResolutionManagement {
    repositories {
        mavenCentral()
        google()
        mavenLocal() // For SKaiNET snapshots
    }
}

Then add the necessary dependencies to your build.gradle.kts file:
// build.gradle.kts
dependencies {
    implementation("org.jetbrains.kotlinx:kotlin-deeplearning-impl:[KOTLIN-DL-VERSION]")
    implementation("org.jetbrains.kotlinx:kotlin-deeplearning-dataset:[KOTLIN-DL-VERSION]")
}

Or, using a Gradle version catalog (gradle/libs.versions.toml):
[versions]
kotlindl = "[KOTLIN-DL-VERSION]"

[libraries]
kotlindl-impl = { module = "org.jetbrains.kotlinx:kotlin-deeplearning-impl", version.ref = "kotlindl" }
kotlindl-dataset = { module = "org.jetbrains.kotlinx:kotlin-deeplearning-dataset", version.ref = "kotlindl" }

// build.gradle.kts
dependencies {
    implementation(libs.kotlindl.impl)
    implementation(libs.kotlindl.dataset)
}

The latest KotlinDL version is 0.6.0.
KotlinDL is built with Kotlin Multiplatform, allowing you to share deep learning code across all supported platforms. In your build.gradle.kts:
kotlin {
    // JVM & Android
    jvm()
    androidTarget()

    // Native - Desktop
    linuxX64()
    linuxArm64()
    macosArm64()

    // Native - iOS
    iosArm64()
    iosSimulatorArm64()

    // JavaScript & WebAssembly
    js(IR) { browser(); nodejs() }
    wasmJs { browser(); nodejs() }

    sourceSets {
        commonMain.dependencies {
            implementation("org.jetbrains.kotlinx:kotlin-deeplearning-impl:[KOTLIN-DL-VERSION]")
            implementation("org.jetbrains.kotlinx:kotlin-deeplearning-dataset:[KOTLIN-DL-VERSION]")
        }
    }
}

SKaiNET provides the platform-specific tensor computation implementations automatically, so your deep learning code in commonMain works across all targets without modification.
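As an illustration of code that can live in commonMain, the following sketch uses only the Kotlin standard library (kotlin.math.exp, arrays, require) and no java.* imports, so the same source compiles for the JVM, Android, Native, JS, and Wasm targets declared in the kotlin { } block. The functions sigmoid and denseSigmoid are hypothetical helpers written for this example; they are not part of the KotlinDL or SKaiNET API.

```kotlin
import kotlin.math.exp

// Sketch of platform-neutral inference code for commonMain. It relies only on
// the Kotlin standard library, so it compiles unchanged for every target.
// sigmoid and denseSigmoid are hypothetical helpers, not KotlinDL API.

/** Logistic sigmoid, computed with kotlin.math.exp (available on all targets). */
fun sigmoid(z: Double): Double = 1.0 / (1.0 + exp(-z))

/**
 * Forward pass of a single dense layer with sigmoid activation:
 * out_j = sigmoid(sum_i input_i * weights[i][j] + bias_j).
 */
fun denseSigmoid(
    input: DoubleArray,
    weights: Array<DoubleArray>, // shape: input.size x outputs
    bias: DoubleArray,           // shape: outputs
): DoubleArray {
    require(weights.size == input.size) { "weights must have one row per input feature" }
    return DoubleArray(bias.size) { j ->
        val z = input.indices.sumOf { i -> input[i] * weights[i][j] } + bias[j]
        sigmoid(z)
    }
}

fun main() {
    val out = denseSigmoid(
        input = doubleArrayOf(1.0, 0.0),
        weights = arrayOf(doubleArrayOf(2.0), doubleArrayOf(-1.0)),
        bias = doubleArrayOf(-2.0),
    )
    println(out.toList()) // prints [0.5]
}
```

With input (1, 0), weights (2, -1), and bias -2 the pre-activation is 0, so the single output is sigmoid(0) = 0.5 on every target.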
KotlinDL supports Android development with SKaiNET as the tensor computation backend.
To use KotlinDL in your Android project, add the following to your build.gradle.kts:
// build.gradle.kts
android {
    compileSdk = 34
    defaultConfig {
        minSdk = 26
        targetSdk = 34
    }
    compileOptions {
        sourceCompatibility = JavaVersion.VERSION_21
        targetCompatibility = JavaVersion.VERSION_21
    }
}

dependencies {
    implementation("org.jetbrains.kotlinx:kotlin-deeplearning-impl:[KOTLIN-DL-VERSION]")
}

| KotlinDL Version | Kotlin Version | Minimum Java Version | Android: Compile SDK Version |
|---|---|---|---|
| 0.6.0+ | 2.1.0 | 21 | 34 |
| 0.5.0-0.5.2 | 1.8.x | 11 | 31 |
You do not need prior experience with Deep Learning to use KotlinDL.
We are working on extensive documentation to help you get started. In the meantime, feel free to check out the following tutorials we have prepared:
For more inspiration, take a look at the code examples in this repository.
By default, the API module uses the kotlin-logging library, which separates logging calls from the specific logger implementation.
You can use any widely used JVM logging library that provides a Simple Logging Facade for Java (SLF4J) implementation, such as Logback or Log4j 2.
If you wish to use Log4j 2, add the following dependencies and place a log4j2.xml configuration file in your project's resources folder (for example, src/main/resources):
// build.gradle.kts
dependencies {
    implementation("org.apache.logging.log4j:log4j-api:2.17.2")
    implementation("org.apache.logging.log4j:log4j-core:2.17.2")
    implementation("org.apache.logging.log4j:log4j-slf4j-impl:2.17.2")
}

<Configuration status="WARN">
    <Appenders>
        <Console name="STDOUT" target="SYSTEM_OUT">
            <PatternLayout pattern="%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n"/>
        </Console>
    </Appenders>
    <Loggers>
        <Root level="debug">
            <AppenderRef ref="STDOUT" level="DEBUG"/>
        </Root>
        <Logger name="io.jhdf" level="off" additivity="true">
            <AppenderRef ref="STDOUT"/>
        </Logger>
    </Loggers>
</Configuration>

If you wish to use Logback, add the following dependency and place a logback.xml configuration file in your project's resources folder (for example, src/main/resources):
// build.gradle.kts
dependencies {
    implementation("ch.qos.logback:logback-classic:1.4.5")
}

<configuration>
    <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
        <encoder>
            <pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
        </encoder>
    </appender>
    <root level="info">
        <appender-ref ref="STDOUT"/>
    </root>
</configuration>

Read the Contributing Guidelines.
Please use GitHub issues for filing feature requests and bug reports. You are also welcome to join the #kotlindl channel in Kotlin Slack.
This project and the corresponding community are governed by the JetBrains Open Source and Community Code of Conduct. Please make sure you read it.
KotlinDL is licensed under the Apache 2.0 License.