Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 9423e16

Browse files
committed
update ModelHelper's func interface in the examplar code
1 parent d119a04 commit 9423e16

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

examples/convnet_at_fmnist.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,11 @@ def forward_fn(inputs, data_format):
6969
class ModelHelper(AbstractModelHelper):
7070
"""Model helper for creating a ConvNet model for the Fashion-MNIST dataset."""
7171

72-
def __init__(self):
72+
def __init__(self, data_format='channels_last'):
7373
"""Constructor function."""
7474

7575
# class-independent initialization
76-
super(ModelHelper, self).__init__()
76+
super(ModelHelper, self).__init__(data_format)
7777

7878
# initialize training & evaluation subsets
7979
self.dataset_train = FMnistDataset(is_train=True)
@@ -89,15 +89,15 @@ def build_dataset_eval(self):
8989

9090
return self.dataset_eval.build()
9191

92-
def forward_train(self, inputs, data_format='channels_last'):
92+
def forward_train(self, inputs):
9393
"""Forward computation at training."""
9494

95-
return forward_fn(inputs, data_format)
95+
return forward_fn(inputs, self.data_format)
9696

97-
def forward_eval(self, inputs, data_format='channels_last'):
97+
def forward_eval(self, inputs):
9898
"""Forward computation at evaluation."""
9999

100-
return forward_fn(inputs, data_format)
100+
return forward_fn(inputs, self.data_format)
101101

102102
def calc_loss(self, labels, outputs, trainable_vars):
103103
"""Calculate loss (and some extra evaluation metrics)."""

0 commit comments

Comments
 (0)