|
9 | 9 | from __future__ import print_function, division |
10 | 10 | import sys |
11 | 11 | import torch |
12 | | -from coplenet import COPLENet |
13 | | -from pymic.net_run.net_run import TrainInferAgent |
14 | 12 | from pymic.util.parse_config import parse_config |
| 13 | +from pymic.net_run.net_run_agent import NetRunAgent |
| 14 | +from pymic.net.net_dict import NetDict |
| 15 | +from coplenet import COPLENet |
| 16 | + |
| 17 | +my_net_dict = NetDict |
| 18 | +my_net_dict['COPLENet'] = COPLENet |
15 | 19 |
|
16 | 20 | def main(): |
17 | | - if(len(sys.argv) < 2): |
18 | | - print('Number of arguments should be 2. e.g.') |
19 | | - print(' python net_run.py config.cfg') |
| 21 | + if(len(sys.argv) < 3): |
| 22 | + print('Number of arguments should be 3. e.g.') |
| 23 | + print(' python train_infer.py train config.cfg') |
20 | 24 | exit() |
21 | 25 | cfg_file = str(sys.argv[1]) |
22 | 26 | config = parse_config(cfg_file) |
23 | 27 |
|
24 | | - # parameters of COPLENet |
25 | | - net_param = {"class_num" : 2, |
26 | | - "in_chns" : 1, |
27 | | - "bilinear" : True, |
28 | | - "feature_chns": [32, 64, 128, 256, 512], |
29 | | - "dropout" : [0.0, 0.0, 0.3, 0.4, 0.5]} |
30 | | - config['network'] = net_param |
31 | | - |
32 | | - net = COPLENet(net_param) |
33 | | - agent = TrainInferAgent(config, 'test') |
34 | | - agent.set_network(net) |
| 28 | + stage = str(sys.argv[1]) |
| 29 | + cfg_file = str(sys.argv[2]) |
| 30 | + config = parse_config(cfg_file) |
| 31 | + |
| 32 | + # use custormized CNN and loss function |
| 33 | + agent = NetRunAgent(config, stage) |
| 34 | + agent.set_network_dict(my_net_dict) |
35 | 35 | agent.run() |
36 | 36 |
|
37 | 37 | if __name__ == "__main__": |
|
0 commit comments