forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathOnnxRuntime.java
More file actions
534 lines (473 loc) · 19.4 KB
/
OnnxRuntime.java
File metadata and controls
534 lines (473 loc) · 19.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
/*
* Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.Locale;
import java.util.Optional;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
* Static loader for the JNI binding. No public API, but called from various classes in this package
* to ensure shared libraries are properly loaded.
*/
final class OnnxRuntime {
private static final Logger logger = Logger.getLogger(OnnxRuntime.class.getName());
// The initial release of the ORT API.
private static final int ORT_API_VERSION_1 = 1;
// Post 1.0 builds of the ORT API.
private static final int ORT_API_VERSION_2 = 2;
// Post 1.3 builds of the ORT API
private static final int ORT_API_VERSION_3 = 3;
// Post 1.6 builds of the ORT API
private static final int ORT_API_VERSION_7 = 7;
// Post 1.7 builds of the ORT API
private static final int ORT_API_VERSION_8 = 8;
// Post 1.10 builds of the ORT API
private static final int ORT_API_VERSION_11 = 11;
// Post 1.12 builds of the ORT API
private static final int ORT_API_VERSION_13 = 13;
// Post 1.13 builds of the ORT API
private static final int ORT_API_VERSION_14 = 14;
// Post 1.22 builds of the ORT API
private static final int ORT_API_VERSION_23 = 23;
// The initial release of the ORT training API.
private static final int ORT_TRAINING_API_VERSION_1 = 1;
/**
* The name of the system property which when set gives the path on disk where the ONNX Runtime
* native libraries are stored.
*/
static final String ONNXRUNTIME_NATIVE_PATH = "onnxruntime.native.path";
/** The short name of the ONNX runtime shared library */
static final String ONNXRUNTIME_LIBRARY_NAME = "onnxruntime";
/** The short name of the ONNX runtime JNI shared library */
static final String ONNXRUNTIME_JNI_LIBRARY_NAME = "onnxruntime4j_jni";
/** The short name of the ONNX runtime shared provider library */
static final String ONNXRUNTIME_LIBRARY_SHARED_NAME = "onnxruntime_providers_shared";
/** The short name of the ONNX runtime CUDA provider library */
static final String ONNXRUNTIME_LIBRARY_CUDA_NAME = "onnxruntime_providers_cuda";
/** The short name of the ONNX runtime ROCM provider library */
static final String ONNXRUNTIME_LIBRARY_ROCM_NAME = "onnxruntime_providers_rocm";
/** The short name of the ONNX runtime DNNL provider library */
static final String ONNXRUNTIME_LIBRARY_DNNL_NAME = "onnxruntime_providers_dnnl";
/** The short name of the ONNX runtime OpenVINO provider library */
static final String ONNXRUNTIME_LIBRARY_OPENVINO_NAME = "onnxruntime_providers_openvino";
/** The short name of the ONNX runtime TensorRT provider library */
static final String ONNXRUNTIME_LIBRARY_TENSORRT_NAME = "onnxruntime_providers_tensorrt";
/** The short name of the ONNX runtime QNN provider library */
static final String ONNXRUNTIME_LIBRARY_QNN_NAME = "onnxruntime_providers_qnn";
/** The short name of the WebGPU DAWN library */
static final String ONNXRUNTIME_LIBRARY_WEBGPU_DAWN_NAME = "webgpu_dawn";
/** The short name of the WebGPU DXC library "dxil.dll" */
static final String ONNXRUNTIME_LIBRARY_WEBGPU_DXC_DXIL_NAME = "dxil";
/** The short name of the WebGPU DXC library "dxcompiler.dll" */
static final String ONNXRUNTIME_LIBRARY_WEBGPU_DXC_DXCOMPILER_NAME = "dxcompiler";
/** The OS & CPU architecture string */
private static final String OS_ARCH_STR = initOsArch();
/** Have the core ONNX Runtime native libraries been loaded */
private static boolean loaded = false;
/** The temp directory where native libraries are extracted */
private static Path tempDirectory;
/** The value of the {@link #ONNXRUNTIME_NATIVE_PATH} system property */
private static String libraryDirPathProperty;
/** Tracks if the shared providers have been extracted */
private static final Set<String> extractedSharedProviders = new HashSet<>();
/** The API handle. */
static long ortApiHandle;
/** The Training API handle. */
static long ortTrainingApiHandle;
/** The Compile API handle. */
static long ortCompileApiHandle;
/** Is training enabled in the native library */
static boolean trainingEnabled;
/** The available runtime providers */
static EnumSet<OrtProvider> providers;
/** The version string. */
private static String version;
private OnnxRuntime() {}
/* Computes and initializes OS_ARCH_STR (such as linux-x64) */
private static String initOsArch() {
String detectedOS = null;
String os = System.getProperty("os.name", "generic").toLowerCase(Locale.ENGLISH);
if (os.contains("mac") || os.contains("darwin")) {
detectedOS = "osx";
} else if (os.contains("win")) {
detectedOS = "win";
} else if (os.contains("nux")) {
detectedOS = "linux";
} else if (isAndroid()) {
detectedOS = "android";
} else {
throw new IllegalStateException("Unsupported os:" + os);
}
String detectedArch = null;
String arch = System.getProperty("os.arch", "generic").toLowerCase(Locale.ENGLISH);
if (arch.startsWith("amd64") || arch.startsWith("x86_64")) {
detectedArch = "x64";
} else if (arch.startsWith("x86")) {
// 32-bit x86 is not supported by the Java API
detectedArch = "x86";
} else if (arch.startsWith("aarch64")) {
detectedArch = "aarch64";
} else if (arch.startsWith("ppc64")) {
detectedArch = "ppc64";
} else if (arch.startsWith("loongarch64")) {
detectedArch = "loongarch64";
} else if (isAndroid()) {
detectedArch = arch;
} else {
throw new IllegalStateException("Unsupported arch:" + arch);
}
return detectedOS + '-' + detectedArch;
}
/**
* Loads the native C library.
*
* @throws IOException If it can't write to disk to copy out the library from the jar file.
* @throws IllegalStateException If the native library failed to load.
*/
static synchronized void init() throws IOException {
if (loaded) {
return;
}
tempDirectory = isAndroid() ? null : Files.createTempDirectory("onnxruntime-java");
try {
libraryDirPathProperty = System.getProperty(ONNXRUNTIME_NATIVE_PATH);
// Extract and prepare the shared provider library but don't try to load it,
// the ONNX Runtime native library will load it
extractProviderLibrary(ONNXRUNTIME_LIBRARY_SHARED_NAME);
// Extract and prepare the Dawn shared libraries (if present) but don't try to load them,
// the ONNX Runtime native library will load them
extractProviderLibrary(ONNXRUNTIME_LIBRARY_WEBGPU_DAWN_NAME);
extractProviderLibrary(ONNXRUNTIME_LIBRARY_WEBGPU_DXC_DXIL_NAME);
extractProviderLibrary(ONNXRUNTIME_LIBRARY_WEBGPU_DXC_DXCOMPILER_NAME);
if (!isAndroid()) {
load(ONNXRUNTIME_LIBRARY_NAME);
}
load(ONNXRUNTIME_JNI_LIBRARY_NAME);
ortApiHandle = initialiseAPIBase(ORT_API_VERSION_23);
if (ortApiHandle == 0L) {
throw new IllegalStateException(
"There is a mismatch between the ORT class files and the ORT native library, and the native library could not be loaded");
}
ortTrainingApiHandle = initialiseTrainingAPIBase(ortApiHandle, ORT_API_VERSION_23);
ortCompileApiHandle = initialiseCompileAPIBase(ortApiHandle);
trainingEnabled = ortTrainingApiHandle != 0L;
providers = initialiseProviders(ortApiHandle);
version = initialiseVersion();
loaded = true;
} finally {
if (tempDirectory != null) {
cleanUp(tempDirectory.toFile());
}
}
}
/**
* Marks the file for delete on exit.
*
* @param file The file to remove.
*/
private static void cleanUp(File file) {
if (!file.exists()) {
return;
}
logger.log(Level.FINE, "Deleting " + file + " on exit");
file.deleteOnExit();
}
/**
* Gets the native library version string.
*
* @return The version string.
*/
static String version() {
return version;
}
/**
* Extracts the CUDA provider library from the classpath resources if present, or checks to see if
* the CUDA provider library is in the directory specified by {@link #ONNXRUNTIME_NATIVE_PATH}.
*
* @return True if the CUDA provider library is ready for loading, false otherwise.
*/
static boolean extractCUDA() {
return extractProviderLibrary(ONNXRUNTIME_LIBRARY_CUDA_NAME);
}
/**
* Extracts the ROCM provider library from the classpath resources if present, or checks to see if
* the ROCM provider library is in the directory specified by {@link #ONNXRUNTIME_NATIVE_PATH}.
*
* @return True if the ROCM provider library is ready for loading, false otherwise.
*/
static boolean extractROCM() {
return extractProviderLibrary(ONNXRUNTIME_LIBRARY_ROCM_NAME);
}
/**
* Extracts the DNNL provider library from the classpath resources if present, or checks to see if
* the DNNL provider library is in the directory specified by {@link #ONNXRUNTIME_NATIVE_PATH}.
*
* @return True if the DNNL provider library is ready for loading, false otherwise.
*/
static boolean extractDNNL() {
return extractProviderLibrary(ONNXRUNTIME_LIBRARY_DNNL_NAME);
}
/**
* Extracts the OpenVINO provider library from the classpath resources if present, or checks to
* see if the OpenVINO provider library is in the directory specified by {@link
* #ONNXRUNTIME_NATIVE_PATH}.
*
* @return True if the OpenVINO provider library is ready for loading, false otherwise.
*/
static boolean extractOpenVINO() {
return extractProviderLibrary(ONNXRUNTIME_LIBRARY_OPENVINO_NAME);
}
/**
* Extracts the TensorRT provider library from the classpath resources if present, or checks to
* see if the TensorRT provider library is in the directory specified by {@link
* #ONNXRUNTIME_NATIVE_PATH}.
*
* @return True if the TensorRT provider library is ready for loading, false otherwise.
*/
static boolean extractTensorRT() {
return extractProviderLibrary(ONNXRUNTIME_LIBRARY_TENSORRT_NAME);
}
/**
* Extracts the QNN provider library from the classpath resources if present, or checks to see if
* the QNN provider library is in the directory specified by {@link #ONNXRUNTIME_NATIVE_PATH}.
*
* @return True if the QNN provider library is ready for loading, false otherwise.
*/
static boolean extractQNN() {
return extractProviderLibrary(ONNXRUNTIME_LIBRARY_QNN_NAME);
}
/**
* Extracts a shared provider library from the classpath resources if present, or checks to see if
* that library is in the directory specified by {@link #ONNXRUNTIME_NATIVE_PATH}.
*
* @param libraryName The shared provider library to load.
* @return True if the library is ready for loading by ORT's native code, false otherwise.
*/
static synchronized boolean extractProviderLibrary(String libraryName) {
// Android does not need to extract provider libraries.
if (isAndroid()) {
return false;
}
// Check if we've already extracted or check this provider, and it's ready
if (extractedSharedProviders.contains(libraryName)) {
return true;
}
// If a native library directory is configured, prefer it and skip extraction when present.
if (libraryDirPathProperty != null) {
String libraryFileName = mapLibraryName(libraryName);
File libraryFile = Paths.get(libraryDirPathProperty, libraryFileName).toFile();
if (libraryFile.exists()) {
extractedSharedProviders.add(libraryName);
return true;
}
}
// Otherwise extract the file from the classpath resources.
Optional<File> file = extractFromResources(libraryName);
if (file.isPresent()) {
extractedSharedProviders.add(libraryName);
return true;
} else {
return false;
}
}
/**
* Check if we're running on Android.
*
* @return True if the property java.vendor equals The Android Project, false otherwise.
*/
static boolean isAndroid() {
return System.getProperty("java.vendor", "generic").equals("The Android Project");
}
/**
* Load a shared library by name.
*
* <p>If the library path is not specified via a system property then it attempts to extract the
* library from the classpath before loading it.
*
* @param library The bare name of the library.
* @throws IOException If the file failed to read or write.
*/
private static void load(String library) throws IOException {
// On Android, we simply use System.loadLibrary
if (isAndroid()) {
System.loadLibrary(library);
return;
}
// 1) The user may skip loading of this library:
String skip = System.getProperty("onnxruntime.native." + library + ".skip");
if (Boolean.TRUE.toString().equalsIgnoreCase(skip)) {
logger.log(Level.FINE, "Skipping load of native library '" + library + "'");
return;
}
// Resolve the platform dependent library name.
String libraryFileName = mapLibraryName(library);
// 2) The user may explicitly specify the path to a directory containing all shared libraries:
if (libraryDirPathProperty != null) {
logger.log(
Level.FINE,
"Attempting to load native library '"
+ library
+ "' from specified path: "
+ libraryDirPathProperty);
// TODO: Switch this to Path.of when the minimum Java version is 11.
File libraryFile = Paths.get(libraryDirPathProperty, libraryFileName).toFile();
String libraryFilePath = libraryFile.getAbsolutePath();
if (!libraryFile.exists()) {
throw new IOException("Native library '" + library + "' not found at " + libraryFilePath);
}
System.load(libraryFilePath);
logger.log(Level.FINE, "Loaded native library '" + library + "' from specified path");
return;
}
// 3) The user may explicitly specify the path to their shared library:
String libraryPathProperty = System.getProperty("onnxruntime.native." + library + ".path");
if (libraryPathProperty != null) {
logger.log(
Level.FINE,
"Attempting to load native library '"
+ library
+ "' from specified path: "
+ libraryPathProperty);
File libraryFile = new File(libraryPathProperty);
String libraryFilePath = libraryFile.getAbsolutePath();
if (!libraryFile.exists()) {
throw new IOException("Native library '" + library + "' not found at " + libraryFilePath);
}
System.load(libraryFilePath);
logger.log(Level.FINE, "Loaded native library '" + library + "' from specified path");
return;
}
// 4) try loading from resources or library path:
Optional<File> extractedPath = extractFromResources(library);
if (extractedPath.isPresent()) {
// extracted library from resources
System.load(extractedPath.get().getAbsolutePath());
logger.log(Level.FINE, "Loaded native library '" + library + "' from resource path");
} else {
// failed to load library from resources, try to load it from the library path
logger.log(
Level.FINE, "Attempting to load native library '" + library + "' from library path");
System.loadLibrary(library);
logger.log(Level.FINE, "Loaded native library '" + library + "' from library path");
}
}
/**
* Extracts the library from the classpath resources. returns optional.empty if it failed to
* extract or couldn't be found.
*
* @param library The library name
* @return An optional containing the file if it is successfully extracted, or an empty optional
* if it failed to extract or couldn't be found.
*/
private static Optional<File> extractFromResources(String library) {
String libraryFileName = mapLibraryName(library);
String resourcePath = "/ai/onnxruntime/native/" + OS_ARCH_STR + '/' + libraryFileName;
File tempFile = tempDirectory.resolve(libraryFileName).toFile();
try (InputStream is = OnnxRuntime.class.getResourceAsStream(resourcePath)) {
if (is == null) {
// Not found in classpath resources
return Optional.empty();
} else {
// Found in classpath resources, load via temporary file
logger.log(
Level.FINE,
"Attempting to load native library '"
+ library
+ "' from resource path "
+ resourcePath
+ " copying to "
+ tempFile);
byte[] buffer = new byte[4096];
int readBytes;
try (FileOutputStream os = new FileOutputStream(tempFile)) {
while ((readBytes = is.read(buffer)) != -1) {
os.write(buffer, 0, readBytes);
}
}
logger.log(Level.FINE, "Extracted native library '" + library + "' from resource path");
return Optional.of(tempFile);
}
} catch (IOException e) {
logger.log(
Level.WARNING, "Failed to extract library '" + library + "' from the resources", e);
return Optional.empty();
} finally {
cleanUp(tempFile);
}
}
/**
* Maps the library name into a platform dependent library filename. Converts macOS's "jnilib" to
* "dylib" but otherwise is the same as {@link System#mapLibraryName(String)}.
*
* @param library The library name
* @return The library filename.
*/
private static String mapLibraryName(String library) {
return System.mapLibraryName(library).replace("jnilib", "dylib");
}
/**
* Extracts the providers array from the C API, converts it into an EnumSet.
*
* <p>Throws IllegalArgumentException if a provider isn't recognized (note this exception should
* only happen during development of ONNX Runtime, if it happens at any other point, file an issue
* on <a href="https://github.com/microsoft/onnxruntime">GitHub</a>).
*
* @param ortApiHandle The API Handle.
* @return The enum set.
*/
private static EnumSet<OrtProvider> initialiseProviders(long ortApiHandle) {
String[] providersArray = getAvailableProviders(ortApiHandle);
EnumSet<OrtProvider> providers = EnumSet.noneOf(OrtProvider.class);
for (String provider : providersArray) {
providers.add(OrtProvider.mapFromName(provider));
}
return providers;
}
/**
* Get a reference to the API struct.
*
* @param apiVersionNumber The API version to use.
* @return A pointer to the API struct.
*/
private static native long initialiseAPIBase(int apiVersionNumber);
/**
* Get a reference to the training API struct.
*
* @param apiHandle The ORT API struct pointer.
* @param apiVersionNumber The API version to use.
* @return A pointer to the training API struct.
*/
private static native long initialiseTrainingAPIBase(long apiHandle, int apiVersionNumber);
/**
* Get a reference to the compile API struct.
*
* @param apiHandle The ORT API struct pointer.
* @return A pointer to the compile API struct.
*/
private static native long initialiseCompileAPIBase(long apiHandle);
/**
* Gets the array of available providers.
*
* @param ortApiHandle The API handle
* @return The array of providers
*/
private static native String[] getAvailableProviders(long ortApiHandle);
/**
* Gets the version string from the native library.
*
* @return The version string.
*/
private static native String initialiseVersion();
}