Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions src/Microsoft.ML.LightGbm/LightGbmBinaryTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.ML;
using Microsoft.ML.Calibrators;
using Microsoft.ML.CommandLine;
Expand Down Expand Up @@ -228,6 +230,26 @@ internal LightGbmBinaryTrainer(IHostEnvironment env,
{
}

/// <summary>
/// Initializes a new instance of <see cref="LightGbmBinaryTrainer"/> from a pre-trained LightGBM model.
/// </summary>
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="lightGbmModel">A <see cref="System.IO.Stream"/> containing a pre-trained LightGBM model file, used for inferencing. Must not be null.</param>
/// <param name="featureColumnName">The name of the feature column.</param>
internal LightGbmBinaryTrainer(IHostEnvironment env,
    Stream lightGbmModel,
    string featureColumnName = DefaultColumnNames.Features)
    : base(env,
          LoadNameValue,
          new Options()
          {
              FeatureColumnName = featureColumnName,
              LightGbmModel = lightGbmModel
          },
          // No label column is supplied: a pre-trained model file does not need one.
          new SchemaShape.Column())
{
    // The rest of the trainer decides between "train" and "use pre-trained model" by testing
    // LightGbmModel for null, so a null stream here would silently produce a trainer that
    // expects a label column it was never given. Fail early with a clear argument error instead.
    Host.CheckValue(lightGbmModel, nameof(lightGbmModel));
}

private protected override CalibratedModelParametersBase<LightGbmBinaryModelParameters, PlattCalibrator> CreatePredictor()
{
Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete");
Expand All @@ -241,11 +263,16 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data)
{
Host.AssertValue(ch);
base.CheckDataValid(ch, data);
var labelType = data.Schema.Label.Value.Type;
if (!(labelType is BooleanDataViewType || labelType is KeyDataViewType || labelType == NumberDataViewType.Single))

// If using a pre-trained model file we don't need a label column
if (LightGbmTrainerOptions.LightGbmModel == null)
{
throw ch.ExceptParam(nameof(data),
$"Label column '{data.Schema.Label.Value.Name}' is of type '{labelType.RawType}', but must be unsigned int, boolean or float.");
var labelType = data.Schema.Label.Value.Type;
if (!(labelType is BooleanDataViewType || labelType is KeyDataViewType || labelType == NumberDataViewType.Single))
{
throw ch.ExceptParam(nameof(data),
$"Label column '{data.Schema.Label.Value.Name}' is of type '{labelType.RawType}', but must be unsigned int, boolean or float.");
}
}
}

Expand Down
65 changes: 65 additions & 0 deletions src/Microsoft.ML.LightGbm/LightGbmCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.IO;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.LightGbm;
Expand Down Expand Up @@ -67,6 +68,22 @@ public static LightGbmRegressionTrainer LightGbm(this RegressionCatalog.Regressi
return new LightGbmRegressionTrainer(env, options);
}

/// <summary>
/// Create <see cref="LightGbmRegressionTrainer"/> from a pre-trained LightGBM model, which predicts a target using a gradient boosting decision tree regression.
/// </summary>
/// <param name="catalog">The <see cref="RegressionCatalog"/>.</param>
/// <param name="lightGbmModel">A <see cref="System.IO.Stream"/> containing a pre-trained LightGBM model file, used for inferencing.</param>
/// <param name="featureColumnName">The name of the feature column. The column data must be a known-sized vector of <see cref="System.Single"/>.</param>
public static LightGbmRegressionTrainer LightGbm(this RegressionCatalog.RegressionTrainers catalog,
Copy link
Contributor

@LittleLittleCloud LittleLittleCloud May 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if we load a LightGBM model newer than 2.3.1? Maybe we can add a test for that scenario.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, at least "officially", we aren't going to support it. But @torronen said he would test it and let us know, since the model format itself should be the same. Since we are going to be updating to the latest version this year, though, this is only a temporary situation.

Stream lightGbmModel,
string featureColumnName = DefaultColumnNames.Features
)
{
Contracts.CheckValue(catalog, nameof(catalog));
var env = CatalogUtils.GetEnvironment(catalog);
return new LightGbmRegressionTrainer(env, lightGbmModel, featureColumnName);
}

/// <summary>
/// Create <see cref="LightGbmBinaryTrainer"/>, which predicts a target using a gradient boosting decision tree binary classification.
/// </summary>
Expand Down Expand Up @@ -119,6 +136,22 @@ public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.Bi
return new LightGbmBinaryTrainer(env, options);
}

/// <summary>
/// Create <see cref="LightGbmBinaryTrainer"/> from a pre-trained LightGBM model, which predicts a target using a gradient boosting decision tree binary classification.
/// </summary>
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
/// <param name="lightGbmModel">A <see cref="System.IO.Stream"/> containing a pre-trained LightGBM model file, used for inferencing.</param>
/// <param name="featureColumnName">The name of the feature column. The column data must be a known-sized vector of <see cref="System.Single"/>.</param>
public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
    Stream lightGbmModel,
    string featureColumnName = DefaultColumnNames.Features)
{
    Contracts.CheckValue(catalog, nameof(catalog));

    // The trainer is built directly on the environment owned by the catalog.
    return new LightGbmBinaryTrainer(CatalogUtils.GetEnvironment(catalog), lightGbmModel, featureColumnName);
}

/// <summary>
/// Create <see cref="LightGbmRankingTrainer"/>, which predicts a target using a gradient boosting decision tree ranking model.
/// </summary>
Expand Down Expand Up @@ -174,6 +207,22 @@ public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainer
return new LightGbmRankingTrainer(env, options);
}

/// <summary>
/// Create <see cref="LightGbmRankingTrainer"/> from a pre-trained LightGBM model, which predicts a target using a gradient boosting decision tree ranking model.
/// </summary>
/// <param name="catalog">The <see cref="RankingCatalog"/>.</param>
/// <param name="lightGbmModel">A <see cref="System.IO.Stream"/> containing a pre-trained LightGBM model file, used for inferencing.</param>
/// <param name="featureColumnName">The name of the feature column. The column data must be a known-sized vector of <see cref="System.Single"/>.</param>
public static LightGbmRankingTrainer LightGbm(this RankingCatalog.RankingTrainers catalog,
    Stream lightGbmModel,
    string featureColumnName = DefaultColumnNames.Features)
{
    Contracts.CheckValue(catalog, nameof(catalog));

    // The trainer is built directly on the environment owned by the catalog.
    return new LightGbmRankingTrainer(CatalogUtils.GetEnvironment(catalog), lightGbmModel, featureColumnName);
}

/// <summary>
/// Create <see cref="LightGbmMulticlassTrainer"/>, which predicts a target using a gradient boosting decision tree multiclass classification model.
/// </summary>
Expand Down Expand Up @@ -225,5 +274,21 @@ public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCa
var env = CatalogUtils.GetEnvironment(catalog);
return new LightGbmMulticlassTrainer(env, options);
}

/// <summary>
/// Create <see cref="LightGbmMulticlassTrainer"/> from a pre-trained LightGBM model, which predicts a target using a gradient boosting decision tree multiclass classification model.
/// </summary>
/// <param name="catalog">The <see cref="MulticlassClassificationCatalog"/>.</param>
/// <param name="lightGbmModel">A <see cref="System.IO.Stream"/> containing a pre-trained LightGBM model file, used for inferencing.</param>
/// <param name="featureColumnName">The name of the feature column. The column data must be a known-sized vector of <see cref="System.Single"/>.</param>
public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
    Stream lightGbmModel,
    string featureColumnName = DefaultColumnNames.Features)
{
    Contracts.CheckValue(catalog, nameof(catalog));

    // The trainer is built directly on the environment owned by the catalog.
    return new LightGbmMulticlassTrainer(CatalogUtils.GetEnvironment(catalog), lightGbmModel, featureColumnName);
}
}
}
60 changes: 52 additions & 8 deletions src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Calibrators;
Expand Down Expand Up @@ -170,6 +171,26 @@ internal LightGbmMulticlassTrainer(IHostEnvironment env,
{
}

/// <summary>
/// Initializes a new instance of <see cref="LightGbmMulticlassTrainer"/> from a pre-trained LightGBM model.
/// </summary>
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="lightGbmModel">A <see cref="System.IO.Stream"/> containing a pre-trained LightGBM model file, used for inferencing. Must not be null.</param>
/// <param name="featureColumnName">The name of the feature column.</param>
internal LightGbmMulticlassTrainer(IHostEnvironment env,
    Stream lightGbmModel,
    string featureColumnName = DefaultColumnNames.Features)
    : base(env,
          LoadNameValue,
          new Options()
          {
              FeatureColumnName = featureColumnName,
              LightGbmModel = lightGbmModel
          },
          // No label column is supplied: a pre-trained model file does not need one.
          new SchemaShape.Column())
{
    // The rest of the trainer decides between "train" and "use pre-trained model" by testing
    // LightGbmModel for null, so a null stream here would silently produce a trainer that
    // expects a label column it was never given. Fail early with a clear argument error instead.
    Host.CheckValue(lightGbmModel, nameof(lightGbmModel));
}

private InternalTreeEnsemble GetBinaryEnsemble(int classID)
{
var res = new InternalTreeEnsemble();
Expand Down Expand Up @@ -213,11 +234,15 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data)
{
Host.AssertValue(ch);
base.CheckDataValid(ch, data);
var labelType = data.Schema.Label.Value.Type;
if (!(labelType is BooleanDataViewType || labelType is KeyDataViewType || labelType == NumberDataViewType.Single))
// If using a pre-trained model file we don't need a label or group column
if (LightGbmTrainerOptions.LightGbmModel == null)
{
throw ch.ExceptParam(nameof(data),
$"Label column '{data.Schema.Label.Value.Name}' is of type '{labelType.RawType}', but must be of unsigned int, boolean or float.");
var labelType = data.Schema.Label.Value.Type;
if (!(labelType is BooleanDataViewType || labelType is KeyDataViewType || labelType == NumberDataViewType.Single))
{
throw ch.ExceptParam(nameof(data),
$"Label column '{data.Schema.Label.Value.Name}' is of type '{labelType.RawType}', but must be of unsigned int, boolean or float.");
}
}
}

Expand All @@ -227,6 +252,22 @@ private protected override void InitializeBeforeTraining()
_numberOfClasses = 0;
}

/// <summary>
/// Reads the multiclass-specific settings from the text of a pre-trained LightGBM model file:
/// the objective name is stored in <c>GbmOptions["objective"]</c> and the class count in both
/// class-count fields of this trainer.
/// </summary>
/// <param name="modelText">The full text of the LightGBM model file.</param>
private protected override void AdditionalLoadPreTrainedModel(string modelText)
{
    Host.CheckValue(modelText, nameof(modelText));

    const string objectivePrefix = "objective=";
    const string numClassPrefix = "num_class:";

    string[] lines = modelText.Split(new char[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries);

    // Locate the objective entry. Its format in the file is: objective=multiclass num_class:4
    // Ordinal comparison is used because this is a machine-written key, not linguistic text.
    string objectiveLine = Array.Find(lines, line => line.StartsWith(objectivePrefix, StringComparison.Ordinal));
    Host.Check(objectiveLine != null, "The LightGBM model file does not contain an 'objective' entry.");

    string[] tokens = objectiveLine.Split(new char[] { ' ' }, StringSplitOptions.RemoveEmptyEntries);

    // The first token is "objective=<name>"; keep everything up to the next '=' (if any), as before.
    string objectiveName = tokens[0].Substring(objectivePrefix.Length).Split('=')[0];

    // Look the class count up by its key rather than by position, so extra objective
    // parameters on the same line do not break parsing.
    string numClassToken = Array.Find(tokens, token => token.StartsWith(numClassPrefix, StringComparison.Ordinal));
    Host.Check(numClassToken != null, "The LightGBM model file 'objective' entry does not specify 'num_class'.");

    // The model file is culture-independent, so parse the number with the invariant culture.
    bool parsed = int.TryParse(
        numClassToken.Substring(numClassPrefix.Length),
        System.Globalization.NumberStyles.Integer,
        System.Globalization.CultureInfo.InvariantCulture,
        out int numberOfClasses);
    Host.Check(parsed && numberOfClasses > 0, "The LightGBM model file contains an invalid 'num_class' value.");

    // Only mutate trainer state once everything has been parsed successfully.
    GbmOptions["objective"] = objectiveName;
    _numberOfClassesIncludingNan = numberOfClasses;
    _numberOfClasses = _numberOfClassesIncludingNan;
}


private protected override void ConvertNaNLabels(IChannel ch, RoleMappedData data, float[] labels)
{
// Only initialize one time.
Expand Down Expand Up @@ -317,11 +358,14 @@ private protected override void CheckAndUpdateParametersBeforeTraining(IChannel

private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
{
bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
Contracts.Assert(success);
SchemaShape.Column labelCol = default;
if (LightGbmTrainerOptions.LightGbmModel == null)
{
bool success = inputSchema.TryFindColumn(LabelColumn.Name, out labelCol);
Contracts.Assert(success);
}

var metadata = new SchemaShape(labelCol.Annotations.Where(x => x.Name == AnnotationUtils.Kinds.KeyValues)
.Concat(AnnotationUtils.GetTrainerOutputAnnotation()));
var metadata = LightGbmTrainerOptions.LightGbmModel == null ? new SchemaShape(labelCol.Annotations.Where(x => x.Name == AnnotationUtils.Kinds.KeyValues).Concat(AnnotationUtils.GetTrainerOutputAnnotation())) : new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation());
return new[]
{
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelCol))),
Expand Down
58 changes: 42 additions & 16 deletions src/Microsoft.ML.LightGbm/LightGbmRankingTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
Expand Down Expand Up @@ -215,27 +216,52 @@ internal LightGbmRankingTrainer(IHostEnvironment env,
Host.CheckNonEmpty(rowGroupIdColumnName, nameof(rowGroupIdColumnName));
}

/// <summary>
/// Initializes a new instance of <see cref="LightGbmRankingTrainer"/> from a pre-trained LightGBM model.
/// </summary>
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="lightGbmModel">A <see cref="System.IO.Stream"/> containing a pre-trained LightGBM model file, used for inferencing. Must not be null.</param>
/// <param name="featureColumnName">The name of the feature column.</param>
internal LightGbmRankingTrainer(IHostEnvironment env,
    Stream lightGbmModel,
    string featureColumnName = DefaultColumnNames.Features)
    : base(env,
          LoadNameValue,
          new Options()
          {
              FeatureColumnName = featureColumnName,
              LightGbmModel = lightGbmModel
          },
          // No label column is supplied: a pre-trained model file does not need a label or group column.
          new SchemaShape.Column())
{
    // The rest of the trainer decides between "train" and "use pre-trained model" by testing
    // LightGbmModel for null, so a null stream here would silently produce a trainer that
    // expects label and group columns it was never given. Fail early with a clear argument error instead.
    Host.CheckValue(lightGbmModel, nameof(lightGbmModel));
}

/// <summary>
/// Validates the input data for ranking. After the base-class checks, the label and group
/// columns are validated only when the trainer is not using a pre-trained model file.
/// </summary>
/// <param name="ch">The channel used to create exceptions.</param>
/// <param name="data">The role-mapped data to validate.</param>
private protected override void CheckDataValid(IChannel ch, RoleMappedData data)
{
    Host.AssertValue(ch);
    base.CheckDataValid(ch, data);

    // If using a pre-trained model file we don't need a label or group column.
    if (LightGbmTrainerOptions.LightGbmModel != null)
        return;

    // Check label types.
    var labelCol = data.Schema.Label.Value;
    var labelType = labelCol.Type;
    if (!(labelType is KeyDataViewType || labelType == NumberDataViewType.Single))
    {
        throw ch.ExceptParam(nameof(data),
            $"Label column '{labelCol.Name}' is of type '{labelType.RawType}', but must be Key or Single.");
    }

    // Check group types.
    if (!data.Schema.Group.HasValue)
        throw ch.ExceptValue(nameof(data.Schema.Group), "Group column is missing.");
    var groupCol = data.Schema.Group.Value;
    var groupType = groupCol.Type;
    if (!(groupType == NumberDataViewType.UInt32 || groupType is KeyDataViewType))
    {
        throw ch.ExceptParam(nameof(data),
            $"Group column '{groupCol.Name}' is of type '{groupType.RawType}', but must be UInt32 or Key.");
    }
}

Expand Down
34 changes: 30 additions & 4 deletions src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.IO;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
Expand Down Expand Up @@ -192,6 +193,26 @@ internal LightGbmRegressionTrainer(IHostEnvironment env, Options options)
{
}

/// <summary>
/// Initializes a new instance of <see cref="LightGbmRegressionTrainer"/> from a pre-trained LightGBM model.
/// </summary>
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="lightGbmModel">A <see cref="System.IO.Stream"/> containing a pre-trained LightGBM model file, used for inferencing. Must not be null.</param>
/// <param name="featureColumnName">The name of the feature column.</param>
internal LightGbmRegressionTrainer(IHostEnvironment env,
    Stream lightGbmModel,
    string featureColumnName = DefaultColumnNames.Features)
    : base(env,
          LoadNameValue,
          new Options()
          {
              FeatureColumnName = featureColumnName,
              LightGbmModel = lightGbmModel
          },
          // No label column is supplied: a pre-trained model file does not need one.
          new SchemaShape.Column())
{
    // The rest of the trainer decides between "train" and "use pre-trained model" by testing
    // LightGbmModel for null, so a null stream here would silently produce a trainer that
    // expects a label column it was never given. Fail early with a clear argument error instead.
    Host.CheckValue(lightGbmModel, nameof(lightGbmModel));
}

private protected override LightGbmRegressionModelParameters CreatePredictor()
{
Host.Check(TrainedEnsemble != null,
Expand All @@ -204,11 +225,16 @@ private protected override void CheckDataValid(IChannel ch, RoleMappedData data)
{
Host.AssertValue(ch);
base.CheckDataValid(ch, data);
var labelType = data.Schema.Label.Value.Type;
if (!(labelType is BooleanDataViewType || labelType is KeyDataViewType || labelType == NumberDataViewType.Single))

// If using a pre-trained model file we don't need a label column
if (LightGbmTrainerOptions.LightGbmModel == null)
{
throw ch.ExceptParam(nameof(data),
$"Label column '{data.Schema.Label.Value.Name}' is of type '{labelType.RawType}', but must be an unsigned int, boolean or float.");
var labelType = data.Schema.Label.Value.Type;
if (!(labelType is BooleanDataViewType || labelType is KeyDataViewType || labelType == NumberDataViewType.Single))
{
throw ch.ExceptParam(nameof(data),
$"Label column '{data.Schema.Label.Value.Name}' is of type '{labelType.RawType}', but must be an unsigned int, boolean or float.");
}
}
}

Expand Down
Loading