Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e446564

Browse files
committed
Protect the determination of whether to dispatch to accelerated algos (at the callsite)
1 parent e24dc80 commit e446564

File tree

4 files changed

+37
-21
lines changed

4 files changed

+37
-21
lines changed

src/Microsoft.ML.FastTree/RandomForestClassification.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,15 @@ public static extern unsafe int DecisionForestClassificationCompute(
265265
[BestFriend]
266266
private bool IsDispatchingToOneDalEnabled()
267267
{
268-
return OneDalUtils.IsDispatchingEnabled();
268+
try
269+
{
270+
return OneDalUtils.IsDispatchingEnabled();
271+
}
272+
catch (Exception)
273+
{
274+
// Bail to default implementation upon encountering any situation where dispatch failed
275+
return false;
276+
}
269277
}
270278

271279
[BestFriend]

src/Microsoft.ML.FastTree/RandomForestRegression.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,15 @@ public static extern unsafe int DecisionForestRegressionCompute(
398398
[BestFriend]
399399
private bool IsDispatchingToOneDalEnabled()
400400
{
401-
return OneDalUtils.IsDispatchingEnabled();
401+
try
402+
{
403+
return OneDalUtils.IsDispatchingEnabled();
404+
}
405+
catch (Exception)
406+
{
407+
// fall back to original implementation for any circumstance that prevents dispatching
408+
return false;
409+
}
402410
}
403411

404412
[BestFriend]

src/Microsoft.ML.Mkl.Components/OlsLinearRegression.cs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
using Microsoft.ML.Internal.Internallearn;
1515
using Microsoft.ML.Internal.Utilities;
1616
using Microsoft.ML.Model;
17+
using Microsoft.ML.OneDal;
1718
using Microsoft.ML.Runtime;
1819
using Microsoft.ML.Trainers;
19-
using Microsoft.ML.OneDal;
2020

2121
[assembly: LoadableClass(OlsTrainer.Summary, typeof(OlsTrainer), typeof(OlsTrainer.Options),
2222
new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
@@ -409,7 +409,15 @@ private void ComputeMklRegression(IChannel ch, FloatLabelCursor.Factory cursorFa
409409
[BestFriend]
410410
private bool IsDispatchingToOneDalEnabled()
411411
{
412-
return OneDalUtils.IsDispatchingEnabled();
412+
try
413+
{
414+
return OneDalUtils.IsDispatchingEnabled();
415+
}
416+
catch (Exception)
417+
{
418+
// Bail to default implementation upon any situation that prevents dispatching
419+
return false;
420+
}
413421
}
414422

415423
private OlsModelParameters TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount)

src/Microsoft.ML.OneDal/OneDalUtils.cs

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,32 +20,24 @@ internal static class OneDalUtils
2020
[BestFriend]
2121
internal static bool IsDispatchingEnabled()
2222
{
23-
try
23+
if (Environment.GetEnvironmentVariable("MLNET_BACKEND") == "ONEDAL" &&
24+
System.Runtime.InteropServices.RuntimeInformation.ProcessArchitecture == System.Runtime.InteropServices.Architecture.X64)
2425
{
25-
if (Environment.GetEnvironmentVariable("MLNET_BACKEND") == "ONEDAL" &&
26-
System.Runtime.InteropServices.RuntimeInformation.ProcessArchitecture == System.Runtime.InteropServices.Architecture.X64)
26+
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
2727
{
28-
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
29-
{
3028
#if NETFRAMEWORK
3129
// AppContext not available in the framework, user needs to set PATH manually
3230
// this will probably result in a runtime error where the user needs to set the PATH
3331
#else
34-
var currentDir = AppContext.BaseDirectory;
35-
var nativeLibs = Path.Combine(currentDir, "runtimes", "win-x64", "native");
36-
var originalPath = Environment.GetEnvironmentVariable("PATH");
37-
Environment.SetEnvironmentVariable("PATH", nativeLibs + ";" + originalPath);
32+
var currentDir = AppContext.BaseDirectory;
33+
var nativeLibs = Path.Combine(currentDir, "runtimes", "win-x64", "native");
34+
var originalPath = Environment.GetEnvironmentVariable("PATH");
35+
Environment.SetEnvironmentVariable("PATH", nativeLibs + ";" + originalPath);
3836
#endif
39-
}
40-
return true;
4137
}
42-
return false;
43-
}
44-
catch (Exception)
45-
{
46-
// fallback to default algorithm implementation if dispatch fails
47-
return false;
38+
return true;
4839
}
40+
return false;
4941
}
5042

5143
[BestFriend]

0 commit comments

Comments
 (0)