@@ -16,31 +16,36 @@ def __init__(self, output, parameters):
1616 for param in gm .getParameters ():
1717 val = param .getBuf (api .PARAMETER_VALUE )
1818 name = param .getName ()
19- assert isinstance (val , api .Matrix )
20- val .copyFromNumpyMat (parameters .get (name ))
19+ assert isinstance (val , api .Vector )
20+ val .copyFromNumpyArray (parameters .get (name ). flatten ( ))
2121 self .__gradient_machine__ = gm
2222 self .__data_types__ = topo .data_type ()
2323
2424 def iter_infer (self , reader , reader_dict = None ):
25+ if reader_dict is None :
26+ reader_dict = self .default_reader_dict ()
2527 feeder = DataFeeder (self .__data_types__ , reader_dict )
26- out_args = api .Arguments .createArguments (0 )
2728 self .__gradient_machine__ .start ()
2829 for data_batch in reader ():
29- yield self .__gradient_machine__ .forwardTest (
30- feeder (data_batch ), out_args , api .PASS_TEST )
30+ yield self .__gradient_machine__ .forwardTest (feeder (data_batch ))
3131 self .__gradient_machine__ .finish ()
3232
3333 def iter_infer_field (self , field , ** kwargs ):
3434 for result in self .iter_infer (** kwargs ):
3535 yield [each_result [field ] for each_result in result ]
3636
3737 def infer (self , field = 'value' , ** kwargs ):
38- retv = []
39- for result in itertools .izip (
40- self .iter_infer_field (
41- field = field , ** kwargs )):
42- retv .append (numpy .concatenate (result ))
43- return retv
38+ retv = None
39+ for result in self .iter_infer_field (field = field , ** kwargs ):
40+ if retv is None :
41+ retv = [[]] * len (result )
42+ for i , item in enumerate (result ):
43+ retv [i ].append (item )
44+ retv = [numpy .concatenate (out ) for out in retv ]
45+ if len (retv ) == 1 :
46+ return retv [0 ]
47+ else :
48+ return retv
4449
4550 def default_reader_dict (self ):
4651 reader_dict = dict ()
0 commit comments