Closed
Description
Implement a clean API to use ATM from Python.
The usage should follow this pattern:
>>> from atm import data
>>> demo_datasets = data.get_demos()
>>> demo_datasets
{
'iris': 'demos/iris.csv',
'pollution': 'demos/pollution.csv',
'pitchfork_genres': 'demos/pitchfork_genres.csv'
}
>>> from atm import ATM
>>> atm = ATM() # Additional DB connection arguments can be passed here
>>> path_to_csv = demo_datasets['pollution']
>>> results = atm.run(train_path=path_to_csv) # Additional dataset and datarun configuration can be passed here
Processing dataset demos/pollution.csv
100%|##########################| 100/100 [00:10<00:00, 6.09it/s]
>>> results.describe()
Datarun 1 summary:
Dataset: 'demos/pollution.csv'
Column Name: 'class'
Judgment Metric: 'f1'
Classifiers Tested: 100
Elapsed Time: 0:00:07.638668
>>> results.get_best_classifier()
Classifier id: 94
Classifier type: knn
Params chosen:
n_neighbors: 13
leaf_size: 38
weights: uniform
algorithm: kd_tree
metric: manhattan
_scale: True
Cross Validation Score: 0.858 +- 0.096
Test Score: 0.714
>>> scores = results.get_scores()
>>> scores.head()
   cv_judgment_metric  cv_judgment_metric_stdev  id  test_judgment_metric  rank
0        0.8584126984              0.0960095737  94          0.7142857143   1.0
1        0.8222222222              0.0623609564  12          0.6250000000   2.0
2        0.8147619048              0.1117618135  64          0.8750000000   3.0
3        0.8139393939              0.0588721670  68          0.6086956522   4.0
4        0.8067754468              0.0875180564  50          0.6250000000   5.0
>>> results.export_best_classifier('path/to/model.pkl')
Classifier 94 saved as path/to/model.pkl
>>> from atm import Model
>>> model = Model.load('path/to/model.pkl')
>>> import pandas as pd
>>> df = pd.read_csv(demo_datasets['pollution'])  # avoid shadowing the atm.data module imported above
>>> predictions = model.predict(df.head())
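The ranking implied by the example output (rank 1 is the highest `cv_judgment_metric`, and classifier 94 is both rank 1 and what `get_best_classifier()` returns), together with the `.pkl` export/load round trip, can be sketched with the standard library alone. This is a minimal sketch under stated assumptions, not ATM's implementation. `ClassifierScore`, `rank_scores`, `best_classifier`, `export_model` and `load_model` are hypothetical names. A plain list stands in for the pandas DataFrame that `get_scores()` returns. Pickle is assumed only because the spec uses a `.pkl` extension.

```python
import os
import pickle
import tempfile
from dataclasses import dataclass


@dataclass
class ClassifierScore:
    """Hypothetical record for one tested classifier (fields mirror the scores table)."""
    id: int
    cv_judgment_metric: float
    cv_judgment_metric_stdev: float
    test_judgment_metric: float


def rank_scores(scores):
    """Return (rank, score) pairs sorted by descending cross-validation metric.

    Rank 1 is the classifier with the highest cv_judgment_metric, i.e. the one
    a get_best_classifier()-style call would be expected to return.
    """
    ordered = sorted(scores, key=lambda s: s.cv_judgment_metric, reverse=True)
    return [(rank, s) for rank, s in enumerate(ordered, start=1)]


def best_classifier(scores):
    """Return the rank-1 classifier record."""
    return rank_scores(scores)[0][1]


def export_model(model, path):
    """Persist a model object with pickle (assumed from the .pkl extension)."""
    with open(path, "wb") as f:
        pickle.dump(model, f)


def load_model(path):
    """Load an object previously written by export_model."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Rows from the example scores table, deliberately shuffled.
scores = [
    ClassifierScore(12, 0.8222222222, 0.0623609564, 0.6250000000),
    ClassifierScore(50, 0.8067754468, 0.0875180564, 0.6250000000),
    ClassifierScore(94, 0.8584126984, 0.0960095737, 0.7142857143),
    ClassifierScore(68, 0.8139393939, 0.0588721670, 0.6086956522),
    ClassifierScore(64, 0.8147619048, 0.1117618135, 0.8750000000),
]
ranked = rank_scores(scores)
best = best_classifier(scores)

# A plain dict stands in for a fitted classifier in this sketch.
stand_in_model = {"classifier_id": best.id, "cv_judgment_metric": best.cv_judgment_metric}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")
    export_model(stand_in_model, path)
    restored = load_model(path)
```

Sorting by the cross-validation metric rather than the test metric matches the example, where classifier 64 has the highest test score (0.875) but is ranked third.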