40
40
from pystruct .datasets import load_snakes
41
41
from pystruct .utils import make_grid_edges , edge_list_to_features
42
42
from pystruct .models import EdgeFeatureGraphCRF
43
- from pystruct .inference import get_installed
44
43
45
44
46
45
def one_hot_colors (x ):
47
46
x = x / 255
48
- flat = np .dot (x .reshape (- 1 , 3 ), 2 ** np .arange (3 ))
47
+ flat = np .dot (x .reshape (- 1 , 3 ), 2 ** np .arange (3 ))
49
48
one_hot = label_binarize (flat , classes = [1 , 2 , 3 , 4 , 6 ])
50
49
return one_hot .reshape (x .shape [0 ], x .shape [1 ], 5 )
51
50
@@ -93,7 +92,7 @@ def prepare_data(X):
93
92
return X_directions , X_edge_features
94
93
95
94
96
- print ("Please be patient. Will take 5-20 minutes." )
95
+ print ("Please be patient. Learning will take 5-20 minutes." )
97
96
snakes = load_snakes ()
98
97
X_train , Y_train = snakes ['X_train' ], snakes ['Y_train' ]
99
98
@@ -102,10 +101,7 @@ def prepare_data(X):
102
101
103
102
X_train_directions , X_train_edge_features = prepare_data (X_train )
104
103
105
- if 'ogm' in get_installed ():
106
- inference = ('ogm' , {'alg' : 'fm' })
107
- else :
108
- inference = 'qpbo'
104
+ inference = 'qpbo'
109
105
# first, train on X with directions only:
110
106
crf = EdgeFeatureGraphCRF (inference_method = inference )
111
107
ssvm = OneSlackSSVM (crf , inference_cache = 50 , C = .1 , tol = .1 , max_iter = 100 ,
0 commit comments