diff --git a/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/BigQuery.java b/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/BigQuery.java index c9c2e8226673..a566fbeb8421 100644 --- a/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/BigQuery.java +++ b/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/BigQuery.java @@ -132,7 +132,9 @@ enum ModelField implements FieldSelector { LAST_MODIFIED_TIME("lastModifiedTime"), LOCATION("location"), MODEL_REFERENCE("modelReference"), - TIME_PARTITIONING("timePartitioning"), + TRAINING_RUNS("trainingRuns"), + LABEL_COLUMNS("labelColumns"), + FEATURE_COLUMNS("featureColumns"), TYPE("modelType"); static final List REQUIRED_FIELDS = ImmutableList.of(MODEL_REFERENCE); diff --git a/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/Model.java b/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/Model.java index 7606a5b25e78..fd06985e9882 100644 --- a/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/Model.java +++ b/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/Model.java @@ -18,9 +18,12 @@ import static com.google.common.base.Preconditions.checkNotNull; +import com.google.api.services.bigquery.model.StandardSqlField; +import com.google.api.services.bigquery.model.TrainingRun; import com.google.cloud.bigquery.BigQuery.ModelOption; import java.io.IOException; import java.io.ObjectInputStream; +import java.util.List; import java.util.Map; import java.util.Objects; @@ -106,6 +109,24 @@ public Builder setLabels(Map labels) { return this; } + @Override + Builder setTrainingRuns(List trainingRunList) { + infoBuilder.setTrainingRuns(trainingRunList); + return this; + } + + @Override + Builder setLabelColumns(List labelColumnList) { + infoBuilder.setLabelColumns(labelColumnList); + return this; + } + + @Override + Builder setFeatureColumns(List featureColumnList) { + infoBuilder.setFeatureColumns(featureColumnList); + return this; + } + public Model build() { return new Model(bigquery, infoBuilder); } diff --git a/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/ModelInfo.java b/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/ModelInfo.java index 15df4f492444..23f263708bda 100644 --- a/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/ModelInfo.java +++ b/google-cloud-clients/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/ModelInfo.java @@ -18,11 +18,17 @@ import static com.google.common.base.Preconditions.checkNotNull; +import com.google.api.core.BetaApi; import com.google.api.services.bigquery.model.Model; +import com.google.api.services.bigquery.model.StandardSqlField; +import com.google.api.services.bigquery.model.TrainingRun; import com.google.common.base.Function; import com.google.common.base.MoreObjects; import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; import java.io.Serializable; +import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.Objects; @@ -62,6 +68,9 @@ public Model apply(ModelInfo ModelInfo) { private final Long lastModifiedTime; private final Long expirationTime; private final Labels labels; + private final ImmutableList trainingRunList; + private final ImmutableList featureColumnList; + private final ImmutableList labelColumnList; /** A builder for {@code ModelInfo} objects. */ public abstract static class Builder { @@ -97,6 +106,12 @@ public abstract static class Builder { abstract Builder setLastModifiedTime(Long lastModifiedTime); + abstract Builder setTrainingRuns(List trainingRunList); + + abstract Builder setLabelColumns(List labelColumnList); + + abstract Builder setFeatureColumns(List featureColumnList); + /** Creates a {@code ModelInfo} object. */ public abstract ModelInfo build(); } @@ -112,6 +127,9 @@ static class BuilderImpl extends Builder { private Long lastModifiedTime; private Long expirationTime; private Labels labels = Labels.ZERO; + private List trainingRunList = Collections.emptyList(); + private List labelColumnList = Collections.emptyList(); + private List featureColumnList = Collections.emptyList(); BuilderImpl() {} @@ -124,6 +142,9 @@ static class BuilderImpl extends Builder { this.creationTime = modelInfo.creationTime; this.lastModifiedTime = modelInfo.lastModifiedTime; this.expirationTime = modelInfo.expirationTime; + this.trainingRunList = modelInfo.trainingRunList; + this.labelColumnList = modelInfo.labelColumnList; + this.featureColumnList = modelInfo.featureColumnList; } BuilderImpl(Model modelPb) { @@ -139,6 +160,15 @@ static class BuilderImpl extends Builder { this.lastModifiedTime = modelPb.getLastModifiedTime(); this.expirationTime = modelPb.getExpirationTime(); this.labels = Labels.fromPb(modelPb.getLabels()); + if (modelPb.getTrainingRuns() != null) { + this.trainingRunList = modelPb.getTrainingRuns(); + } + if (modelPb.getLabelColumns() != null) { + this.labelColumnList = modelPb.getLabelColumns(); + } + if (modelPb.getFeatureColumns() != null) { + this.featureColumnList = modelPb.getFeatureColumns(); + } } @Override @@ -195,6 +225,24 @@ public Builder setLabels(Map labels) { return this; } + @Override + Builder setTrainingRuns(List trainingRunList) { + this.trainingRunList = checkNotNull(trainingRunList); + return this; + } + + @Override + Builder setLabelColumns(List labelColumnList) { + this.labelColumnList = checkNotNull(labelColumnList); + return this; + } + + @Override + Builder setFeatureColumns(List featureColumnList) { + this.featureColumnList = checkNotNull(featureColumnList); + return this; + } + @Override public ModelInfo build() { return new ModelInfo(this); @@ -211,6 +259,9 @@ public ModelInfo build() { this.lastModifiedTime = builder.lastModifiedTime; this.expirationTime = builder.expirationTime; this.labels = builder.labels; + this.trainingRunList = ImmutableList.copyOf(builder.trainingRunList); + this.labelColumnList = ImmutableList.copyOf(builder.labelColumnList); + this.featureColumnList = ImmutableList.copyOf(builder.featureColumnList); } /** Returns the hash of the model resource. */ @@ -261,6 +312,24 @@ public Map getLabels() { return labels.userMap(); } + /** Returns metadata about each training run iteration. */ + @BetaApi + public ImmutableList getTrainingRuns() { + return trainingRunList; + } + + /** Returns information about the label columns for this model. */ + @BetaApi + public ImmutableList getLabelColumns() { + return labelColumnList; + } + + /** Returns information about the feature columns for this model. */ + @BetaApi + public ImmutableList getFeatureColumns() { + return featureColumnList; + } + public Builder toBuilder() { return new BuilderImpl(this); } @@ -277,6 +346,9 @@ public String toString() { .add("lastModifiedTime", lastModifiedTime) .add("expirationTime", expirationTime) .add("labels", labels) + .add("trainingRuns", trainingRunList) + .add("labelColumns", labelColumnList) + .add("featureColumns", featureColumnList) .toString(); } @@ -321,6 +393,9 @@ Model toPb() { modelPb.setLastModifiedTime(lastModifiedTime); modelPb.setExpirationTime(expirationTime); modelPb.setLabels(labels.toPb()); + modelPb.setTrainingRuns(trainingRunList); + modelPb.setLabelColumns(labelColumnList); + modelPb.setFeatureColumns(featureColumnList); return modelPb; } diff --git a/google-cloud-clients/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/ModelInfoTest.java b/google-cloud-clients/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/ModelInfoTest.java index c5590b1afb57..891ba07ff084 100644 --- a/google-cloud-clients/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/ModelInfoTest.java +++ b/google-cloud-clients/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/ModelInfoTest.java @@ -18,6 +18,10 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; +import com.google.api.services.bigquery.model.TrainingOptions; +import com.google.api.services.bigquery.model.TrainingRun; +import java.util.Arrays; +import java.util.List; import org.junit.Test; public class ModelInfoTest { @@ -30,6 +34,12 @@ public class ModelInfoTest { private static final String DESCRIPTION = "description"; private static final String FRIENDLY_NAME = "friendlyname"; + private static final TrainingOptions TRAINING_OPTIONS = + new TrainingOptions().setDataSplitColumn("foo").setEarlyStop(true).setLossType("bar"); + private static final TrainingRun TRAINING_RUN = + new TrainingRun().setTrainingOptions(TRAINING_OPTIONS); + private static final List TRAINING_RUN_LIST = Arrays.asList(TRAINING_RUN); + private static final ModelInfo MODEL_INFO = ModelInfo.newBuilder(MODEL_ID) .setEtag(ETAG) @@ -38,6 +48,7 @@ public class ModelInfoTest { .setLastModifiedTime(LAST_MODIFIED_TIME) .setDescription(DESCRIPTION) .setFriendlyName(FRIENDLY_NAME) + .setTrainingRuns(TRAINING_RUN_LIST) .build(); @Test @@ -59,6 +70,7 @@ public void testBuilder() { assertEquals(EXPIRATION_TIME, MODEL_INFO.getExpirationTime()); assertEquals(DESCRIPTION, MODEL_INFO.getDescription()); assertEquals(FRIENDLY_NAME, MODEL_INFO.getFriendlyName()); + assertEquals(TRAINING_OPTIONS, MODEL_INFO.getTrainingRuns().get(0).getTrainingOptions()); } @Test @@ -71,6 +83,9 @@ public void testOf() { assertNull(modelInfo.getExpirationTime()); assertNull(modelInfo.getDescription()); assertNull(modelInfo.getFriendlyName()); + assertEquals(modelInfo.getTrainingRuns().isEmpty(), true); + assertEquals(modelInfo.getLabelColumns().isEmpty(), true); + assertEquals(modelInfo.getFeatureColumns().isEmpty(), true); } @Test @@ -94,5 +109,8 @@ private void compareModelInfo(ModelInfo expected, ModelInfo value) { assertEquals(expected.getFriendlyName(), value.getFriendlyName()); assertEquals(expected.getLabels(), value.getLabels()); assertEquals(expected.hashCode(), value.hashCode()); + assertEquals(expected.getTrainingRuns(), value.getTrainingRuns()); + assertEquals(expected.getLabelColumns(), value.getLabelColumns()); + assertEquals(expected.getFeatureColumns(), value.getFeatureColumns()); } } diff --git a/google-cloud-clients/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/it/ITBigQueryTest.java b/google-cloud-clients/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/it/ITBigQueryTest.java index 7dc394add0f7..9305f75478eb 100644 --- a/google-cloud-clients/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/it/ITBigQueryTest.java +++ b/google-cloud-clients/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/it/ITBigQueryTest.java @@ -1083,6 +1083,11 @@ public void testModelLifecycle() throws InterruptedException { Model model = bigquery.getModel(modelId); assertNotNull(model); assertEquals(model.getModelType(), "LINEAR_REGRESSION"); + // Compare the extended model metadata. + assertEquals(model.getFeatureColumns().get(0).getName(), "f1"); + assertEquals(model.getLabelColumns().get(0).getName(), "predicted_label"); + assertEquals( + model.getTrainingRuns().get(0).getTrainingOptions().getLearnRateStrategy(), "CONSTANT"); // Mutate metadata. ModelInfo info = model.toBuilder().setDescription("TEST").build();