@@ -69,11 +69,11 @@ def forward_fn(inputs, data_format):
6969class ModelHelper (AbstractModelHelper ):
7070 """Model helper for creating a ConvNet model for the Fashion-MNIST dataset."""
7171
72- def __init__ (self ):
72+ def __init__ (self , data_format = 'channels_last' ):
7373 """Constructor function."""
7474
7575 # class-independent initialization
76- super (ModelHelper , self ).__init__ ()
76+ super (ModelHelper , self ).__init__ (data_format )
7777
7878 # initialize training & evaluation subsets
7979 self .dataset_train = FMnistDataset (is_train = True )
@@ -89,15 +89,15 @@ def build_dataset_eval(self):
8989
9090 return self .dataset_eval .build ()
9191
92- def forward_train (self , inputs , data_format = 'channels_last' ):
92+ def forward_train (self , inputs ):
9393 """Forward computation at training."""
9494
95- return forward_fn (inputs , data_format )
95+ return forward_fn (inputs , self . data_format )
9696
97- def forward_eval (self , inputs , data_format = 'channels_last' ):
97+ def forward_eval (self , inputs ):
9898 """Forward computation at evaluation."""
9999
100- return forward_fn (inputs , data_format )
100+ return forward_fn (inputs , self . data_format )
101101
102102 def calc_loss (self , labels , outputs , trainable_vars ):
103103 """Calculate loss (and some extra evaluation metrics)."""
0 commit comments