@@ -87,43 +87,65 @@ def get_labels(self):
        return np.asarray(self._label)


-def data_generator(batch_size, width_height, file_name):
-    _dataset = Dataset(file_name, True, width_height)
-
-    def get_epoch():
-
-        _index_in_epoch = 0
-        _perm = np.arange(_dataset.n_samples)
-        np.random.shuffle(_perm)
-        for _ in range(int(math.ceil(_dataset.n_samples / batch_size))):
-            start = _index_in_epoch
-            _index_in_epoch += batch_size
-            # finish one epoch
-            if _index_in_epoch > _dataset.n_samples:
-                data, label = _dataset.data(_perm[start:])
-                data1, label1 = _dataset.data(
-                    _perm[:_index_in_epoch - _dataset.n_samples])
-                data = np.concatenate([data, data1], axis=0)
-                label = np.concatenate([label, label1], axis=0)
-            else:
-                end = _index_in_epoch
-                data, label = _dataset.data(_perm[start:end])
-
-            # n*h*w*c -> n*c*h*w
-            data = np.transpose(data, (0, 3, 1, 2))
-            # bgr -> rgb
-            data = data[:, ::-1, :, :]
-            data = np.reshape(data, (batch_size, -1))
-            yield (data, label)
-
-    return get_epoch
-
-
-def load_train(batch_size, width_height, data_root):
-    return [data_generator(batch_size, width_height, os.path.join(data_root, split + '.txt'))
-            for split in ["train", "database_nolabel", "test"]]
-
-def load_val(batch_size, width_height, data_root):
-    return [data_generator(batch_size, width_height, os.path.join(data_root, split + '.txt'))
-            for split in ["database", "test.txt"]]
+class Dataloader(object):
+
+    def __init__(self, batch_size, width_height, data_root):
+        self.batch_size = batch_size
+        self.width_height = width_height
+        self.data_root = data_root
+
+    def data_generator(self, split):
+        file_name = os.path.join(self.data_root, split + '.txt')
+        _dataset = Dataset(file_name, True, self.width_height)
+
+        def get_epoch():
+
+            _index_in_epoch = 0
+            _perm = np.arange(_dataset.n_samples)
+            np.random.shuffle(_perm)
+            for _ in range(int(math.ceil(_dataset.n_samples / self.batch_size))):
+                start = _index_in_epoch
+                _index_in_epoch += self.batch_size
+                # finish one epoch
+                if _index_in_epoch > _dataset.n_samples:
+                    data, label = _dataset.data(_perm[start:])
+                    data1, label1 = _dataset.data(
+                        _perm[:_index_in_epoch - _dataset.n_samples])
+                    data = np.concatenate([data, data1], axis=0)
+                    label = np.concatenate([label, label1], axis=0)
+                else:
+                    end = _index_in_epoch
+                    data, label = _dataset.data(_perm[start:end])
+
+                # n*h*w*c -> n*c*h*w
+                data = np.transpose(data, (0, 3, 1, 2))
+                # bgr -> rgb
+                data = data[:, ::-1, :, :]
+                data = np.reshape(data, (self.batch_size, -1))
+                yield (data, label)
+
+        return get_epoch

+    @property
+    def train_gen(self):
+        return self.data_generator('train')
+
+    @property
+    def test_gen(self):
+        return self.data_generator('test')
+
+    @property
+    def db_gen(self):
+        return self.data_generator('database')
+
+    @property
+    def unlabeled_db_gen(self):
+        return self.data_generator('database_nolabel')
+
+    @staticmethod
+    def inf_gen(gen):
+        def generator():
+            while True:
+                for images_iter_, labels_iter_ in gen():
+                    yield images_iter_, labels_iter_
+        return generator
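
Below is a minimal usage sketch of the refactored Dataloader, assuming the split files referenced in the diff (train/test/database/database_nolabel) exist under data_root; the batch size, image size, and data_root path are illustrative placeholders, not values taken from this commit.

# Hypothetical usage sketch (not part of the commit); constructor arguments
# are placeholder values chosen for illustration.
loader = Dataloader(batch_size=64, width_height=224, data_root='path/to/data')

# Each *_gen property builds a fresh epoch factory; calling it yields one
# shuffled pass of (batch_size, C*H*W) image batches with their labels.
for images, labels in loader.train_gen():
    pass  # feed the batch to the model

# inf_gen wraps a finite epoch factory into an endless batch stream,
# which suits step-based rather than epoch-based training loops.
infinite_batches = Dataloader.inf_gen(loader.train_gen)()
images, labels = next(infinite_batches)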