Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ enum ModelField implements FieldSelector {
LAST_MODIFIED_TIME("lastModifiedTime"),
LOCATION("location"),
MODEL_REFERENCE("modelReference"),
TIME_PARTITIONING("timePartitioning"),
TRAINING_RUNS("trainingRuns"),
LABEL_COLUMNS("labelColumns"),
FEATURE_COLUMNS("featureColumns"),
TYPE("modelType");

static final List<? extends FieldSelector> REQUIRED_FIELDS = ImmutableList.of(MODEL_REFERENCE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@

import static com.google.common.base.Preconditions.checkNotNull;

import com.google.api.services.bigquery.model.StandardSqlField;
import com.google.api.services.bigquery.model.TrainingRun;
import com.google.cloud.bigquery.BigQuery.ModelOption;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.List;
import java.util.Map;
import java.util.Objects;

Expand Down Expand Up @@ -106,6 +109,24 @@ public Builder setLabels(Map<String, String> labels) {
return this;
}

@Override
Builder setTrainingRuns(List<TrainingRun> trainingRunList) {
infoBuilder.setTrainingRuns(trainingRunList);
return this;
}

@Override
Builder setLabelColumns(List<StandardSqlField> labelColumnList) {
infoBuilder.setLabelColumns(labelColumnList);
return this;
}

@Override
Builder setFeatureColumns(List<StandardSqlField> featureColumnList) {
infoBuilder.setFeatureColumns(featureColumnList);
return this;
}

public Model build() {
return new Model(bigquery, infoBuilder);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,17 @@

import static com.google.common.base.Preconditions.checkNotNull;

import com.google.api.core.BetaApi;
import com.google.api.services.bigquery.model.Model;
import com.google.api.services.bigquery.model.StandardSqlField;
import com.google.api.services.bigquery.model.TrainingRun;
import com.google.common.base.Function;
import com.google.common.base.MoreObjects;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import java.io.Serializable;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;

Expand Down Expand Up @@ -62,6 +68,9 @@ public Model apply(ModelInfo ModelInfo) {
private final Long lastModifiedTime;
private final Long expirationTime;
private final Labels labels;
private final ImmutableList<TrainingRun> trainingRunList;
private final ImmutableList<StandardSqlField> featureColumnList;
private final ImmutableList<StandardSqlField> labelColumnList;

/** A builder for {@code ModelInfo} objects. */
public abstract static class Builder {
Expand Down Expand Up @@ -97,6 +106,12 @@ public abstract static class Builder {

abstract Builder setLastModifiedTime(Long lastModifiedTime);

abstract Builder setTrainingRuns(List<TrainingRun> trainingRunList);

abstract Builder setLabelColumns(List<StandardSqlField> labelColumnList);

abstract Builder setFeatureColumns(List<StandardSqlField> featureColumnList);

/** Creates a {@code ModelInfo} object. */
public abstract ModelInfo build();
}
Expand All @@ -112,6 +127,9 @@ static class BuilderImpl extends Builder {
private Long lastModifiedTime;
private Long expirationTime;
private Labels labels = Labels.ZERO;
private List<TrainingRun> trainingRunList = Collections.emptyList();
private List<StandardSqlField> labelColumnList = Collections.emptyList();
private List<StandardSqlField> featureColumnList = Collections.emptyList();

BuilderImpl() {}

Expand All @@ -124,6 +142,9 @@ static class BuilderImpl extends Builder {
this.creationTime = modelInfo.creationTime;
this.lastModifiedTime = modelInfo.lastModifiedTime;
this.expirationTime = modelInfo.expirationTime;
this.trainingRunList = modelInfo.trainingRunList;
this.labelColumnList = modelInfo.labelColumnList;
this.featureColumnList = modelInfo.featureColumnList;
}

BuilderImpl(Model modelPb) {
Expand All @@ -139,6 +160,15 @@ static class BuilderImpl extends Builder {
this.lastModifiedTime = modelPb.getLastModifiedTime();
this.expirationTime = modelPb.getExpirationTime();
this.labels = Labels.fromPb(modelPb.getLabels());
if (modelPb.getTrainingRuns() != null) {
this.trainingRunList = modelPb.getTrainingRuns();
}
if (modelPb.getLabelColumns() != null) {
this.labelColumnList = modelPb.getLabelColumns();
}
if (modelPb.getFeatureColumns() != null) {
this.featureColumnList = modelPb.getFeatureColumns();
}
}

@Override
Expand Down Expand Up @@ -195,6 +225,24 @@ public Builder setLabels(Map<String, String> labels) {
return this;
}

@Override
Builder setTrainingRuns(List<TrainingRun> trainingRunList) {
this.trainingRunList = checkNotNull(trainingRunList);
return this;
}

@Override
Builder setLabelColumns(List<StandardSqlField> labelColumnList) {
this.labelColumnList = checkNotNull(labelColumnList);
return this;
}

@Override
Builder setFeatureColumns(List<StandardSqlField> featureColumnList) {
this.featureColumnList = checkNotNull(featureColumnList);
return this;
}

@Override
public ModelInfo build() {
return new ModelInfo(this);
Expand All @@ -211,6 +259,9 @@ public ModelInfo build() {
this.lastModifiedTime = builder.lastModifiedTime;
this.expirationTime = builder.expirationTime;
this.labels = builder.labels;
this.trainingRunList = ImmutableList.copyOf(builder.trainingRunList);
this.labelColumnList = ImmutableList.copyOf(builder.labelColumnList);
this.featureColumnList = ImmutableList.copyOf(builder.featureColumnList);
}

/** Returns the hash of the model resource. */
Expand Down Expand Up @@ -261,6 +312,24 @@ public Map<String, String> getLabels() {
return labels.userMap();
}

/** Returns metadata about each training run iteration. */
@BetaApi
public ImmutableList<TrainingRun> getTrainingRuns() {
return trainingRunList;
}

/** Returns information about the label columns for this model. */
@BetaApi
public ImmutableList<StandardSqlField> getLabelColumns() {
return labelColumnList;
}

/** Returns information about the feature columns for this model. */
@BetaApi
public ImmutableList<StandardSqlField> getFeatureColumns() {
return featureColumnList;
}

public Builder toBuilder() {
return new BuilderImpl(this);
}
Expand All @@ -277,6 +346,9 @@ public String toString() {
.add("lastModifiedTime", lastModifiedTime)
.add("expirationTime", expirationTime)
.add("labels", labels)
.add("trainingRuns", trainingRunList)
.add("labelColumns", labelColumnList)
.add("featureColumns", featureColumnList)
.toString();
}

Expand Down Expand Up @@ -321,6 +393,9 @@ Model toPb() {
modelPb.setLastModifiedTime(lastModifiedTime);
modelPb.setExpirationTime(expirationTime);
modelPb.setLabels(labels.toPb());
modelPb.setTrainingRuns(trainingRunList);
modelPb.setLabelColumns(labelColumnList);
modelPb.setFeatureColumns(featureColumnList);
return modelPb;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;

import com.google.api.services.bigquery.model.TrainingOptions;
import com.google.api.services.bigquery.model.TrainingRun;
import java.util.Arrays;
import java.util.List;
import org.junit.Test;

public class ModelInfoTest {
Expand All @@ -30,6 +34,12 @@ public class ModelInfoTest {
private static final String DESCRIPTION = "description";
private static final String FRIENDLY_NAME = "friendlyname";

private static final TrainingOptions TRAINING_OPTIONS =
new TrainingOptions().setDataSplitColumn("foo").setEarlyStop(true).setLossType("bar");
private static final TrainingRun TRAINING_RUN =
new TrainingRun().setTrainingOptions(TRAINING_OPTIONS);
private static final List<TrainingRun> TRAINING_RUN_LIST = Arrays.asList(TRAINING_RUN);

private static final ModelInfo MODEL_INFO =
ModelInfo.newBuilder(MODEL_ID)
.setEtag(ETAG)
Expand All @@ -38,6 +48,7 @@ public class ModelInfoTest {
.setLastModifiedTime(LAST_MODIFIED_TIME)
.setDescription(DESCRIPTION)
.setFriendlyName(FRIENDLY_NAME)
.setTrainingRuns(TRAINING_RUN_LIST)
.build();

@Test
Expand All @@ -59,6 +70,7 @@ public void testBuilder() {
assertEquals(EXPIRATION_TIME, MODEL_INFO.getExpirationTime());
assertEquals(DESCRIPTION, MODEL_INFO.getDescription());
assertEquals(FRIENDLY_NAME, MODEL_INFO.getFriendlyName());
assertEquals(TRAINING_OPTIONS, MODEL_INFO.getTrainingRuns().get(0).getTrainingOptions());
}

@Test
Expand All @@ -71,6 +83,9 @@ public void testOf() {
assertNull(modelInfo.getExpirationTime());
assertNull(modelInfo.getDescription());
assertNull(modelInfo.getFriendlyName());
assertEquals(modelInfo.getTrainingRuns().isEmpty(), true);
assertEquals(modelInfo.getLabelColumns().isEmpty(), true);
assertEquals(modelInfo.getFeatureColumns().isEmpty(), true);
}

@Test
Expand All @@ -94,5 +109,8 @@ private void compareModelInfo(ModelInfo expected, ModelInfo value) {
assertEquals(expected.getFriendlyName(), value.getFriendlyName());
assertEquals(expected.getLabels(), value.getLabels());
assertEquals(expected.hashCode(), value.hashCode());
assertEquals(expected.getTrainingRuns(), value.getTrainingRuns());
assertEquals(expected.getLabelColumns(), value.getLabelColumns());
assertEquals(expected.getFeatureColumns(), value.getFeatureColumns());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1083,6 +1083,11 @@ public void testModelLifecycle() throws InterruptedException {
Model model = bigquery.getModel(modelId);
assertNotNull(model);
assertEquals(model.getModelType(), "LINEAR_REGRESSION");
// Compare the extended model metadata.
assertEquals(model.getFeatureColumns().get(0).getName(), "f1");
assertEquals(model.getLabelColumns().get(0).getName(), "predicted_label");
assertEquals(
model.getTrainingRuns().get(0).getTrainingOptions().getLearnRateStrategy(), "CONSTANT");

// Mutate metadata.
ModelInfo info = model.toBuilder().setDescription("TEST").build();
Expand Down