From 74cc552319647c425a6742fa4388c89cd121b0b3 Mon Sep 17 00:00:00 2001 From: Seth Hollyman Date: Thu, 23 May 2019 10:46:24 +0100 Subject: [PATCH 1/2] BigQuery: Augment BQ Model Metadata. This change introduces extended ML model metadata into the existing CRUD implementation for BigQuery ML models. This implementation exposes the underlying discovery types directly without an abstraction for the ML metadata. Methods that provide this metadata are all marked @BetaApi to indicate the getters may change. --- .../com/google/cloud/bigquery/BigQuery.java | 4 +- .../java/com/google/cloud/bigquery/Model.java | 21 ++++++ .../com/google/cloud/bigquery/ModelInfo.java | 75 +++++++++++++++++++ .../google/cloud/bigquery/ModelInfoTest.java | 18 +++++ .../cloud/bigquery/it/ITBigQueryTest.java | 5 ++ 5 files changed, 122 insertions(+), 1 deletion(-) diff --git a/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/BigQuery.java b/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/BigQuery.java index c9c2e8226673..a566fbeb8421 100644 --- a/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/BigQuery.java +++ b/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/BigQuery.java @@ -132,7 +132,9 @@ enum ModelField implements FieldSelector { LAST_MODIFIED_TIME("lastModifiedTime"), LOCATION("location"), MODEL_REFERENCE("modelReference"), - TIME_PARTITIONING("timePartitioning"), + TRAINING_RUNS("trainingRuns"), + LABEL_COLUMNS("labelColumns"), + FEATURE_COLUMNS("featureColumns"), TYPE("modelType"); static final List REQUIRED_FIELDS = ImmutableList.of(MODEL_REFERENCE); diff --git a/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/Model.java b/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/Model.java index 7606a5b25e78..76d3d08d645b 100644 --- a/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/Model.java +++ b/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/Model.java @@ -18,9 +18,12 @@ import static com.google.common.base.Preconditions.checkNotNull; +import com.google.api.services.bigquery.model.StandardSqlField; +import com.google.api.services.bigquery.model.TrainingRun; import com.google.cloud.bigquery.BigQuery.ModelOption; import java.io.IOException; import java.io.ObjectInputStream; +import java.util.List; import java.util.Map; import java.util.Objects; @@ -106,6 +109,24 @@ public Builder setLabels(Map labels) { return this; } + @Override + public Builder setTrainingRuns(List trainingRunList) { + infoBuilder.setTrainingRuns(trainingRunList); + return this; + } + + @Override + public Builder setLabelColumns(List labelColumnList) { + infoBuilder.setLabelColumns(labelColumnList); + return this; + } + + @Override + public Builder setFeatureColumns(List featureColumnList) { + infoBuilder.setFeatureColumns(featureColumnList); + return this; + } + public Model build() { return new Model(bigquery, infoBuilder); } diff --git a/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/ModelInfo.java b/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/ModelInfo.java index 15df4f492444..23f263708bda 100644 --- a/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/ModelInfo.java +++ b/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/ModelInfo.java @@ -18,11 +18,17 @@ import static com.google.common.base.Preconditions.checkNotNull; +import com.google.api.core.BetaApi; import com.google.api.services.bigquery.model.Model; +import com.google.api.services.bigquery.model.StandardSqlField; +import com.google.api.services.bigquery.model.TrainingRun; import com.google.common.base.Function; import com.google.common.base.MoreObjects; import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; import java.io.Serializable; +import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.Objects; @@ -62,6 +68,9 @@ public Model apply(ModelInfo ModelInfo) { private final Long lastModifiedTime; private final Long expirationTime; private final Labels labels; + private final ImmutableList trainingRunList; + private final ImmutableList featureColumnList; + private final ImmutableList labelColumnList; /** A builder for {@code ModelInfo} objects. */ public abstract static class Builder { @@ -97,6 +106,12 @@ public abstract static class Builder { abstract Builder setLastModifiedTime(Long lastModifiedTime); + abstract Builder setTrainingRuns(List trainingRunList); + + abstract Builder setLabelColumns(List labelColumnList); + + abstract Builder setFeatureColumns(List featureColumnList); + /** Creates a {@code ModelInfo} object. */ public abstract ModelInfo build(); } @@ -112,6 +127,9 @@ static class BuilderImpl extends Builder { private Long lastModifiedTime; private Long expirationTime; private Labels labels = Labels.ZERO; + private List trainingRunList = Collections.emptyList(); + private List labelColumnList = Collections.emptyList(); + private List featureColumnList = Collections.emptyList(); BuilderImpl() {} @@ -124,6 +142,9 @@ static class BuilderImpl extends Builder { this.creationTime = modelInfo.creationTime; this.lastModifiedTime = modelInfo.lastModifiedTime; this.expirationTime = modelInfo.expirationTime; + this.trainingRunList = modelInfo.trainingRunList; + this.labelColumnList = modelInfo.labelColumnList; + this.featureColumnList = modelInfo.featureColumnList; } BuilderImpl(Model modelPb) { @@ -139,6 +160,15 @@ static class BuilderImpl extends Builder { this.lastModifiedTime = modelPb.getLastModifiedTime(); this.expirationTime = modelPb.getExpirationTime(); this.labels = Labels.fromPb(modelPb.getLabels()); + if (modelPb.getTrainingRuns() != null) { + this.trainingRunList = modelPb.getTrainingRuns(); + } + if (modelPb.getLabelColumns() != null) { + this.labelColumnList = modelPb.getLabelColumns(); + } + if (modelPb.getFeatureColumns() != null) { + this.featureColumnList = modelPb.getFeatureColumns(); + } } @Override @@ -195,6 +225,24 @@ public Builder setLabels(Map labels) { return this; } + @Override + Builder setTrainingRuns(List trainingRunList) { + this.trainingRunList = checkNotNull(trainingRunList); + return this; + } + + @Override + Builder setLabelColumns(List labelColumnList) { + this.labelColumnList = checkNotNull(labelColumnList); + return this; + } + + @Override + Builder setFeatureColumns(List featureColumnList) { + this.featureColumnList = checkNotNull(featureColumnList); + return this; + } + @Override public ModelInfo build() { return new ModelInfo(this); @@ -211,6 +259,9 @@ public ModelInfo build() { this.lastModifiedTime = builder.lastModifiedTime; this.expirationTime = builder.expirationTime; this.labels = builder.labels; + this.trainingRunList = ImmutableList.copyOf(builder.trainingRunList); + this.labelColumnList = ImmutableList.copyOf(builder.labelColumnList); + this.featureColumnList = ImmutableList.copyOf(builder.featureColumnList); } /** Returns the hash of the model resource. */ @@ -261,6 +312,24 @@ public Map getLabels() { return labels.userMap(); } + /** Returns metadata about each training run iteration. */ + @BetaApi + public ImmutableList getTrainingRuns() { + return trainingRunList; + } + + /** Returns information about the label columns for this model. */ + @BetaApi + public ImmutableList getLabelColumns() { + return labelColumnList; + } + + /** Returns information about the feature columns for this model. */ + @BetaApi + public ImmutableList getFeatureColumns() { + return featureColumnList; + } + public Builder toBuilder() { return new BuilderImpl(this); } @@ -277,6 +346,9 @@ public String toString() { .add("lastModifiedTime", lastModifiedTime) .add("expirationTime", expirationTime) .add("labels", labels) + .add("trainingRuns", trainingRunList) + .add("labelColumns", labelColumnList) + .add("featureColumns", featureColumnList) .toString(); } @@ -321,6 +393,9 @@ Model toPb() { modelPb.setLastModifiedTime(lastModifiedTime); modelPb.setExpirationTime(expirationTime); modelPb.setLabels(labels.toPb()); + modelPb.setTrainingRuns(trainingRunList); + modelPb.setLabelColumns(labelColumnList); + modelPb.setFeatureColumns(featureColumnList); return modelPb; } diff --git a/google-cloud-clients/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/ModelInfoTest.java b/google-cloud-clients/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/ModelInfoTest.java index c5590b1afb57..891ba07ff084 100644 --- a/google-cloud-clients/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/ModelInfoTest.java +++ b/google-cloud-clients/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/ModelInfoTest.java @@ -18,6 +18,10 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; +import com.google.api.services.bigquery.model.TrainingOptions; +import com.google.api.services.bigquery.model.TrainingRun; +import java.util.Arrays; +import java.util.List; import org.junit.Test; public class ModelInfoTest { @@ -30,6 +34,12 @@ public class ModelInfoTest { private static final String DESCRIPTION = "description"; private static final String FRIENDLY_NAME = "friendlyname"; + private static final TrainingOptions TRAINING_OPTIONS = + new TrainingOptions().setDataSplitColumn("foo").setEarlyStop(true).setLossType("bar"); + private static final TrainingRun TRAINING_RUN = + new TrainingRun().setTrainingOptions(TRAINING_OPTIONS); + private static final List TRAINING_RUN_LIST = Arrays.asList(TRAINING_RUN); + private static final ModelInfo MODEL_INFO = ModelInfo.newBuilder(MODEL_ID) .setEtag(ETAG) @@ -38,6 +48,7 @@ public class ModelInfoTest { .setLastModifiedTime(LAST_MODIFIED_TIME) .setDescription(DESCRIPTION) .setFriendlyName(FRIENDLY_NAME) + .setTrainingRuns(TRAINING_RUN_LIST) .build(); @Test @@ -59,6 +70,7 @@ public void testBuilder() { assertEquals(EXPIRATION_TIME, MODEL_INFO.getExpirationTime()); assertEquals(DESCRIPTION, MODEL_INFO.getDescription()); assertEquals(FRIENDLY_NAME, MODEL_INFO.getFriendlyName()); + assertEquals(TRAINING_OPTIONS, MODEL_INFO.getTrainingRuns().get(0).getTrainingOptions()); } @Test @@ -71,6 +83,9 @@ public void testOf() { assertNull(modelInfo.getExpirationTime()); assertNull(modelInfo.getDescription()); assertNull(modelInfo.getFriendlyName()); + assertEquals(modelInfo.getTrainingRuns().isEmpty(), true); + assertEquals(modelInfo.getLabelColumns().isEmpty(), true); + assertEquals(modelInfo.getFeatureColumns().isEmpty(), true); } @Test @@ -94,5 +109,8 @@ private void compareModelInfo(ModelInfo expected, ModelInfo value) { assertEquals(expected.getFriendlyName(), value.getFriendlyName()); assertEquals(expected.getLabels(), value.getLabels()); assertEquals(expected.hashCode(), value.hashCode()); + assertEquals(expected.getTrainingRuns(), value.getTrainingRuns()); + assertEquals(expected.getLabelColumns(), value.getLabelColumns()); + assertEquals(expected.getFeatureColumns(), value.getFeatureColumns()); } } diff --git a/google-cloud-clients/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/it/ITBigQueryTest.java b/google-cloud-clients/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/it/ITBigQueryTest.java index 7dc394add0f7..9305f75478eb 100644 --- a/google-cloud-clients/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/it/ITBigQueryTest.java +++ b/google-cloud-clients/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/it/ITBigQueryTest.java @@ -1083,6 +1083,11 @@ public void testModelLifecycle() throws InterruptedException { Model model = bigquery.getModel(modelId); assertNotNull(model); assertEquals(model.getModelType(), "LINEAR_REGRESSION"); + // Compare the extended model metadata. + assertEquals(model.getFeatureColumns().get(0).getName(), "f1"); + assertEquals(model.getLabelColumns().get(0).getName(), "predicted_label"); + assertEquals( + model.getTrainingRuns().get(0).getTrainingOptions().getLearnRateStrategy(), "CONSTANT"); // Mutate metadata. ModelInfo info = model.toBuilder().setDescription("TEST").build(); From 1f817f2b0483dd566bca975f1c59206b4de2a521 Mon Sep 17 00:00:00 2001 From: Seth Hollyman Date: Thu, 23 May 2019 14:32:34 +0100 Subject: [PATCH 2/2] nit: reduce visibility of set methods on Model for extended stats --- .../src/main/java/com/google/cloud/bigquery/Model.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/Model.java b/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/Model.java index 76d3d08d645b..fd06985e9882 100644 --- a/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/Model.java +++ b/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/Model.java @@ -110,19 +110,19 @@ public Builder setLabels(Map labels) { } @Override - public Builder setTrainingRuns(List trainingRunList) { + Builder setTrainingRuns(List trainingRunList) { infoBuilder.setTrainingRuns(trainingRunList); return this; } @Override - public Builder setLabelColumns(List labelColumnList) { + Builder setLabelColumns(List labelColumnList) { infoBuilder.setLabelColumns(labelColumnList); return this; } @Override - public Builder setFeatureColumns(List featureColumnList) { + Builder setFeatureColumns(List featureColumnList) { infoBuilder.setFeatureColumns(featureColumnList); return this; }