From 04edf0ccd43dd304b22abdc72d403a53eee0be36 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 29 Jul 2022 10:37:19 +0800 Subject: [PATCH 001/225] add logging to classification agent --- pymic/net_run/agent_abstract.py | 3 ++- pymic/net_run/agent_cls.py | 13 +++++++------ pymic/net_run/agent_seg.py | 5 +++-- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index 420e2f9..fce71f1 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -3,6 +3,7 @@ import os import random +import logging import torch import numpy as np import torch.optim as optim @@ -42,7 +43,7 @@ def __init__(self, config, stage = 'train'): self.random_seed = config['training'].get('random_seed', 1) if(self.deterministic): seed_torch(self.random_seed) - print("deterministric is true") + logging.info("deterministric is true") def set_datasets(self, train_set, valid_set, test_set): self.train_set = train_set diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index 4a80532..71d7c30 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -3,6 +3,7 @@ import copy import csv +import logging import time import torch from torchvision import transforms @@ -71,7 +72,7 @@ def create_network(self): else: self.net.double() param_number = sum(p.numel() for p in self.net.parameters() if p.requires_grad) - print('parameter number:', param_number) + logging.info('parameter number {0:}'.format(param_number)) def get_parameters_to_update(self): params = self.net.get_parameters_to_update() @@ -176,10 +177,10 @@ def write_scalars(self, train_scalars, valid_scalars, glob_it): self.summ_writer.add_scalars('loss', loss_scalar, glob_it) self.summ_writer.add_scalars(metrics, acc_scalar, glob_it) - print("{0:} it {1:}".format(str(datetime.now())[:-7], glob_it)) - print('train loss {0:.4f}, avg {1:} {2:.4f}'.format( + logging.info("{0:} it {1:}".format(str(datetime.now())[:-7], glob_it)) + logging.info('train loss {0:.4f}, avg {1:} {2:.4f}'.format( train_scalars['loss'], metrics, train_scalars[metrics])) - print('valid loss {0:.4f}, avg {1:} {2:.4f}'.format( + logging.info('valid loss {0:.4f}, avg {1:} {2:.4f}'.format( valid_scalars['loss'], metrics, valid_scalars[metrics])) def train_valid(self): @@ -218,7 +219,7 @@ def train_valid(self): self.trainIter = iter(self.train_loader) - print("{0:} training start".format(str(datetime.now())[:-7])) + logging.info("{0:} training start".format(str(datetime.now())[:-7])) self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) for it in range(iter_start, iter_max, iter_valid): train_scalars = self.training() @@ -252,7 +253,7 @@ def train_valid(self): txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefx), 'wt') txt_file.write(str(self.max_val_it)) txt_file.close() - print('The best perfroming iter is {0:}, valid {1:} {2:}'.format(\ + logging.info('The best perfroming iter is {0:}, valid {1:} {2:}'.format(\ self.max_val_it, metrics, self.max_val_score)) self.summ_writer.close() diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index e4bb97c..e1d4cc4 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -323,6 +323,7 @@ def train_valid(self): t0 = time.time() train_scalars = self.training() t1 = time.time() + valid_scalars = self.validation() t2 = time.time() self.glob_it = it + iter_valid @@ -428,7 +429,7 @@ def test_time_dropout(m): self.save_ouputs(data) infer_time_list = np.asarray(infer_time_list) time_avg, time_std = infer_time_list.mean(), infer_time_list.std() - print("testing time {0:} +/- {1:}".format(time_avg, time_std)) + logging.info("testing time {0:} +/- {1:}".format(time_avg, time_std)) def infer_with_multiple_checkpoints(self): """ @@ -482,7 +483,7 @@ def infer_with_multiple_checkpoints(self): self.save_ouputs(data) infer_time_list = np.asarray(infer_time_list) time_avg, time_std = infer_time_list.mean(), infer_time_list.std() - print("testing time {0:} +/- {1:}".format(time_avg, time_std)) + logging.info("testing time {0:} +/- {1:}".format(time_avg, time_std)) def save_ouputs(self, data): output_dir = self.config['testing']['output_dir'] From c2bbec5522d4c1a7f1ac0fa29d1d2d7e764946d6 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 30 Jul 2022 10:25:29 +0800 Subject: [PATCH 002/225] add lr scheduler --- pymic/net_run/agent_abstract.py | 18 ++++++------- pymic/net_run/agent_seg.py | 33 ++++++++++++++++-------- pymic/net_run/get_optimizer.py | 45 ++++++++++++++++++++++++--------- pymic/util/general.py | 27 ++++++++++++++++++++ 4 files changed, 92 insertions(+), 31 deletions(-) create mode 100644 pymic/util/general.py diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index fce71f1..d701534 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -8,7 +8,7 @@ import numpy as np import torch.optim as optim from abc import ABCMeta, abstractmethod -from pymic.net_run.get_optimizer import get_optimiser +from pymic.net_run.get_optimizer import get_lr_scheduler, get_optimizer def seed_torch(seed=1): random.seed(seed) @@ -72,7 +72,9 @@ def get_checkpoint_name(self): ckpt_mode = self.config['testing']['ckpt_mode'] if(ckpt_mode == 0 or ckpt_mode == 1): ckpt_dir = self.config['training']['ckpt_save_dir'] - ckpt_prefix = ckpt_dir.split('/')[-1] + ckpt_prefix = self.config['training'].get('ckpt_prefix', None) + if(ckpt_prefix is None): + ckpt_prefix = ckpt_dir.split('/')[-1] txt_name = ckpt_dir + '/' + ckpt_prefix txt_name += "_latest.txt" if ckpt_mode == 0 else "_best.txt" with open(txt_name, 'r') as txt_file: @@ -146,19 +148,17 @@ def worker_init_fn(worker_id): batch_size = bn_test, shuffle=False, num_workers= bn_test) def create_optimizer(self, params): + opt_params = self.config['training'] if(self.optimizer is None): - self.optimizer = get_optimiser(self.config['training']['optimizer'], - params, - self.config['training']) + self.optimizer = get_optimizer(opt_params['optimizer'], + params, opt_params) last_iter = -1 if(self.checkpoint is not None): self.optimizer.load_state_dict(self.checkpoint['optimizer_state_dict']) last_iter = self.checkpoint['iteration'] - 1 if(self.scheduler is None): - self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, - self.config['training']['lr_milestones'], - self.config['training']['lr_gamma'], - last_epoch = last_iter) + opt_params["laster_iter"] = last_iter + self.scheduler = get_lr_scheduler(self.optimizer, opt_params) def convert_tensor_type(self, input_tensor): if(self.tensor_type == 'float'): diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index e1d4cc4..6da26cc 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -26,6 +26,7 @@ from pymic.loss.seg.util import get_classwise_dice from pymic.transform.trans_dict import TransformDict from pymic.util.image_process import convert_label +from pymic.util.general import keyword_match class SegmentationAgent(NetRunAgent): def __init__(self, config, stage = 'train'): @@ -192,10 +193,10 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) loss = self.get_loss_value(data, outputs, labels_prob) - # if (self.config['training']['use']) loss.backward() self.optimizer.step() - self.scheduler.step() + if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + self.scheduler.step() train_loss = train_loss + loss.item() # get dice evaluation for each class @@ -251,15 +252,19 @@ def validation(self): valid_cls_dice = np.asarray(valid_dice_list).mean(axis = 0) valid_avg_dice = valid_cls_dice.mean() + if(keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + self.scheduler.step(valid_avg_dice) + valid_scalers = {'loss': valid_avg_loss, 'avg_dice': valid_avg_dice,\ 'class_dice': valid_cls_dice} return valid_scalers - def write_scalars(self, train_scalars, valid_scalars, glob_it): + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): loss_scalar ={'train':train_scalars['loss'], 'valid':valid_scalars['loss']} dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']} self.summ_writer.add_scalars('loss', loss_scalar, glob_it) self.summ_writer.add_scalars('dice', dice_scalar, glob_it) + self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) class_num = self.config['network']['class_num'] for c in range(class_num): cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ @@ -282,13 +287,14 @@ def train_valid(self): self.device = torch.device("cuda:{0:}".format(device_ids[0])) self.net.to(self.device) ckpt_dir = self.config['training']['ckpt_save_dir'] - if(ckpt_dir[-1] == "/"): - ckpt_dir = ckpt_dir[:-1] - ckpt_prefx = ckpt_dir.split('/')[-1] + ckpt_prefix = self.config['training'].get('ckpt_prefix', None) + if(ckpt_prefix is None): + ckpt_prefix = ckpt_dir.split('/')[-1] iter_start = self.config['training']['iter_start'] iter_max = self.config['training']['iter_max'] iter_valid = self.config['training']['iter_valid'] iter_save = self.config['training']['iter_save'] + early_stop_it = self.config['training'].get('early_stop_patience', None) if(isinstance(iter_save, (tuple, list))): iter_save_list = iter_save else: @@ -299,7 +305,7 @@ def train_valid(self): self.best_model_wts = None self.checkpoint = None if(iter_start > 0): - checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefx, iter_start) + checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start) self.checkpoint = torch.load(checkpoint_file, map_location = self.device) # assert(self.checkpoint['iteration'] == iter_start) if(len(device_ids) > 1): @@ -320,6 +326,7 @@ def train_valid(self): self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) self.glob_it = iter_start for it in range(iter_start, iter_max, iter_valid): + lr_value = self.optimizer.param_groups[0]['lr'] t0 = time.time() train_scalars = self.training() t1 = time.time() @@ -327,9 +334,10 @@ def train_valid(self): valid_scalars = self.validation() t2 = time.time() self.glob_it = it + iter_valid - logging.info("{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) + logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) + logging.info('learning rate {0:}'.format(lr_value)) logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1)) - self.write_scalars(train_scalars, valid_scalars, self.glob_it) + self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it) if(valid_scalars['avg_dice'] > self.max_val_dice): self.max_val_dice = valid_scalars['avg_dice'] self.max_val_it = self.glob_it @@ -338,7 +346,9 @@ def train_valid(self): else: self.best_model_wts = copy.deepcopy(self.net.state_dict()) - if (self.glob_it in iter_save_list): + stop_now = True if(early_stop_it is not None and \ + self.glob_it - self.max_val_it > early_stop_it) else False + if ((self.glob_it in iter_save_list) or stop_now): save_dict = {'iteration': self.glob_it, 'valid_pred': valid_scalars['avg_dice'], 'model_state_dict': self.net.module.state_dict() \ @@ -349,6 +359,9 @@ def train_valid(self): txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefx), 'wt') txt_file.write(str(self.glob_it)) txt_file.close() + if(stop_now): + logging.info("The training is early stopped") + break # save the best performing checkpoint save_dict = {'iteration': self.max_val_it, 'valid_pred': self.max_val_dice, diff --git a/pymic/net_run/get_optimizer.py b/pymic/net_run/get_optimizer.py index 7170b6e..e475286 100644 --- a/pymic/net_run/get_optimizer.py +++ b/pymic/net_run/get_optimizer.py @@ -2,33 +2,54 @@ from __future__ import print_function, division import torch -import torch.optim as optim +from torch import optim +from torch.optim import lr_scheduler +from pymic.util.general import keyword_match -def get_optimiser(name, net_params, optim_params): +def get_optimizer(name, net_params, optim_params): lr = optim_params['learning_rate'] momentum = optim_params['momentum'] weight_decay = optim_params['weight_decay'] - if(name == "SGD"): + if(keyword_match(name, "SGD")): return optim.SGD(net_params, lr, momentum = momentum, weight_decay = weight_decay) - elif(name == "Adam"): + elif(keyword_match(name, "Adam")): return optim.Adam(net_params, lr, weight_decay = weight_decay) - elif(name == "SparseAdam"): + elif(keyword_match(name, "SparseAdam")): return optim.SparseAdam(net_params, lr) - elif(name == "Adadelta"): + elif(keyword_match(name, "Adadelta")): return optim.Adadelta(net_params, lr, weight_decay = weight_decay) - elif(name == "Adagrad"): + elif(keyword_match(name, "Adagrad")): return optim.Adagrad(net_params, lr, weight_decay = weight_decay) - elif(name == "Adamax"): + elif(keyword_match(name, "Adamax")): return optim.Adamax(net_params, lr, weight_decay = weight_decay) - elif(name == "ASGD"): + elif(keyword_match(name, "ASGD")): return optim.ASGD(net_params, lr, weight_decay = weight_decay) - elif(name == "LBFGS"): + elif(keyword_match(name, "LBFGS")): return optim.LBFGS(net_params, lr) - elif(name == "RMSprop"): + elif(keyword_match(name, "RMSprop")): return optim.RMSprop(net_params, lr, momentum = momentum, weight_decay = weight_decay) - elif(name == "Rprop"): + elif(keyword_match(name, "Rprop")): return optim.Rprop(net_params, lr) else: raise ValueError("unsupported optimizer {0:}".format(name)) + + +def get_lr_scheduler(optimizer, sched_params): + name = sched_params["lr_scheduler"] + lr_gamma = sched_params["lr_gamma"] + if(keyword_match(name, "ReduceLROnPlateau")): + patience_it = sched_params["ReduceLROnPlateau_patience".lower()] + val_it = sched_params["iter_valid"] + patience = patience_it / val_it + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, + mode = "max", factor=lr_gamma, patience = patience) + elif(keyword_match(name, "MultiStepLR")): + lr_milestones = sched_params["lr_milestones"] + last_iter = sched_params["last_iter"] + scheduler = lr_scheduler.MultiStepLR(optimizer, + lr_milestones, lr_gamma, last_iter) + else: + raise ValueError("unsupported lr scheduler {0:}".format(name)) + return scheduler \ No newline at end of file diff --git a/pymic/util/general.py b/pymic/util/general.py new file mode 100644 index 0000000..063d654 --- /dev/null +++ b/pymic/util/general.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import torch +import numpy as np + +def keyword_match(a,b): + return a.lower() == b.lower() + +def get_one_hot_seg(label, class_num): + """ + convert a segmentation label to one-hot + label: a tensor with a shape of [N, 1, D, H, W] or [N, 1, H, W] + class_num: class number. + output: an one-hot tensor with a shape of [N, C, D, H, W] or [N, C, H, W] + """ + size = list(label.size()) + if(size[1] != 1): + raise ValueError("The channel should be 1, \ + rather than {0:} before one-hot encoding".format(size[1])) + label = label.view(-1) + ones = torch.sparse.torch.eye(class_num).to(label.device) + one_hot = ones.index_select(0, label) + size.append(class_num) + one_hot = one_hot.view(*size) + one_hot = torch.transpose(one_hot, 1, -1) + one_hot = torch.squeeze(one_hot, -1) + return one_hot \ No newline at end of file From bde766242c45b938b73d42bd1f8858df47564581 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 30 Jul 2022 14:01:28 +0800 Subject: [PATCH 003/225] add early stop and update readme set early stop in agent_seg update readme for annotation-efficient learning --- README.md | 18 ++++++++++-------- pymic/net_run/agent_seg.py | 16 +++++++++------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index d0aa89e..d6007e8 100644 --- a/README.md +++ b/README.md @@ -1,29 +1,31 @@ # PyMIC: A Pytorch-Based Toolkit for Medical Image Computing -PyMIC is a pytorch-based toolkit for medical image computing with deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with higher dimension, multiple modalities and low contrast. The toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configure files. +PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with high dimension and large volume, multiple modalities and difficulies in annotating. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations. Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. It was originally developed for COVID-19 pneumonia lesion segmentation from CT images. If you use this toolkit, please cite the following paper: - * G. Wang, X. Liu, C. Li, Z. Xu, J. Ruan, H. Zhu, T. Meng, K. Li, N. Huang, S. Zhang. [A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions from CT Images.][tmi2020] IEEE Transactions on Medical Imaging. 39(8):2653-2663, 2020. DOI: [10.1109/TMI.2020.3000314][tmi2020] [tmi2020]:https://ieeexplore.ieee.org/document/9109297 -# Advantages -PyMIC provides some basic modules for medical image computing that can be share by different applications. We currently provide the following functions: +# Features +PyMIC provides flixible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions: +* Support for annotation-efficient image segmentation, especially for semi-supervised, weakly-supervised and noisy-label learning. +* User friendly: For beginners, you only need to edit the configuration files for model training and inference, without writing code. For advanced users, you can customize different modules (networks, loss functions, training pipeline, etc) and easily integrate them into PyMIC. * Easy-to-use I/O interface to read and write different 2D and 3D images. +* Various data pre-processing/transformation methods before sending a tensor into a network. +* Implementation of typical neural networks for medical image segmentation. * Re-useable training and testing pipeline that can be transferred to different tasks. -* Various data pre-processing methods before sending a tensor into a network. -* Implementation of loss functions, especially for image segmentation. -* Implementation of evaluation metrics to get quantitative evaluation of your methods (for segmentation). +* Evaluation metrics for quantitative evaluation of your methods. # Usage ## Requirement * [Pytorch][torch_link] version >=1.0.1 * [TensorboardX][tbx_link] to visualize training performance * Some common python packages such as Numpy, Pandas, SimpleITK +* See `requirements.txt` for details. [torch_link]:https://pytorch.org/ [tbx_link]:https://github.com/lanpa/tensorboardX @@ -42,7 +44,7 @@ python setup.py install ``` ## Examples -[PyMIC_examples][examples] provides some examples of starting to use PyMIC. For beginners, you only need to simply change the configuration files to select different datasets, networks and training methods for running the code. For advanced users, you can develop your own modules based on this package. You can find both types of examples +[PyMIC_examples][examples] provides some examples of starting to use PyMIC. At the beginning, you only need to edit the configuration files to select different datasets, networks and training methods for running the code. When you are more familiar with PyMIC, you can customize different modules in the PyMIC package. You can find both types of examples: [examples]: https://github.com/HiLab-git/PyMIC_examples diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 6da26cc..0cc59b3 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -293,12 +293,14 @@ def train_valid(self): iter_start = self.config['training']['iter_start'] iter_max = self.config['training']['iter_max'] iter_valid = self.config['training']['iter_valid'] - iter_save = self.config['training']['iter_save'] + iter_save = self.config['training'].get('iter_save', None) early_stop_it = self.config['training'].get('early_stop_patience', None) - if(isinstance(iter_save, (tuple, list))): + if(iter_save is None): + iter_save_list = [iter_max] + elif(isinstance(iter_save, (tuple, list))): iter_save_list = iter_save else: - iter_save_list = range(iter_start, iter_max +1, iter_save) + iter_save_list = range(iter_start, iter_max + 1, iter_save) self.max_val_dice = 0.0 self.max_val_it = 0 @@ -354,9 +356,9 @@ def train_valid(self): 'model_state_dict': self.net.module.state_dict() \ if len(device_ids) > 1 else self.net.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefx, self.glob_it) + save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.glob_it) torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefx), 'wt') + txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt') txt_file.write(str(self.glob_it)) txt_file.close() if(stop_now): @@ -367,9 +369,9 @@ def train_valid(self): 'valid_pred': self.max_val_dice, 'model_state_dict': self.best_model_wts, 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefx, self.max_val_it) + save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it) torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefx), 'wt') + txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') txt_file.write(str(self.max_val_it)) txt_file.close() logging.info('The best performing iter is {0:}, valid dice {1:}'.format(\ From b98633fbfa721592b749c26fc93480080b4216d0 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 30 Jul 2022 14:10:13 +0800 Subject: [PATCH 004/225] Update README.md --- docs/README.md | 50 ++++++++++++++++++++++---------------------------- 1 file changed, 22 insertions(+), 28 deletions(-) diff --git a/docs/README.md b/docs/README.md index cdfd09b..8886f09 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,41 +1,35 @@ -## Welcome to PyMIC Documentation +## Welcome to PyMIC -This page is under construction, we will update it later. +PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with high dimension and large volume, multiple modalities and difficulies in annotating. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations. + +Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. It was originally developed for COVID-19 pneumonia lesion segmentation from CT images. ### Features +PyMIC provides flixible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions: +* Support for annotation-efficient image segmentation, especially for semi-supervised, weakly-supervised and noisy-label learning. +* User friendly: For beginners, you only need to edit the configuration files for model training and inference, without writing code. For advanced users, you can customize different modules (networks, loss functions, training pipeline, etc) and easily integrate them into PyMIC. +* Easy-to-use I/O interface to read and write different 2D and 3D images. +* Various data pre-processing/transformation methods before sending a tensor into a network. +* Implementation of typical neural networks for medical image segmentation. +* Re-useable training and testing pipeline that can be transferred to different tasks. +* Evaluation metrics for quantitative evaluation of your methods. ### Installation +Run the following command to install the current released version of PyMIC: -### Quick Start - -### Links - -### Markdown - -```markdown -Syntax highlighted code block - -# Header 1 -## Header 2 -### Header 3 - -- Bulleted -- List - -1. Numbered -2. List - -**Bold** and _Italic_ and `Code` text - -[Link](url) and ![Image](src) +```bash +pip install PYMIC ``` -For more details see [Basic writing and formatting syntax](https://docs.github.com/en/github/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax). +Alternatively, you can download the source code for the latest version. Run the following command to compile and install: -### Jekyll Themes +```bash +python setup.py install +``` -Your Pages site will use the layout and styles from the Jekyll theme you have selected in your [repository settings](https://github.com/HiLab-git/PyMIC/settings/pages). The name of this theme is saved in the Jekyll `_config.yml` configuration file. +### Quick Start +[PyMIC_examples][examples] provides some examples of starting to use PyMIC. At the beginning, you only need to edit the configuration files to select different datasets, networks and training methods for running the code. When you are more familiar with PyMIC, you can customize different modules in the PyMIC package. You can find both types of examples: -### Support or Contact +[examples]: https://github.com/HiLab-git/PyMIC_examples Having trouble with Pages? Check out our [documentation](https://docs.github.com/categories/github-pages-basics/) or [contact support](https://support.github.com/contact) and we’ll help you sort it out. From b0f09afb63da076628af49f1d14c8ad73226339c Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 30 Jul 2022 14:11:43 +0800 Subject: [PATCH 005/225] Update README.md --- docs/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/README.md b/docs/README.md index 8886f09..174cc33 100644 --- a/docs/README.md +++ b/docs/README.md @@ -32,4 +32,3 @@ python setup.py install [examples]: https://github.com/HiLab-git/PyMIC_examples -Having trouble with Pages? Check out our [documentation](https://docs.github.com/categories/github-pages-basics/) or [contact support](https://support.github.com/contact) and we’ll help you sort it out. From f5ff615c462853a9db538eb6889d5d7ba9c584df Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 30 Jul 2022 14:15:22 +0800 Subject: [PATCH 006/225] Update index.md --- docs/index.md | 55 +++++++++++++++++++++------------------------------ 1 file changed, 22 insertions(+), 33 deletions(-) diff --git a/docs/index.md b/docs/index.md index be80e7b..174cc33 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,45 +1,34 @@ -## Welcome to PyMIC Documentation +## Welcome to PyMIC -This page is under construction, we will update it later. +PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with high dimension and large volume, multiple modalities and difficulies in annotating. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations. + +Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. It was originally developed for COVID-19 pneumonia lesion segmentation from CT images. ### Features +PyMIC provides flixible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions: +* Support for annotation-efficient image segmentation, especially for semi-supervised, weakly-supervised and noisy-label learning. +* User friendly: For beginners, you only need to edit the configuration files for model training and inference, without writing code. For advanced users, you can customize different modules (networks, loss functions, training pipeline, etc) and easily integrate them into PyMIC. +* Easy-to-use I/O interface to read and write different 2D and 3D images. +* Various data pre-processing/transformation methods before sending a tensor into a network. +* Implementation of typical neural networks for medical image segmentation. +* Re-useable training and testing pipeline that can be transferred to different tasks. +* Evaluation metrics for quantitative evaluation of your methods. ### Installation +Run the following command to install the current released version of PyMIC: -### Quick Start - -### Links -[API][api_link] - -[api_link]:API/index.html -### Markdown - -Markdown is a lightweight and easy-to-use syntax for styling your writing. It includes conventions for - -```markdown -Syntax highlighted code block - -# Header 1 -## Header 2 -### Header 3 - -- Bulleted -- List - -1. Numbered -2. List - -**Bold** and _Italic_ and `Code` text - -[Link](url) and ![Image](src) +```bash +pip install PYMIC ``` -For more details see [Basic writing and formatting syntax](https://docs.github.com/en/github/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax). +Alternatively, you can download the source code for the latest version. Run the following command to compile and install: -### Jekyll Themes +```bash +python setup.py install +``` -Your Pages site will use the layout and styles from the Jekyll theme you have selected in your [repository settings](https://github.com/HiLab-git/PyMIC/settings/pages). The name of this theme is saved in the Jekyll `_config.yml` configuration file. +### Quick Start +[PyMIC_examples][examples] provides some examples of starting to use PyMIC. At the beginning, you only need to edit the configuration files to select different datasets, networks and training methods for running the code. When you are more familiar with PyMIC, you can customize different modules in the PyMIC package. You can find both types of examples: -### Support or Contact +[examples]: https://github.com/HiLab-git/PyMIC_examples -Having trouble with Pages? Check out our [documentation](https://docs.github.com/categories/github-pages-basics/) or [contact support](https://support.github.com/contact) and we’ll help you sort it out. From 15b227a6724afefc7dae623c6ab545a674333a86 Mon Sep 17 00:00:00 2001 From: taigw Date: Sun, 31 Jul 2022 16:13:43 +0800 Subject: [PATCH 007/225] update intensity transform add gaussian noise --- pymic/transform/gamma_correction.py | 40 ----------------- pymic/transform/intensity.py | 70 +++++++++++++++++++++++++++++ pymic/transform/trans_dict.py | 6 ++- 3 files changed, 74 insertions(+), 42 deletions(-) delete mode 100644 pymic/transform/gamma_correction.py create mode 100644 pymic/transform/intensity.py diff --git a/pymic/transform/gamma_correction.py b/pymic/transform/gamma_correction.py deleted file mode 100644 index 4a88f1c..0000000 --- a/pymic/transform/gamma_correction.py +++ /dev/null @@ -1,40 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import torch -import json -import math -import random -import numpy as np -from scipy import ndimage -from pymic.transform.abstract_transform import AbstractTransform -from pymic.util.image_process import * - - -class ChannelWiseGammaCorrection(AbstractTransform): - """ - apply random gamma correction to each channel - """ - def __init__(self, params): - """ - (gamma_min, gamma_max) specify the range of gamma - """ - super(ChannelWiseGammaCorrection, self).__init__(params) - self.gamma_min = params['ChannelWiseGammaCorrection_gamma_min'.lower()] - self.gamma_max = params['ChannelWiseGammaCorrection_gamma_max'.lower()] - self.inverse = params.get('ChannelWiseGammaCorrection_inverse'.lower(), False) - - def __call__(self, sample): - image= sample['image'] - for chn in range(image.shape[0]): - gamma_c = random.random() * (self.gamma_max - self.gamma_min) + self.gamma_min - img_c = image[chn] - v_min = img_c.min() - v_max = img_c.max() - img_c = (img_c - v_min)/(v_max - v_min) - img_c = np.power(img_c, gamma_c)*(v_max - v_min) + v_min - image[chn] = img_c - - sample['image'] = image - return sample - diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py new file mode 100644 index 0000000..b9e6070 --- /dev/null +++ b/pymic/transform/intensity.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import json +import math +import random +import numpy as np +from scipy import ndimage +from pymic.transform.abstract_transform import AbstractTransform +from pymic.util.image_process import * + + +class GammaCorrection(AbstractTransform): + """ + apply random gamma correction to each channel + """ + def __init__(self, params): + """ + (gamma_min, gamma_max) specify the range of gamma + """ + super(GammaCorrection, self).__init__(params) + self.channels = params['GammaCorrection_channels'.lower()] + self.gamma_min = params['GammaCorrection_gamma_min'.lower()] + self.gamma_max = params['GammaCorrection_gamma_max'.lower()] + self.prob = params.get('GammaCorrection_probability'.lower(), 0.5) + self.inverse = params.get('GammaCorrection_inverse'.lower(), False) + + def __call__(self, sample): + if(np.random.uniform() > self.prob): + return sample + image= sample['image'] + for chn in self.channels: + gamma_c = random.random() * (self.gamma_max - self.gamma_min) + self.gamma_min + img_c = image[chn] + v_min = img_c.min() + v_max = img_c.max() + img_c = (img_c - v_min)/(v_max - v_min) + img_c = np.power(img_c, gamma_c)*(v_max - v_min) + v_min + image[chn] = img_c + + sample['image'] = image + return sample + +class GaussianNoise(AbstractTransform): + """ + apply random gamma correction to each channel + """ + def __init__(self, params): + """ + (gamma_min, gamma_max) specify the range of gamma + """ + super(GaussianNoise, self).__init__(params) + self.channels = params['GaussianNoise_channels'.lower()] + self.mean = params['GaussianNoise_mean'.lower()] + self.std = params['GaussianNoise_std'.lower()] + self.prob = params.get('GaussianNoise_probability'.lower(), 0.5) + self.inverse = params.get('GaussianNoise_inverse'.lower(), False) + + def __call__(self, sample): + if(np.random.uniform() > self.prob): + return sample + image= sample['image'] + for chn in self.channels: + img_c = image[chn] + noise = np.random.normal(self.mean, self.std, img_c.shape) + image[chn] = img_c + noise + + sample['image'] = image + return sample diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index ae9ce9c..d90e431 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division -from pymic.transform.gamma_correction import ChannelWiseGammaCorrection +from pymic.transform.intensity import * from pymic.transform.gray2rgb import GrayscaleToRGB from pymic.transform.flip import RandomFlip +from pymic.transform.intensity import GaussianNoise from pymic.transform.pad import Pad from pymic.transform.rotate import RandomRotate from pymic.transform.rescale import Rescale, RandomRescale @@ -12,12 +13,13 @@ from pymic.transform.label_convert import * TransformDict = { - 'ChannelWiseGammaCorrection': ChannelWiseGammaCorrection, 'ChannelWiseThreshold': ChannelWiseThreshold, 'ChannelWiseThresholdWithNormalize': ChannelWiseThresholdWithNormalize, 'CropWithBoundingBox': CropWithBoundingBox, 'CenterCrop': CenterCrop, 'GrayscaleToRGB': GrayscaleToRGB, + 'GammaCorrection': GammaCorrection, + 'GaussianNoise': GaussianNoise, 'LabelConvert': LabelConvert, 'LabelConvertNonzero': LabelConvertNonzero, 'LabelToProbability': LabelToProbability, From 376f2108bc2c8293f467697b4946aea6346e9810 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 1 Aug 2022 10:00:52 +0800 Subject: [PATCH 008/225] Update infer_func.py set pre-defined tta for 2d images --- pymic/net_run/infer_func.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/pymic/net_run/infer_func.py b/pymic/net_run/infer_func.py index e603725..35bfb4c 100644 --- a/pymic/net_run/infer_func.py +++ b/pymic/net_run/infer_func.py @@ -131,24 +131,26 @@ def run(self, model, image): tta_mode = self.config.get('tta_mode', 0) if(tta_mode == 0): outputs = self.__infer(image) - elif(tta_mode == 1): # test time augmentation with flip in 2D + elif(tta_mode == 1): + # test time augmentation with flip in 2D + # you may define your own method for test time augmentation outputs1 = self.__infer(image) outputs2 = self.__infer(torch.flip(image, [-2])) - outputs3 = self.__infer(torch.flip(image, [-3])) - outputs4 = self.__infer(torch.flip(image, [-2, -3])) + outputs3 = self.__infer(torch.flip(image, [-1])) + outputs4 = self.__infer(torch.flip(image, [-2, -1])) if(isinstance(outputs1, (tuple, list))): outputs = [] for i in range(len(outputs)): temp_out1 = outputs1[i] temp_out2 = torch.flip(outputs2[i], [-2]) - temp_out3 = torch.flip(outputs3[i], [-3]) - temp_out4 = torch.flip(outputs4[i], [-2, -3]) + temp_out3 = torch.flip(outputs3[i], [-1]) + temp_out4 = torch.flip(outputs4[i], [-2, -1]) temp_mean = (temp_out1 + temp_out2 + temp_out3 + temp_out4) / 4 outputs.append(temp_mean) else: outputs2 = torch.flip(outputs2, [-2]) - outputs3 = torch.flip(outputs3, [-3]) - outputs4 = torch.flip(outputs4, [-2, -3]) + outputs3 = torch.flip(outputs3, [-1]) + outputs4 = torch.flip(outputs4, [-2, -1]) outputs = (outputs1 + outputs2 + outputs3 + outputs4) / 4 else: raise ValueError("Undefined tta_mode {0:}".format(tta_mode)) From 31113ed7cd14d84ea45f4af4ac255e6a66a8bf2e Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 1 Aug 2022 16:41:00 +0800 Subject: [PATCH 009/225] update ssl and nll rename net_run_noise as net_run_nll update ssl_abstract --- pymic/{net_run_noise => net_run_nll}/cl.py | 0 .../co_teaching.py | 0 pymic/net_run_ssl/ssl_abstract.py | 101 ++++++++++++++++++ pymic/net_run_ssl/ssl_cps.py | 4 +- pymic/net_run_ssl/ssl_em.py | 80 +------------- pymic/net_run_ssl/ssl_mt.py | 4 +- pymic/net_run_ssl/ssl_urpc.py | 4 +- 7 files changed, 110 insertions(+), 83 deletions(-) rename pymic/{net_run_noise => net_run_nll}/cl.py (100%) rename pymic/{net_run_noise => net_run_nll}/co_teaching.py (100%) create mode 100644 pymic/net_run_ssl/ssl_abstract.py diff --git a/pymic/net_run_noise/cl.py b/pymic/net_run_nll/cl.py similarity index 100% rename from pymic/net_run_noise/cl.py rename to pymic/net_run_nll/cl.py diff --git a/pymic/net_run_noise/co_teaching.py b/pymic/net_run_nll/co_teaching.py similarity index 100% rename from pymic/net_run_noise/co_teaching.py rename to pymic/net_run_nll/co_teaching.py diff --git a/pymic/net_run_ssl/ssl_abstract.py b/pymic/net_run_ssl/ssl_abstract.py new file mode 100644 index 0000000..acb3b5a --- /dev/null +++ b/pymic/net_run_ssl/ssl_abstract.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import logging +import numpy as np +import random +import torch +import torchvision.transforms as transforms +from pymic.io.nifty_dataset import NiftyDataset +from pymic.loss.seg.util import get_soft_label +from pymic.loss.seg.util import reshape_prediction_and_ground_truth +from pymic.loss.seg.util import get_classwise_dice +from pymic.loss.seg.ssl import EntropyLoss +from pymic.net_run.agent_seg import SegmentationAgent +from pymic.transform.trans_dict import TransformDict +from pymic.util.ramps import sigmoid_rampup + +class SSLSegAgent(SegmentationAgent): + """ + Implementation of the following paper: + Yves Grandvalet and Yoshua Bengio, + Semi-supervised Learningby Entropy Minimization. + NeurIPS, 2005. + """ + def __init__(self, config, stage = 'train'): + super(SSLSegAgent, self).__init__(config, stage) + self.transform_dict = TransformDict + self.train_set_unlab = None + + def get_unlabeled_dataset_from_config(self): + root_dir = self.config['dataset']['root_dir'] + modal_num = self.config['dataset']['modal_num'] + transform_names = self.config['dataset']['train_transform_unlab'] + + self.transform_list = [] + if(transform_names is None or len(transform_names) == 0): + data_transform = None + else: + transform_param = self.config['dataset'] + transform_param['task'] = 'segmentation' + for name in transform_names: + if(name not in self.transform_dict): + raise(ValueError("Undefined transform {0:}".format(name))) + one_transform = self.transform_dict[name](transform_param) + self.transform_list.append(one_transform) + data_transform = transforms.Compose(self.transform_list) + + csv_file = self.config['dataset'].get('train_csv_unlab', None) + dataset = NiftyDataset(root_dir=root_dir, + csv_file = csv_file, + modal_num = modal_num, + with_label= False, + transform = data_transform ) + return dataset + + def create_dataset(self): + super(SSLSegAgent, self).create_dataset() + if(self.stage == 'train'): + if(self.train_set_unlab is None): + self.train_set_unlab = self.get_unlabeled_dataset_from_config() + if(self.deterministic): + def worker_init_fn(worker_id): + random.seed(self.random_seed+worker_id) + worker_init = worker_init_fn + else: + worker_init = None + + bn_train_unlab = self.config['dataset']['train_batch_size_unlab'] + num_worker = self.config['dataset'].get('num_workder', 16) + self.train_loader_unlab = torch.utils.data.DataLoader(self.train_set_unlab, + batch_size = bn_train_unlab, shuffle=True, num_workers= num_worker, + worker_init_fn=worker_init) + + def training(self): + pass + + def write_scalars(self, train_scalars, valid_scalars, glob_it): + loss_scalar ={'train':train_scalars['loss'], + 'valid':valid_scalars['loss']} + loss_sup_scalar = {'train':train_scalars['loss_sup']} + loss_upsup_scalar = {'train':train_scalars['loss_reg']} + dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']} + self.summ_writer.add_scalars('loss', loss_scalar, glob_it) + self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it) + self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it) + self.summ_writer.add_scalars('regular_w', {'regular_w':train_scalars['regular_w']}, glob_it) + self.summ_writer.add_scalars('dice', dice_scalar, glob_it) + class_num = self.config['network']['class_num'] + for c in range(class_num): + cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ + 'valid':valid_scalars['class_dice'][c]} + self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) + logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format( + train_scalars['loss'], train_scalars['avg_dice']) + "[" + \ + ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") + logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( + valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ + ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + + def train_valid(self): + self.trainIter_unlab = iter(self.train_loader_unlab) + super(SSLSegAgent, self).train_valid() diff --git a/pymic/net_run_ssl/ssl_cps.py b/pymic/net_run_ssl/ssl_cps.py index e9bc41b..49a6e11 100644 --- a/pymic/net_run_ssl/ssl_cps.py +++ b/pymic/net_run_ssl/ssl_cps.py @@ -9,10 +9,10 @@ from pymic.loss.seg.util import get_classwise_dice from pymic.util.ramps import sigmoid_rampup from pymic.net_run.get_optimizer import get_optimiser -from pymic.net_run_ssl.ssl_em import SSLEntropyMinimization +from pymic.net_run_ssl.ssl_abstract import SSLSegAgent from pymic.net.net_dict_seg import SegNetDict -class SSLCrossPseudoSupervision(SSLEntropyMinimization): +class SSLCrossPseudoSupervision(SSLSegAgent): """ Using cross pseudo supervision according to the following paper: Xiaokang Chen, Yuhui Yuan, Gang Zeng, Jingdong Wang, diff --git a/pymic/net_run_ssl/ssl_em.py b/pymic/net_run_ssl/ssl_em.py index 32f85c1..28a9dec 100644 --- a/pymic/net_run_ssl/ssl_em.py +++ b/pymic/net_run_ssl/ssl_em.py @@ -2,19 +2,16 @@ from __future__ import print_function, division import logging import numpy as np -import random import torch -import torchvision.transforms as transforms -from pymic.io.nifty_dataset import NiftyDataset from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.ssl import EntropyLoss -from pymic.net_run.agent_seg import SegmentationAgent +from pymic.net_run_ssl.ssl_abstract import SSLSegAgent from pymic.transform.trans_dict import TransformDict from pymic.util.ramps import sigmoid_rampup -class SSLEntropyMinimization(SegmentationAgent): +class SSLEntropyMinimization(SSLSegAgent): """ Implementation of the following paper: Yves Grandvalet and Yoshua Bengio, @@ -26,50 +23,6 @@ def __init__(self, config, stage = 'train'): self.transform_dict = TransformDict self.train_set_unlab = None - def get_unlabeled_dataset_from_config(self): - root_dir = self.config['dataset']['root_dir'] - modal_num = self.config['dataset']['modal_num'] - transform_names = self.config['dataset']['train_transform_unlab'] - - self.transform_list = [] - if(transform_names is None or len(transform_names) == 0): - data_transform = None - else: - transform_param = self.config['dataset'] - transform_param['task'] = 'segmentation' - for name in transform_names: - if(name not in self.transform_dict): - raise(ValueError("Undefined transform {0:}".format(name))) - one_transform = self.transform_dict[name](transform_param) - self.transform_list.append(one_transform) - data_transform = transforms.Compose(self.transform_list) - - csv_file = self.config['dataset'].get('train_csv_unlab', None) - dataset = NiftyDataset(root_dir=root_dir, - csv_file = csv_file, - modal_num = modal_num, - with_label= False, - transform = data_transform ) - return dataset - - def create_dataset(self): - super(SSLEntropyMinimization, self).create_dataset() - if(self.stage == 'train'): - if(self.train_set_unlab is None): - self.train_set_unlab = self.get_unlabeled_dataset_from_config() - if(self.deterministic): - def worker_init_fn(worker_id): - random.seed(self.random_seed+worker_id) - worker_init = worker_init_fn - else: - worker_init = None - - bn_train_unlab = self.config['dataset']['train_batch_size_unlab'] - num_worker = self.config['dataset'].get('num_workder', 16) - self.train_loader_unlab = torch.utils.data.DataLoader(self.train_set_unlab, - batch_size = bn_train_unlab, shuffle=True, num_workers= num_worker, - worker_init_fn=worker_init) - def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] @@ -142,31 +95,4 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} - return train_scalers - - def write_scalars(self, train_scalars, valid_scalars, glob_it): - loss_scalar ={'train':train_scalars['loss'], - 'valid':valid_scalars['loss']} - loss_sup_scalar = {'train':train_scalars['loss_sup']} - loss_upsup_scalar = {'train':train_scalars['loss_reg']} - dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']} - self.summ_writer.add_scalars('loss', loss_scalar, glob_it) - self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it) - self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it) - self.summ_writer.add_scalars('regular_w', {'regular_w':train_scalars['regular_w']}, glob_it) - self.summ_writer.add_scalars('dice', dice_scalar, glob_it) - class_num = self.config['network']['class_num'] - for c in range(class_num): - cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ - 'valid':valid_scalars['class_dice'][c]} - self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) - logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format( - train_scalars['loss'], train_scalars['avg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") - logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( - valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") - - def train_valid(self): - self.trainIter_unlab = iter(self.train_loader_unlab) - super(SSLEntropyMinimization, self).train_valid() + return train_scalers \ No newline at end of file diff --git a/pymic/net_run_ssl/ssl_mt.py b/pymic/net_run_ssl/ssl_mt.py index d25edbc..b96e0b1 100644 --- a/pymic/net_run_ssl/ssl_mt.py +++ b/pymic/net_run_ssl/ssl_mt.py @@ -7,10 +7,10 @@ from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.util.ramps import sigmoid_rampup -from pymic.net_run_ssl.ssl_em import SSLEntropyMinimization +from pymic.net_run_ssl.ssl_abstract import SSLSegAgent from pymic.net.net_dict_seg import SegNetDict -class SSLMeanTeacher(SSLEntropyMinimization): +class SSLMeanTeacher(SSLSegAgent): """ Training and testing agent for semi-supervised segmentation """ diff --git a/pymic/net_run_ssl/ssl_urpc.py b/pymic/net_run_ssl/ssl_urpc.py index d2d5a1f..d9dd953 100644 --- a/pymic/net_run_ssl/ssl_urpc.py +++ b/pymic/net_run_ssl/ssl_urpc.py @@ -8,9 +8,9 @@ from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.util.ramps import sigmoid_rampup -from pymic.net_run_ssl.ssl_em import SSLEntropyMinimization +from pymic.net_run_ssl.ssl_abstract import SSLSegAgent -class SSLURPC(SSLEntropyMinimization): +class SSLURPC(SSLSegAgent): """ Uncertainty-Rectified Pyramid Consistency according to the following paper: Xiangde Luo, Wenjun Liao, Jieneng Chen, Tao Song, Yinan Chen, From 58109748221f2ffb705b21c368665b181bcac050 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 1 Aug 2022 16:58:27 +0800 Subject: [PATCH 010/225] update wsl rename the classes --- pymic/net_run_wsl/wsl_abstract.py | 37 +++++++++++++++++++++++++++ pymic/net_run_wsl/wsl_dmpls.py | 10 +++----- pymic/net_run_wsl/wsl_em.py | 34 +++--------------------- pymic/net_run_wsl/wsl_gatedcrf.py | 10 +++----- pymic/net_run_wsl/wsl_mumford_shah.py | 10 +++----- pymic/net_run_wsl/wsl_tv.py | 10 +++----- pymic/net_run_wsl/wsl_ustm.py | 12 +++------ 7 files changed, 57 insertions(+), 66 deletions(-) create mode 100644 pymic/net_run_wsl/wsl_abstract.py diff --git a/pymic/net_run_wsl/wsl_abstract.py b/pymic/net_run_wsl/wsl_abstract.py new file mode 100644 index 0000000..fe80ea5 --- /dev/null +++ b/pymic/net_run_wsl/wsl_abstract.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import logging +from pymic.net_run.agent_seg import SegmentationAgent + +class WSLSegAgent(SegmentationAgent): + """ + Training and testing agent for semi-supervised segmentation + """ + def __init__(self, config, stage = 'train'): + super(WSLSegAgent, self).__init__(config, stage) + + def training(self): + pass + + def write_scalars(self, train_scalars, valid_scalars, glob_it): + loss_scalar ={'train':train_scalars['loss'], + 'valid':valid_scalars['loss']} + loss_sup_scalar = {'train':train_scalars['loss_sup']} + loss_upsup_scalar = {'train':train_scalars['loss_reg']} + dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']} + self.summ_writer.add_scalars('loss', loss_scalar, glob_it) + self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it) + self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it) + self.summ_writer.add_scalars('regular_w', {'regular_w':train_scalars['regular_w']}, glob_it) + self.summ_writer.add_scalars('dice', dice_scalar, glob_it) + class_num = self.config['network']['class_num'] + for c in range(class_num): + cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ + 'valid':valid_scalars['class_dice'][c]} + self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) + logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format( + train_scalars['loss'], train_scalars['avg_dice']) + "[" + \ + ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") + logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( + valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ + ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") diff --git a/pymic/net_run_wsl/wsl_dmpls.py b/pymic/net_run_wsl/wsl_dmpls.py index c42ed7a..234f1a2 100644 --- a/pymic/net_run_wsl/wsl_dmpls.py +++ b/pymic/net_run_wsl/wsl_dmpls.py @@ -4,18 +4,14 @@ import numpy as np import random import torch -import torchvision.transforms as transforms -from pymic.io.nifty_dataset import NiftyDataset from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.dice import DiceLoss -from pymic.loss.seg.ssl import TotalVariationLoss -from pymic.net_run.agent_seg import SegmentationAgent -from pymic.net_run_wsl.wsl_em import WSL_EntropyMinimization +from pymic.net_run_wsl.wsl_abstract import WSLSegAgent from pymic.util.ramps import sigmoid_rampup -class WSL_DMPLS(WSL_EntropyMinimization): +class WSLDMPLS(WSLSegAgent): """ Implementation of the following paper: Xiangde Luo, Minhao Hu, Wenjun Liao, Shuwei Zhai, Tao Song, Guotai Wang, @@ -28,7 +24,7 @@ def __init__(self, config, stage = 'train'): if net_type not in ['DualBranchUNet2D', 'DualBranchUNet3D']: raise ValueError("""For WSL_DMPLS, a dual branch network is expected. \ It only supports DualBranchUNet2D and DualBranchUNet3D currently.""") - super(WSL_DMPLS, self).__init__(config, stage) + super(WSLDMPLS, self).__init__(config, stage) def training(self): class_num = self.config['network']['class_num'] diff --git a/pymic/net_run_wsl/wsl_em.py b/pymic/net_run_wsl/wsl_em.py index 9534504..cc19600 100644 --- a/pymic/net_run_wsl/wsl_em.py +++ b/pymic/net_run_wsl/wsl_em.py @@ -2,24 +2,21 @@ from __future__ import print_function, division import logging import numpy as np -import random import torch -import torchvision.transforms as transforms -from pymic.io.nifty_dataset import NiftyDataset from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.ssl import EntropyLoss from pymic.net_run.agent_seg import SegmentationAgent -from pymic.transform.trans_dict import TransformDict +from pymic.net_run_wsl.wsl_abstract import WSLSegAgent from pymic.util.ramps import sigmoid_rampup -class WSL_EntropyMinimization(SegmentationAgent): +class WSLEntropyMinimization(WSLSegAgent): """ Training and testing agent for semi-supervised segmentation """ def __init__(self, config, stage = 'train'): - super(WSL_EntropyMinimization, self).__init__(config, stage) + super(WSLEntropyMinimization, self).__init__(config, stage) def training(self): class_num = self.config['network']['class_num'] @@ -85,27 +82,4 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} - return train_scalers - - def write_scalars(self, train_scalars, valid_scalars, glob_it): - loss_scalar ={'train':train_scalars['loss'], - 'valid':valid_scalars['loss']} - loss_sup_scalar = {'train':train_scalars['loss_sup']} - loss_upsup_scalar = {'train':train_scalars['loss_reg']} - dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']} - self.summ_writer.add_scalars('loss', loss_scalar, glob_it) - self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it) - self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it) - self.summ_writer.add_scalars('regular_w', {'regular_w':train_scalars['regular_w']}, glob_it) - self.summ_writer.add_scalars('dice', dice_scalar, glob_it) - class_num = self.config['network']['class_num'] - for c in range(class_num): - cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ - 'valid':valid_scalars['class_dice'][c]} - self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) - logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format( - train_scalars['loss'], train_scalars['avg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") - logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( - valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + return train_scalers \ No newline at end of file diff --git a/pymic/net_run_wsl/wsl_gatedcrf.py b/pymic/net_run_wsl/wsl_gatedcrf.py index 8e9c6de..af6c562 100644 --- a/pymic/net_run_wsl/wsl_gatedcrf.py +++ b/pymic/net_run_wsl/wsl_gatedcrf.py @@ -2,24 +2,20 @@ from __future__ import print_function, division import logging import numpy as np -import random import torch -import torchvision.transforms as transforms -from pymic.io.nifty_dataset import NiftyDataset from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.gatedcrf import ModelLossSemsegGatedCRF -from pymic.net_run.agent_seg import SegmentationAgent -from pymic.net_run_wsl.wsl_em import WSL_EntropyMinimization +from pymic.net_run_wsl.wsl_abstract import WSLSegAgent from pymic.util.ramps import sigmoid_rampup -class WSL_GatedCRF(WSL_EntropyMinimization): +class WSLGatedCRF(WSLSegAgent): """ Training and testing agent for semi-supervised segmentation """ def __init__(self, config, stage = 'train'): - super(WSL_GatedCRF, self).__init__(config, stage) + super(WSLGatedCRF, self).__init__(config, stage) # parameters for gated CRF wsl_cfg = self.config['weakly_supervised_learning'] w0 = wsl_cfg.get('GatedCRFLoss_W0'.lower(), 1.0) diff --git a/pymic/net_run_wsl/wsl_mumford_shah.py b/pymic/net_run_wsl/wsl_mumford_shah.py index 909a65b..f642e59 100644 --- a/pymic/net_run_wsl/wsl_mumford_shah.py +++ b/pymic/net_run_wsl/wsl_mumford_shah.py @@ -2,24 +2,20 @@ from __future__ import print_function, division import logging import numpy as np -import random import torch -import torchvision.transforms as transforms -from pymic.io.nifty_dataset import NiftyDataset from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.mumford_shah import MumfordShahLoss -from pymic.net_run.agent_seg import SegmentationAgent -from pymic.net_run_wsl.wsl_em import WSL_EntropyMinimization +from pymic.net_run_wsl.wsl_abstract import WSLSegAgent from pymic.util.ramps import sigmoid_rampup -class WSL_MumfordShah(WSL_EntropyMinimization): +class WSLMumfordShah(WSLSegAgent): """ Training and testing agent for semi-supervised segmentation """ def __init__(self, config, stage = 'train'): - super(WSL_MumfordShah, self).__init__(config, stage) + super(WSLMumfordShah, self).__init__(config, stage) def training(self): class_num = self.config['network']['class_num'] diff --git a/pymic/net_run_wsl/wsl_tv.py b/pymic/net_run_wsl/wsl_tv.py index f11c5e0..9492150 100644 --- a/pymic/net_run_wsl/wsl_tv.py +++ b/pymic/net_run_wsl/wsl_tv.py @@ -2,24 +2,20 @@ from __future__ import print_function, division import logging import numpy as np -import random import torch -import torchvision.transforms as transforms -from pymic.io.nifty_dataset import NiftyDataset from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.ssl import TotalVariationLoss -from pymic.net_run.agent_seg import SegmentationAgent -from pymic.net_run_wsl.wsl_em import WSL_EntropyMinimization +from pymic.net_run_wsl.wsl_abstract import WSLSegAgent from pymic.util.ramps import sigmoid_rampup -class WSL_TotalVariation(WSL_EntropyMinimization): +class WSLTotalVariation(WSLSegAgent): """ Training and testing agent for semi-supervised segmentation """ def __init__(self, config, stage = 'train'): - super(WSL_TotalVariation, self).__init__(config, stage) + super(WSLTotalVariation, self).__init__(config, stage) def training(self): class_num = self.config['network']['class_num'] diff --git a/pymic/net_run_wsl/wsl_ustm.py b/pymic/net_run_wsl/wsl_ustm.py index c7306e8..32e79df 100644 --- a/pymic/net_run_wsl/wsl_ustm.py +++ b/pymic/net_run_wsl/wsl_ustm.py @@ -5,27 +5,23 @@ import random import torch import torch.nn.functional as F -import torchvision.transforms as transforms -from pymic.io.nifty_dataset import NiftyDataset from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice -from pymic.loss.seg.ssl import EntropyLoss from pymic.net.net_dict_seg import SegNetDict -from pymic.net_run_wsl.wsl_em import WSL_EntropyMinimization -from pymic.transform.trans_dict import TransformDict +from pymic.net_run_wsl.wsl_abstract import WSLSegAgent from pymic.util.ramps import sigmoid_rampup -class WSL_USTM(WSL_EntropyMinimization): +class WSLUSTM(WSLSegAgent): """ Training and testing agent for semi-supervised segmentation """ def __init__(self, config, stage = 'train'): - super(WSL_USTM, self).__init__(config, stage) + super(WSLUSTM, self).__init__(config, stage) self.net_ema = None def create_network(self): - super(WSL_USTM, self).create_network() + super(WSLUSTM, self).create_network() if(self.net_ema is None): net_name = self.config['network']['net_type'] if(net_name not in SegNetDict): From 5f99937ec783526dc29e04224cb5129099dc7654 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 4 Aug 2022 21:39:09 +0800 Subject: [PATCH 011/225] update nll method rename nll method --- pymic/net_run_nll/{cl.py => nll_cl.py} | 6 +++--- .../{co_teaching.py => nll_co_teaching.py} | 18 ++++++++---------- 2 files changed, 11 insertions(+), 13 deletions(-) rename pymic/net_run_nll/{cl.py => nll_cl.py} (97%) rename pymic/net_run_nll/{co_teaching.py => nll_co_teaching.py} (94%) diff --git a/pymic/net_run_nll/cl.py b/pymic/net_run_nll/nll_cl.py similarity index 97% rename from pymic/net_run_nll/cl.py rename to pymic/net_run_nll/nll_cl.py index de31e3e..8792ccd 100644 --- a/pymic/net_run_nll/cl.py +++ b/pymic/net_run_nll/nll_cl.py @@ -45,9 +45,9 @@ def get_confident_map(gt, pred, CL_type = 'both'): noise = cleanlab.pruning.get_noise_indices(gt, prob, prune_method=CL_type, n_jobs=1) return noise -class SegmentationAgentwithCL(SegmentationAgent): +class NLLConfidentLeran(SegmentationAgent): def __init__(self, config, stage = 'test'): - super(SegmentationAgentwithCL, self).__init__(config, stage) + super(NLLConfidentLeran, self).__init__(config, stage) def infer_with_cl(self): device_ids = self.config['testing']['gpus'] @@ -179,7 +179,7 @@ def main(): with_label= True, transform = data_transform ) - agent = SegmentationAgentwithCL(config, 'test') + agent = NLLConfidentLeran(config, 'test') agent.set_datasets(None, None, dataset) agent.transform_list = transform_list agent.run() diff --git a/pymic/net_run_nll/co_teaching.py b/pymic/net_run_nll/nll_co_teaching.py similarity index 94% rename from pymic/net_run_nll/co_teaching.py rename to pymic/net_run_nll/nll_co_teaching.py index 228e3bd..e1392eb 100644 --- a/pymic/net_run_nll/co_teaching.py +++ b/pymic/net_run_nll/nll_co_teaching.py @@ -29,16 +29,14 @@ import sys from pymic.util.parse_config import * -class CoTeachingAgent(SegmentationAgent): +class NLLCoTeaching(SegmentationAgent): """ - Using cross pseudo supervision according to the following paper: - Xiaokang Chen, Yuhui Yuan, Gang Zeng, Jingdong Wang, - Semi-Supervised Semantic Segmentation with Cross Pseudo Supervision, - CVPR 2021, pp. 2613-2022. - https://arxiv.org/abs/2106.01226 + Co-teaching: Robust Training of Deep Neural Networks with Extremely + Noisy Labels + https://arxiv.org/abs/1804.06872 """ def __init__(self, config, stage = 'train'): - super(CoTeachingAgent, self).__init__(config, stage) + super(NLLCoTeaching, self).__init__(config, stage) self.net2 = None self.optimizer2 = None self.scheduler2 = None @@ -48,7 +46,7 @@ def __init__(self, config, stage = 'train'): " coteaching, the specified loss {0:} is ingored".format(loss_type)) def create_network(self): - super(CoTeachingAgent, self).create_network() + super(NLLCoTeaching, self).create_network() if(self.net2 is None): net_name = self.config['network']['net_type'] if(net_name not in SegNetDict): @@ -74,7 +72,7 @@ def train_valid(self): self.config['training']['lr_milestones'], self.config['training']['lr_gamma'], last_epoch = last_iter) - super(CoTeachingAgent, self).train_valid() + super(NLLCoTeaching, self).train_valid() def training(self): class_num = self.config['network']['class_num'] @@ -211,5 +209,5 @@ def write_scalars(self, train_scalars, valid_scalars, glob_it): format='%(message)s') logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) - agent = CoTeachingAgent(config, stage) + agent = NLLCoTeaching(config, stage) agent.run() \ No newline at end of file From 25bee6e88861635c482d4e1c94bb9c3264c75f3f Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 5 Aug 2022 22:00:24 +0800 Subject: [PATCH 012/225] Update index.md --- docs/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/index.md b/docs/index.md index 174cc33..ef7ab5b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,6 +1,6 @@ ## Welcome to PyMIC -PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with high dimension and large volume, multiple modalities and difficulies in annotating. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations. +PyMIC is a Pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with high dimension and large volume, multiple modalities and difficulies in annotating. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations. Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. It was originally developed for COVID-19 pneumonia lesion segmentation from CT images. From 6cb0f22dcfdda1c068e0bd5722c466f9e0d660af Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 5 Aug 2022 22:10:41 +0800 Subject: [PATCH 013/225] update doc support Read the Docs --- .gitignore | 158 +++++++++++++++++++++++++++++++++++++++++- docs/API/index.html | 6 -- docs/Makefile | 21 +++++- docs/make.bat | 35 ++++++++++ docs/source/api.rst | 7 ++ docs/source/conf.py | 35 ++++++++++ docs/source/index.rst | 20 ++++++ docs/source/usage.rst | 34 +++++++++ pyproject.toml | 8 +++ 9 files changed, 316 insertions(+), 8 deletions(-) delete mode 100644 docs/API/index.html create mode 100644 docs/make.bat create mode 100644 docs/source/api.rst create mode 100644 docs/source/conf.py create mode 100644 docs/source/index.rst create mode 100644 docs/source/usage.rst create mode 100644 pyproject.toml diff --git a/.gitignore b/.gitignore index 81e73ef..f7da7ac 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,160 @@ build/* dist/* *egg*/* *stop* -files.txt \ No newline at end of file +files.txt + +# Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks +# Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks + +### JupyterNotebooks ### +# gitignore template for Jupyter Notebooks +# website: http://jupyter.org/ + +.ipynb_checkpoints +*/.ipynb_checkpoints/* + +# IPython +profile_default/ +ipython_config.py + +# Remove previous ipynb_checkpoints +# git rm -r .ipynb_checkpoints/ + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook + +# IPython + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks diff --git a/docs/API/index.html b/docs/API/index.html deleted file mode 100644 index f462d8d..0000000 --- a/docs/API/index.html +++ /dev/null @@ -1,6 +0,0 @@ - - - Codestin Search App - -

This is a subfolder of PyMCI document

- \ No newline at end of file diff --git a/docs/Makefile b/docs/Makefile index 0519ecb..d0c3cbf 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -1 +1,20 @@ - \ No newline at end of file +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..9534b01 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/source/api.rst b/docs/source/api.rst new file mode 100644 index 0000000..ec94338 --- /dev/null +++ b/docs/source/api.rst @@ -0,0 +1,7 @@ +API +=== + +.. autosummary:: + :toctree: generated + + lumache diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 0000000..9096c8f --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,35 @@ +# Configuration file for the Sphinx documentation builder. + +# -- Project information + +project = 'PyMIC' +copyright = '2021, HiLab' +author = 'HiLab' + +release = '0.1' +version = '0.1.0' + +# -- General configuration + +extensions = [ + 'sphinx.ext.duration', + 'sphinx.ext.doctest', + 'sphinx.ext.autodoc', + 'sphinx.ext.autosummary', + 'sphinx.ext.intersphinx', +] + +intersphinx_mapping = { + 'python': ('https://docs.python.org/3/', None), + 'sphinx': ('https://www.sphinx-doc.org/en/master/', None), +} +intersphinx_disabled_domains = ['std'] + +templates_path = ['_templates'] + +# -- Options for HTML output + +html_theme = 'sphinx_rtd_theme' + +# -- Options for EPUB output +epub_show_urls = 'footnote' diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 0000000..4f822cc --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,20 @@ +Welcome to PyMIC's documentation! +=================================== + +PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. +PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations. + +Check out the :doc:`usage` section for further information, including +how to :ref:`installation` the project. + +.. note:: + + This project is under active development. + +Contents +-------- + +.. toctree:: + + usage + api diff --git a/docs/source/usage.rst b/docs/source/usage.rst new file mode 100644 index 0000000..924afcf --- /dev/null +++ b/docs/source/usage.rst @@ -0,0 +1,34 @@ +Usage +===== + +.. _installation: + +Installation +------------ + +To use Lumache, first install it using pip: + +.. code-block:: console + + (.venv) $ pip install lumache + +Creating recipes +---------------- + +To retrieve a list of random ingredients, +you can use the ``lumache.get_random_ingredients()`` function: + +.. autofunction:: lumache.get_random_ingredients + +The ``kind`` parameter should be either ``"meat"``, ``"fish"``, +or ``"veggies"``. Otherwise, :py:func:`lumache.get_random_ingredients` +will raise an exception. + +.. autoexception:: lumache.InvalidKindError + +For example: + +>>> import lumache +>>> lumache.get_random_ingredients() +['shells', 'gorgonzola', 'parsley'] + diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..4bbff29 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,8 @@ +[build-system] +requires = ["flit_core >=3.2,<4"] +build-backend = "flit_core.buildapi" + +[project] +name = "PyMIC" +authors = [{name = "Graziella", email = "graziella@lumache"}] +dynamic = ["version", "description"] From a778c6ad46dd3a7d8934d52dad40d72e419d626f Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 5 Aug 2022 22:30:24 +0800 Subject: [PATCH 014/225] Update index.rst test --- docs/source/index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index 4f822cc..dddd5df 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -9,7 +9,7 @@ how to :ref:`installation` the project. .. note:: - This project is under active development. + This project is under active development. It will be updated later. Contents -------- From ee67842972f80841edcc5641fcfd9be0261e9ec9 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 6 Aug 2022 12:48:53 +0800 Subject: [PATCH 015/225] update test-time augmentation and post process --- pymic/net_run/infer_func.py | 2 +- pymic/util/image_process.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/pymic/net_run/infer_func.py b/pymic/net_run/infer_func.py index 35bfb4c..78184fe 100644 --- a/pymic/net_run/infer_func.py +++ b/pymic/net_run/infer_func.py @@ -140,7 +140,7 @@ def run(self, model, image): outputs4 = self.__infer(torch.flip(image, [-2, -1])) if(isinstance(outputs1, (tuple, list))): outputs = [] - for i in range(len(outputs)): + for i in range(len(outputs1)): temp_out1 = outputs1[i] temp_out2 = torch.flip(outputs2[i], [-2]) temp_out3 = torch.flip(outputs3[i], [-1]) diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index 61b577b..aa9f611 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -107,9 +107,9 @@ def crop_and_pad_ND_array_to_desired_shape(image, out_shape, pad_mod): return image_pad -def get_largest_component(image): +def get_largest_k_components(image, k = 1): """ - get the largest component from 2D or 3D binary image + get the largest K components from 2D or 3D binary image image: nd array """ dim = len(image.shape) @@ -124,8 +124,12 @@ def get_largest_component(image): raise ValueError("the dimension number should be 2 or 3") labeled_array, numpatches = ndimage.label(image, s) sizes = ndimage.sum(image, labeled_array, range(1, numpatches + 1)) - max_label = np.where(sizes == sizes.max())[0] + 1 - output = np.asarray(labeled_array == max_label, np.uint8) + sizes_sort = sorted(sizes, reverse = True) + kmin = min(k, numpatches) + output = np.zeros_like(image) + for i in range(kmin): + labeli = np.where(sizes == sizes_sort[i])[0] + 1 + output = output + np.asarray(labeled_array == labeli, np.uint8) return output def get_euclidean_distance(image, dim = 3, spacing = [1.0, 1.0, 1.0]): From 7235afd08f7e2e03bd164ad2a16e39f8c60c4696 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 6 Aug 2022 14:04:36 +0800 Subject: [PATCH 016/225] add post process for inference --- pymic/net_run/agent_seg.py | 14 ++++++++++++- pymic/util/image_process.py | 7 ++----- pymic/util/post_process.py | 40 +++++++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 6 deletions(-) create mode 100644 pymic/util/post_process.py diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 0cc59b3..38c75ae 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -25,13 +25,16 @@ from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.transform.trans_dict import TransformDict +from pymic.util.post_process import PostProcessDict from pymic.util.image_process import convert_label from pymic.util.general import keyword_match class SegmentationAgent(NetRunAgent): def __init__(self, config, stage = 'train'): super(SegmentationAgent, self).__init__(config, stage) - self.transform_dict = TransformDict + self.transform_dict = TransformDict + self.postprocess_dict = PostProcessDict + self.postprocessor = None def get_stage_dataset_from_config(self, stage): assert(stage in ['train', 'valid', 'test']) @@ -155,6 +158,9 @@ def get_loss_value(self, data, pred, gt, param = None): loss_value = self.loss_calculator(loss_input_dict) return loss_value + def set_postprocessor(self, postprocessor): + self.postprocessor = postprocessor + def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] @@ -410,6 +416,9 @@ def test_time_dropout(m): infer_cfg = self.config['testing'] infer_cfg['class_num'] = self.config['network']['class_num'] self.inferer = Inferer(infer_cfg) + postpro_name = self.config['testing'].get('post_process', None) + if(self.postprocessor is None and postpro_name is not None): + self.postprocessor = PostProcessDict[postpro_name](self.config['testing']) infer_time_list = [] with torch.no_grad(): for data in self.test_loader: @@ -518,6 +527,9 @@ def save_ouputs(self, data): output = np.asarray(np.argmax(prob, axis = 1), np.uint8) if((label_source is not None) and (label_target is not None)): output = convert_label(output, label_source, label_target) + if(self.postprocessor is not None): + for i in range(len(names)): + output[i] = self.postprocessor(output[i]) # save the output and (optionally) probability predictions root_dir = self.config['dataset']['root_dir'] for i in range(len(names)): diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index aa9f611..896e8c1 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -116,12 +116,9 @@ def get_largest_k_components(image, k = 1): if(image.sum() == 0 ): print('the largest component is null') return image - if(dim == 2): - s = ndimage.generate_binary_structure(2,1) - elif(dim == 3): - s = ndimage.generate_binary_structure(3,1) - else: + if(dim < 2 or dim > 3): raise ValueError("the dimension number should be 2 or 3") + s = ndimage.generate_binary_structure(dim,1) labeled_array, numpatches = ndimage.label(image, s) sizes = ndimage.sum(image, labeled_array, range(1, numpatches + 1)) sizes_sort = sorted(sizes, reverse = True) diff --git a/pymic/util/post_process.py b/pymic/util/post_process.py new file mode 100644 index 0000000..da133ca --- /dev/null +++ b/pymic/util/post_process.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import, print_function + +import os +import numpy as np +import SimpleITK as sitk +from pymic.util.image_process import get_largest_k_components + +class PostProcess(object): + def __init__(self, params): + self.params = params + + def __call__(self, seg): + return seg + +class PostKeepLargestComponent(PostProcess): + def __init__(self, params): + super(PostKeepLargestComponent, self).__init__(params) + self.mode = params.get("KeepLargestComponent_mode".lower(), 1) + """ + mode = 1: keep the largest component of the union of foreground classes. + mode = 2: keep the largest component for each foreground class. + """ + + def __call__(self, seg): + if(self.mode == 1): + mask = np.asarray(seg > 0, np.uint8) + mask = get_largest_k_components(mask) + seg = seg * mask + elif(self.mode == 2): + class_num = seg.max() + output = np.zeros_like(seg) + for c in range(1, class_num + 1): + seg_c = np.asarray(seg == c, np.uint8) + seg_c = get_largest_k_components(seg_c) + output = output + seg_c * c + return seg + +PostProcessDict = { + 'KeepLargestComponent': PostKeepLargestComponent} \ No newline at end of file From 03d99646aa665423e8f3c0e260a22b7423b758f8 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 8 Aug 2022 16:49:14 +0800 Subject: [PATCH 017/225] update ssl_em according to agent_seg, allow recording lr during training, and support ReduceLROnPlateau --- pymic/net_run_ssl/ssl_abstract.py | 7 ++++--- pymic/net_run_ssl/ssl_cps.py | 4 ++-- pymic/net_run_ssl/ssl_em.py | 4 +++- pymic/util/evaluation_seg.py | 6 +++--- 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/pymic/net_run_ssl/ssl_abstract.py b/pymic/net_run_ssl/ssl_abstract.py index acb3b5a..f1b97ba 100644 --- a/pymic/net_run_ssl/ssl_abstract.py +++ b/pymic/net_run_ssl/ssl_abstract.py @@ -28,7 +28,7 @@ def __init__(self, config, stage = 'train'): def get_unlabeled_dataset_from_config(self): root_dir = self.config['dataset']['root_dir'] - modal_num = self.config['dataset']['modal_num'] + modal_num = self.config['dataset'].get('modal_num', 1) transform_names = self.config['dataset']['train_transform_unlab'] self.transform_list = [] @@ -72,8 +72,8 @@ def worker_init_fn(worker_id): def training(self): pass - - def write_scalars(self, train_scalars, valid_scalars, glob_it): + + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): loss_scalar ={'train':train_scalars['loss'], 'valid':valid_scalars['loss']} loss_sup_scalar = {'train':train_scalars['loss_sup']} @@ -83,6 +83,7 @@ def write_scalars(self, train_scalars, valid_scalars, glob_it): self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it) self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it) self.summ_writer.add_scalars('regular_w', {'regular_w':train_scalars['regular_w']}, glob_it) + self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) self.summ_writer.add_scalars('dice', dice_scalar, glob_it) class_num = self.config['network']['class_num'] for c in range(class_num): diff --git a/pymic/net_run_ssl/ssl_cps.py b/pymic/net_run_ssl/ssl_cps.py index 49a6e11..a878ea0 100644 --- a/pymic/net_run_ssl/ssl_cps.py +++ b/pymic/net_run_ssl/ssl_cps.py @@ -8,7 +8,7 @@ from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.util.ramps import sigmoid_rampup -from pymic.net_run.get_optimizer import get_optimiser +from pymic.net_run.get_optimizer import get_optimizer from pymic.net_run_ssl.ssl_abstract import SSLSegAgent from pymic.net.net_dict_seg import SegNetDict @@ -41,7 +41,7 @@ def create_network(self): def train_valid(self): # create optimizor for the second network if(self.optimizer2 is None): - self.optimizer2 = get_optimiser(self.config['training']['optimizer'], + self.optimizer2 = get_optimizer(self.config['training']['optimizer'], self.net2.parameters(), self.config['training']) last_iter = -1 diff --git a/pymic/net_run_ssl/ssl_em.py b/pymic/net_run_ssl/ssl_em.py index 28a9dec..bb1cd55 100644 --- a/pymic/net_run_ssl/ssl_em.py +++ b/pymic/net_run_ssl/ssl_em.py @@ -10,6 +10,7 @@ from pymic.net_run_ssl.ssl_abstract import SSLSegAgent from pymic.transform.trans_dict import TransformDict from pymic.util.ramps import sigmoid_rampup +from pymic.util.general import keyword_match class SSLEntropyMinimization(SSLSegAgent): """ @@ -73,7 +74,8 @@ def training(self): # if (self.config['training']['use']) loss.backward() self.optimizer.step() - self.scheduler.step() + if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/util/evaluation_seg.py b/pymic/util/evaluation_seg.py index 61ae51c..b04880a 100644 --- a/pymic/util/evaluation_seg.py +++ b/pymic/util/evaluation_seg.py @@ -74,7 +74,7 @@ def get_edge_points(img): return edge -def binary_hausdorff95(s, g, spacing = None): +def binary_hd95(s, g, spacing = None): """ get the hausdorff distance between a binary segmentation and the ground truth inputs: @@ -165,8 +165,8 @@ def get_binary_evaluation_score(s_volume, g_volume, spacing, metric): elif(metric_lower == 'assd'): score = binary_assd(s_volume, g_volume, spacing) - elif(metric_lower == "hausdorff95"): - score = binary_hausdorff95(s_volume, g_volume, spacing) + elif(metric_lower == "hd95"): + score = binary_hd95(s_volume, g_volume, spacing) elif(metric_lower == "rve"): score = binary_relative_volume_error(s_volume, g_volume) From e87c0fe59adbeca0eb8b0c891dd488ff0893e0e6 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 9 Aug 2022 11:54:57 +0800 Subject: [PATCH 018/225] update ssl method add lr scheduler --- pymic/net_run_ssl/ssl_cps.py | 36 ++++++++++++++++++++--------------- pymic/net_run_ssl/ssl_main.py | 4 ++-- pymic/net_run_ssl/ssl_mt.py | 6 ++++-- pymic/net_run_ssl/ssl_uamt.py | 7 +++++-- pymic/net_run_ssl/ssl_urpc.py | 7 ++++--- 5 files changed, 36 insertions(+), 24 deletions(-) diff --git a/pymic/net_run_ssl/ssl_cps.py b/pymic/net_run_ssl/ssl_cps.py index a878ea0..582cdc0 100644 --- a/pymic/net_run_ssl/ssl_cps.py +++ b/pymic/net_run_ssl/ssl_cps.py @@ -7,12 +7,13 @@ from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice -from pymic.util.ramps import sigmoid_rampup -from pymic.net_run.get_optimizer import get_optimizer +from pymic.net_run.get_optimizer import get_optimizer, get_lr_scheduler from pymic.net_run_ssl.ssl_abstract import SSLSegAgent from pymic.net.net_dict_seg import SegNetDict +from pymic.util.ramps import sigmoid_rampup +from pymic.util.general import keyword_match -class SSLCrossPseudoSupervision(SSLSegAgent): +class SSLCPS(SSLSegAgent): """ Using cross pseudo supervision according to the following paper: Xiaokang Chen, Yuhui Yuan, Gang Zeng, Jingdong Wang, @@ -21,13 +22,13 @@ class SSLCrossPseudoSupervision(SSLSegAgent): https://arxiv.org/abs/2106.01226 """ def __init__(self, config, stage = 'train'): - super(SSLCrossPseudoSupervision, self).__init__(config, stage) + super(SSLCPS, self).__init__(config, stage) self.net2 = None self.optimizer2 = None self.scheduler2 = None def create_network(self): - super(SSLCrossPseudoSupervision, self).create_network() + super(SSLCPS, self).create_network() if(self.net2 is None): net_name = self.config['network']['net_type'] if(net_name not in SegNetDict): @@ -40,20 +41,18 @@ def create_network(self): def train_valid(self): # create optimizor for the second network + opt_params = self.config['training'] if(self.optimizer2 is None): - self.optimizer2 = get_optimizer(self.config['training']['optimizer'], - self.net2.parameters(), - self.config['training']) + self.optimizer2 = get_optimizer(opt_params['optimizer'], + self.net2.parameters(), opt_params) last_iter = -1 # if(self.checkpoint is not None): # self.optimizer2.load_state_dict(self.checkpoint['optimizer_state_dict']) # last_iter = self.checkpoint['iteration'] - 1 if(self.scheduler2 is None): - self.scheduler2 = optim.lr_scheduler.MultiStepLR(self.optimizer2, - self.config['training']['lr_milestones'], - self.config['training']['lr_gamma'], - last_epoch = last_iter) - super(SSLCrossPseudoSupervision, self).train_valid() + opt_params["laster_iter"] = last_iter + self.scheduler2 = get_lr_scheduler(self.optimizer, opt_params) + super(SSLCPS, self).train_valid() def training(self): class_num = self.config['network']['class_num'] @@ -121,9 +120,10 @@ def training(self): loss.backward() self.optimizer.step() - self.scheduler.step() self.optimizer2.step() - self.scheduler2.step() + if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + self.scheduler.step() + self.scheduler2.step() train_loss = train_loss + loss.item() train_loss_sup1 = train_loss_sup1 + loss_sup1.item() @@ -152,6 +152,12 @@ def training(self): 'loss_pse_sup1':train_avg_loss_pse_sup1, 'loss_pse_sup2': train_avg_loss_pse_sup2, 'regular_w':regular_w, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} return train_scalers + + def validation(self): + return_value = super(SSLCPS, self).validation() + if(keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + self.scheduler2.step(return_value['avg_dice']) + return return_value def write_scalars(self, train_scalars, valid_scalars, glob_it): loss_scalar ={'train':train_scalars['loss'], diff --git a/pymic/net_run_ssl/ssl_main.py b/pymic/net_run_ssl/ssl_main.py index cf5a8cd..54492ae 100644 --- a/pymic/net_run_ssl/ssl_main.py +++ b/pymic/net_run_ssl/ssl_main.py @@ -9,13 +9,13 @@ from pymic.net_run_ssl.ssl_mt import SSLMeanTeacher from pymic.net_run_ssl.ssl_uamt import SSLUncertaintyAwareMeanTeacher from pymic.net_run_ssl.ssl_urpc import SSLURPC -from pymic.net_run_ssl.ssl_cps import SSLCrossPseudoSupervision +from pymic.net_run_ssl.ssl_cps import SSLCPS SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization, 'MeanTeacher': SSLMeanTeacher, 'UAMT': SSLUncertaintyAwareMeanTeacher, 'URPC': SSLURPC, - 'CPS': SSLCrossPseudoSupervision} + 'CPS': SSLCPS} def main(): if(len(sys.argv) < 3): diff --git a/pymic/net_run_ssl/ssl_mt.py b/pymic/net_run_ssl/ssl_mt.py index b96e0b1..7905968 100644 --- a/pymic/net_run_ssl/ssl_mt.py +++ b/pymic/net_run_ssl/ssl_mt.py @@ -6,9 +6,10 @@ from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice -from pymic.util.ramps import sigmoid_rampup from pymic.net_run_ssl.ssl_abstract import SSLSegAgent from pymic.net.net_dict_seg import SegNetDict +from pymic.util.ramps import sigmoid_rampup +from pymic.util.general import keyword_match class SSLMeanTeacher(SSLSegAgent): """ @@ -89,7 +90,8 @@ def training(self): loss.backward() self.optimizer.step() - self.scheduler.step() + if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + self.scheduler.step() # update EMA alpha = ssl_cfg.get('ema_decay', 0.99) diff --git a/pymic/net_run_ssl/ssl_uamt.py b/pymic/net_run_ssl/ssl_uamt.py index d1de32f..3352231 100644 --- a/pymic/net_run_ssl/ssl_uamt.py +++ b/pymic/net_run_ssl/ssl_uamt.py @@ -6,8 +6,9 @@ from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice -from pymic.util.ramps import sigmoid_rampup from pymic.net_run_ssl.ssl_mt import SSLMeanTeacher +from pymic.util.ramps import sigmoid_rampup +from pymic.util.general import keyword_match class SSLUncertaintyAwareMeanTeacher(SSLMeanTeacher): """ @@ -97,7 +98,9 @@ def training(self): loss.backward() self.optimizer.step() - self.scheduler.step() + if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + self.scheduler.step() + # update EMA alpha = ssl_cfg.get('ema_decay', 0.99) diff --git a/pymic/net_run_ssl/ssl_urpc.py b/pymic/net_run_ssl/ssl_urpc.py index d9dd953..0513dc2 100644 --- a/pymic/net_run_ssl/ssl_urpc.py +++ b/pymic/net_run_ssl/ssl_urpc.py @@ -7,8 +7,9 @@ from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice -from pymic.util.ramps import sigmoid_rampup from pymic.net_run_ssl.ssl_abstract import SSLSegAgent +from pymic.util.ramps import sigmoid_rampup +from pymic.util.general import keyword_match class SSLURPC(SSLSegAgent): """ @@ -90,8 +91,8 @@ def training(self): loss.backward() self.optimizer.step() - self.scheduler.step() - + if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() train_loss_reg = train_loss_reg + loss_reg.item() From e057152fe4d6aaffc933c684bd26925ddfd3a95e Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 9 Aug 2022 17:27:27 +0800 Subject: [PATCH 019/225] update network and ssl cct allow dropout in the decoder add CCT for SSL --- pymic/net/net2d/unet2d.py | 8 +- pymic/net/net2d/unet2d_cct.py | 195 ++++++++++++++++++++++++++++++++ pymic/net/net2d/unet2d_scse.py | 8 +- pymic/net/net3d/unet2d5.py | 8 +- pymic/net/net3d/unet3d.py | 8 +- pymic/net/net3d/unet3d_scse.py | 8 +- pymic/net/net_dict_seg.py | 2 + pymic/net_run/agent_abstract.py | 2 +- pymic/net_run_ssl/ssl_cct.py | 156 +++++++++++++++++++++++++ pymic/net_run_ssl/ssl_cps.py | 5 +- pymic/net_run_ssl/ssl_main.py | 9 +- 11 files changed, 383 insertions(+), 26 deletions(-) create mode 100644 pymic/net/net2d/unet2d_cct.py create mode 100644 pymic/net_run_ssl/ssl_cct.py diff --git a/pymic/net/net2d/unet2d.py b/pymic/net/net2d/unet2d.py index 0cc607f..703ced3 100644 --- a/pymic/net/net2d/unet2d.py +++ b/pymic/net/net2d/unet2d.py @@ -91,10 +91,10 @@ def __init__(self, params): self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) if(len(self.ft_chns) == 5): self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], 0.0, self.bilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], 0.0, self.bilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], 0.0, self.bilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], 0.0, self.bilinear) + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear) self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1) if(self.deep_sup): diff --git a/pymic/net/net2d/unet2d_cct.py b/pymic/net/net2d/unet2d_cct.py new file mode 100644 index 0000000..88a369f --- /dev/null +++ b/pymic/net/net2d/unet2d_cct.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +""" +An modification the U-Net with auxiliary decoders according to +the CCT paper: + Yassine Ouali, Celine Hudelot and Myriam Tami: + Semi-Supervised Semantic Segmentation With Cross-Consistency Training. + CVPR 2020. + https://arxiv.org/abs/2003.09005 +Code adapted from: https://github.com/yassouali/CCT +""" +from __future__ import print_function, division + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torch.distributions.uniform import Uniform +from pymic.net.net2d.unet2d import ConvBlock, DownBlock, UpBlock + +class Encoder(nn.Module): + def __init__(self, params): + super(Encoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + + self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) + self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) + self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) + self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) + if(len(self.ft_chns) == 5): + self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) + + def forward(self, x): + x0 = self.in_conv(x) + x1 = self.down1(x0) + x2 = self.down2(x1) + x3 = self.down3(x2) + output = [x0, x1, x2, x3] + if(len(self.ft_chns) == 5): + x4 = self.down4(x3) + output.append(x4) + return output + +class Decoder(nn.Module): + def __init__(self, params): + super(Decoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + self.n_class = self.params['class_num'] + self.bilinear = self.params['bilinear'] + + assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + + if(len(self.ft_chns) == 5): + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear) + self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1) + + def forward(self, x): + if(len(self.ft_chns) == 5): + assert(len(x) == 5) + x0, x1, x2, x3, x4 = x + x_d3 = self.up1(x4, x3) + else: + assert(len(x) == 4) + x0, x1, x2, x3 = x + x_d3 = x3 + x_d2 = self.up2(x_d3, x2) + x_d1 = self.up3(x_d2, x1) + x_d0 = self.up4(x_d1, x0) + output = self.out_conv(x_d0) + return output + +def _l2_normalize(d): + # Normalizing per batch axis + d_reshaped = d.view(d.shape[0], -1, *(1 for _ in range(d.dim() - 2))) + d /= torch.norm(d_reshaped, dim=1, keepdim=True) + 1e-8 + return d + + + +def get_r_adv(x_list, decoder, it=1, xi=1e-1, eps=10.0): + """ + Virtual Adversarial Training according to + https://arxiv.org/abs/1704.03976 + """ + x_detached = [item.detach() for item in x_list] + xe_detached = x_detached[-1] + with torch.no_grad(): + pred = F.softmax(decoder(x_detached), dim=1) + + d = torch.rand(x_list[-1].shape).sub(0.5).to(x_list[-1].device) + d = _l2_normalize(d) + + for _ in range(it): + d.requires_grad_() + x_detached[-1] = xe_detached + xi * d + pred_hat = decoder(x_detached) + logp_hat = F.log_softmax(pred_hat, dim=1) + adv_distance = F.kl_div(logp_hat, pred, reduction='batchmean') + adv_distance.backward() + d = _l2_normalize(d.grad) + decoder.zero_grad() + + r_adv = d * eps + return x_list[-1] + r_adv + + +class AuxiliaryDecoder(nn.Module): + def __init__(self, params, aux_type): + super(AuxiliaryDecoder, self).__init__() + self.params = params + self.decoder = Decoder(params) + self.aux_type = aux_type + uniform_range = params.get("Uniform_range".lower(), 0.3) + self.uni_dist = Uniform(-uniform_range, uniform_range) + + def feature_drop(self, x): + attention = torch.mean(x, dim=1, keepdim=True) + max_val, _ = torch.max(attention.view(x.size(0), -1), dim=1, keepdim=True) + threshold = max_val * np.random.uniform(0.7, 0.9) + threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention) + drop_mask = (attention < threshold).float() + return x.mul(drop_mask) + + def feature_based_noise(self, x): + noise_vector = self.uni_dist.sample(x.shape[1:]).to(x.device).unsqueeze(0) + x_noise = x.mul(noise_vector) + x + return x_noise + + def forward(self, x): + if(self.aux_type == "DropOut"): + pass + elif(self.aux_type == "FeatureDrop"): + x[-1] = self.feature_drop(x[-1]) + elif(self.aux_type == "FeatureNoise"): + x[-1] = self.feature_based_noise(x[-1]) + elif(self.aux_type == "VAT"): + it = self.params.get("VAT_it".lower(), 2) + xi = self.params.get("VAT_xi".lower(), 1e-6) + eps= self.params.get("VAT_eps".lower(), 2.0) + x[-1] = get_r_adv(x, self.decoder, it, xi, eps) + else: + raise ValueError("Undefined auxiliary decoder type {0:}".format(self.aux_type)) + + output = self.decoder(x) + return output + + +class UNet2D_CCT(nn.Module): + def __init__(self, params): + super(UNet2D_CCT, self).__init__() + self.params = params + self.encoder = Encoder(params) + self.decoder = Decoder(params) + aux_names = params.get("CCT_aux_decoders".lower(), None) + if aux_names is None: + aux_names = ["DropOut", "FeatureDrop", "FeatureNoise", "VAT"] + aux_decoders = [] + for aux_name in aux_names: + aux_decoders.append(AuxiliaryDecoder(params, aux_name)) + self.aux_decoders = nn.ModuleList(aux_decoders) + + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + f = self.encoder(x) + output = self.decoder(f) + if(len(x_shape) == 5): + new_shape = [N, D] + list(output.shape)[1:] + output = torch.reshape(output, new_shape) + output = torch.transpose(output, 1, 2) + + if(self.training): + aux_outputs = [aux_d(f) for aux_d in self.aux_decoders] + if(len(x_shape) == 5): + for i in range(len(aux_outputs)): + aux_outi = torch.reshape(aux_outputs[i], new_shape) + aux_outputs[i] = torch.transpose(aux_outi, 1, 2) + return output, aux_outputs + else: + return output \ No newline at end of file diff --git a/pymic/net/net2d/unet2d_scse.py b/pymic/net/net2d/unet2d_scse.py index 95b25b1..6f9d3f7 100644 --- a/pymic/net/net2d/unet2d_scse.py +++ b/pymic/net/net2d/unet2d_scse.py @@ -79,10 +79,10 @@ def __init__(self, params): self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = 0.0) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p = 0.0) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p = 0.0) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p = 0.0) + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = self.dropout[3]) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p = self.dropout[2]) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p = self.dropout[1]) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p = self.dropout[0]) self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 3, padding = 1) diff --git a/pymic/net/net3d/unet2d5.py b/pymic/net/net3d/unet2d5.py index 4ed8d7e..9e6a72d 100644 --- a/pymic/net/net3d/unet2d5.py +++ b/pymic/net/net3d/unet2d5.py @@ -149,13 +149,13 @@ def __init__(self, params): self.block3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dims[3], self.dropout[3], True) self.block4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dims[4], self.dropout[4], False) self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], - self.dims[3], dropout_p = 0.0, bilinear = self.bilinear) + self.dims[3], dropout_p = self.dropout[3], bilinear = self.bilinear) self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], - self.dims[2], dropout_p = 0.0, bilinear = self.bilinear) + self.dims[2], dropout_p = self.dropout[2], bilinear = self.bilinear) self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], - self.dims[1], dropout_p = 0.0, bilinear = self.bilinear) + self.dims[1], dropout_p = self.dropout[1], bilinear = self.bilinear) self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], - self.dims[0], dropout_p = 0.0, bilinear = self.bilinear) + self.dims[0], dropout_p = self.dropout[0], bilinear = self.bilinear) self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = (1, 3, 3), padding = (0, 1, 1)) diff --git a/pymic/net/net3d/unet3d.py b/pymic/net/net3d/unet3d.py index a37204e..058cb79 100644 --- a/pymic/net/net3d/unet3d.py +++ b/pymic/net/net3d/unet3d.py @@ -106,13 +106,13 @@ def __init__(self, params): if(len(self.ft_chns) == 5): self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], - dropout_p = 0.0, trilinear=self.trilinear) + dropout_p = self.dropout[3], trilinear=self.trilinear) self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], - dropout_p = 0.0, trilinear=self.trilinear) + dropout_p = self.dropout[2], trilinear=self.trilinear) self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], - dropout_p = 0.0, trilinear=self.trilinear) + dropout_p = self.dropout[1], trilinear=self.trilinear) self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], - dropout_p = 0.0, trilinear=self.trilinear) + dropout_p = self.dropout[0], trilinear=self.trilinear) self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) if(self.deep_sup): diff --git a/pymic/net/net3d/unet3d_scse.py b/pymic/net/net3d/unet3d_scse.py index 5832830..0f15e25 100644 --- a/pymic/net/net3d/unet3d_scse.py +++ b/pymic/net/net3d/unet3d_scse.py @@ -78,10 +78,10 @@ def __init__(self, params): self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = 0.0) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p = 0.0) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p = 0.0) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p = 0.0) + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = self.dropout[3]) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p = self.dropout[2]) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p = self.dropout[1]) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p = self.dropout[0]) self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 3, padding = 1) diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index 55711d4..aa912bc 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -3,6 +3,7 @@ from pymic.net.net2d.unet2d import UNet2D from pymic.net.net2d.unet2d_dual_branch import DualBranchUNet2D from pymic.net.net2d.unet2d_urpc import UNet2D_URPC +from pymic.net.net2d.unet2d_cct import UNet2D_CCT from pymic.net.net2d.cople_net import COPLENet from pymic.net.net2d.unet2d_attention import AttentionUNet2D from pymic.net.net2d.unet2d_nest import NestedUNet2D @@ -15,6 +16,7 @@ 'UNet2D': UNet2D, 'DualBranchUNet2D': DualBranchUNet2D, 'UNet2D_URPC': UNet2D_URPC, + 'UNet2D_CCT': UNet2D_CCT, 'COPLENet': COPLENet, 'AttentionUNet2D': AttentionUNet2D, 'NestedUNet2D': NestedUNet2D, diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index d701534..8ffadc1 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -157,7 +157,7 @@ def create_optimizer(self, params): self.optimizer.load_state_dict(self.checkpoint['optimizer_state_dict']) last_iter = self.checkpoint['iteration'] - 1 if(self.scheduler is None): - opt_params["laster_iter"] = last_iter + opt_params["last_iter"] = last_iter self.scheduler = get_lr_scheduler(self.optimizer, opt_params) def convert_tensor_type(self, input_tensor): diff --git a/pymic/net_run_ssl/ssl_cct.py b/pymic/net_run_ssl/ssl_cct.py new file mode 100644 index 0000000..80e9d07 --- /dev/null +++ b/pymic/net_run_ssl/ssl_cct.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from pymic.loss.seg.util import get_soft_label +from pymic.loss.seg.util import reshape_prediction_and_ground_truth +from pymic.loss.seg.util import get_classwise_dice +from pymic.net_run_ssl.ssl_abstract import SSLSegAgent +from pymic.util.ramps import sigmoid_rampup +from pymic.util.general import keyword_match + +def softmax_mse_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False): + assert inputs.requires_grad == True and targets.requires_grad == False + assert inputs.size() == targets.size() # (batch_size * num_classes * H * W) + inputs = F.softmax(inputs, dim=1) + if use_softmax: + targets = F.softmax(targets, dim=1) + + if conf_mask: + loss_mat = F.mse_loss(inputs, targets, reduction='none') + mask = (targets.max(1)[0] > threshold) + loss_mat = loss_mat[mask.unsqueeze(1).expand_as(loss_mat)] + if loss_mat.shape.numel() == 0: loss_mat = torch.tensor([0.]).to(inputs.device) + return loss_mat.mean() + else: + return F.mse_loss(inputs, targets, reduction='mean') # take the mean over the batch_size + + +def softmax_kl_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False): + assert inputs.requires_grad == True and targets.requires_grad == False + assert inputs.size() == targets.size() + input_log_softmax = F.log_softmax(inputs, dim=1) + if use_softmax: + targets = F.softmax(targets, dim=1) + + if conf_mask: + loss_mat = F.kl_div(input_log_softmax, targets, reduction='none') + mask = (targets.max(1)[0] > threshold) + loss_mat = loss_mat[mask.unsqueeze(1).expand_as(loss_mat)] + if loss_mat.shape.numel() == 0: loss_mat = torch.tensor([0.]).to(inputs.device) + return loss_mat.sum() / mask.shape.numel() + else: + return F.kl_div(input_log_softmax, targets, reduction='mean') + + +def softmax_js_loss(inputs, targets, **_): + assert inputs.requires_grad == True and targets.requires_grad == False + assert inputs.size() == targets.size() + epsilon = 1e-5 + + M = (F.softmax(inputs, dim=1) + targets) * 0.5 + kl1 = F.kl_div(F.log_softmax(inputs, dim=1), M, reduction='mean') + kl2 = F.kl_div(torch.log(targets+epsilon), M, reduction='mean') + return (kl1 + kl2) * 0.5 + +unsup_loss_dict = {"MSE": softmax_mse_loss, + "KL":softmax_kl_loss, + "JS":softmax_js_loss} + +class SSLCCT(SSLSegAgent): + """ + Cross-Consistency Training according to the following paper: + Yassine Ouali, Celine Hudelot and Myriam Tami: + Semi-Supervised Semantic Segmentation With Cross-Consistency Training. + CVPR 2020. + https://arxiv.org/abs/2003.09005 + Code adapted from: https://github.com/yassouali/CCT + """ + def training(self): + class_num = self.config['network']['class_num'] + iter_valid = self.config['training']['iter_valid'] + ssl_cfg = self.config['semi_supervised_learning'] + unsup_loss_name = ssl_cfg.get('unsupervised_loss', "MSE") + self.unsup_loss_f = unsup_loss_dict[unsup_loss_name] + train_loss = 0 + train_loss_sup = 0 + train_loss_reg = 0 + train_dice_list = [] + self.net.train() + + for it in range(iter_valid): + try: + data_lab = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data_lab = next(self.trainIter) + try: + data_unlab = next(self.trainIter_unlab) + except StopIteration: + self.trainIter_unlab = iter(self.train_loader_unlab) + data_unlab = next(self.trainIter_unlab) + + # get the inputs + x0 = self.convert_tensor_type(data_lab['image']) + y0 = self.convert_tensor_type(data_lab['label_prob']) + x1 = self.convert_tensor_type(data_unlab['image']) + inputs = torch.cat([x0, x1], dim = 0) + inputs, y0 = inputs.to(self.device), y0.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward pass + output, aux_outputs = self.net(inputs) + n0 = list(x0.shape)[0] + + # get supervised loss + p0 = output[:n0] + loss_sup = self.get_loss_value(data_lab, p0, y0) + + # get regularization loss + p1 = F.softmax(output[n0:].detach(), dim=1) + p1_aux = [aux_out[n0:] for aux_out in aux_outputs] + loss_reg = 0.0 + for p1_auxi in p1_aux: + loss_reg += self.unsup_loss_f( p1_auxi, p1, use_softmax = True) + loss_reg = loss_reg / len(p1_aux) + + iter_max = self.config['training']['iter_max'] + ramp_up_length = ssl_cfg.get('ramp_up_length', iter_max) + regular_w = 0.0 + if(self.glob_it > ssl_cfg.get('iter_sup', 0)): + regular_w = ssl_cfg.get('regularize_w', 0.1) + if(ramp_up_length is not None and self.glob_it < ramp_up_length): + regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + + loss = loss_sup + regular_w*loss_reg + + loss.backward() + self.optimizer.step() + if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + self.scheduler.step() + train_loss = train_loss + loss.item() + train_loss_sup = train_loss_sup + loss_sup.item() + train_loss_reg = train_loss_reg + loss_reg.item() + # get dice evaluation for each class in annotated images + if(isinstance(p0, tuple) or isinstance(p0, list)): + p0 = p0[0] + p0_argmax = torch.argmax(p0, dim = 1, keepdim = True) + p0_soft = get_soft_label(p0_argmax, class_num, self.tensor_type) + p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) + dice_list = get_classwise_dice(p0_soft, y0) + train_dice_list.append(dice_list.cpu().numpy()) + train_avg_loss = train_loss / iter_valid + train_avg_loss_sup = train_loss_sup / iter_valid + train_avg_loss_reg = train_loss_reg / iter_valid + train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) + train_avg_dice = train_cls_dice.mean() + + train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, + 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, + 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} + return train_scalers diff --git a/pymic/net_run_ssl/ssl_cps.py b/pymic/net_run_ssl/ssl_cps.py index 582cdc0..e39bec9 100644 --- a/pymic/net_run_ssl/ssl_cps.py +++ b/pymic/net_run_ssl/ssl_cps.py @@ -51,7 +51,7 @@ def train_valid(self): # last_iter = self.checkpoint['iteration'] - 1 if(self.scheduler2 is None): opt_params["laster_iter"] = last_iter - self.scheduler2 = get_lr_scheduler(self.optimizer, opt_params) + self.scheduler2 = get_lr_scheduler(self.optimizer2, opt_params) super(SSLCPS, self).train_valid() def training(self): @@ -159,7 +159,7 @@ def validation(self): self.scheduler2.step(return_value['avg_dice']) return return_value - def write_scalars(self, train_scalars, valid_scalars, glob_it): + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): loss_scalar ={'train':train_scalars['loss'], 'valid':valid_scalars['loss']} loss_sup_scalar = {'net1':train_scalars['loss_sup1'], @@ -171,6 +171,7 @@ def write_scalars(self, train_scalars, valid_scalars, glob_it): self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it) self.summ_writer.add_scalars('loss_pseudo_sup', loss_pse_sup_scalar, glob_it) self.summ_writer.add_scalars('regular_w', {'regular_w':train_scalars['regular_w']}, glob_it) + self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) self.summ_writer.add_scalars('dice', dice_scalar, glob_it) class_num = self.config['network']['class_num'] for c in range(class_num): diff --git a/pymic/net_run_ssl/ssl_main.py b/pymic/net_run_ssl/ssl_main.py index 54492ae..6bddf29 100644 --- a/pymic/net_run_ssl/ssl_main.py +++ b/pymic/net_run_ssl/ssl_main.py @@ -8,14 +8,17 @@ from pymic.net_run_ssl.ssl_em import SSLEntropyMinimization from pymic.net_run_ssl.ssl_mt import SSLMeanTeacher from pymic.net_run_ssl.ssl_uamt import SSLUncertaintyAwareMeanTeacher -from pymic.net_run_ssl.ssl_urpc import SSLURPC +from pymic.net_run_ssl.ssl_cct import SSLCCT from pymic.net_run_ssl.ssl_cps import SSLCPS +from pymic.net_run_ssl.ssl_urpc import SSLURPC + SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization, 'MeanTeacher': SSLMeanTeacher, 'UAMT': SSLUncertaintyAwareMeanTeacher, - 'URPC': SSLURPC, - 'CPS': SSLCPS} + 'CCT': SSLCCT, + 'CPS': SSLCPS, + 'URPC': SSLURPC} def main(): if(len(sys.argv) < 3): From 153afe56f548d2eea55b65368510a404fd92c00a Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 11 Aug 2022 10:00:56 +0800 Subject: [PATCH 020/225] update reference for ssl methods --- pymic/net_run_ssl/ssl_em.py | 5 +++-- pymic/net_run_ssl/ssl_mt.py | 6 +++++- pymic/net_run_ssl/ssl_urpc.py | 11 +++++------ 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/pymic/net_run_ssl/ssl_em.py b/pymic/net_run_ssl/ssl_em.py index bb1cd55..e4b1f19 100644 --- a/pymic/net_run_ssl/ssl_em.py +++ b/pymic/net_run_ssl/ssl_em.py @@ -15,9 +15,10 @@ class SSLEntropyMinimization(SSLSegAgent): """ Implementation of the following paper: - Yves Grandvalet and Yoshua Bengio, + Yves Grandvalet and Yoshua Bengio: Semi-supervised Learningby Entropy Minimization. - NeurIPS, 2005. + NeurIPS, 2005. + https://papers.nips.cc/paper/2004/file/96f2b50b5d3613adf9c27049b2a888c7-Paper.pdf """ def __init__(self, config, stage = 'train'): super(SSLEntropyMinimization, self).__init__(config, stage) diff --git a/pymic/net_run_ssl/ssl_mt.py b/pymic/net_run_ssl/ssl_mt.py index 7905968..aa7fbff 100644 --- a/pymic/net_run_ssl/ssl_mt.py +++ b/pymic/net_run_ssl/ssl_mt.py @@ -13,7 +13,11 @@ class SSLMeanTeacher(SSLSegAgent): """ - Training and testing agent for semi-supervised segmentation + Mean Teacher for semi-supervised learning according to the following paper: + Antti Tarvainen, Harri Valpola: Mean teachers are better role models: Weight-averaged + consistency targets improve semi-supervised deep learning results. + NeurIPS 2017. + https://arxiv.org/abs/1703.01780 """ def __init__(self, config, stage = 'train'): super(SSLMeanTeacher, self).__init__(config, stage) diff --git a/pymic/net_run_ssl/ssl_urpc.py b/pymic/net_run_ssl/ssl_urpc.py index 0513dc2..5d269fd 100644 --- a/pymic/net_run_ssl/ssl_urpc.py +++ b/pymic/net_run_ssl/ssl_urpc.py @@ -14,12 +14,11 @@ class SSLURPC(SSLSegAgent): """ Uncertainty-Rectified Pyramid Consistency according to the following paper: - Xiangde Luo, Wenjun Liao, Jieneng Chen, Tao Song, Yinan Chen, - Shichuan Zhang, Nianyong Chen, Guotai Wang, Shaoting Zhang. - Efficient Semi-supervised Gross Target Volume of Nasopharyngeal Carcinoma - Segmentation via Uncertainty Rectified Pyramid Consistency. - MICCAI 2021, pp. 318-329. - https://arxiv.org/abs/2012.07042 + Xiangde Luo, Guotai Wang*, Wenjun Liao, Jieneng Chen, Tao Song, Yinan Chen, + Shichuan Zhang, Dimitris N. Metaxas, Shaoting Zhang. + Semi-Supervised Medical Image Segmentation via Uncertainty Rectified Pyramid Consistency . + Medical Image Analysis 2022. + https://doi.org/10.1016/j.media.2022.102517 """ def training(self): class_num = self.config['network']['class_num'] From 89fd6ccf996e57f8ca60963ebfdf8455be5bcf42 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 11 Aug 2022 12:36:39 +0800 Subject: [PATCH 021/225] update network for wsl Rename WSL classes, update learning rate scheduler and update dual-branch network --- pymic/net/net2d/unet2d.py | 61 ++++++++++++++++++++++++++ pymic/net/net2d/unet2d_cct.py | 63 +-------------------------- pymic/net/net2d/unet2d_dual_branch.py | 32 +++++++++++++- pymic/net/net_dict_seg.py | 4 +- pymic/net_run_wsl/wsl_abstract.py | 5 ++- pymic/net_run_wsl/wsl_dmpls.py | 9 ++-- pymic/net_run_wsl/wsl_em.py | 4 +- pymic/net_run_wsl/wsl_gatedcrf.py | 3 +- pymic/net_run_wsl/wsl_main.py | 24 +++++----- pymic/net_run_wsl/wsl_tv.py | 6 ++- pymic/net_run_wsl/wsl_ustm.py | 4 +- 11 files changed, 127 insertions(+), 88 deletions(-) diff --git a/pymic/net/net2d/unet2d.py b/pymic/net/net2d/unet2d.py index 703ced3..a361f0f 100644 --- a/pymic/net/net2d/unet2d.py +++ b/pymic/net/net2d/unet2d.py @@ -72,6 +72,67 @@ def forward(self, x1, x2): x = torch.cat([x2, x1], dim=1) return self.conv(x) +class Encoder(nn.Module): + def __init__(self, params): + super(Encoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + + self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) + self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) + self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) + self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) + if(len(self.ft_chns) == 5): + self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) + + def forward(self, x): + x0 = self.in_conv(x) + x1 = self.down1(x0) + x2 = self.down2(x1) + x3 = self.down3(x2) + output = [x0, x1, x2, x3] + if(len(self.ft_chns) == 5): + x4 = self.down4(x3) + output.append(x4) + return output + +class Decoder(nn.Module): + def __init__(self, params): + super(Decoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + self.n_class = self.params['class_num'] + self.bilinear = self.params['bilinear'] + + assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + + if(len(self.ft_chns) == 5): + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear) + self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1) + + def forward(self, x): + if(len(self.ft_chns) == 5): + assert(len(x) == 5) + x0, x1, x2, x3, x4 = x + x_d3 = self.up1(x4, x3) + else: + assert(len(x) == 4) + x0, x1, x2, x3 = x + x_d3 = x3 + x_d2 = self.up2(x_d3, x2) + x_d1 = self.up3(x_d2, x1) + x_d0 = self.up4(x_d1, x0) + output = self.out_conv(x_d0) + return output + class UNet2D(nn.Module): def __init__(self, params): super(UNet2D, self).__init__() diff --git a/pymic/net/net2d/unet2d_cct.py b/pymic/net/net2d/unet2d_cct.py index 88a369f..f7558bc 100644 --- a/pymic/net/net2d/unet2d_cct.py +++ b/pymic/net/net2d/unet2d_cct.py @@ -15,68 +15,7 @@ import torch.nn.functional as F import numpy as np from torch.distributions.uniform import Uniform -from pymic.net.net2d.unet2d import ConvBlock, DownBlock, UpBlock - -class Encoder(nn.Module): - def __init__(self, params): - super(Encoder, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - if(len(self.ft_chns) == 5): - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - - def forward(self, x): - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - output = [x0, x1, x2, x3] - if(len(self.ft_chns) == 5): - x4 = self.down4(x3) - output.append(x4) - return output - -class Decoder(nn.Module): - def __init__(self, params): - super(Decoder, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] - - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - if(len(self.ft_chns) == 5): - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear) - self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1) - - def forward(self, x): - if(len(self.ft_chns) == 5): - assert(len(x) == 5) - x0, x1, x2, x3, x4 = x - x_d3 = self.up1(x4, x3) - else: - assert(len(x) == 4) - x0, x1, x2, x3 = x - x_d3 = x3 - x_d2 = self.up2(x_d3, x2) - x_d1 = self.up3(x_d2, x1) - x_d0 = self.up4(x_d1, x0) - output = self.out_conv(x_d0) - return output +from pymic.net.net2d.unet2d import Encoder, Decoder def _l2_normalize(d): # Normalizing per batch axis diff --git a/pymic/net/net2d/unet2d_dual_branch.py b/pymic/net/net2d/unet2d_dual_branch.py index 59ec138..9622bd0 100644 --- a/pymic/net/net2d/unet2d_dual_branch.py +++ b/pymic/net/net2d/unet2d_dual_branch.py @@ -11,10 +11,38 @@ import torch import torch.nn as nn -import numpy as np -from torch.nn.functional import interpolate from pymic.net.net2d.unet2d import * +class UNet2D_DualBranch(nn.Module): + def __init__(self, params): + super(UNet2D_DualBranch, self).__init__() + self.encoder = Encoder(params) + self.decoder1 = Decoder(params) + self.decoder2 = Decoder(params) + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + f = self.encoder(x) + output1 = self.decoder1(f) + output2 = self.decoder2(f) + if(len(x_shape) == 5): + new_shape = [N, D] + list(output1.shape)[1:] + output1 = torch.reshape(output1, new_shape) + output1 = torch.transpose(output1, 1, 2) + output2 = torch.reshape(output2, new_shape) + output2 = torch.transpose(output2, 1, 2) + + if(self.training): + return output1, output2 + else: + return (output1 + output2)/2 + # for backup class DualBranchUNet2D(UNet2D): def __init__(self, params): params['deep_supervise'] = False diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index aa912bc..0ee554e 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division from pymic.net.net2d.unet2d import UNet2D -from pymic.net.net2d.unet2d_dual_branch import DualBranchUNet2D +from pymic.net.net2d.unet2d_dual_branch import UNet2D_DualBranch from pymic.net.net2d.unet2d_urpc import UNet2D_URPC from pymic.net.net2d.unet2d_cct import UNet2D_CCT from pymic.net.net2d.cople_net import COPLENet @@ -14,7 +14,7 @@ SegNetDict = { 'UNet2D': UNet2D, - 'DualBranchUNet2D': DualBranchUNet2D, + 'UNet2D_DualBranch': UNet2D_DualBranch, 'UNet2D_URPC': UNet2D_URPC, 'UNet2D_CCT': UNet2D_CCT, 'COPLENet': COPLENet, diff --git a/pymic/net_run_wsl/wsl_abstract.py b/pymic/net_run_wsl/wsl_abstract.py index fe80ea5..d64063e 100644 --- a/pymic/net_run_wsl/wsl_abstract.py +++ b/pymic/net_run_wsl/wsl_abstract.py @@ -5,7 +5,7 @@ class WSLSegAgent(SegmentationAgent): """ - Training and testing agent for semi-supervised segmentation + Training and testing agent for weakly supervised segmentation """ def __init__(self, config, stage = 'train'): super(WSLSegAgent, self).__init__(config, stage) @@ -13,7 +13,7 @@ def __init__(self, config, stage = 'train'): def training(self): pass - def write_scalars(self, train_scalars, valid_scalars, glob_it): + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): loss_scalar ={'train':train_scalars['loss'], 'valid':valid_scalars['loss']} loss_sup_scalar = {'train':train_scalars['loss_sup']} @@ -23,6 +23,7 @@ def write_scalars(self, train_scalars, valid_scalars, glob_it): self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it) self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it) self.summ_writer.add_scalars('regular_w', {'regular_w':train_scalars['regular_w']}, glob_it) + self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) self.summ_writer.add_scalars('dice', dice_scalar, glob_it) class_num = self.config['network']['class_num'] for c in range(class_num): diff --git a/pymic/net_run_wsl/wsl_dmpls.py b/pymic/net_run_wsl/wsl_dmpls.py index 234f1a2..1e60e47 100644 --- a/pymic/net_run_wsl/wsl_dmpls.py +++ b/pymic/net_run_wsl/wsl_dmpls.py @@ -10,6 +10,7 @@ from pymic.loss.seg.dice import DiceLoss from pymic.net_run_wsl.wsl_abstract import WSLSegAgent from pymic.util.ramps import sigmoid_rampup +from pymic.util.general import keyword_match class WSLDMPLS(WSLSegAgent): """ @@ -18,12 +19,13 @@ class WSLDMPLS(WSLSegAgent): Shaoting Zhang. ScribblScribble-Supervised Medical Image Segmentation via Dual-Branch Network and Dynamically Mixed Pseudo Labels Supervision. MICCAI 2022. + https://arxiv.org/abs/2203.02106 """ def __init__(self, config, stage = 'train'): net_type = config['network']['net_type'] - if net_type not in ['DualBranchUNet2D', 'DualBranchUNet3D']: + if net_type not in ['UNet2D_DualBranch', 'UNet3D_DualBranch']: raise ValueError("""For WSL_DMPLS, a dual branch network is expected. \ - It only supports DualBranchUNet2D and DualBranchUNet3D currently.""") + It only supports UNet2D_DualBranch and UNet3D_DualBranch currently.""") super(WSLDMPLS, self).__init__(config, stage) def training(self): @@ -82,7 +84,8 @@ def training(self): loss.backward() self.optimizer.step() - self.scheduler.step() + if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run_wsl/wsl_em.py b/pymic/net_run_wsl/wsl_em.py index cc19600..66823c1 100644 --- a/pymic/net_run_wsl/wsl_em.py +++ b/pymic/net_run_wsl/wsl_em.py @@ -10,6 +10,7 @@ from pymic.net_run.agent_seg import SegmentationAgent from pymic.net_run_wsl.wsl_abstract import WSLSegAgent from pymic.util.ramps import sigmoid_rampup +from pymic.util.general import keyword_match class WSLEntropyMinimization(WSLSegAgent): """ @@ -60,7 +61,8 @@ def training(self): loss.backward() self.optimizer.step() - self.scheduler.step() + if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run_wsl/wsl_gatedcrf.py b/pymic/net_run_wsl/wsl_gatedcrf.py index af6c562..d728328 100644 --- a/pymic/net_run_wsl/wsl_gatedcrf.py +++ b/pymic/net_run_wsl/wsl_gatedcrf.py @@ -86,7 +86,8 @@ def training(self): loss.backward() self.optimizer.step() - self.scheduler.step() + if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run_wsl/wsl_main.py b/pymic/net_run_wsl/wsl_main.py index 595aa3e..916e1d8 100644 --- a/pymic/net_run_wsl/wsl_main.py +++ b/pymic/net_run_wsl/wsl_main.py @@ -5,19 +5,19 @@ import os import sys from pymic.util.parse_config import * -from pymic.net_run_wsl.wsl_em import WSL_EntropyMinimization -from pymic.net_run_wsl.wsl_gatedcrf import WSL_GatedCRF -from pymic.net_run_wsl.wsl_mumford_shah import WSL_MumfordShah -from pymic.net_run_wsl.wsl_tv import WSL_TotalVariation -from pymic.net_run_wsl.wsl_ustm import WSL_USTM -from pymic.net_run_wsl.wsl_dmpls import WSL_DMPLS +from pymic.net_run_wsl.wsl_em import WSLEntropyMinimization +from pymic.net_run_wsl.wsl_gatedcrf import WSLGatedCRF +from pymic.net_run_wsl.wsl_mumford_shah import WSLMumfordShah +from pymic.net_run_wsl.wsl_tv import WSLTotalVariation +from pymic.net_run_wsl.wsl_ustm import WSLUSTM +from pymic.net_run_wsl.wsl_dmpls import WSLDMPLS -WSLMethodDict = {'EntropyMinimization': WSL_EntropyMinimization, - 'GatedCRF': WSL_GatedCRF, - 'MumfordShah': WSL_MumfordShah, - 'TotalVariation': WSL_TotalVariation, - 'USTM': WSL_USTM, - 'DMPLS': WSL_DMPLS} +WSLMethodDict = {'EntropyMinimization': WSLEntropyMinimization, + 'GatedCRF': WSLGatedCRF, + 'MumfordShah': WSLMumfordShah, + 'TotalVariation': WSLTotalVariation, + 'USTM': WSLUSTM, + 'DMPLS': WSLDMPLS} def main(): if(len(sys.argv) < 3): diff --git a/pymic/net_run_wsl/wsl_tv.py b/pymic/net_run_wsl/wsl_tv.py index 9492150..fde1c10 100644 --- a/pymic/net_run_wsl/wsl_tv.py +++ b/pymic/net_run_wsl/wsl_tv.py @@ -9,10 +9,11 @@ from pymic.loss.seg.ssl import TotalVariationLoss from pymic.net_run_wsl.wsl_abstract import WSLSegAgent from pymic.util.ramps import sigmoid_rampup +from pymic.util.general import keyword_match class WSLTotalVariation(WSLSegAgent): """ - Training and testing agent for semi-supervised segmentation + Weakly suepervised segmentation with Total Variation Regularization. """ def __init__(self, config, stage = 'train'): super(WSLTotalVariation, self).__init__(config, stage) @@ -59,7 +60,8 @@ def training(self): # if (self.config['training']['use']) loss.backward() self.optimizer.step() - self.scheduler.step() + if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run_wsl/wsl_ustm.py b/pymic/net_run_wsl/wsl_ustm.py index 32e79df..dd556e7 100644 --- a/pymic/net_run_wsl/wsl_ustm.py +++ b/pymic/net_run_wsl/wsl_ustm.py @@ -11,6 +11,7 @@ from pymic.net.net_dict_seg import SegNetDict from pymic.net_run_wsl.wsl_abstract import WSLSegAgent from pymic.util.ramps import sigmoid_rampup +from pymic.util.general import keyword_match class WSLUSTM(WSLSegAgent): """ @@ -108,7 +109,8 @@ def training(self): loss.backward() self.optimizer.step() - self.scheduler.step() + if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + self.scheduler.step() # update EMA alpha = wsl_cfg.get('ema_decay', 0.99) From dd2f79781b598ed95f54a3cef851ca034164463f Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 11 Aug 2022 15:22:42 +0800 Subject: [PATCH 022/225] Update wsl_gatedcrf.py --- pymic/net_run_wsl/wsl_gatedcrf.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pymic/net_run_wsl/wsl_gatedcrf.py b/pymic/net_run_wsl/wsl_gatedcrf.py index d728328..2ae8318 100644 --- a/pymic/net_run_wsl/wsl_gatedcrf.py +++ b/pymic/net_run_wsl/wsl_gatedcrf.py @@ -9,10 +9,16 @@ from pymic.loss.seg.gatedcrf import ModelLossSemsegGatedCRF from pymic.net_run_wsl.wsl_abstract import WSLSegAgent from pymic.util.ramps import sigmoid_rampup +from pymic.util.general import keyword_match class WSLGatedCRF(WSLSegAgent): """ - Training and testing agent for semi-supervised segmentation + Implementation of the Gated CRF Loss for Weakly Supervised Semantic Image Segmentation. + Anton Obukhov, Stamatios Georgoulis, Dengxin Dai, Luc Van Gool: + Gated CRF Loss for Weakly Supervised Semantic Image Segmentation. + CoRR, abs/1906.04651, 2019 + http://arxiv.org/abs/1906.04651 + } """ def __init__(self, config, stage = 'train'): super(WSLGatedCRF, self).__init__(config, stage) From 41152e2fa85f6a7e26e0d332a0335acdc6abba16 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 11 Aug 2022 16:03:42 +0800 Subject: [PATCH 023/225] update mumford shah method --- pymic/loss/seg/mumford_shah.py | 27 ++------------------------- pymic/net_run_wsl/wsl_mumford_shah.py | 9 +++++++-- 2 files changed, 9 insertions(+), 27 deletions(-) diff --git a/pymic/loss/seg/mumford_shah.py b/pymic/loss/seg/mumford_shah.py index eeaa250..f167b71 100644 --- a/pymic/loss/seg/mumford_shah.py +++ b/pymic/loss/seg/mumford_shah.py @@ -4,37 +4,14 @@ import torch import torch.nn as nn -class DiceLoss(nn.Module): - def __init__(self, params = None): - super(DiceLoss, self).__init__() - if(params is None): - self.softmax = True - else: - self.softmax = params.get('loss_softmax', True) - - def forward(self, loss_input_dict): - predict = loss_input_dict['prediction'] - soft_y = loss_input_dict['ground_truth'] - - if(isinstance(predict, (list, tuple))): - predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) - predict = reshape_tensor_to_2D(predict) - soft_y = reshape_tensor_to_2D(soft_y) - dice_score = get_classwise_dice(predict, soft_y) - dice_loss = 1.0 - dice_score.mean() - return dice_loss - class MumfordShahLoss(nn.Module): """ Implementation of Mumford Shah Loss in this paper: - Boah Kim and Jong Chul Ye, Mumford–Shah Loss Functional + Boah Kim and Jong Chul Ye: Mumford–Shah Loss Functional for Image Segmentation With Deep Learning. IEEE TIP, 2019. The oringial implementation is availabel at: https://github.com/jongcye/CNN_MumfordShah_Loss - - currently only 2D version is supported. + Currently only 2D version is supported. """ def __init__(self, params = None): super(MumfordShahLoss, self).__init__() diff --git a/pymic/net_run_wsl/wsl_mumford_shah.py b/pymic/net_run_wsl/wsl_mumford_shah.py index f642e59..095a0f6 100644 --- a/pymic/net_run_wsl/wsl_mumford_shah.py +++ b/pymic/net_run_wsl/wsl_mumford_shah.py @@ -9,10 +9,14 @@ from pymic.loss.seg.mumford_shah import MumfordShahLoss from pymic.net_run_wsl.wsl_abstract import WSLSegAgent from pymic.util.ramps import sigmoid_rampup +from pymic.util.general import keyword_match class WSLMumfordShah(WSLSegAgent): """ - Training and testing agent for semi-supervised segmentation + Weakly supervised learning with Mumford Shah Loss according to this paper: + Boah Kim and Jong Chul Ye: Mumford–Shah Loss Functional + for Image Segmentation With Deep Learning. IEEE TIP, 2019. + https://doi.org/10.1109/TIP.2019.2941265 """ def __init__(self, config, stage = 'train'): super(WSLMumfordShah, self).__init__(config, stage) @@ -61,7 +65,8 @@ def training(self): # if (self.config['training']['use']) loss.backward() self.optimizer.step() - self.scheduler.step() + if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() From cb61d458b5db53ae0f789fa767105d85bbeef71a Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 13 Aug 2022 09:23:19 +0800 Subject: [PATCH 024/225] Update wsl_ustm.py add reference --- pymic/net_run_wsl/wsl_ustm.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pymic/net_run_wsl/wsl_ustm.py b/pymic/net_run_wsl/wsl_ustm.py index dd556e7..6083069 100644 --- a/pymic/net_run_wsl/wsl_ustm.py +++ b/pymic/net_run_wsl/wsl_ustm.py @@ -15,7 +15,12 @@ class WSLUSTM(WSLSegAgent): """ - Training and testing agent for semi-supervised segmentation + USTM for scribble-supervised segmentation according to the following paper: + Xiaoming Liu, Quan Yuan, Yaozong Gao, Helei He, Shuo Wang, Xiao Tang, + Jinshan Tang, Dinggang Shen: + Weakly Supervised Segmentation of COVID19 Infection with Scribble Annotation on CT Images. + Patter Recognition, 2022. + https://doi.org/10.1016/j.patcog.2021.108341 """ def __init__(self, config, stage = 'train'): super(WSLUSTM, self).__init__(config, stage) From 251e4d1a42083907b0d52a80b36f04a5cd9a137f Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 17 Aug 2022 16:07:54 +0800 Subject: [PATCH 025/225] update SSL, WSL and NLL update rampup and lr scheduler.step --- pymic/loss/loss_dict_seg.py | 4 +- pymic/loss/seg/ce.py | 23 ++-- pymic/loss/seg/slsr.py | 9 +- pymic/net_run/agent_seg.py | 9 +- pymic/net_run/get_optimizer.py | 2 + pymic/net_run_nll/nll_cl.py | 52 ++++---- pymic/net_run_nll/nll_co_teaching.py | 109 ++++++---------- pymic/net_run_nll/nll_main.py | 37 ++++++ pymic/net_run_nll/nll_trinet.py | 178 ++++++++++++++++++++++++++ pymic/net_run_ssl/ssl_abstract.py | 4 - pymic/net_run_ssl/ssl_cct.py | 22 ++-- pymic/net_run_ssl/ssl_cps.py | 83 +++++------- pymic/net_run_ssl/ssl_em.py | 20 +-- pymic/net_run_ssl/ssl_mt.py | 20 +-- pymic/net_run_ssl/ssl_uamt.py | 22 ++-- pymic/net_run_ssl/ssl_urpc.py | 20 ++- pymic/net_run_wsl/wsl_dmpls.py | 19 ++- pymic/net_run_wsl/wsl_em.py | 21 ++- pymic/net_run_wsl/wsl_gatedcrf.py | 19 ++- pymic/net_run_wsl/wsl_mumford_shah.py | 19 ++- pymic/net_run_wsl/wsl_tv.py | 18 +-- pymic/net_run_wsl/wsl_ustm.py | 21 ++- pymic/util/ramps.py | 27 ++-- 23 files changed, 457 insertions(+), 301 deletions(-) create mode 100644 pymic/net_run_nll/nll_main.py create mode 100644 pymic/net_run_nll/nll_trinet.py diff --git a/pymic/loss/loss_dict_seg.py b/pymic/loss/loss_dict_seg.py index 929ec43..a8a53ad 100644 --- a/pymic/loss/loss_dict_seg.py +++ b/pymic/loss/loss_dict_seg.py @@ -1,14 +1,14 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import torch.nn as nn -from pymic.loss.seg.ce import CrossEntropyLoss, GeneralizedCrossEntropyLoss +from pymic.loss.seg.ce import CrossEntropyLoss, GeneralizedCELoss from pymic.loss.seg.dice import DiceLoss, FocalDiceLoss, NoiseRobustDiceLoss from pymic.loss.seg.slsr import SLSRLoss from pymic.loss.seg.exp_log import ExpLogLoss from pymic.loss.seg.mse import MSELoss, MAELoss SegLossDict = {'CrossEntropyLoss': CrossEntropyLoss, - 'GeneralizedCrossEntropyLoss': GeneralizedCrossEntropyLoss, + 'GeneralizedCELoss': GeneralizedCELoss, 'SLSRLoss': SLSRLoss, 'DiceLoss': DiceLoss, 'FocalDiceLoss': FocalDiceLoss, diff --git a/pymic/loss/seg/ce.py b/pymic/loss/seg/ce.py index dadeba7..da2bf14 100644 --- a/pymic/loss/seg/ce.py +++ b/pymic/loss/seg/ce.py @@ -59,34 +59,36 @@ def forward(self, loss_input_dict): ce = torch.mean(ce) return ce -class GeneralizedCrossEntropyLoss(nn.Module): +class GeneralizedCELoss(nn.Module): """ Generalized cross entropy loss to deal with noisy labels. Z. Zhang et al. Generalized Cross Entropy Loss for Training Deep Neural Networks with Noisy Labels, NeurIPS 2018. """ def __init__(self, params): - super(GeneralizedCrossEntropyLoss, self).__init__() - self.enable_pix_weight = params['GeneralizedCrossEntropyLoss_Enable_Pixel_Weight'.lower()] - self.enable_cls_weight = params['GeneralizedCrossEntropyLoss_Enable_Class_Weight'.lower()] - self.q = params['GeneralizedCrossEntropyLoss_q'.lower()] + """ + q: in (0, 1), becmomes MAE when q = 1 + """ + super(GeneralizedCELoss, self).__init__() + self.enable_pix_weight = params.get('GeneralizedCELoss_Enable_Pixel_Weight', False) + self.enable_cls_weight = params.get('GeneralizedCELoss_Enable_Class_Weight', False) + self.q = params.get('GeneralizedCELoss_q', 0.5) + self.softmax = params.get('loss_softmax', True) def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] - soft_y = loss_input_dict['ground_truth'] - pix_w = loss_input_dict['pixel_weight'] - cls_w = loss_input_dict['class_weight'] - softmax = loss_input_dict['softmax'] + soft_y = loss_input_dict['ground_truth'] if(isinstance(predict, (list, tuple))): predict = predict[0] - if(softmax): + if(self.softmax): predict = nn.Softmax(dim = 1)(predict) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) gce = (1.0 - torch.pow(predict, self.q)) / self.q * soft_y if(self.enable_cls_weight): + cls_w = loss_input_dict.get('class_weight', None) if(cls_w is None): raise ValueError("Class weight is enabled but not defined") gce = torch.sum(gce * cls_w, dim = 1) @@ -94,6 +96,7 @@ def forward(self, loss_input_dict): gce = torch.sum(gce, dim = 1) if(self.enable_pix_weight): + pix_w = loss_input_dict.get('pixel_weight', None) if(pix_w is None): raise ValueError("Pixel weight is enabled but not defined") pix_w = reshape_tensor_to_2D(pix_w) diff --git a/pymic/loss/seg/slsr.py b/pymic/loss/seg/slsr.py index 6ad60b3..706d2fc 100644 --- a/pymic/loss/seg/slsr.py +++ b/pymic/loss/seg/slsr.py @@ -2,8 +2,10 @@ """ Spatial Label Smoothing Regularization (SLSR) loss for learning from noisy annotatins according to the following paper: - Minqing Zhang, Jiantao Gao et al., Characterizing Label Errors: - Confident Learning for Noisy-Labeled Image Segmentation, MICCAI 2020. + Minqing Zhang, Jiantao Gao et al.: + Characterizing Label Errors: Confident Learning for Noisy-Labeled Image + Segmentation, MICCAI 2020. + https://link.springer.com/chapter/10.1007/978-3-030-59710-8_70 """ from __future__ import print_function, division @@ -17,7 +19,7 @@ def __init__(self, params): if(params is None): params = {} self.softmax = params.get('loss_softmax', True) - self.epsilon = params.get('slsrloss_softmax', 0.25) + self.epsilon = params.get('slsrloss_epsilon', 0.25) def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] @@ -35,7 +37,6 @@ def forward(self, loss_input_dict): soft_y = reshape_tensor_to_2D(soft_y) if(pix_w is not None): pix_w = reshape_tensor_to_2D(pix_w > 0).float() - # smooth labels for pixels in the unconfident mask smooth_y = (soft_y - 0.5) * (0.5 - self.epsilon) / 0.5 + 0.5 smooth_y = pix_w * smooth_y + (1 - pix_w) * soft_y diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 38c75ae..d8e74a2 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -10,6 +10,7 @@ import numpy as np import torch.nn as nn import torch.optim as optim +from torch.optim import lr_scheduler import torch.nn.functional as F from datetime import datetime from tensorboardX import SummaryWriter @@ -27,7 +28,6 @@ from pymic.transform.trans_dict import TransformDict from pymic.util.post_process import PostProcessDict from pymic.util.image_process import convert_label -from pymic.util.general import keyword_match class SegmentationAgent(NetRunAgent): def __init__(self, config, stage = 'train'): @@ -164,7 +164,7 @@ def set_postprocessor(self, postprocessor): def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] - train_loss = 0 + train_loss = 0 train_dice_list = [] self.net.train() for it in range(iter_valid): @@ -201,7 +201,8 @@ def training(self): loss = self.get_loss_value(data, outputs, labels_prob) loss.backward() self.optimizer.step() - if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step() train_loss = train_loss + loss.item() @@ -258,7 +259,7 @@ def validation(self): valid_cls_dice = np.asarray(valid_dice_list).mean(axis = 0) valid_avg_dice = valid_cls_dice.mean() - if(keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step(valid_avg_dice) valid_scalers = {'loss': valid_avg_loss, 'avg_dice': valid_avg_dice,\ diff --git a/pymic/net_run/get_optimizer.py b/pymic/net_run/get_optimizer.py index e475286..c4504de 100644 --- a/pymic/net_run/get_optimizer.py +++ b/pymic/net_run/get_optimizer.py @@ -38,6 +38,8 @@ def get_optimizer(name, net_params, optim_params): def get_lr_scheduler(optimizer, sched_params): name = sched_params["lr_scheduler"] + if(name is None): + return None lr_gamma = sched_params["lr_gamma"] if(keyword_match(name, "ReduceLROnPlateau")): patience_it = sched_params["ReduceLROnPlateau_patience".lower()] diff --git a/pymic/net_run_nll/nll_cl.py b/pymic/net_run_nll/nll_cl.py index 8792ccd..8173471 100644 --- a/pymic/net_run_nll/nll_cl.py +++ b/pymic/net_run_nll/nll_cl.py @@ -14,6 +14,7 @@ import sys import torch import numpy as np +import pandas as pd import torch.nn as nn import torchvision.transforms as transforms from PIL import Image @@ -45,9 +46,9 @@ def get_confident_map(gt, pred, CL_type = 'both'): noise = cleanlab.pruning.get_noise_indices(gt, prob, prune_method=CL_type, n_jobs=1) return noise -class NLLConfidentLeran(SegmentationAgent): +class NLLConfidentLearn(SegmentationAgent): def __init__(self, config, stage = 'test'): - super(NLLConfidentLeran, self).__init__(config, stage) + super(NLLConfidentLearn, self).__init__(config, stage) def infer_with_cl(self): device_ids = self.config['testing']['gpus'] @@ -93,16 +94,6 @@ def test_time_dropout(m): filename_list.append(names) images = images.to(device) - # for debug - # for i in range(images.shape[0]): - # image_i = images[i][0] - # label_i = images[i][0] - # image_name = "temp/{0:}_image.nii.gz".format(names[0]) - # label_name = "temp/{0:}_label.nii.gz".format(names[0]) - # save_nd_array_as_image(image_i, image_name, reference_name = None) - # save_nd_array_as_image(label_i, label_name, reference_name = None) - # continue - pred = self.inferer.run(self.net, images) # convert tensor to numpy if(isinstance(pred, (tuple, list))): @@ -142,15 +133,10 @@ def test_time_dropout(m): dst_path = os.path.join(save_dir, filename) conf_map.save(dst_path) - def run(self): - self.create_dataset() - self.create_network() - self.infer_with_cl() - -def main(): +def get_confidence_map(): if(len(sys.argv) < 2): print('Number of arguments should be 3. e.g.') - print(' python cl.py config.cfg') + print(' python nll_cl.py config.cfg') exit() cfg_file = str(sys.argv[1]) config = parse_config(cfg_file) @@ -172,17 +158,35 @@ def main(): transform_list.append(one_transform) data_transform = transforms.Compose(transform_list) print('transform list', transform_list) - csv_file = config['dataset']['train_csv'] + csv_file = config['dataset']['train_csv'] + modal_num = config['dataset'].get('modal_num', 1) dataset = NiftyDataset(root_dir = config['dataset']['root_dir'], csv_file = csv_file, - modal_num = config['dataset']['modal_num'], + modal_num = modal_num, with_label= True, transform = data_transform ) - agent = NLLConfidentLeran(config, 'test') + agent = NLLConfidentLearn(config, 'test') agent.set_datasets(None, None, dataset) agent.transform_list = transform_list - agent.run() + agent.create_dataset() + agent.create_network() + agent.infer_with_cl() + + # create training csv for confidence learning + df_train = pd.read_csv(csv_file) + pixel_weight = [] + weight_dir = config['testing']['output_dir'] + "_conf" + for i in range(len(df_train["label"])): + lab_name = df_train["label"][i].split('/')[-1] + weight_name = "../" + weight_dir + '/' + lab_name + pixel_weight.append(weight_name) + train_cl_dict = {"image": df_train["image"], + "pixel_weight": pixel_weight, + "label": df_train["label"]} + train_cl_csv = csv_file.replace(".csv", "_cl.csv") + df_cl = pd.DataFrame.from_dict(train_cl_dict) + df_cl.to_csv(train_cl_csv, index = False) if __name__ == "__main__": - main() \ No newline at end of file + get_confidence_map() \ No newline at end of file diff --git a/pymic/net_run_nll/nll_co_teaching.py b/pymic/net_run_nll/nll_co_teaching.py index e1392eb..bcaec4e 100644 --- a/pymic/net_run_nll/nll_co_teaching.py +++ b/pymic/net_run_nll/nll_co_teaching.py @@ -11,23 +11,37 @@ """ from __future__ import print_function, division import logging +import os +import sys import numpy as np import torch import torch.nn as nn import torch.optim as optim +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.util import reshape_tensor_to_2D -from pymic.util.ramps import sigmoid_rampup -from pymic.net_run.get_optimizer import get_optimiser from pymic.net_run.agent_seg import SegmentationAgent from pymic.net.net_dict_seg import SegNetDict - -import logging -import os -import sys from pymic.util.parse_config import * +from pymic.util.ramps import get_rampup_ratio + +class BiNet(nn.Module): + def __init__(self, params): + super(BiNet, self).__init__() + net_name = params['net_type'] + self.net1 = SegNetDict[net_name](params) + self.net2 = SegNetDict[net_name](params) + + def forward(self, x): + out1 = self.net1(x) + out2 = self.net2(x) + + if(self.training): + return out1, out2 + else: + return (out1 + out2) / 3 class NLLCoTeaching(SegmentationAgent): """ @@ -37,48 +51,27 @@ class NLLCoTeaching(SegmentationAgent): """ def __init__(self, config, stage = 'train'): super(NLLCoTeaching, self).__init__(config, stage) - self.net2 = None - self.optimizer2 = None - self.scheduler2 = None loss_type = config['training']["loss_type"] if(loss_type != "CrossEntropyLoss"): logging.warn("only CrossEntropyLoss supported for" + " coteaching, the specified loss {0:} is ingored".format(loss_type)) def create_network(self): - super(NLLCoTeaching, self).create_network() - if(self.net2 is None): - net_name = self.config['network']['net_type'] - if(net_name not in SegNetDict): - raise ValueError("Undefined network {0:}".format(net_name)) - self.net2 = SegNetDict[net_name](self.config['network']) + if(self.net is None): + self.net = BiNet(self.config['network']) if(self.tensor_type == 'float'): - self.net2.float() + self.net.float() else: - self.net2.double() - - def train_valid(self): - # create optimizor for the second network - if(self.optimizer2 is None): - self.optimizer2 = get_optimiser(self.config['training']['optimizer'], - self.net2.parameters(), - self.config['training']) - last_iter = -1 - # if(self.checkpoint is not None): - # self.optimizer2.load_state_dict(self.checkpoint['optimizer_state_dict']) - # last_iter = self.checkpoint['iteration'] - 1 - if(self.scheduler2 is None): - self.scheduler2 = optim.lr_scheduler.MultiStepLR(self.optimizer2, - self.config['training']['lr_milestones'], - self.config['training']['lr_gamma'], - last_epoch = last_iter) - super(NLLCoTeaching, self).train_valid() + self.net.double() def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] - select_ratio = self.config['training']['co_teaching_select_ratio'] - rampup_length = self.config['training']['co_teaching_rampup_length'] + nll_cfg = self.config['noisy_label_learning'] + select_ratio = nll_cfg['co_teaching_select_ratio'] + iter_max = self.config['training']['iter_max'] + rampup_start = nll_cfg.get('rampup_start', 0) + rampup_end = nll_cfg.get('rampup_end', iter_max) train_loss_no_select1 = 0 train_loss_no_select2 = 0 @@ -86,8 +79,6 @@ def training(self): train_loss2 = 0 train_dice_list = [] self.net.train() - self.net2.train() - self.net2.to(self.device) for it in range(iter_valid): try: data = next(self.trainIter) @@ -102,11 +93,9 @@ def training(self): # zero the parameter gradients self.optimizer.zero_grad() - self.optimizer2.zero_grad() # forward + backward + optimize - outputs1 = self.net(inputs) - outputs2 = self.net2(inputs) + outputs1, outputs2 = self.net(inputs) prob1 = nn.Softmax(dim = 1)(outputs1) prob2 = nn.Softmax(dim = 1)(outputs2) @@ -122,8 +111,9 @@ def training(self): loss2 = torch.sum(loss2, dim = 1) # shape is [N] ind_2_sorted = torch.argsort(loss2) - forget_ratio = (1 - select_ratio) * self.glob_it / rampup_length - remb_ratio = max(select_ratio, 1 - forget_ratio) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + forget_ratio = (1 - select_ratio) * rampup_ratio + remb_ratio = 1 - forget_ratio num_remb = int(remb_ratio * len(loss1)) ind_1_update = ind_1_sorted[:num_remb] @@ -134,22 +124,17 @@ def training(self): loss = loss1_select.mean() + loss2_select.mean() - # if (self.config['training']['use']) loss.backward() self.optimizer.step() - self.scheduler.step() - self.optimizer2.step() - self.scheduler2.step() + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step() train_loss_no_select1 = train_loss_no_select1 + loss1.mean().item() train_loss_no_select2 = train_loss_no_select2 + loss2.mean().item() train_loss1 = train_loss1 + loss1_select.mean().item() train_loss2 = train_loss2 + loss2_select.mean().item() - # get dice evaluation for each class in annotated images - # if(isinstance(outputs1, tuple) or isinstance(outputs1, list)): - # outputs1 = outputs1[0] - outputs1_argmax = torch.argmax(outputs1, dim = 1, keepdim = True) soft_out1 = get_soft_label(outputs1_argmax, class_num, self.tensor_type) soft_out1, labels_prob = reshape_prediction_and_ground_truth(soft_out1, labels_prob) @@ -169,7 +154,7 @@ def training(self): 'select_ratio':remb_ratio, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} return train_scalers - def write_scalars(self, train_scalars, valid_scalars, glob_it): + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): loss_scalar ={'train':train_scalars['loss'], 'valid':valid_scalars['loss']} loss_no_select_scalar = {'net1':train_scalars['loss_no_select1'], @@ -179,6 +164,7 @@ def write_scalars(self, train_scalars, valid_scalars, glob_it): self.summ_writer.add_scalars('loss', loss_scalar, glob_it) self.summ_writer.add_scalars('loss_no_select', loss_no_select_scalar, glob_it) self.summ_writer.add_scalars('select_ratio', {'select_ratio':train_scalars['select_ratio']}, glob_it) + self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) self.summ_writer.add_scalars('dice', dice_scalar, glob_it) class_num = self.config['network']['class_num'] for c in range(class_num): @@ -192,22 +178,3 @@ def write_scalars(self, train_scalars, valid_scalars, glob_it): logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") - -if __name__ == "__main__": - if(len(sys.argv) < 3): - print('Number of arguments should be 3. e.g.') - print(' pymic_ssl train config.cfg') - exit() - stage = str(sys.argv[1]) - cfg_file = str(sys.argv[2]) - config = parse_config(cfg_file) - config = synchronize_config(config) - log_dir = config['training']['ckpt_save_dir'] - if(not os.path.exists(log_dir)): - os.mkdir(log_dir) - logging.basicConfig(filename=log_dir+"/log.txt", level=logging.INFO, - format='%(message)s') - logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging_config(config) - agent = NLLCoTeaching(config, stage) - agent.run() \ No newline at end of file diff --git a/pymic/net_run_nll/nll_main.py b/pymic/net_run_nll/nll_main.py new file mode 100644 index 0000000..d1ae7a1 --- /dev/null +++ b/pymic/net_run_nll/nll_main.py @@ -0,0 +1,37 @@ + +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import logging +import os +import sys +from pymic.util.parse_config import * +from pymic.net_run_nll.nll_co_teaching import NLLCoTeaching +from pymic.net_run_nll.nll_trinet import NLLTriNet + +NLLMethodDict = {'CoTeaching': NLLCoTeaching, + "TriNet": NLLTriNet} + +def main(): + if(len(sys.argv) < 3): + print('Number of arguments should be 3. e.g.') + print(' pymic_nll train config.cfg') + exit() + stage = str(sys.argv[1]) + cfg_file = str(sys.argv[2]) + config = parse_config(cfg_file) + config = synchronize_config(config) + log_dir = config['training']['ckpt_save_dir'] + if(not os.path.exists(log_dir)): + os.mkdir(log_dir) + logging.basicConfig(filename=log_dir+"/log.txt", level=logging.INFO, + format='%(message)s') + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging_config(config) + nll_method = config['noisy_label_learning']['nll_method'] + agent = NLLMethodDict[nll_method](config, stage) + agent.run() + +if __name__ == "__main__": + main() + + \ No newline at end of file diff --git a/pymic/net_run_nll/nll_trinet.py b/pymic/net_run_nll/nll_trinet.py new file mode 100644 index 0000000..eb0ecdd --- /dev/null +++ b/pymic/net_run_nll/nll_trinet.py @@ -0,0 +1,178 @@ +# -*- coding: utf-8 -*- +""" +Implementation of Co-teaching for learning from noisy samples for +segmentation tasks according to the following paper: + Bo Han et al., Co-teaching: Robust Training of Deep NeuralNetworks + with Extremely Noisy Labels, NeurIPS, 2018 +The author's original implementation was: +https://github.com/bhanML/Co-teaching + + +""" +from __future__ import print_function, division +import logging +import os +import sys +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.optim import lr_scheduler +from pymic.loss.seg.util import get_soft_label +from pymic.loss.seg.util import reshape_prediction_and_ground_truth +from pymic.loss.seg.util import get_classwise_dice +from pymic.loss.seg.util import reshape_tensor_to_2D +from pymic.net_run.agent_seg import SegmentationAgent +from pymic.net.net_dict_seg import SegNetDict +from pymic.util.parse_config import * +from pymic.util.ramps import get_rampup_ratio + + + +class TriNet(nn.Module): + def __init__(self, params): + super(TriNet, self).__init__() + net_name = params['net_type'] + self.net1 = SegNetDict[net_name](params) + self.net2 = SegNetDict[net_name](params) + self.net3 = SegNetDict[net_name](params) + + def forward(self, x): + out1 = self.net1(x) + out2 = self.net2(x) + out3 = self.net3(x) + + if(self.training): + return out1, out2, out3 + else: + return (out1 + out2 + out3) / 3 + +class NLLTriNet(SegmentationAgent): + """ + Co-teaching: Robust Training of Deep Neural Networks with Extremely + Noisy Labels + https://arxiv.org/abs/1804.06872 + """ + def __init__(self, config, stage = 'train'): + super(NLLTriNet, self).__init__(config, stage) + + def create_network(self): + if(self.net is None): + self.net = TriNet(self.config['network']) + if(self.tensor_type == 'float'): + self.net.float() + else: + self.net.double() + + def get_loss_and_confident_mask(self, pred, labels_prob, conf_ratio): + prob = nn.Softmax(dim = 1)(pred) + prob_2d = reshape_tensor_to_2D(prob) * 0.999 + 5e-4 + y_2d = reshape_tensor_to_2D(labels_prob) + + loss = - y_2d* torch.log(prob_2d) + loss = torch.sum(loss, dim = 1) # shape is [N] + threshold = torch.quantile(loss, conf_ratio) + mask = loss < threshold + return loss, mask + + def training(self): + class_num = self.config['network']['class_num'] + iter_valid = self.config['training']['iter_valid'] + nll_cfg = self.config['noisy_label_learning'] + iter_max = self.config['training']['iter_max'] + select_ratio = nll_cfg['trinet_select_ratio'] + rampup_start = nll_cfg.get('rampup_start', 0) + rampup_end = nll_cfg.get('rampup_end', iter_max) + + train_loss_no_select1 = 0 + train_loss_no_select2 = 0 + train_loss1, train_loss2, train_loss3 = 0, 0, 0 + train_dice_list = [] + self.net.train() + for it in range(iter_valid): + try: + data = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data = next(self.trainIter) + + # get the inputs + inputs = self.convert_tensor_type(data['image']) + labels_prob = self.convert_tensor_type(data['label_prob']) + inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + outputs1, outputs2, outputs3 = self.net(inputs) + + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end) + forget_ratio = (1 - select_ratio) * rampup_ratio + remb_ratio = 1 - forget_ratio + + loss1, mask1 = self.get_loss_and_confident_mask(outputs1, labels_prob, remb_ratio) + loss2, mask2 = self.get_loss_and_confident_mask(outputs2, labels_prob, remb_ratio) + loss3, mask3 = self.get_loss_and_confident_mask(outputs3, labels_prob, remb_ratio) + mask12, mask13, mask23 = mask1 * mask2, mask1 * mask3, mask2 * mask3 + mask12, mask13, mask23 = mask12.detach(), mask13.detach(), mask23.detach() + + loss1_avg = torch.sum(loss1 * mask23) / mask23.sum() + loss2_avg = torch.sum(loss2 * mask13) / mask13.sum() + loss3_avg = torch.sum(loss3 * mask12) / mask12.sum() + loss = (loss1_avg + loss2_avg + loss3_avg) / 3 + + loss.backward() + self.optimizer.step() + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step() + + train_loss_no_select1 = train_loss_no_select1 + loss1.mean().item() + train_loss_no_select2 = train_loss_no_select2 + loss2.mean().item() + train_loss1 = train_loss1 + loss1_avg.item() + train_loss2 = train_loss2 + loss2_avg.item() + + outputs1_argmax = torch.argmax(outputs1, dim = 1, keepdim = True) + soft_out1 = get_soft_label(outputs1_argmax, class_num, self.tensor_type) + soft_out1, labels_prob = reshape_prediction_and_ground_truth(soft_out1, labels_prob) + dice_list = get_classwise_dice(soft_out1, labels_prob).detach().cpu().numpy() + train_dice_list.append(dice_list) + train_avg_loss_no_select1 = train_loss_no_select1 / iter_valid + train_avg_loss_no_select2 = train_loss_no_select2 / iter_valid + train_avg_loss1 = train_loss1 / iter_valid + train_avg_loss2 = train_loss2 / iter_valid + train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) + train_avg_dice = train_cls_dice.mean() + + train_scalers = {'loss': (train_avg_loss1 + train_avg_loss2) / 2, + 'loss1':train_avg_loss1, 'loss2': train_avg_loss2, + 'loss_no_select1':train_avg_loss_no_select1, + 'loss_no_select2':train_avg_loss_no_select2, + 'select_ratio':remb_ratio, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} + return train_scalers + + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): + loss_scalar ={'train':train_scalars['loss'], + 'valid':valid_scalars['loss']} + loss_no_select_scalar = {'net1':train_scalars['loss_no_select1'], + 'net2':train_scalars['loss_no_select2']} + + dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']} + self.summ_writer.add_scalars('loss', loss_scalar, glob_it) + self.summ_writer.add_scalars('loss_no_select', loss_no_select_scalar, glob_it) + self.summ_writer.add_scalars('select_ratio', {'select_ratio':train_scalars['select_ratio']}, glob_it) + self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) + self.summ_writer.add_scalars('dice', dice_scalar, glob_it) + class_num = self.config['network']['class_num'] + for c in range(class_num): + cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ + 'valid':valid_scalars['class_dice'][c]} + self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) + + logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format( + train_scalars['loss'], train_scalars['avg_dice']) + "[" + \ + ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") + logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( + valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ + ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") diff --git a/pymic/net_run_ssl/ssl_abstract.py b/pymic/net_run_ssl/ssl_abstract.py index f1b97ba..1d18c4d 100644 --- a/pymic/net_run_ssl/ssl_abstract.py +++ b/pymic/net_run_ssl/ssl_abstract.py @@ -12,7 +12,6 @@ from pymic.loss.seg.ssl import EntropyLoss from pymic.net_run.agent_seg import SegmentationAgent from pymic.transform.trans_dict import TransformDict -from pymic.util.ramps import sigmoid_rampup class SSLSegAgent(SegmentationAgent): """ @@ -70,9 +69,6 @@ def worker_init_fn(worker_id): batch_size = bn_train_unlab, shuffle=True, num_workers= num_worker, worker_init_fn=worker_init) - def training(self): - pass - def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): loss_scalar ={'train':train_scalars['loss'], 'valid':valid_scalars['loss']} diff --git a/pymic/net_run_ssl/ssl_cct.py b/pymic/net_run_ssl/ssl_cct.py index 80e9d07..d0c4f24 100644 --- a/pymic/net_run_ssl/ssl_cct.py +++ b/pymic/net_run_ssl/ssl_cct.py @@ -5,12 +5,12 @@ import torch.nn as nn import torch.nn.functional as F import numpy as np +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.net_run_ssl.ssl_abstract import SSLSegAgent -from pymic.util.ramps import sigmoid_rampup -from pymic.util.general import keyword_match +from pymic.util.ramps import get_rampup_ratio def softmax_mse_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False): assert inputs.requires_grad == True and targets.requires_grad == False @@ -73,6 +73,9 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] ssl_cfg = self.config['semi_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = ssl_cfg.get('rampup_start', 0) + rampup_end = ssl_cfg.get('rampup_end', iter_max) unsup_loss_name = ssl_cfg.get('unsupervised_loss', "MSE") self.unsup_loss_f = unsup_loss_dict[unsup_loss_name] train_loss = 0 @@ -118,20 +121,15 @@ def training(self): for p1_auxi in p1_aux: loss_reg += self.unsup_loss_f( p1_auxi, p1, use_softmax = True) loss_reg = loss_reg / len(p1_aux) - - iter_max = self.config['training']['iter_max'] - ramp_up_length = ssl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > ssl_cfg.get('iter_sup', 0)): - regular_w = ssl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) - + + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg loss.backward() self.optimizer.step() - if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run_ssl/ssl_cps.py b/pymic/net_run_ssl/ssl_cps.py index e39bec9..2264d0d 100644 --- a/pymic/net_run_ssl/ssl_cps.py +++ b/pymic/net_run_ssl/ssl_cps.py @@ -3,15 +3,30 @@ import logging import numpy as np import torch -import torch.optim as optim +import torch.nn as nn +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice -from pymic.net_run.get_optimizer import get_optimizer, get_lr_scheduler from pymic.net_run_ssl.ssl_abstract import SSLSegAgent from pymic.net.net_dict_seg import SegNetDict -from pymic.util.ramps import sigmoid_rampup -from pymic.util.general import keyword_match +from pymic.util.ramps import get_rampup_ratio + +class BiNet(nn.Module): + def __init__(self, params): + super(BiNet, self).__init__() + net_name = params['net_type'] + self.net1 = SegNetDict[net_name](params) + self.net2 = SegNetDict[net_name](params) + + def forward(self, x): + out1 = self.net1(x) + out2 = self.net2(x) + + if(self.training): + return out1, out2 + else: + return (out1 + out2) / 3 class SSLCPS(SSLSegAgent): """ @@ -23,48 +38,27 @@ class SSLCPS(SSLSegAgent): """ def __init__(self, config, stage = 'train'): super(SSLCPS, self).__init__(config, stage) - self.net2 = None - self.optimizer2 = None - self.scheduler2 = None def create_network(self): - super(SSLCPS, self).create_network() - if(self.net2 is None): - net_name = self.config['network']['net_type'] - if(net_name not in SegNetDict): - raise ValueError("Undefined network {0:}".format(net_name)) - self.net2 = SegNetDict[net_name](self.config['network']) + if(self.net is None): + self.net = BiNet(self.config['network']) if(self.tensor_type == 'float'): - self.net2.float() + self.net.float() else: - self.net2.double() - - def train_valid(self): - # create optimizor for the second network - opt_params = self.config['training'] - if(self.optimizer2 is None): - self.optimizer2 = get_optimizer(opt_params['optimizer'], - self.net2.parameters(), opt_params) - last_iter = -1 - # if(self.checkpoint is not None): - # self.optimizer2.load_state_dict(self.checkpoint['optimizer_state_dict']) - # last_iter = self.checkpoint['iteration'] - 1 - if(self.scheduler2 is None): - opt_params["laster_iter"] = last_iter - self.scheduler2 = get_lr_scheduler(self.optimizer2, opt_params) - super(SSLCPS, self).train_valid() + self.net.double() def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] ssl_cfg = self.config['semi_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = ssl_cfg.get('rampup_start', 0) + rampup_end = ssl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup1, train_loss_pseudo_sup1 = 0, 0 train_loss_sup2, train_loss_pseudo_sup2 = 0, 0 train_dice_list = [] self.net.train() - self.net2.train() - self.net2.to(self.device) for it in range(iter_valid): try: data_lab = next(self.trainIter) @@ -86,9 +80,8 @@ def training(self): # zero the parameter gradients self.optimizer.zero_grad() - self.optimizer2.zero_grad() - outputs1, outputs2 = self.net(inputs), self.net2(inputs) + outputs1, outputs2 = self.net(inputs) outputs_soft1 = torch.softmax(outputs1, dim=1) outputs_soft2 = torch.softmax(outputs2, dim=1) @@ -106,13 +99,8 @@ def training(self): pse_sup1 = self.get_loss_value(data_unlab, outputs1[n0:], pse_prob2) pse_sup2 = self.get_loss_value(data_unlab, outputs2[n0:], pse_prob1) - iter_max = self.config['training']['iter_max'] - ramp_up_len = ssl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > ssl_cfg.get('iter_sup', 0)): - regular_w = ssl_cfg.get('regularize_w', 0.1) - if(ramp_up_len is not None and self.glob_it < ramp_up_len): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_len) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio model1_loss = loss_sup1 + regular_w * pse_sup1 model2_loss = loss_sup2 + regular_w * pse_sup2 @@ -120,10 +108,9 @@ def training(self): loss.backward() self.optimizer.step() - self.optimizer2.step() - if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step() - self.scheduler2.step() train_loss = train_loss + loss.item() train_loss_sup1 = train_loss_sup1 + loss_sup1.item() @@ -152,13 +139,7 @@ def training(self): 'loss_pse_sup1':train_avg_loss_pse_sup1, 'loss_pse_sup2': train_avg_loss_pse_sup2, 'regular_w':regular_w, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} return train_scalers - - def validation(self): - return_value = super(SSLCPS, self).validation() - if(keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): - self.scheduler2.step(return_value['avg_dice']) - return return_value - + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): loss_scalar ={'train':train_scalars['loss'], 'valid':valid_scalars['loss']} diff --git a/pymic/net_run_ssl/ssl_em.py b/pymic/net_run_ssl/ssl_em.py index e4b1f19..810a90c 100644 --- a/pymic/net_run_ssl/ssl_em.py +++ b/pymic/net_run_ssl/ssl_em.py @@ -3,14 +3,14 @@ import logging import numpy as np import torch +from torhc.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.ssl import EntropyLoss from pymic.net_run_ssl.ssl_abstract import SSLSegAgent from pymic.transform.trans_dict import TransformDict -from pymic.util.ramps import sigmoid_rampup -from pymic.util.general import keyword_match +from pymic.util.ramps import get_rampup_ratio class SSLEntropyMinimization(SSLSegAgent): """ @@ -29,6 +29,9 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] ssl_cfg = self.config['semi_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = ssl_cfg.get('rampup_start', 0) + rampup_end = ssl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -64,18 +67,15 @@ def training(self): loss_dict = {"prediction":outputs, 'softmax':True} loss_reg = EntropyLoss()(loss_dict) - iter_max = self.config['training']['iter_max'] - ramp_up_length = ssl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > ssl_cfg.get('iter_sup', 0)): - regular_w = ssl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio + loss = loss_sup + regular_w*loss_reg # if (self.config['training']['use']) loss.backward() self.optimizer.step() - if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step() train_loss = train_loss + loss.item() diff --git a/pymic/net_run_ssl/ssl_mt.py b/pymic/net_run_ssl/ssl_mt.py index aa7fbff..0456726 100644 --- a/pymic/net_run_ssl/ssl_mt.py +++ b/pymic/net_run_ssl/ssl_mt.py @@ -3,13 +3,13 @@ import logging import torch import numpy as np +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.net_run_ssl.ssl_abstract import SSLSegAgent from pymic.net.net_dict_seg import SegNetDict -from pymic.util.ramps import sigmoid_rampup -from pymic.util.general import keyword_match +from pymic.util.ramps import get_rampup_ratio class SSLMeanTeacher(SSLSegAgent): """ @@ -39,6 +39,9 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] ssl_cfg = self.config['semi_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = ssl_cfg.get('rampup_start', 0) + rampup_end = ssl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -82,19 +85,16 @@ def training(self): outputs_ema = self.net_ema(inputs_ema) p1_ema_soft = torch.softmax(outputs_ema, dim=1) - iter_max = self.config['training']['iter_max'] - ramp_up_length = ssl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > ssl_cfg.get('iter_sup', 0)): - regular_w = ssl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio + loss_reg = torch.nn.MSELoss()(p1_soft, p1_ema_soft) loss = loss_sup + regular_w*loss_reg loss.backward() self.optimizer.step() - if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step() # update EMA diff --git a/pymic/net_run_ssl/ssl_uamt.py b/pymic/net_run_ssl/ssl_uamt.py index 3352231..360dab1 100644 --- a/pymic/net_run_ssl/ssl_uamt.py +++ b/pymic/net_run_ssl/ssl_uamt.py @@ -3,12 +3,12 @@ import logging import torch import numpy as np +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.net_run_ssl.ssl_mt import SSLMeanTeacher -from pymic.util.ramps import sigmoid_rampup -from pymic.util.general import keyword_match +from pymic.util.ramps import get_rampup_ratio class SSLUncertaintyAwareMeanTeacher(SSLMeanTeacher): """ @@ -22,6 +22,9 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] ssl_cfg = self.config['semi_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = ssl_cfg.get('rampup_start', 0) + rampup_end = ssl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -81,24 +84,19 @@ def training(self): uncertainty = -1.0 * torch.sum(preds*torch.log(preds + 1e-6), dim=1, keepdim=True) - iter_max = self.config['training']['iter_max'] - ramp_up_length = ssl_cfg.get('ramp_up_length', iter_max) - threshold_ramp = sigmoid_rampup(self.glob_it, iter_max) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") class_num = list(y0.shape)[1] - threshold = (0.75+0.25*threshold_ramp)*np.log(class_num) + threshold = (0.75+0.25*rampup_ratio)*np.log(class_num) mask = (uncertainty < threshold).float() loss_reg = torch.sum(mask*square_error)/(2*torch.sum(mask)+1e-16) - regular_w = 0.0 - if(self.glob_it > ssl_cfg.get('iter_sup', 0)): - regular_w = ssl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg loss.backward() self.optimizer.step() - if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step() diff --git a/pymic/net_run_ssl/ssl_urpc.py b/pymic/net_run_ssl/ssl_urpc.py index 5d269fd..d0179cd 100644 --- a/pymic/net_run_ssl/ssl_urpc.py +++ b/pymic/net_run_ssl/ssl_urpc.py @@ -4,12 +4,12 @@ import torch import torch.nn as nn import numpy as np +from torhc.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.net_run_ssl.ssl_abstract import SSLSegAgent -from pymic.util.ramps import sigmoid_rampup -from pymic.util.general import keyword_match +from pymic.util.ramps import get_rampup_ratio class SSLURPC(SSLSegAgent): """ @@ -24,6 +24,9 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] ssl_cfg = self.config['semi_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = ssl_cfg.get('rampup_start', 0) + rampup_end = ssl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -78,19 +81,14 @@ def training(self): loss_reg += loss_i loss_reg = loss_reg / len(outputs_list) - iter_max = self.config['training']['iter_max'] - ramp_up_length = ssl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > ssl_cfg.get('iter_sup', 0)): - regular_w = ssl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) - + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg loss.backward() self.optimizer.step() - if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run_wsl/wsl_dmpls.py b/pymic/net_run_wsl/wsl_dmpls.py index 1e60e47..a198ddc 100644 --- a/pymic/net_run_wsl/wsl_dmpls.py +++ b/pymic/net_run_wsl/wsl_dmpls.py @@ -4,13 +4,13 @@ import numpy as np import random import torch +from torhc.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.dice import DiceLoss from pymic.net_run_wsl.wsl_abstract import WSLSegAgent -from pymic.util.ramps import sigmoid_rampup -from pymic.util.general import keyword_match +from pymic.util.ramps import get_rampup_ratio class WSLDMPLS(WSLSegAgent): """ @@ -32,6 +32,9 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] wsl_cfg = self.config['weakly_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = wsl_cfg.get('rampup_start', 0) + rampup_end = wsl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -73,18 +76,14 @@ def training(self): loss_dict2 = {"prediction":outputs2, 'ground_truth':pseudo_lab} loss_reg = 0.5 * (loss_calculator(loss_dict1) + loss_calculator(loss_dict2)) - iter_max = self.config['training']['iter_max'] - ramp_up_length = wsl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > wsl_cfg.get('iter_sup', 0)): - regular_w = wsl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg loss.backward() self.optimizer.step() - if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step() train_loss = train_loss + loss.item() diff --git a/pymic/net_run_wsl/wsl_em.py b/pymic/net_run_wsl/wsl_em.py index 66823c1..3b2d595 100644 --- a/pymic/net_run_wsl/wsl_em.py +++ b/pymic/net_run_wsl/wsl_em.py @@ -3,18 +3,18 @@ import logging import numpy as np import torch +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.ssl import EntropyLoss from pymic.net_run.agent_seg import SegmentationAgent from pymic.net_run_wsl.wsl_abstract import WSLSegAgent -from pymic.util.ramps import sigmoid_rampup -from pymic.util.general import keyword_match +from pymic.util.ramps import get_rampup_ratio class WSLEntropyMinimization(WSLSegAgent): """ - Training and testing agent for semi-supervised segmentation + Weakly suepervised segmentation with Entropy Minimization Regularization. """ def __init__(self, config, stage = 'train'): super(WSLEntropyMinimization, self).__init__(config, stage) @@ -23,6 +23,9 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] wsl_cfg = self.config['weakly_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = wsl_cfg.get('rampup_start', 0) + rampup_end = wsl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -50,18 +53,14 @@ def training(self): loss_dict= {"prediction":outputs, 'softmax':True} loss_reg = EntropyLoss()(loss_dict) - iter_max = self.config['training']['iter_max'] - ramp_up_length = wsl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > wsl_cfg.get('iter_sup', 0)): - regular_w = wsl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg loss.backward() self.optimizer.step() - if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step() train_loss = train_loss + loss.item() diff --git a/pymic/net_run_wsl/wsl_gatedcrf.py b/pymic/net_run_wsl/wsl_gatedcrf.py index 2ae8318..2be8856 100644 --- a/pymic/net_run_wsl/wsl_gatedcrf.py +++ b/pymic/net_run_wsl/wsl_gatedcrf.py @@ -3,13 +3,13 @@ import logging import numpy as np import torch +from torhc.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.gatedcrf import ModelLossSemsegGatedCRF from pymic.net_run_wsl.wsl_abstract import WSLSegAgent -from pymic.util.ramps import sigmoid_rampup -from pymic.util.general import keyword_match +from pymic.util.ramps import get_rampup_ratio class WSLGatedCRF(WSLSegAgent): """ @@ -38,6 +38,9 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] wsl_cfg = self.config['weakly_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = wsl_cfg.get('rampup_start', 0) + rampup_end = wsl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -81,18 +84,14 @@ def training(self): loss_reg = gatecrf_loss(outputs_soft, self.kernels, self.radius, batch_dict,input_shape[-2], input_shape[-1])["loss"] - iter_max = self.config['training']['iter_max'] - ramp_up_length = wsl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > wsl_cfg.get('iter_sup', 0)): - regular_w = wsl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg loss.backward() self.optimizer.step() - if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step() train_loss = train_loss + loss.item() diff --git a/pymic/net_run_wsl/wsl_mumford_shah.py b/pymic/net_run_wsl/wsl_mumford_shah.py index 095a0f6..df4c68f 100644 --- a/pymic/net_run_wsl/wsl_mumford_shah.py +++ b/pymic/net_run_wsl/wsl_mumford_shah.py @@ -3,13 +3,13 @@ import logging import numpy as np import torch +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.mumford_shah import MumfordShahLoss from pymic.net_run_wsl.wsl_abstract import WSLSegAgent -from pymic.util.ramps import sigmoid_rampup -from pymic.util.general import keyword_match +from pymic.util.ramps import get_rampup_ratio class WSLMumfordShah(WSLSegAgent): """ @@ -25,6 +25,9 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] wsl_cfg = self.config['weakly_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = wsl_cfg.get('rampup_start', 0) + rampup_end = wsl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -54,18 +57,14 @@ def training(self): loss_dict = {"prediction":outputs, 'image':inputs} loss_reg = reg_loss_calculator(loss_dict) - iter_max = self.config['training']['iter_max'] - ramp_up_length = wsl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > wsl_cfg.get('iter_sup', 0)): - regular_w = wsl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg # if (self.config['training']['use']) loss.backward() self.optimizer.step() - if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step() train_loss = train_loss + loss.item() diff --git a/pymic/net_run_wsl/wsl_tv.py b/pymic/net_run_wsl/wsl_tv.py index fde1c10..2e56cb4 100644 --- a/pymic/net_run_wsl/wsl_tv.py +++ b/pymic/net_run_wsl/wsl_tv.py @@ -3,12 +3,13 @@ import logging import numpy as np import torch +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.ssl import TotalVariationLoss from pymic.net_run_wsl.wsl_abstract import WSLSegAgent -from pymic.util.ramps import sigmoid_rampup +from pymic.util.ramps import get_rampup_ratio from pymic.util.general import keyword_match class WSLTotalVariation(WSLSegAgent): @@ -22,6 +23,9 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] wsl_cfg = self.config['weakly_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = wsl_cfg.get('rampup_start', 0) + rampup_end = wsl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -49,18 +53,14 @@ def training(self): loss_dict = {"prediction":outputs, 'softmax':True} loss_reg = TotalVariationLoss()(loss_dict) - iter_max = self.config['training']['iter_max'] - ramp_up_length = wsl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > wsl_cfg.get('iter_sup', 0)): - regular_w = wsl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg # if (self.config['training']['use']) loss.backward() self.optimizer.step() - if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step() train_loss = train_loss + loss.item() diff --git a/pymic/net_run_wsl/wsl_ustm.py b/pymic/net_run_wsl/wsl_ustm.py index 6083069..0a2f7e1 100644 --- a/pymic/net_run_wsl/wsl_ustm.py +++ b/pymic/net_run_wsl/wsl_ustm.py @@ -5,12 +5,13 @@ import random import torch import torch.nn.functional as F +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.net.net_dict_seg import SegNetDict from pymic.net_run_wsl.wsl_abstract import WSLSegAgent -from pymic.util.ramps import sigmoid_rampup +from pymic.util.ramps import get_rampup_ratio from pymic.util.general import keyword_match class WSLUSTM(WSLSegAgent): @@ -42,6 +43,9 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] wsl_cfg = self.config['weakly_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = wsl_cfg.get('rampup_start', 0) + rampup_end = wsl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -97,24 +101,19 @@ def training(self): uncertainty = -1.0 * torch.sum(preds*torch.log(preds + 1e-6), dim=1, keepdim=True) - iter_max = self.config['training']['iter_max'] - ramp_up_length = wsl_cfg.get('ramp_up_length', iter_max) - threshold_ramp = sigmoid_rampup(self.glob_it, iter_max) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") class_num = list(y.shape)[1] - threshold = (0.75+0.25*threshold_ramp)*np.log(class_num) + threshold = (0.75+0.25*rampup_ratio)*np.log(class_num) mask = (uncertainty < threshold).float() loss_reg = torch.sum(mask*square_error)/(2*torch.sum(mask)+1e-16) - regular_w = 0.0 - if(self.glob_it > wsl_cfg.get('iter_sup', 0)): - regular_w = wsl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg loss.backward() self.optimizer.step() - if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step() # update EMA diff --git a/pymic/util/ramps.py b/pymic/util/ramps.py index e344cfe..b58adb6 100644 --- a/pymic/util/ramps.py +++ b/pymic/util/ramps.py @@ -10,24 +10,21 @@ 0 and 1. """ -def sigmoid_rampup(i, length): - """Exponential rampup from https://arxiv.org/abs/1610.02242""" - if length == 0: - return 1.0 - else: - i = np.clip(i, 0.0, length) - phase = 1.0 - (i + 0.0) / length - return float(np.exp(-5.0 * phase * phase)) - -def linear_rampup(i, length): - """Linear rampup""" - assert i >= 0 and length >= 0 - i = np.clip(i, 0.0, length) - return (i + 0.0) / length +def get_rampup_ratio(i, start, end, mode = "linear"): + if( i < start): + rampup = 0.0 + elif(i > end): + rampup = 1.0 + elif(mode == "linear"): + rampup = (i - start) / (end - start) + elif(mode == "sigmoid"): + phase = 1.0 - (i - start) / (end - start) + rampup = float(np.exp(-5.0 * phase * phase)) + return rampup -def cosine_rampdown(i, length): +def cosine_rampdown(i, start, end): """Cosine rampdown from https://arxiv.org/abs/1608.03983""" i = np.clip(i, 0.0, length) return float(.5 * (np.cos(np.pi * i / length) + 1)) \ No newline at end of file From 04567f1cc4c1abf23dc97c83a8f5168c260ead2c Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 18 Aug 2022 16:00:45 +0800 Subject: [PATCH 026/225] update log name update log name --- pymic/net_run/net_run.py | 4 ++-- pymic/net_run_nll/nll_main.py | 2 +- pymic/net_run_ssl/ssl_main.py | 2 +- pymic/net_run_wsl/wsl_main.py | 2 +- pymic/util/average_model.py | 1 + 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pymic/net_run/net_run.py b/pymic/net_run/net_run.py index 4ec1ce7..971af7e 100644 --- a/pymic/net_run/net_run.py +++ b/pymic/net_run/net_run.py @@ -18,8 +18,8 @@ def main(): config = synchronize_config(config) log_dir = config['training']['ckpt_save_dir'] if(not os.path.exists(log_dir)): - os.mkdir(log_dir) - logging.basicConfig(filename=log_dir+"/log.txt", level=logging.INFO, + os.makedirs(log_dir, exist_ok=True) + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, format='%(message)s') logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) diff --git a/pymic/net_run_nll/nll_main.py b/pymic/net_run_nll/nll_main.py index d1ae7a1..7d8f1f8 100644 --- a/pymic/net_run_nll/nll_main.py +++ b/pymic/net_run_nll/nll_main.py @@ -23,7 +23,7 @@ def main(): log_dir = config['training']['ckpt_save_dir'] if(not os.path.exists(log_dir)): os.mkdir(log_dir) - logging.basicConfig(filename=log_dir+"/log.txt", level=logging.INFO, + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, format='%(message)s') logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) diff --git a/pymic/net_run_ssl/ssl_main.py b/pymic/net_run_ssl/ssl_main.py index 6bddf29..d904ab1 100644 --- a/pymic/net_run_ssl/ssl_main.py +++ b/pymic/net_run_ssl/ssl_main.py @@ -32,7 +32,7 @@ def main(): log_dir = config['training']['ckpt_save_dir'] if(not os.path.exists(log_dir)): os.mkdir(log_dir) - logging.basicConfig(filename=log_dir+"/log.txt", level=logging.INFO, + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, format='%(message)s') logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) diff --git a/pymic/net_run_wsl/wsl_main.py b/pymic/net_run_wsl/wsl_main.py index 916e1d8..abedb6b 100644 --- a/pymic/net_run_wsl/wsl_main.py +++ b/pymic/net_run_wsl/wsl_main.py @@ -31,7 +31,7 @@ def main(): log_dir = config['training']['ckpt_save_dir'] if(not os.path.exists(log_dir)): os.mkdir(log_dir) - logging.basicConfig(filename=log_dir+"/log.txt", level=logging.INFO, + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, format='%(message)s') logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) diff --git a/pymic/util/average_model.py b/pymic/util/average_model.py index 0b6fb29..73a537f 100644 --- a/pymic/util/average_model.py +++ b/pymic/util/average_model.py @@ -1,3 +1,4 @@ + import torch checkpoint_name1 = "/home/guotai/projects/PyMIC/examples/brats/model/casecade/wt/unet3d_4_8000.pt" From 137c7ebec02482761a70a02db345c9e3a9ec935e Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 19 Aug 2022 11:57:36 +0800 Subject: [PATCH 027/225] fix typo --- pymic/net/net3d/unet3d.py | 3 +- pymic/net_run/agent_seg.py | 4 +- pymic/net_run/net_run.py | 2 +- pymic/net_run_nll/nll_clslsr.py | 191 ++++++++++++++++++++++++++++++++ pymic/net_run_ssl/ssl_em.py | 2 +- pymic/net_run_ssl/ssl_urpc.py | 2 +- 6 files changed, 197 insertions(+), 7 deletions(-) create mode 100644 pymic/net_run_nll/nll_clslsr.py diff --git a/pymic/net/net3d/unet3d.py b/pymic/net/net3d/unet3d.py index 058cb79..fdedf4d 100644 --- a/pymic/net/net3d/unet3d.py +++ b/pymic/net/net3d/unet3d.py @@ -96,7 +96,6 @@ def __init__(self, params): self.n_class = self.params['class_num'] self.trilinear = self.params['trilinear'] self.deep_sup = self.params['deep_supervise'] - self.stage = self.params['stage'] assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) @@ -134,7 +133,7 @@ def forward(self, x): x_d1 = self.up3(x_d2, x1) x_d0 = self.up4(x_d1, x0) output = self.out_conv(x_d0) - if(self.deep_sup and self.stage == "train"): + if(self.deep_sup): out_shape = list(output.shape)[2:] output1 = self.out_conv1(x_d1) output1 = interpolate(output1, out_shape, mode = 'trilinear') diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index d8e74a2..5a53f3e 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -307,7 +307,7 @@ def train_valid(self): elif(isinstance(iter_save, (tuple, list))): iter_save_list = iter_save else: - iter_save_list = range(iter_start, iter_max + 1, iter_save) + iter_save_list = range(0, iter_max + 1, iter_save) self.max_val_dice = 0.0 self.max_val_it = 0 @@ -519,7 +519,7 @@ def save_ouputs(self, data): filename_replace_source = self.config['testing'].get('filename_replace_source', None) filename_replace_target = self.config['testing'].get('filename_replace_target', None) if(not os.path.exists(output_dir)): - os.mkdir(output_dir) + os.makedirs(output_dir, exist_ok=True) names, pred = data['names'], data['predict'] if(isinstance(pred, (list, tuple))): diff --git a/pymic/net_run/net_run.py b/pymic/net_run/net_run.py index 971af7e..4c953ad 100644 --- a/pymic/net_run/net_run.py +++ b/pymic/net_run/net_run.py @@ -10,7 +10,7 @@ def main(): if(len(sys.argv) < 3): print('Number of arguments should be 3. e.g.') - print(' pymic_net_run train config.cfg') + print(' pymic_run train config.cfg') exit() stage = str(sys.argv[1]) cfg_file = str(sys.argv[2]) diff --git a/pymic/net_run_nll/nll_clslsr.py b/pymic/net_run_nll/nll_clslsr.py new file mode 100644 index 0000000..2894db6 --- /dev/null +++ b/pymic/net_run_nll/nll_clslsr.py @@ -0,0 +1,191 @@ +# -*- coding: utf-8 -*- +""" +Caculating the confidence map of labels of training samples, +which is used in the method of SLSR. + Minqing Zhang et al., Characterizing Label Errors: Confident Learning + for Noisy-Labeled Image Segmentation, MICCAI 2020. +""" + +from __future__ import print_function, division +import cleanlab +import logging +import os +import scipy +import sys +import torch +import numpy as np +import pandas as pd +import torch.nn as nn +import torchvision.transforms as transforms +from PIL import Image +from pymic.io.nifty_dataset import NiftyDataset +from pymic.transform.trans_dict import TransformDict +from pymic.util.parse_config import * +from pymic.net_run.agent_seg import SegmentationAgent +from pymic.net_run.infer_func import Inferer + +def get_confident_map(gt, pred, CL_type = 'both'): + """ + gt: ground truth label (one-hot) with shape of NXC + pred: digit prediction of network with shape of NXC + """ + prob = scipy.special.softmax(pred, axis = 1) + if CL_type in ['both', 'Qij']: + noise = cleanlab.pruning.get_noise_indices(gt, prob, prune_method='both', n_jobs=1) + elif CL_type == 'Cij': + noise = cleanlab.pruning.get_noise_indices(gt, pred, prune_method='both', n_jobs=1) + elif CL_type == 'intersection': + noise_qij = cleanlab.pruning.get_noise_indices(gt, prob, prune_method='both', n_jobs=1) + noise_cij = cleanlab.pruning.get_noise_indices(gt, pred, prune_method='both', n_jobs=1) + noise = noise_qij & noise_cij + elif CL_type == 'union': + noise_qij = cleanlab.pruning.get_noise_indices(gt, prob, prune_method='both', n_jobs=1) + noise_cij = cleanlab.pruning.get_noise_indices(gt, pred, prune_method='both', n_jobs=1) + noise = noise_qij | noise_cij + elif CL_type in ['prune_by_class', 'prune_by_noise_rate']: + noise = cleanlab.pruning.get_noise_indices(gt, prob, prune_method=CL_type, n_jobs=1) + return noise + +class NLLCLSLSR(SegmentationAgent): + def __init__(self, config, stage = 'test'): + super(NLLCLSLSR, self).__init__(config, stage) + + def infer_with_cl(self): + device_ids = self.config['testing']['gpus'] + device = torch.device("cuda:{0:}".format(device_ids[0])) + self.net.to(device) + + if(self.config['testing'].get('evaluation_mode', True)): + self.net.eval() + if(self.config['testing'].get('test_time_dropout', False)): + def test_time_dropout(m): + if(type(m) == nn.Dropout): + logging.info('dropout layer') + m.train() + self.net.apply(test_time_dropout) + + ckpt_mode = self.config['testing']['ckpt_mode'] + ckpt_name = self.get_checkpoint_name() + if(ckpt_mode == 3): + assert(isinstance(ckpt_name, (tuple, list))) + self.infer_with_multiple_checkpoints() + return + else: + if(isinstance(ckpt_name, (tuple, list))): + raise ValueError("ckpt_mode should be 3 if ckpt_name is a list") + + # load network parameters and set the network as evaluation mode + checkpoint = torch.load(ckpt_name, map_location = device) + self.net.load_state_dict(checkpoint['model_state_dict']) + + if(self.inferer is None): + infer_cfg = self.config['testing'] + class_num = self.config['network']['class_num'] + infer_cfg['class_num'] = class_num + self.inferer = Inferer(infer_cfg) + pred_list = [] + gt_list = [] + filename_list = [] + with torch.no_grad(): + for data in self.test_loader: + images = self.convert_tensor_type(data['image']) + labels = self.convert_tensor_type(data['label_prob']) + names = data['names'] + filename_list.append(names) + images = images.to(device) + + pred = self.inferer.run(self.net, images) + # convert tensor to numpy + if(isinstance(pred, (tuple, list))): + pred = [item.cpu().numpy() for item in pred] + else: + pred = pred.cpu().numpy() + data['predict'] = pred + # inverse transform + for transform in self.transform_list[::-1]: + if (transform.inverse): + data = transform.inverse_transform_for_prediction(data) + + pred = data['predict'] + # conver prediction from N, C, H, W to (N*H*W)*C + print(names, pred.shape, labels.shape) + pred_2d = np.swapaxes(pred, 1, 2) + pred_2d = np.swapaxes(pred_2d, 2, 3) + pred_2d = pred_2d.reshape(-1, class_num) + lab = labels.cpu().numpy() + lab_2d = np.swapaxes(lab, 1, 2) + lab_2d = np.swapaxes(lab_2d, 2, 3) + lab_2d = lab_2d.reshape(-1, class_num) + pred_list.append(pred_2d) + gt_list.append(lab_2d) + + pred_cat = np.concatenate(pred_list) + gt_cat = np.concatenate(gt_list) + gt = np.argmax(gt_cat, axis = 1) + gt = gt.reshape(-1).astype(np.uint8) + print(gt.shape, pred_cat.shape) + conf = get_confident_map(gt, pred_cat) + conf = conf.reshape(-1, 256, 256).astype(np.uint8) * 255 + save_dir = self.config['dataset']['root_dir'] + "/slsr_conf" + for idx in range(len(filename_list)): + filename = filename_list[idx][0].split('/')[-1] + conf_map = Image.fromarray(conf[idx]) + dst_path = os.path.join(save_dir, filename) + conf_map.save(dst_path) + +def get_confidence_map(): + if(len(sys.argv) < 2): + print('Number of arguments should be 3. e.g.') + print(' python nll_cl.py config.cfg') + exit() + cfg_file = str(sys.argv[1]) + config = parse_config(cfg_file) + config = synchronize_config(config) + + # set dataset + transform_names = config['dataset']['valid_transform'] + transform_list = [] + transform_dict = TransformDict + if(transform_names is None or len(transform_names) == 0): + data_transform = None + else: + transform_param = config['dataset'] + transform_param['task'] = 'segmentation' + for name in transform_names: + if(name not in transform_dict): + raise(ValueError("Undefined transform {0:}".format(name))) + one_transform = transform_dict[name](transform_param) + transform_list.append(one_transform) + data_transform = transforms.Compose(transform_list) + print('transform list', transform_list) + csv_file = config['dataset']['train_csv'] + modal_num = config['dataset'].get('modal_num', 1) + dataset = NiftyDataset(root_dir = config['dataset']['root_dir'], + csv_file = csv_file, + modal_num = modal_num, + with_label= True, + transform = data_transform ) + + agent = NLLCLSLSR(config, 'test') + agent.set_datasets(None, None, dataset) + agent.transform_list = transform_list + agent.create_dataset() + agent.create_network() + agent.infer_with_cl() + + # create training csv for confidence learning + df_train = pd.read_csv(csv_file) + pixel_weight = [] + for i in range(len(df_train["label"])): + lab_name = df_train["label"][i].split('/')[-1] + weight_name = "slsr_conf/" + lab_name + pixel_weight.append(weight_name) + train_cl_dict = {"image": df_train["image"], + "pixel_weight": pixel_weight, + "label": df_train["label"]} + train_cl_csv = csv_file.replace(".csv", "_clslsr.csv") + df_cl = pd.DataFrame.from_dict(train_cl_dict) + df_cl.to_csv(train_cl_csv, index = False) + +if __name__ == "__main__": + get_confidence_map() \ No newline at end of file diff --git a/pymic/net_run_ssl/ssl_em.py b/pymic/net_run_ssl/ssl_em.py index 810a90c..49dd22f 100644 --- a/pymic/net_run_ssl/ssl_em.py +++ b/pymic/net_run_ssl/ssl_em.py @@ -3,7 +3,7 @@ import logging import numpy as np import torch -from torhc.optim import lr_scheduler +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice diff --git a/pymic/net_run_ssl/ssl_urpc.py b/pymic/net_run_ssl/ssl_urpc.py index d0179cd..20b3d84 100644 --- a/pymic/net_run_ssl/ssl_urpc.py +++ b/pymic/net_run_ssl/ssl_urpc.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn import numpy as np -from torhc.optim import lr_scheduler +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice From 303e624b80e55ddd6d478a67311f299e32b6e8a3 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 19 Aug 2022 16:45:45 +0800 Subject: [PATCH 028/225] add dast for nll add dast for nll set the output model of dual branch network --- pymic/loss/seg/ce.py | 2 +- pymic/net/net2d/unet2d_dual_branch.py | 51 +---- pymic/net_run_nll/nll_dast.py | 260 ++++++++++++++++++++++++++ pymic/net_run_nll/nll_main.py | 4 +- 4 files changed, 271 insertions(+), 46 deletions(-) create mode 100644 pymic/net_run_nll/nll_dast.py diff --git a/pymic/loss/seg/ce.py b/pymic/loss/seg/ce.py index da2bf14..cdef1a0 100644 --- a/pymic/loss/seg/ce.py +++ b/pymic/loss/seg/ce.py @@ -6,7 +6,7 @@ from pymic.loss.seg.util import reshape_tensor_to_2D class CrossEntropyLoss(nn.Module): - def __init__(self, params): + def __init__(self, params = None): super(CrossEntropyLoss, self).__init__() if(params is None): self.softmax = True diff --git a/pymic/net/net2d/unet2d_dual_branch.py b/pymic/net/net2d/unet2d_dual_branch.py index 9622bd0..3531c89 100644 --- a/pymic/net/net2d/unet2d_dual_branch.py +++ b/pymic/net/net2d/unet2d_dual_branch.py @@ -16,6 +16,7 @@ class UNet2D_DualBranch(nn.Module): def __init__(self, params): super(UNet2D_DualBranch, self).__init__() + self.output_mode = params.get("output_mode", "average") self.encoder = Encoder(params) self.decoder1 = Decoder(params) self.decoder2 = Decoder(params) @@ -41,47 +42,9 @@ def forward(self, x): if(self.training): return output1, output2 else: - return (output1 + output2)/2 - # for backup -class DualBranchUNet2D(UNet2D): - def __init__(self, params): - params['deep_supervise'] = False - super(DualBranchUNet2D, self).__init__(params) - if(len(self.ft_chns) == 5): - self.up1_aux = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], 0.0, self.bilinear) - self.up2_aux = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], 0.0, self.bilinear) - self.up3_aux = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], 0.0, self.bilinear) - self.up4_aux = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], 0.0, self.bilinear) - - self.out_conv_aux = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1) - - def forward(self, x): - x_shape = list(x.shape) - if(len(x_shape) == 5): - [N, C, D, H, W] = x_shape - new_shape = [N*D, C, H, W] - x = torch.transpose(x, 1, 2) - x = torch.reshape(x, new_shape) - - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - if(len(self.ft_chns) == 5): - x4 = self.down4(x3) - x_d3, x_d3_aux = self.up1(x4, x3), self.up1_aux(x4, x3) - else: - x_d3, x_d3_aux = x3, x3 - - x_d2, x_d2_aux = self.up2(x_d3, x2), self.up2_aux(x_d3_aux, x2) - x_d1, x_d1_aux = self.up3(x_d2, x1), self.up3_aux(x_d2_aux, x1) - x_d0, x_d0_aux = self.up4(x_d1, x0), self.up4_aux(x_d1_aux, x0) - output, output_aux = self.out_conv(x_d0), self.out_conv_aux(x_d0_aux) - - if(len(x_shape) == 5): - new_shape = [N, D] + list(output.shape)[1:] - output = torch.reshape(output, new_shape) - output = torch.transpose(output, 1, 2) - output_aux = torch.reshape(output_aux, new_shape) - output_aux = torch.transpose(output_aux, 1, 2) - return output, output_aux \ No newline at end of file + if(self.output_mode == "average"): + return (output1 + output2)/2 + elif(self.output_mode == "first"): + return output1 + else: + return output2 diff --git a/pymic/net_run_nll/nll_dast.py b/pymic/net_run_nll/nll_dast.py new file mode 100644 index 0000000..d95eec0 --- /dev/null +++ b/pymic/net_run_nll/nll_dast.py @@ -0,0 +1,260 @@ +# -*- coding: utf-8 -*- +""" +Implementation of DAST for noise robust learning according to the following paper. + Shuojue Yang, Guotai Wang, Hui Sun, Xiangde Luo, Peng Sun, Kang Li, Qijun Wang, + Shaoting Zhang: Learning COVID-19 Pneumonia Lesion Segmentation from Imperfect + Annotations via Divergence-Aware Selective Training. + JBHI 2022. https://ieeexplore.ieee.org/document/9770406 +""" + +from __future__ import print_function, division +import random +import torch +import numpy as np +import torch.nn as nn +import torchvision.transforms as transforms +from torch.optim import lr_scheduler +from pymic.io.nifty_dataset import NiftyDataset +from pymic.loss.seg.util import get_soft_label +from pymic.loss.seg.util import reshape_prediction_and_ground_truth +from pymic.loss.seg.util import get_classwise_dice +from pymic.net_run.agent_seg import SegmentationAgent +from pymic.util.parse_config import * +from pymic.util.ramps import get_rampup_ratio + +class Rank(object): + """ + Dynamically rank the current training sample with specific metrics + """ + def __init__(self, quene_length = 100): + self.vals = [] + self.quene_length = quene_length + + def add_val(self, val): + """ + Update the quene and calculate the order of the input value. + + Return + --------- + rank: rank of the input value with a range of (0, self.quenen_length) + """ + if len(self.vals) < self.quene_length: + self.vals.append(val) + rank = -1 + else: + self.vals.pop(0) + self.vals.append(val) + assert len(self.vals) == self.quene_length + idxes = np.argsort(self.vals) + rank = np.where(idxes == self.quene_length-1)[0][0] + return rank + +class ConsistLoss(nn.Module): + def __init__(self): + super(ConsistLoss, self).__init__() + + def kl_div_map(self, input, label): + kl_map = torch.sum(label * (torch.log(label + 1e-16) - torch.log(input + 1e-16)), dim = 1) + return kl_map + + def kl_loss(self,input, target, size_average=True): + kl_div = self.kl_div_map(input, target) + if size_average: + return torch.mean(kl_div) + else: + return kl_div + + def forward(self, input1, input2, size_average = True): + kl1 = self.kl_loss(input1, input2.detach(), size_average=size_average) + kl2 = self.kl_loss(input2, input1.detach(), size_average=size_average) + return (kl1 + kl2) / 2 + +def get_ce(prob, soft_y, size_avg = True): + prob = prob * 0.999 + 5e-4 + ce = - soft_y* torch.log(prob) + ce = torch.sum(ce, dim = 1) # shape is [N] + if(size_avg): + ce = torch.mean(ce) + return ce + +@torch.no_grad() +def select_criterion(no_noisy_sample, cl_noisy_sample, label): + """ + no_noisy_sample: noisy branch's output probability for noisy sample + cl_noisy_sample: clean branch's output probability for noisy sample + label: noisy label + """ + l_n = get_ce(no_noisy_sample, label, size_avg = False) + l_c = get_ce(cl_noisy_sample, label, size_avg = False) + js_distance = ConsistLoss() + variance = js_distance(no_noisy_sample, cl_noisy_sample, size_average=False) + exp_variance = torch.exp(-16 * variance) + loss_n = torch.mean(l_c * exp_variance).item() + loss_c = torch.mean(l_n * exp_variance).item() + return loss_n, loss_c + +class NLLDAST(SegmentationAgent): + def __init__(self, config, stage = 'train'): + super(NLLDAST, self).__init__(config, stage) + self.train_set_noise = None + self.train_loader_noise = None + self.trainIter_noise = None + self.noisy_rank = None + self.clean_rank = None + + def get_noisy_dataset_from_config(self): + root_dir = self.config['dataset']['root_dir'] + modal_num = self.config['dataset'].get('modal_num', 1) + transform_names = self.config['dataset']['train_transform'] + + self.transform_list = [] + if(transform_names is None or len(transform_names) == 0): + data_transform = None + else: + transform_param = self.config['dataset'] + transform_param['task'] = 'segmentation' + for name in transform_names: + if(name not in self.transform_dict): + raise(ValueError("Undefined transform {0:}".format(name))) + one_transform = self.transform_dict[name](transform_param) + self.transform_list.append(one_transform) + data_transform = transforms.Compose(self.transform_list) + + csv_file = self.config['dataset'].get('train_csv_noise', None) + dataset = NiftyDataset(root_dir=root_dir, + csv_file = csv_file, + modal_num = modal_num, + with_label= True, + transform = data_transform ) + return dataset + + def create_dataset(self): + super(NLLDAST, self).create_dataset() + if(self.stage == 'train'): + if(self.train_set_noise is None): + self.train_set_noise = self.get_noisy_dataset_from_config() + if(self.deterministic): + def worker_init_fn(worker_id): + random.seed(self.random_seed + worker_id) + worker_init = worker_init_fn + else: + worker_init = None + + bn_train_noise = self.config['dataset']['train_batch_size_noise'] + num_worker = self.config['dataset'].get('num_workder', 16) + self.train_loader_noise = torch.utils.data.DataLoader(self.train_set_noise, + batch_size = bn_train_noise, shuffle=True, num_workers= num_worker, + worker_init_fn=worker_init) + + def training(self): + class_num = self.config['network']['class_num'] + iter_valid = self.config['training']['iter_valid'] + nll_cfg = self.config['noisy_label_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = nll_cfg.get('rampup_start', 0) + rampup_end = nll_cfg.get('rampup_end', iter_max) + train_loss = 0 + train_loss_sup = 0 + train_loss_reg = 0 + train_dice_list = [] + self.net.train() + + rank_length = nll_cfg.get("dast_rank_length", 20) + consist_loss = ConsistLoss() + for it in range(iter_valid): + try: + data_cl = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data_cl = next(self.trainIter) + try: + data_no = next(self.trainIter_noise) + except StopIteration: + self.trainIter_noise = iter(self.train_loader_noise) + data_no = next(self.trainIter_noise) + + # get the inputs + x0 = self.convert_tensor_type(data_cl['image']) # clean sample + y0 = self.convert_tensor_type(data_cl['label_prob']) + x1 = self.convert_tensor_type(data_no['image']) # noisy sample + y1 = self.convert_tensor_type(data_no['label_prob']) + inputs = torch.cat([x0, x1], dim = 0).to(self.device) + y0, y1 = y0.to(self.device), y1.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + b0_pred, b1_pred = self.net(inputs) + n0 = list(x0.shape)[0] # number of clean samples + b0_x0_pred = b0_pred[:n0] # predication of clean samples from clean branch + b0_x1_pred = b0_pred[n0:] # predication of noisy samples from clean branch + b1_x1_pred = b1_pred[n0:] # predication of noisy samples from noisy branch + + # supervised loss for the clean and noisy branches, respectively + loss_sup_cl = self.get_loss_value(data_cl, b0_x0_pred, y0) + loss_sup_no = self.get_loss_value(data_no, b1_x1_pred, y1) + loss_sup = (loss_sup_cl + loss_sup_no) / 2 + loss = loss_sup + + # Severe Noise supression & Supplementary Training + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + w_dbc = nll_cfg.get('dast_dbc_w', 0.1) * rampup_ratio + w_st = nll_cfg.get('dast_st_w', 0.1) * rampup_ratio + b1_x1_prob = nn.Softmax(dim = 1)(b1_x1_pred) + b0_x1_prob = nn.Softmax(dim = 1)(b0_x1_pred) + loss_n, loss_c = select_criterion(b1_x1_prob, b0_x1_prob, y1) + rank_n = self.noisy_rank.add_val(loss_n) + rank_c = self.clean_rank.add_val(loss_c) + if loss_n < loss_c: + if rank_c >= rank_length * 0.8: + loss_dbc = consist_loss(b1_x1_prob, b0_x1_prob) + loss = loss + loss_dbc * w_dbc + if rank_n <= 0.2 * rank_length: + b0_x1_argmax = torch.argmax(b0_x1_pred, dim = 1, keepdim = True) + b0_x1_lab = get_soft_label(b0_x1_argmax, class_num, self.tensor_type) + b1_x1_argmax = torch.argmax(b1_x1_pred, dim = 1, keepdim = True) + b1_x1_lab = get_soft_label(b1_x1_argmax, class_num, self.tensor_type) + pseudo_label = (b0_x1_lab + b1_x1_lab + y1) / 3 + sharpen = lambda p,T: p**(1.0/T)/(p**(1.0/T) + (1-p)**(1.0/T)) + b0_x1_prob = nn.Softmax(dim = 1)(b0_x1_pred) + loss_st = torch.mean(torch.abs(b0_x1_prob - sharpen(pseudo_label, 0.5))) + loss = loss + loss_st * w_st + + loss.backward() + self.optimizer.step() + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step() + + train_loss = train_loss + loss.item() + train_loss_sup = train_loss_sup + loss_sup.item() + # train_loss_reg = train_loss_reg + loss_reg.item() + # get dice evaluation for each class in annotated images + if(isinstance(b0_x0_pred, tuple) or isinstance(b0_x0_pred, list)): + p0 = b0_x0_pred[0] + else: + p0 = b0_x0_pred + p0_argmax = torch.argmax(p0, dim = 1, keepdim = True) + p0_soft = get_soft_label(p0_argmax, class_num, self.tensor_type) + p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) + dice_list = get_classwise_dice(p0_soft, y0) + train_dice_list.append(dice_list.cpu().numpy()) + train_avg_loss = train_loss / iter_valid + train_avg_loss_sup = train_loss_sup / iter_valid + train_avg_loss_reg = train_loss_reg / iter_valid + train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) + train_avg_dice = train_cls_dice.mean() + + train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, + 'loss_reg':train_avg_loss_reg, 'regular_w':w_dbc, + 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} + return train_scalers + + def train_valid(self): + self.trainIter_noise = iter(self.train_loader_noise) + nll_cfg = self.config['noisy_label_learning'] + rank_length = nll_cfg.get("dast_rank_length", 20) + self.noisy_rank = Rank(rank_length) + self.clean_rank = Rank(rank_length) + super(NLLDAST, self).train_valid() \ No newline at end of file diff --git a/pymic/net_run_nll/nll_main.py b/pymic/net_run_nll/nll_main.py index 7d8f1f8..cc07a44 100644 --- a/pymic/net_run_nll/nll_main.py +++ b/pymic/net_run_nll/nll_main.py @@ -7,9 +7,11 @@ from pymic.util.parse_config import * from pymic.net_run_nll.nll_co_teaching import NLLCoTeaching from pymic.net_run_nll.nll_trinet import NLLTriNet +from pymic.net_run_nll.nll_dast import NLLDAST NLLMethodDict = {'CoTeaching': NLLCoTeaching, - "TriNet": NLLTriNet} + "TriNet": NLLTriNet, + "DAST": NLLDAST} def main(): if(len(sys.argv) < 3): From 24c46cffc58d75d3b7f3cee8552f0d045c7d1398 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 19 Aug 2022 16:50:19 +0800 Subject: [PATCH 029/225] Update nll_dast.py add config parameter --- pymic/net_run_nll/nll_dast.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pymic/net_run_nll/nll_dast.py b/pymic/net_run_nll/nll_dast.py index d95eec0..19a59a2 100644 --- a/pymic/net_run_nll/nll_dast.py +++ b/pymic/net_run_nll/nll_dast.py @@ -207,10 +207,11 @@ def training(self): rank_n = self.noisy_rank.add_val(loss_n) rank_c = self.clean_rank.add_val(loss_c) if loss_n < loss_c: - if rank_c >= rank_length * 0.8: + select_ratio = nll_cfg.get('dast_select_ratio', 0.2) + if rank_c >= rank_length * (1 - select_ratio): loss_dbc = consist_loss(b1_x1_prob, b0_x1_prob) loss = loss + loss_dbc * w_dbc - if rank_n <= 0.2 * rank_length: + if rank_n <= rank_length * select_ratio: b0_x1_argmax = torch.argmax(b0_x1_pred, dim = 1, keepdim = True) b0_x1_lab = get_soft_label(b0_x1_argmax, class_num, self.tensor_type) b1_x1_argmax = torch.argmax(b1_x1_pred, dim = 1, keepdim = True) From 75622c346f7e753c668ef734f0be5085d6a6074a Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 20 Aug 2022 12:01:23 +0800 Subject: [PATCH 030/225] update comment and fix typo update comment and fix typo --- pymic/net_run/agent_cls.py | 14 ++++++++------ pymic/net_run_nll/nll_clslsr.py | 3 ++- pymic/net_run_nll/nll_trinet.py | 17 +++++------------ 3 files changed, 15 insertions(+), 19 deletions(-) diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index 71d7c30..8687048 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -195,7 +195,9 @@ def train_valid(self): ckpt_dir = self.config['training']['ckpt_save_dir'] if(ckpt_dir[-1] == "/"): ckpt_dir = ckpt_dir[:-1] - ckpt_prefx = ckpt_dir.split('/')[-1] + ckpt_prefix = self.config['training'].get('ckpt_prefix', None) + if(ckpt_prefix is None): + ckpt_prefix = ckpt_dir.split('/')[-1] iter_start = self.config['training']['iter_start'] iter_max = self.config['training']['iter_max'] iter_valid = self.config['training']['iter_valid'] @@ -206,7 +208,7 @@ def train_valid(self): self.best_model_wts = None self.checkpoint = None if(iter_start > 0): - checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefx, iter_start) + checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start) self.checkpoint = torch.load(checkpoint_file, map_location = self.device) assert(self.checkpoint['iteration'] == iter_start) self.net.load_state_dict(self.checkpoint['model_state_dict']) @@ -237,9 +239,9 @@ def train_valid(self): 'valid_pred': valid_scalars[metrics], 'model_state_dict': self.net.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefx, glob_it) + save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, glob_it) torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefx), 'wt') + txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt') txt_file.write(str(glob_it)) txt_file.close() @@ -248,9 +250,9 @@ def train_valid(self): 'valid_pred': self.max_val_score, 'model_state_dict': self.best_model_wts, 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefx, self.max_val_it) + save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it) torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefx), 'wt') + txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') txt_file.write(str(self.max_val_it)) txt_file.close() logging.info('The best perfroming iter is {0:}, valid {1:} {2:}'.format(\ diff --git a/pymic/net_run_nll/nll_clslsr.py b/pymic/net_run_nll/nll_clslsr.py index 2894db6..9ee7182 100644 --- a/pymic/net_run_nll/nll_clslsr.py +++ b/pymic/net_run_nll/nll_clslsr.py @@ -3,7 +3,8 @@ Caculating the confidence map of labels of training samples, which is used in the method of SLSR. Minqing Zhang et al., Characterizing Label Errors: Confident Learning - for Noisy-Labeled Image Segmentation, MICCAI 2020. + for Noisy-Labeled Image Segmentation, MICCAI 2020. + https://link.springer.com/chapter/10.1007/978-3-030-59710-8_70 """ from __future__ import print_function, division diff --git a/pymic/net_run_nll/nll_trinet.py b/pymic/net_run_nll/nll_trinet.py index eb0ecdd..6af5449 100644 --- a/pymic/net_run_nll/nll_trinet.py +++ b/pymic/net_run_nll/nll_trinet.py @@ -1,13 +1,11 @@ # -*- coding: utf-8 -*- """ -Implementation of Co-teaching for learning from noisy samples for +Implementation of trinet for learning from noisy samples for segmentation tasks according to the following paper: - Bo Han et al., Co-teaching: Robust Training of Deep NeuralNetworks - with Extremely Noisy Labels, NeurIPS, 2018 -The author's original implementation was: -https://github.com/bhanML/Co-teaching - - + Tianwei Zhang, Lequan Yu, Na Hu, Su Lv, Shi Gu: + Robust Medical Image Segmentation from Non-expert Annotations with Tri-network. + MICCAI 2020. + https://link.springer.com/chapter/10.1007/978-3-030-59719-1_25 """ from __future__ import print_function, division import logging @@ -48,11 +46,6 @@ def forward(self, x): return (out1 + out2 + out3) / 3 class NLLTriNet(SegmentationAgent): - """ - Co-teaching: Robust Training of Deep Neural Networks with Extremely - Noisy Labels - https://arxiv.org/abs/1804.06872 - """ def __init__(self, config, stage = 'train'): super(NLLTriNet, self).__init__(config, stage) From 8f8eb33ff88cb657ca925e5c782a80d0fbb07c3a Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 20 Aug 2022 12:07:53 +0800 Subject: [PATCH 031/225] Update README.md --- README.md | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index d6007e8..e29abde 100644 --- a/README.md +++ b/README.md @@ -31,24 +31,26 @@ PyMIC provides flixible modules for medical image computing tasks including clas [tbx_link]:https://github.com/lanpa/tensorboardX ## Installation -Run the following command to install the current released version of PyMIC: +Run the following command to install the latest released version of PyMIC: ```bash pip install PYMIC ``` -Alternatively, you can download the source code for the latest version. Run the following command to compile and install: +Alternatively, you can download the source code for the latest dev version. Run the following command to compile and install: ```bash python setup.py install ``` -## Examples -[PyMIC_examples][examples] provides some examples of starting to use PyMIC. At the beginning, you only need to edit the configuration files to select different datasets, networks and training methods for running the code. When you are more familiar with PyMIC, you can customize different modules in the PyMIC package. You can find both types of examples: +## How to start +* [PyMIC_examples][exp_link] shows some examples of starting to use PyMIC. +* [PyMIC_doc][docs_link] provides documentation of this project. -[examples]: https://github.com/HiLab-git/PyMIC_examples +[docs_link]:https://pymic.readthedocs.io/en/latest/ +[exp_link]:https://github.com/HiLab-git/PyMIC_examples -# Projects based on PyMIC +## Projects based on PyMIC Using PyMIC, it becomes easy to develop deep learning models for different projects, such as the following: 1, [COPLE-Net][coplenet] (TMI 2020), COVID-19 Pneumonia Segmentation from CT images. From 51629241f12a83016f79d10efce494a2342b996c Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 20 Aug 2022 18:04:05 +0800 Subject: [PATCH 032/225] add init file and update version add init file and update version --- pymic/net_run_nll/__init__.py | 0 pymic/net_run_ssl/__init__.py | 0 pymic/net_run_wsl/__init__.py | 0 pymic/net_run_wsl/wsl_dmpls.py | 2 +- pymic/net_run_wsl/wsl_gatedcrf.py | 2 +- requirements.txt | 12 ++++++++++++ setup.py | 6 ++++-- 7 files changed, 18 insertions(+), 4 deletions(-) create mode 100644 pymic/net_run_nll/__init__.py create mode 100644 pymic/net_run_ssl/__init__.py create mode 100644 pymic/net_run_wsl/__init__.py create mode 100644 requirements.txt diff --git a/pymic/net_run_nll/__init__.py b/pymic/net_run_nll/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pymic/net_run_ssl/__init__.py b/pymic/net_run_ssl/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pymic/net_run_wsl/__init__.py b/pymic/net_run_wsl/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pymic/net_run_wsl/wsl_dmpls.py b/pymic/net_run_wsl/wsl_dmpls.py index a198ddc..8ee9e53 100644 --- a/pymic/net_run_wsl/wsl_dmpls.py +++ b/pymic/net_run_wsl/wsl_dmpls.py @@ -4,7 +4,7 @@ import numpy as np import random import torch -from torhc.optim import lr_scheduler +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice diff --git a/pymic/net_run_wsl/wsl_gatedcrf.py b/pymic/net_run_wsl/wsl_gatedcrf.py index 2be8856..64e0f1b 100644 --- a/pymic/net_run_wsl/wsl_gatedcrf.py +++ b/pymic/net_run_wsl/wsl_gatedcrf.py @@ -3,7 +3,7 @@ import logging import numpy as np import torch -from torhc.optim import lr_scheduler +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..2dc1604 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +matplotlib>=3.1.2 +numpy>=1.17.4 +pandas>=0.25.3 +python>=3.6 +scikit-image>=0.16.2 +scikit-learn>=0.22 +scipy>=1.3.3 +SimpleITK>=1.2.4 +tensorboard>=2.1.0 +tensorboardX>=1.9 +torch>=1.7.1 +torchvision>=0.8.2 diff --git a/setup.py b/setup.py index 498aa26..ce7271b 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ # Get the summary description = 'An open-source deep learning platform' + \ - ' for medical image computing' + ' for annotation-efficient medical image computing' # Get the long description with open('README.md', encoding='utf-8') as f: @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.2.5", + version = "0.3.0", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, @@ -31,6 +31,8 @@ 'console_scripts': [ 'pymic_run = pymic.net_run.net_run:main', 'pymic_ssl = pymic.net_run_ssl.ssl_main:main', + 'pymic_wsl = pymic.net_run_wsl.wsl_main:main', + 'pymic_nll = pymic.net_run_nll.nll_main:main', 'pymic_eval_cls = pymic.util.evaluation_cls:main', 'pymic_eval_seg = pymic.util.evaluation_seg:main' ], From 3797e1ea6c856bb0a5288aa4d7432e4073d0dcc0 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 26 Aug 2022 11:35:33 +0800 Subject: [PATCH 033/225] update installtaion --- docs/source/index.rst | 2 +- docs/source/installation.rst | 44 ++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 docs/source/installation.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index dddd5df..5e71e9c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -15,6 +15,6 @@ Contents -------- .. toctree:: - + installation usage api diff --git a/docs/source/installation.rst b/docs/source/installation.rst new file mode 100644 index 0000000..b434839 --- /dev/null +++ b/docs/source/installation.rst @@ -0,0 +1,44 @@ +.. _installation: + +.. role:: bash(code) + :language: bash + +Installation +============ + +Install PyMIC using pip (e.g., within a `Python virtual environment `_): + +.. code-block:: bash + + pip install PYMIC + +Alternatively, you can download or clone the code from `GitHub `_ and install PyMIC by + +.. code-block:: bash + + git clone https://github.com/HiLab-git/PyMIC + cd PyMIC + python setup.py install + +Dependencies +------------ +PyMIC requires Python 3.6 (or higher) and depends on the following packages: + + - `h5py `_ + - `NumPy `_ + - `scikit-image `_ + - `SciPy `_ + - `SimpleITK `_ + +.. note:: + For the :mod:`pymia.data` package, not all dependencies are installed directly due to their heaviness. + Meaning, you need to either manually install PyTorch by + + - :bash:`pip install torch` + + or TensorFlow by + + - :bash:`pip install tensorflow` + + depending on your preferred deep learning framework when using the :mod:`pymia.data` package. + Upon loading a module from the :mod:`pymia.data` package, pymia will always check if the required dependencies are fulfilled. From 78532b7598bbbb3ffb2809606d9d8cd4243ce98a Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 26 Aug 2022 11:47:43 +0800 Subject: [PATCH 034/225] Update requirements.txt update python to 3.8 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 2dc1604..e84ccae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ matplotlib>=3.1.2 numpy>=1.17.4 pandas>=0.25.3 -python>=3.6 +python>=3.8 scikit-image>=0.16.2 scikit-learn>=0.22 scipy>=1.3.3 From e7f9be98d1a3abaf2f600f9dd87230dae8006da7 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 26 Aug 2022 11:50:56 +0800 Subject: [PATCH 035/225] Update requirements.txt --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e84ccae..c4c9ac3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ matplotlib>=3.1.2 numpy>=1.17.4 pandas>=0.25.3 -python>=3.8 scikit-image>=0.16.2 scikit-learn>=0.22 scipy>=1.3.3 From 21a3279e2eb10a89f2c2fed09e20052d96f4820d Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 26 Aug 2022 12:01:17 +0800 Subject: [PATCH 036/225] update usage add train and testing commands --- docs/source/index.rst | 1 - docs/source/usage.rst | 36 ++++++++++++++++++++---------------- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index 5e71e9c..9f53b3a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -15,6 +15,5 @@ Contents -------- .. toctree:: - installation usage api diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 924afcf..0bc986f 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -6,29 +6,33 @@ Usage Installation ------------ -To use Lumache, first install it using pip: +Install PyMIC using pip (e.g., within a `Python virtual environment `_): -.. code-block:: console +.. code-block:: bash - (.venv) $ pip install lumache + pip install PYMIC -Creating recipes ----------------- +Alternatively, you can download or clone the code from `GitHub `_ and install PyMIC by -To retrieve a list of random ingredients, -you can use the ``lumache.get_random_ingredients()`` function: +.. code-block:: bash -.. autofunction:: lumache.get_random_ingredients + git clone https://github.com/HiLab-git/PyMIC + cd PyMIC + python setup.py install -The ``kind`` parameter should be either ``"meat"``, ``"fish"``, -or ``"veggies"``. Otherwise, :py:func:`lumache.get_random_ingredients` -will raise an exception. +Train and Test +------------ + +PyMIC accepts a configuration file for runing. For example, to train a network +for segmentation with full supervision, run the fullowing command: + +.. code-block:: bash + + pymic_run train myconfig.cfg -.. autoexception:: lumache.InvalidKindError +After training, run the following command for testing: -For example: +.. code-block:: bash ->>> import lumache ->>> lumache.get_random_ingredients() -['shells', 'gorgonzola', 'parsley'] + pymic_run test myconfig.cfg From 7bb24a0377608129cbe75beec11f03497892412c Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 26 Aug 2022 12:44:27 +0800 Subject: [PATCH 037/225] update index add citation --- docs/source/index.rst | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/docs/source/index.rst b/docs/source/index.rst index 9f53b3a..f16dac8 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -17,3 +17,26 @@ Contents .. toctree:: usage api + + +Citation +-------- +If you use PyMIC for your research, please acknowledge it accordingly by citing our paper: + +`G. Wang, X. Luo, R. Gu, S. Yang, Y. Qu, S. Zhai, Q. Zhao, K. Li, S. Zhang. (2022). +PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation. +arXiv, 2208.09350. `_ + + +BibTeX entry: + +.. code-block:: none + + @article{Wang2022pymic, + author = {Guotai Wang and Xiangde Luo and Ran Gu and Shuojue Yang and Yijie Qu and Shuwei Zhai and Qianfei Zhao and Kang Li and Shaoting Zhang}, + title = {{PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation}}, + year = {2022}, + url = {http://arxiv.org/abs/2208.09350}, + journal = {arXiv}, + pages = {1-10}, + } From 3d69be6ead328246963f24d54f50b24f27783d9d Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 26 Aug 2022 13:03:00 +0800 Subject: [PATCH 038/225] add example of configuration file add example of configuration file --- docs/source/index.rst | 1 + docs/source/usage.rst | 85 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 85 insertions(+), 1 deletion(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index f16dac8..54461d7 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -38,5 +38,6 @@ BibTeX entry: year = {2022}, url = {http://arxiv.org/abs/2208.09350}, journal = {arXiv}, + volume = {2208.09350}, pages = {1-10}, } diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 0bc986f..b2948d4 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -21,7 +21,7 @@ Alternatively, you can download or clone the code from `GitHub `_ + +.. code-block:: none + [dataset] + # tensor type (float or double) + tensor_type = float + + task_type = seg + root_dir = ../../PyMIC_data/JSRT + train_csv = config/jsrt_train.csv + valid_csv = config/jsrt_valid.csv + test_csv = config/jsrt_test.csv + + train_batch_size = 4 + + # data transforms + train_transform = [NormalizeWithMeanStd, RandomCrop, LabelConvert, LabelToProbability] + valid_transform = [NormalizeWithMeanStd, LabelConvert, LabelToProbability] + test_transform = [NormalizeWithMeanStd] + + NormalizeWithMeanStd_channels = [0] + RandomCrop_output_size = [240, 240] + + LabelConvert_source_list = [0, 255] + LabelConvert_target_list = [0, 1] + + [network] + # type of network + net_type = UNet2D + + # number of class, required for segmentation task + class_num = 2 + in_chns = 1 + feature_chns = [16, 32, 64, 128, 256] + dropout = [0, 0, 0.3, 0.4, 0.5] + bilinear = False + deep_supervise= False + + [training] + # list of gpus + gpus = [0] + loss_type = DiceLoss + + # for optimizers + optimizer = Adam + learning_rate = 1e-3 + momentum = 0.9 + weight_decay = 1e-5 + + # for lr scheduler (MultiStepLR) + lr_scheduler = MultiStepLR + lr_gamma = 0.5 + lr_milestones = [2000, 4000, 6000] + + ckpt_save_dir = model/unet_dice_loss + ckpt_prefix = unet + + # start iter + iter_start = 0 + iter_max = 8000 + iter_valid = 200 + iter_save = 8000 + + [testing] + # list of gpus + gpus = [0] + # checkpoint mode can be [0-latest, 1-best, 2-specified] + ckpt_mode = 0 + output_dir = result + + # convert the label of prediction output + label_source = [0, 1] + label_target = [0, 255] From 32bff9bdc80a278afce13cddb7fa4715730860a7 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 26 Aug 2022 15:35:42 +0800 Subject: [PATCH 039/225] Update usage.rst --- docs/source/usage.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index b2948d4..837db66 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -40,7 +40,7 @@ Configuration File ------------------ PyMIC uses configuration files to specify the setting and parameters of a deep -learning pipeline, so that users can reuse the code and minimizing their workload. +learning pipeline, so that users can reuse the code and minimize their workload. Users can use configuration files to config almost all the componets involved, such as dataset, network structure, loss function, optimizer, learning rate scheduler and post processing methods, etc. The following is an example configuration @@ -48,6 +48,7 @@ file used for segmentation of lung from radiograph, which can be find in `PyMIC_examples/segmentation/JSRT. `_ .. code-block:: none + [dataset] # tensor type (float or double) tensor_type = float From 49bcaa22e6c0f209899e669ffb7e26e3db3df6cc Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 26 Aug 2022 15:49:38 +0800 Subject: [PATCH 040/225] Update usage.rst add a tip --- docs/source/usage.rst | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 837db66..76455a0 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -43,22 +43,26 @@ PyMIC uses configuration files to specify the setting and parameters of a deep learning pipeline, so that users can reuse the code and minimize their workload. Users can use configuration files to config almost all the componets involved, such as dataset, network structure, loss function, optimizer, learning rate -scheduler and post processing methods, etc. The following is an example configuration +scheduler and post processing methods, etc. + +.. tip:: +Genreally, the configuration file have four sections: ``dataset``, ``network``, +``training`` and ``testing``. + +The following is an example configuration file used for segmentation of lung from radiograph, which can be find in `PyMIC_examples/segmentation/JSRT. `_ .. code-block:: none - + [dataset] # tensor type (float or double) tensor_type = float - task_type = seg root_dir = ../../PyMIC_data/JSRT train_csv = config/jsrt_train.csv valid_csv = config/jsrt_valid.csv test_csv = config/jsrt_test.csv - train_batch_size = 4 # data transforms @@ -73,10 +77,8 @@ file used for segmentation of lung from radiograph, which can be find in LabelConvert_target_list = [0, 1] [network] - # type of network net_type = UNet2D - - # number of class, required for segmentation task + # Parameters for UNet2D class_num = 2 in_chns = 1 feature_chns = [16, 32, 64, 128, 256] @@ -100,8 +102,8 @@ file used for segmentation of lung from radiograph, which can be find in lr_gamma = 0.5 lr_milestones = [2000, 4000, 6000] - ckpt_save_dir = model/unet_dice_loss - ckpt_prefix = unet + ckpt_save_dir = model/unet_dice_loss + ckpt_prefix = unet # start iter iter_start = 0 @@ -113,10 +115,15 @@ file used for segmentation of lung from radiograph, which can be find in # list of gpus gpus = [0] # checkpoint mode can be [0-latest, 1-best, 2-specified] - ckpt_mode = 0 - output_dir = result + ckpt_mode = 0 + output_dir = result # convert the label of prediction output label_source = [0, 1] label_target = [0, 255] + +SegmentationAgent +----------------- + +SegmentationAgent \ No newline at end of file From aced7c494f1e5e5d43bb895eb009cf5369299b27 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 26 Aug 2022 15:53:27 +0800 Subject: [PATCH 041/225] Update usage.rst --- docs/source/usage.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 76455a0..54d753c 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -46,8 +46,8 @@ such as dataset, network structure, loss function, optimizer, learning rate scheduler and post processing methods, etc. .. tip:: -Genreally, the configuration file have four sections: ``dataset``, ``network``, -``training`` and ``testing``. + Genreally, the configuration file have four sections: ``dataset``, ``network``, + ``training`` and ``testing``. The following is an example configuration file used for segmentation of lung from radiograph, which can be find in From dcbc67c0908f8c19e8d6ef41cc446e72f257af0c Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 26 Aug 2022 16:04:17 +0800 Subject: [PATCH 042/225] Update usage.rst add example for SegmentationAgent --- docs/source/usage.rst | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 54d753c..fe220d2 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -126,4 +126,17 @@ file used for segmentation of lung from radiograph, which can be find in SegmentationAgent ----------------- -SegmentationAgent \ No newline at end of file +:mod:`pymic.net_run.agent_seg.SegmentationAgent` is the general class used for training +and inference of deep learning models. You just need to specify a configuration file to +initialize an instance of that class. An example code to use it is: + +.. code-block:: none + from pymic.util.parse_config import * + + config_name = "a_config_file.cfg" + config = parse_config(config_name) + config = synchronize_config(config) + stage = "train" # or "test" + agent = SegmentationAgent(config, stage) + agent.run() + From 77b863dce11abdbedff23b08184b03f0e838dc75 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 26 Aug 2022 16:13:40 +0800 Subject: [PATCH 043/225] Update usage.rst update tip for SegmentationAgent --- docs/source/usage.rst | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index fe220d2..7c52a30 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -46,6 +46,7 @@ such as dataset, network structure, loss function, optimizer, learning rate scheduler and post processing methods, etc. .. tip:: + Genreally, the configuration file have four sections: ``dataset``, ``network``, ``training`` and ``testing``. @@ -131,6 +132,7 @@ and inference of deep learning models. You just need to specify a configuration initialize an instance of that class. An example code to use it is: .. code-block:: none + from pymic.util.parse_config import * config_name = "a_config_file.cfg" @@ -140,3 +142,13 @@ initialize an instance of that class. An example code to use it is: agent = SegmentationAgent(config, stage) agent.run() +The above code will use the dataset, network and loss function, etc specifcied in the +configuration file for running. + +.. tip:: + + If you use the built-in modules such as ``UNet`` and ``Dice`` + ``CrossEntropy`` loss + for segmentation, you don't need to write the above code. Just just use the `pymic_run` + command. + + From 2faaf00fcaa6bd346c21a00694b78811bbb3d299 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 26 Aug 2022 16:18:40 +0800 Subject: [PATCH 044/225] Update usage.rst add tips --- docs/source/usage.rst | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 7c52a30..79fc5e8 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -35,6 +35,12 @@ After training, run the following command for testing: .. code-block:: bash pymic_run test myconfig.cfg + +.. tip:: + + We provide several examples in `PyMIC_examples + `_. Please run these examples to + quickly start with using PyMIC. Configuration File ------------------ @@ -45,7 +51,7 @@ Users can use configuration files to config almost all the componets involved, such as dataset, network structure, loss function, optimizer, learning rate scheduler and post processing methods, etc. -.. tip:: +.. note:: Genreally, the configuration file have four sections: ``dataset``, ``network``, ``training`` and ``testing``. From a5ea7ad93421591eb4c8ead7671ea8a41a54fd63 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 26 Aug 2022 16:23:47 +0800 Subject: [PATCH 045/225] Update usage.rst --- docs/source/usage.rst | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 79fc5e8..308c745 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -37,10 +37,9 @@ After training, run the following command for testing: pymic_run test myconfig.cfg .. tip:: - - We provide several examples in `PyMIC_examples - `_. Please run these examples to - quickly start with using PyMIC. + + We provide several examples in `PyMIC_examples. `_ + Please run these examples to quickly start with using PyMIC. Configuration File ------------------ From b3b01181496580f46fff4b7efabe2566c2db0ce0 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 26 Aug 2022 16:39:20 +0800 Subject: [PATCH 046/225] add fsl add a page for fully supervised learning --- docs/source/usage.fsl.rst | 31 +++++++++++++++++++++++++++++++ docs/source/usage.rst | 30 +++++------------------------- 2 files changed, 36 insertions(+), 25 deletions(-) create mode 100644 docs/source/usage.fsl.rst diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst new file mode 100644 index 0000000..08c1ae5 --- /dev/null +++ b/docs/source/usage.fsl.rst @@ -0,0 +1,31 @@ +.. _fully_supervised_learning: + +Fully Supervised Learning +========================= + +SegmentationAgent +----------------- + +:mod:`pymic.net_run.agent_seg.SegmentationAgent` is the general class used for training +and inference of deep learning models. You just need to specify a configuration file to +initialize an instance of that class. An example code to use it is: + +.. code-block:: none + + from pymic.util.parse_config import * + + config_name = "a_config_file.cfg" + config = parse_config(config_name) + config = synchronize_config(config) + stage = "train" # or "test" + agent = SegmentationAgent(config, stage) + agent.run() + +The above code will use the dataset, network and loss function, etc specifcied in the +configuration file for running. + +.. tip:: + + If you use the built-in modules such as ``UNet`` and ``Dice`` + ``CrossEntropy`` loss + for segmentation, you don't need to write the above code. Just just use the `pymic_run` + command. \ No newline at end of file diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 308c745..bee643f 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -20,6 +20,8 @@ Alternatively, you can download or clone the code from `GitHub `_ Please run these examples to quickly start with using PyMIC. + +.. _configuration: + Configuration File ------------------ @@ -129,31 +134,6 @@ file used for segmentation of lung from radiograph, which can be find in label_target = [0, 255] -SegmentationAgent ------------------ - -:mod:`pymic.net_run.agent_seg.SegmentationAgent` is the general class used for training -and inference of deep learning models. You just need to specify a configuration file to -initialize an instance of that class. An example code to use it is: - -.. code-block:: none - - from pymic.util.parse_config import * - - config_name = "a_config_file.cfg" - config = parse_config(config_name) - config = synchronize_config(config) - stage = "train" # or "test" - agent = SegmentationAgent(config, stage) - agent.run() - -The above code will use the dataset, network and loss function, etc specifcied in the -configuration file for running. - -.. tip:: - If you use the built-in modules such as ``UNet`` and ``Dice`` + ``CrossEntropy`` loss - for segmentation, you don't need to write the above code. Just just use the `pymic_run` - command. From a6999b8f333373372a80fbf5fa10326c7f188443 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 26 Aug 2022 16:55:21 +0800 Subject: [PATCH 047/225] update usage add installation file --- docs/source/index.rst | 18 +++++++++++++----- docs/source/installation.rst | 13 ------------- 2 files changed, 13 insertions(+), 18 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index 54461d7..9d24e9a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -11,16 +11,24 @@ how to :ref:`installation` the project. This project is under active development. It will be updated later. -Contents --------- - .. toctree:: - usage - api + :maxdepth: 1 + :hidden: + :caption: Getting started + + installation + usage + api +* :doc:`installation` helps you installing PyMIC quickly. + +* :doc:`usage` give you an overview of how to use PyMIC. + +* :doc:`api`. Citation -------- + If you use PyMIC for your research, please acknowledge it accordingly by citing our paper: `G. Wang, X. Luo, R. Gu, S. Yang, Y. Qu, S. Zhai, Q. Zhao, K. Li, S. Zhang. (2022). diff --git a/docs/source/installation.rst b/docs/source/installation.rst index b434839..c59fc05 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -29,16 +29,3 @@ PyMIC requires Python 3.6 (or higher) and depends on the following packages: - `scikit-image `_ - `SciPy `_ - `SimpleITK `_ - -.. note:: - For the :mod:`pymia.data` package, not all dependencies are installed directly due to their heaviness. - Meaning, you need to either manually install PyTorch by - - - :bash:`pip install torch` - - or TensorFlow by - - - :bash:`pip install tensorflow` - - depending on your preferred deep learning framework when using the :mod:`pymia.data` package. - Upon loading a module from the :mod:`pymia.data` package, pymia will always check if the required dependencies are fulfilled. From ef48bd70814420f28b6e66466dc64760b670248b Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 26 Aug 2022 17:02:29 +0800 Subject: [PATCH 048/225] Update index.rst update toctree --- docs/source/index.rst | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index 9d24e9a..769d0e0 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -11,21 +11,20 @@ how to :ref:`installation` the project. This project is under active development. It will be updated later. + +Getting Started +--------------- + +If you are new to PyMIC, here are some guides for learning to use it quickly. + .. toctree:: - :maxdepth: 1 - :hidden: + :maxdepth: 2 :caption: Getting started installation usage api -* :doc:`installation` helps you installing PyMIC quickly. - -* :doc:`usage` give you an overview of how to use PyMIC. - -* :doc:`api`. - Citation -------- From 3a81b2b98b9c79b364b11a8e7a51522511b64208 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 26 Aug 2022 21:13:20 +0800 Subject: [PATCH 049/225] update usage update usage --- docs/source/installation.rst | 8 ++++++-- docs/source/usage.rst | 23 ++++++----------------- 2 files changed, 12 insertions(+), 19 deletions(-) diff --git a/docs/source/installation.rst b/docs/source/installation.rst index c59fc05..ded2c6e 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -3,8 +3,8 @@ .. role:: bash(code) :language: bash -Installation -============ +Install with ``pip`` +---------------------- Install PyMIC using pip (e.g., within a `Python virtual environment `_): @@ -12,6 +12,9 @@ Install PyMIC using pip (e.g., within a `Python virtual environment `_ and install PyMIC by .. code-block:: bash @@ -24,6 +27,7 @@ Dependencies ------------ PyMIC requires Python 3.6 (or higher) and depends on the following packages: + - `pandas `_ - `h5py `_ - `NumPy `_ - `scikit-image `_ diff --git a/docs/source/usage.rst b/docs/source/usage.rst index bee643f..4030b33 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -1,24 +1,13 @@ Usage ===== -.. _installation: +.. toctree:: + :maxdepth: 2 + :caption: Getting started -Installation ------------- - -Install PyMIC using pip (e.g., within a `Python virtual environment `_): - -.. code-block:: bash - - pip install PYMIC - -Alternatively, you can download or clone the code from `GitHub `_ and install PyMIC by - -.. code-block:: bash - - git clone https://github.com/HiLab-git/PyMIC - cd PyMIC - python setup.py install + traintest + configuration + fully_supervised_learning .. _traintest: From fdf49041d9483a799425aa0d1bb18eec104cb8e7 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 26 Aug 2022 21:22:11 +0800 Subject: [PATCH 050/225] update installation update installation --- docs/source/index.rst | 7 +------ docs/source/installation.rst | 10 ++++------ 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index 769d0e0..2396626 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -12,14 +12,9 @@ how to :ref:`installation` the project. This project is under active development. It will be updated later. -Getting Started ---------------- - -If you are new to PyMIC, here are some guides for learning to use it quickly. - .. toctree:: :maxdepth: 2 - :caption: Getting started + :caption: Getting Started installation usage diff --git a/docs/source/installation.rst b/docs/source/installation.rst index ded2c6e..092e62b 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -3,8 +3,9 @@ .. role:: bash(code) :language: bash -Install with ``pip`` ----------------------- + +Installation +------------ Install PyMIC using pip (e.g., within a `Python virtual environment `_): @@ -12,8 +13,6 @@ Install PyMIC using pip (e.g., within a `Python virtual environment `_ and install PyMIC by @@ -23,8 +22,7 @@ Alternatively, you can download or clone the code from `GitHub `_ From 1073442e972897fbc2b74c5f65af3b32c538aa6e Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 26 Aug 2022 21:32:56 +0800 Subject: [PATCH 051/225] add quickstart add quickstart --- docs/source/usage.quickstart.rst | 0 docs/source/usage.rst | 116 +------------------------------ 2 files changed, 2 insertions(+), 114 deletions(-) create mode 100644 docs/source/usage.quickstart.rst diff --git a/docs/source/usage.quickstart.rst b/docs/source/usage.quickstart.rst new file mode 100644 index 0000000..e69de29 diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 4030b33..4280c75 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -5,122 +5,10 @@ Usage :maxdepth: 2 :caption: Getting started - traintest - configuration - fully_supervised_learning + usage.quickstart + usage.fsl -.. _traintest: -Train and Test --------------- - -PyMIC accepts a configuration file for runing. For example, to train a network -for segmentation with full supervision, run the fullowing command: - -.. code-block:: bash - - pymic_run train myconfig.cfg - -After training, run the following command for testing: - -.. code-block:: bash - - pymic_run test myconfig.cfg - -.. tip:: - - We provide several examples in `PyMIC_examples. `_ - Please run these examples to quickly start with using PyMIC. - - -.. _configuration: - -Configuration File ------------------- - -PyMIC uses configuration files to specify the setting and parameters of a deep -learning pipeline, so that users can reuse the code and minimize their workload. -Users can use configuration files to config almost all the componets involved, -such as dataset, network structure, loss function, optimizer, learning rate -scheduler and post processing methods, etc. - -.. note:: - - Genreally, the configuration file have four sections: ``dataset``, ``network``, - ``training`` and ``testing``. - -The following is an example configuration -file used for segmentation of lung from radiograph, which can be find in -`PyMIC_examples/segmentation/JSRT. `_ - -.. code-block:: none - - [dataset] - # tensor type (float or double) - tensor_type = float - task_type = seg - root_dir = ../../PyMIC_data/JSRT - train_csv = config/jsrt_train.csv - valid_csv = config/jsrt_valid.csv - test_csv = config/jsrt_test.csv - train_batch_size = 4 - - # data transforms - train_transform = [NormalizeWithMeanStd, RandomCrop, LabelConvert, LabelToProbability] - valid_transform = [NormalizeWithMeanStd, LabelConvert, LabelToProbability] - test_transform = [NormalizeWithMeanStd] - - NormalizeWithMeanStd_channels = [0] - RandomCrop_output_size = [240, 240] - - LabelConvert_source_list = [0, 255] - LabelConvert_target_list = [0, 1] - - [network] - net_type = UNet2D - # Parameters for UNet2D - class_num = 2 - in_chns = 1 - feature_chns = [16, 32, 64, 128, 256] - dropout = [0, 0, 0.3, 0.4, 0.5] - bilinear = False - deep_supervise= False - - [training] - # list of gpus - gpus = [0] - loss_type = DiceLoss - - # for optimizers - optimizer = Adam - learning_rate = 1e-3 - momentum = 0.9 - weight_decay = 1e-5 - - # for lr scheduler (MultiStepLR) - lr_scheduler = MultiStepLR - lr_gamma = 0.5 - lr_milestones = [2000, 4000, 6000] - - ckpt_save_dir = model/unet_dice_loss - ckpt_prefix = unet - - # start iter - iter_start = 0 - iter_max = 8000 - iter_valid = 200 - iter_save = 8000 - - [testing] - # list of gpus - gpus = [0] - # checkpoint mode can be [0-latest, 1-best, 2-specified] - ckpt_mode = 0 - output_dir = result - - # convert the label of prediction output - label_source = [0, 1] - label_target = [0, 255] From 5ad0f311aa97ae3017011bc2c67a2da2e84da0c1 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 26 Aug 2022 21:48:06 +0800 Subject: [PATCH 052/225] update quick start --- docs/source/usage.quickstart.rst | 116 +++++++++++++++++++++++++++++++ docs/source/usage.rst | 8 +++ 2 files changed, 124 insertions(+) diff --git a/docs/source/usage.quickstart.rst b/docs/source/usage.quickstart.rst index e69de29..921cd23 100644 --- a/docs/source/usage.quickstart.rst +++ b/docs/source/usage.quickstart.rst @@ -0,0 +1,116 @@ +.. _quickstart: + +Quick Start +=========== + + +Train and Test +-------------- + +PyMIC accepts a configuration file for runing. For example, to train a network +for segmentation with full supervision, run the fullowing command: + +.. code-block:: bash + + pymic_run train myconfig.cfg + +After training, run the following command for testing: + +.. code-block:: bash + + pymic_run test myconfig.cfg + +.. tip:: + + We provide several examples in `PyMIC_examples. `_ + Please run these examples to quickly start with using PyMIC. + + +.. _configuration: + +Configuration File +------------------ + +PyMIC uses configuration files to specify the setting and parameters of a deep +learning pipeline, so that users can reuse the code and minimize their workload. +Users can use configuration files to config almost all the componets involved, +such as dataset, network structure, loss function, optimizer, learning rate +scheduler and post processing methods, etc. + +.. note:: + + Genreally, the configuration file have four sections: ``dataset``, ``network``, + ``training`` and ``testing``. + +The following is an example configuration +file used for segmentation of lung from radiograph, which can be find in +`PyMIC_examples/segmentation/JSRT. `_ + +.. code-block:: none + + [dataset] + # tensor type (float or double) + tensor_type = float + task_type = seg + root_dir = ../../PyMIC_data/JSRT + train_csv = config/jsrt_train.csv + valid_csv = config/jsrt_valid.csv + test_csv = config/jsrt_test.csv + train_batch_size = 4 + + # data transforms + train_transform = [NormalizeWithMeanStd, RandomCrop, LabelConvert, LabelToProbability] + valid_transform = [NormalizeWithMeanStd, LabelConvert, LabelToProbability] + test_transform = [NormalizeWithMeanStd] + + NormalizeWithMeanStd_channels = [0] + RandomCrop_output_size = [240, 240] + + LabelConvert_source_list = [0, 255] + LabelConvert_target_list = [0, 1] + + [network] + net_type = UNet2D + # Parameters for UNet2D + class_num = 2 + in_chns = 1 + feature_chns = [16, 32, 64, 128, 256] + dropout = [0, 0, 0.3, 0.4, 0.5] + bilinear = False + deep_supervise= False + + [training] + # list of gpus + gpus = [0] + loss_type = DiceLoss + + # for optimizers + optimizer = Adam + learning_rate = 1e-3 + momentum = 0.9 + weight_decay = 1e-5 + + # for lr scheduler (MultiStepLR) + lr_scheduler = MultiStepLR + lr_gamma = 0.5 + lr_milestones = [2000, 4000, 6000] + + ckpt_save_dir = model/unet_dice_loss + ckpt_prefix = unet + + # start iter + iter_start = 0 + iter_max = 8000 + iter_valid = 200 + iter_save = 8000 + + [testing] + # list of gpus + gpus = [0] + # checkpoint mode can be [0-latest, 1-best, 2-specified] + ckpt_mode = 0 + output_dir = result + + # convert the label of prediction output + label_source = [0, 1] + label_target = [0, 255] \ No newline at end of file diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 4280c75..f8a9c4a 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -1,6 +1,14 @@ Usage ===== +This usage give details of how to use PyMIC. +Beginners can easily start with training a deep learning +model with configure files. When you are more familar with +the PyMIC pipeline, you can define your customized modules +and reuse the remaining parts of the pipeline, with minimized +workload. + + .. toctree:: :maxdepth: 2 :caption: Getting started From 48946f57c1224604364b546ffca8ee0f742c1399 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 26 Aug 2022 22:28:08 +0800 Subject: [PATCH 053/225] update usage-dataset update usage-dataset --- docs/source/installation.rst | 8 +---- docs/source/usage.fsl.rst | 56 +++++++++++++++++++++++++++++++- docs/source/usage.quickstart.rst | 3 +- docs/source/usage.rst | 4 +-- pymic/io/h5_dataset.py | 4 +-- 5 files changed, 62 insertions(+), 13 deletions(-) diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 092e62b..ced640f 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -1,11 +1,5 @@ -.. _installation: - -.. role:: bash(code) - :language: bash - - Installation ------------- +============ Install PyMIC using pip (e.g., within a `Python virtual environment `_): diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index 08c1ae5..e96d4a3 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -28,4 +28,58 @@ configuration file for running. If you use the built-in modules such as ``UNet`` and ``Dice`` + ``CrossEntropy`` loss for segmentation, you don't need to write the above code. Just just use the `pymic_run` - command. \ No newline at end of file + command. + +Dataset +------- + +PyMIC provides two types of datasets for loading images from +disk to memory: ``NiftyDataset`` and ``H5DataSet``. + +``NiftyDataset`` is designed for 2D and 3D images in common formats +such as .png, .jpeg, .bmp and nii.gz. ``H5DataSet`` is used for +hdf5 data that are more efficient to load. + +To use ``NiftyDataset``, users need to specify the root path +of the dataset and the csv file storing the image and label +file names. Note that three .csv files are needed, and they are +for training, validation and testing, respectively. For example: + +.. code-block:: none + + [dataset] + # tensor type (float or double) + tensor_type = float + task_type = seg + root_dir = ../../PyMIC_data/JSRT + train_csv = config/jsrt_train.csv + valid_csv = config/jsrt_valid.csv + test_csv = config/jsrt_test.csv + train_batch_size = 4 + +The .csv file should have at least two columns (fields), +one for ``image`` and one for ``label``. If the input image +have multiple modalities, and each modality is saved in a single +file, then the .csv file should have N + 1 columnes, where the +first N columns are for the N modalities, and the last column +is for the label. + +To use your own dataset, you can define a dataset as a child class +of ``NiftyDataset``, ``H5DataSet``, :mod:`or torch.utils.data.Dataset` +, and use :mod:`SegmentationAgent.set_datasets()` +to set the customized datasets. For example: + +.. code-block:: none + + from torch.utils.data import Dataset + + class MyDataset(Dataset): + ... + # define your custom dataset here + + trainset = MyDataset(...) + valset = MyDataset(...) + testset = MyDataset(...) + agent = SegmentationAgent(config, stage) + agent.set_datasets(trainset, valset, testset) + agent.run() diff --git a/docs/source/usage.quickstart.rst b/docs/source/usage.quickstart.rst index 921cd23..e1cada6 100644 --- a/docs/source/usage.quickstart.rst +++ b/docs/source/usage.quickstart.rst @@ -113,4 +113,5 @@ file used for segmentation of lung from radiograph, which can be find in # convert the label of prediction output label_source = [0, 1] - label_target = [0, 255] \ No newline at end of file + label_target = [0, 255] + diff --git a/docs/source/usage.rst b/docs/source/usage.rst index f8a9c4a..1c8f9d0 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -1,7 +1,7 @@ Usage ===== -This usage give details of how to use PyMIC. +This usage gives details of how to use PyMIC. Beginners can easily start with training a deep learning model with configure files. When you are more familar with the PyMIC pipeline, you can define your customized modules @@ -11,7 +11,7 @@ workload. .. toctree:: :maxdepth: 2 - :caption: Getting started + :caption: usage.quickstart usage.fsl diff --git a/pymic/io/h5_dataset.py b/pymic/io/h5_dataset.py index a2913c1..a270d93 100644 --- a/pymic/io/h5_dataset.py +++ b/pymic/io/h5_dataset.py @@ -11,7 +11,7 @@ from torch.utils.data import Dataset from torch.utils.data.sampler import Sampler -class H5DataSets(Dataset): +class H5DataSet(Dataset): """ Dataset for loading images stored in h5 format. It generates 4D tensors with dimention order [C, D, H, W] for 3D images, and @@ -90,7 +90,7 @@ def grouper(iterable, n): if __name__ == "__main__": root_dir = "/home/guotai/disk2t/projects/semi_supervise/SSL4MIS/data/ACDC/data/slices" file_name = "/home/guotai/disk2t/projects/semi_supervise/slices.txt" - dataset = H5DataSets(root_dir, file_name) + dataset = H5DataSet(root_dir, file_name) train_loader = torch.utils.data.DataLoader(dataset, batch_size = 4, shuffle=True, num_workers= 1) for sample in train_loader: From 7544748f5b010f0e94edf1ae637080e99c3b5646 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 26 Aug 2022 22:43:34 +0800 Subject: [PATCH 054/225] Update usage.rst --- docs/source/usage.rst | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 1c8f9d0..2aef1d7 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -8,10 +8,8 @@ the PyMIC pipeline, you can define your customized modules and reuse the remaining parts of the pipeline, with minimized workload. - .. toctree:: :maxdepth: 2 - :caption: usage.quickstart usage.fsl From eee064720d655cf343988eaddcf58dfa6d77a709 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 27 Aug 2022 16:05:14 +0800 Subject: [PATCH 055/225] Update usage.rst --- docs/source/usage.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 2aef1d7..35fbcc5 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -5,7 +5,7 @@ This usage gives details of how to use PyMIC. Beginners can easily start with training a deep learning model with configure files. When you are more familar with the PyMIC pipeline, you can define your customized modules -and reuse the remaining parts of the pipeline, with minimized +and reuse the remaining parts of the pipeline, with minimal workload. .. toctree:: From b72de6a6a5a82528037764753a66327924cf9a3f Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 27 Aug 2022 16:14:34 +0800 Subject: [PATCH 056/225] Update index.rst --- docs/source/index.rst | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index 2396626..e6dcb64 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,11 +1,14 @@ Welcome to PyMIC's documentation! =================================== -PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. -PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations. - -Check out the :doc:`usage` section for further information, including -how to :ref:`installation` the project. +PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient +deep learning. PyMIC is developed to support learning with imperfect labels, including +semi-supervised and weakly supervised learning, and learning with noisy annotations. + +Check out the :doc:`Installation` section for install PyMIC, and go to the :doc:`Usage` +section for understanding modules for the segmentation pipeline designed in PyMIC. +Please follow `PyMIC_examples. `_ +to quickly start with using PyMIC. .. note:: From 0603dd1d6b0a8410e708f5345bfa41d74483f2f0 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 27 Aug 2022 16:31:38 +0800 Subject: [PATCH 057/225] Update usage.fsl.rst add transforms --- docs/source/usage.fsl.rst | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index e96d4a3..88af075 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -27,7 +27,7 @@ configuration file for running. .. tip:: If you use the built-in modules such as ``UNet`` and ``Dice`` + ``CrossEntropy`` loss - for segmentation, you don't need to write the above code. Just just use the `pymic_run` + for segmentation, you don't need to write the above code. Just just use the ``pymic_run`` command. Dataset @@ -35,14 +35,13 @@ Dataset PyMIC provides two types of datasets for loading images from disk to memory: ``NiftyDataset`` and ``H5DataSet``. - ``NiftyDataset`` is designed for 2D and 3D images in common formats -such as .png, .jpeg, .bmp and nii.gz. ``H5DataSet`` is used for +such as png, jpeg, bmp and nii.gz. ``H5DataSet`` is used for hdf5 data that are more efficient to load. To use ``NiftyDataset``, users need to specify the root path of the dataset and the csv file storing the image and label -file names. Note that three .csv files are needed, and they are +file names. Note that three csv files are needed, and they are for training, validation and testing, respectively. For example: .. code-block:: none @@ -57,10 +56,11 @@ for training, validation and testing, respectively. For example: test_csv = config/jsrt_test.csv train_batch_size = 4 -The .csv file should have at least two columns (fields), -one for ``image`` and one for ``label``. If the input image -have multiple modalities, and each modality is saved in a single -file, then the .csv file should have N + 1 columnes, where the +By default, the ``valid_batch_size`` is set to the same as the ``train_batch_size``, +and the ``test_batch_size`` is 1. The csv file should have at least two columns (fields), +one for ``image`` and the other for ``label``. If the input image +have multiple modalities with each modality saved in a single +file, then the csv file should have N + 1 columnes, where the first N columns are for the N modalities, and the last column is for the label. @@ -77,9 +77,21 @@ to set the customized datasets. For example: ... # define your custom dataset here - trainset = MyDataset(...) - valset = MyDataset(...) - testset = MyDataset(...) + trainset, valset, testset = MyDataset(...), MyDataset(...), MyDataset(...) agent = SegmentationAgent(config, stage) agent.set_datasets(trainset, valset, testset) agent.run() + +Transforms +---------- + +Several transforms are defined in PyMIC to preprocess or augment the data +before sending it to the network. The ``TransformDict`` in +:mod:`pymic.transform.trans_dict` lists all the built in transforms supported +in PyMIC. + + +Subject-wise results... +Transform Comments +ChannelWiseThreshold GREYMATTER +ChannelWiseThresholdWithNormalize GREYMATTER \ No newline at end of file From 5089df4abc771e34515655e017c94b669bb58117 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 27 Aug 2022 16:38:06 +0800 Subject: [PATCH 058/225] update transform table update transform table --- docs/source/index.rst | 4 ++-- docs/source/usage.fsl.rst | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index e6dcb64..f9b62ba 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -5,9 +5,9 @@ PyMIC is a pytorch-based toolkit for medical image computing with annotation-eff deep learning. PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations. -Check out the :doc:`Installation` section for install PyMIC, and go to the :doc:`Usage` +Check out the :doc:`installation` section for install PyMIC, and go to the :doc:`usage` section for understanding modules for the segmentation pipeline designed in PyMIC. -Please follow `PyMIC_examples. `_ +Please follow `PyMIC_examples `_ to quickly start with using PyMIC. .. note:: diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index 88af075..25ddcf0 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -91,7 +91,8 @@ before sending it to the network. The ``TransformDict`` in in PyMIC. -Subject-wise results... -Transform Comments -ChannelWiseThreshold GREYMATTER -ChannelWiseThresholdWithNormalize GREYMATTER \ No newline at end of file + +|Transform|Comments| +|---|---| +|ChannelWiseThreshold | GREYMATTER | +|ChannelWiseThresholdWithNormalize | GREYMATTER | \ No newline at end of file From 5a5f9adf12a3ee386b5ada51fe46637ff43900b7 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 27 Aug 2022 16:53:15 +0800 Subject: [PATCH 059/225] Update usage.fsl.rst update transform --- docs/source/usage.fsl.rst | 47 +++++++++++++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index 25ddcf0..7538582 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -90,9 +90,48 @@ before sending it to the network. The ``TransformDict`` in :mod:`pymic.transform.trans_dict` lists all the built in transforms supported in PyMIC. +In the configuration file, users can specify the transforms required for training, +validation and testing data, respectively. The parameters of each tranform class +should also be provided, such as following: +.. code-block:: none -|Transform|Comments| -|---|---| -|ChannelWiseThreshold | GREYMATTER | -|ChannelWiseThresholdWithNormalize | GREYMATTER | \ No newline at end of file + # data transforms + train_transform = [Pad, RandomRotate, RandomCrop, RandomFlip, NormalizeWithMeanStd, GammaCorrection, GaussianNoise, LabelToProbability] + valid_transform = [NormalizeWithMeanStd, Pad, LabelToProbability] + test_transform = [NormalizeWithMeanStd, Pad] + + Pad_output_size = [8, 256, 256] + Pad_ceil_mode = False + Pad_inverse = True # the inverse transform will be enabled during testing + + RandomRotate_angle_range_d = [-90, 90] + RandomRotate_angle_range_h = None + RandomRotate_angle_range_w = None + + RandomCrop_output_size = [6, 192, 192] + RandomCrop_foreground_focus = False + RandomCrop_foreground_ratio = None + Randomcrop_mask_label = None + + RandomFlip_flip_depth = False + RandomFlip_flip_height = True + RandomFlip_flip_width = True + + NormalizeWithMeanStd_channels = [0] + + GammaCorrection_channels = [0] + GammaCorrection_gamma_min = 0.7 + GammaCorrection_gamma_max = 1.5 + + GaussianNoise_channels = [0] + GaussianNoise_mean = 0 + GaussianNoise_std = 0.05 + GaussianNoise_probability = 0.5 + +For spatial transforms, you can specify whether an inverse transform is enabled +or not. Setting the inverse flag as True will transform the prediction output +inversely during testing, which is useful in testing time augmentation. If you +want to make images with different shapes to have the same shape before testing, +then the the correspoinding transform's inverse flag can also be set as True, so +that the prediction output will be transformed back to the original image space. \ No newline at end of file From 8617243cf5b3032312e15a05722520751fb9bb96 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 27 Aug 2022 21:35:18 +0800 Subject: [PATCH 060/225] Update usage.fsl.rst add customized transform --- docs/source/usage.fsl.rst | 41 +++++++++++++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index 7538582..9447bd7 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -78,7 +78,7 @@ to set the customized datasets. For example: # define your custom dataset here trainset, valset, testset = MyDataset(...), MyDataset(...), MyDataset(...) - agent = SegmentationAgent(config, stage) + agent = SegmentationAgent(config, stage) agent.set_datasets(trainset, valset, testset) agent.run() @@ -101,9 +101,10 @@ should also be provided, such as following: valid_transform = [NormalizeWithMeanStd, Pad, LabelToProbability] test_transform = [NormalizeWithMeanStd, Pad] + # the inverse transform will be enabled during testing Pad_output_size = [8, 256, 256] Pad_ceil_mode = False - Pad_inverse = True # the inverse transform will be enabled during testing + Pad_inverse = True RandomRotate_angle_range_d = [-90, 90] RandomRotate_angle_range_h = None @@ -131,7 +132,35 @@ should also be provided, such as following: For spatial transforms, you can specify whether an inverse transform is enabled or not. Setting the inverse flag as True will transform the prediction output -inversely during testing, which is useful in testing time augmentation. If you -want to make images with different shapes to have the same shape before testing, -then the the correspoinding transform's inverse flag can also be set as True, so -that the prediction output will be transformed back to the original image space. \ No newline at end of file +inversely during testing, such as ``Pad_inverse = True`` shown above. +If you want to make images with different shapes to have the same shape before testing, +then the correspoinding transform's inverse flag can be set as True, so +that the prediction output will be transformed back to the original image space. +This is also useful for test time augmentation. + +You can also define your own transform operations. To integrate your customized +transform to the PyMIC pipeline, just add it to the ``TransformDict``, and you can +also specify the parameters via configuration file for the customized transform. +The following is some example code for this: + +.. code-block:: none + from pymic.transform.trans_dict import TransformDict + from pymic.transform.abstract_transform import AbstractTransform + + # customized transform + class MyTransform(AbstractTransform): + def __init__(self, params): + super(MyTransform, self).__init__(params) + ... + + def __call__(self, sample): + ... + + def inverse_transform_for_prediction(self, sample): + ... + + my_trans_dict = TransformDict + my_trans_dict["MyTransform"] = MyTransform + agent = SegmentationAgent(config, stage) + agent.set_transform_dict(my_trans_dict) + agent.run() \ No newline at end of file From ceec0a4f1c2a9b70e3595c136e01823987f0de26 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 27 Aug 2022 21:41:53 +0800 Subject: [PATCH 061/225] Update usage.fsl.rst --- docs/source/usage.fsl.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index 9447bd7..89b0b86 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -140,10 +140,11 @@ This is also useful for test time augmentation. You can also define your own transform operations. To integrate your customized transform to the PyMIC pipeline, just add it to the ``TransformDict``, and you can -also specify the parameters via configuration file for the customized transform. +also specify the parameters via a configuration file for the customized transform. The following is some example code for this: .. code-block:: none + from pymic.transform.trans_dict import TransformDict from pymic.transform.abstract_transform import AbstractTransform From a56ee1638e8d80a05fda25e7cedbf4a57cbfdb8f Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 27 Aug 2022 21:56:50 +0800 Subject: [PATCH 062/225] Update usage.fsl.rst add network example --- docs/source/usage.fsl.rst | 46 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index 89b0b86..89f34ad 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -87,7 +87,7 @@ Transforms Several transforms are defined in PyMIC to preprocess or augment the data before sending it to the network. The ``TransformDict`` in -:mod:`pymic.transform.trans_dict` lists all the built in transforms supported +:mod:`pymic.transform.trans_dict` lists all the built-in transforms supported in PyMIC. In the configuration file, users can specify the transforms required for training, @@ -164,4 +164,48 @@ The following is some example code for this: my_trans_dict["MyTransform"] = MyTransform agent = SegmentationAgent(config, stage) agent.set_transform_dict(my_trans_dict) + agent.run() + +Networks +-------- + +The configuration file has a ``network`` section to specify the network's type and +hyper-parameters. For example, the following is a configuration for using ``2DUNet``: + +.. code-block:: none + + [network] + net_type = UNet2D + # Parameters for UNet2D + class_num = 2 + in_chns = 1 + feature_chns = [16, 32, 64, 128, 256] + dropout = [0, 0, 0.3, 0.4, 0.5] + bilinear = False + deep_supervise= False + +The ``SegNetDict`` in :mod:`pymic.net.neg_dict_seg` lists all the built-in network +structures currently implemented in PyMIC. + +You can also define your own networks. To integrate your customized +network to the PyMIC pipeline, just call ``set_network()`` of ``SegmentationAgent``. +The following is some example code for this: + +.. code-block:: none + + import torch.nn as nn + from pymic.net.net_dict_seg import SegNetDict + + # customized network + class MyNetwork(nn.Module): + def __init__(self, params): + super(MyNetwork, self).__init__() + ... + + def forward(self, x): + ... + + net = MyNetwork(params) + agent = SegmentationAgent(config, stage) + agent.set_network(net) agent.run() \ No newline at end of file From b16fc1eda01154e524bb3ef3a0b40216810860c2 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 27 Aug 2022 22:09:42 +0800 Subject: [PATCH 063/225] Update usage.fsl.rst --- docs/source/usage.fsl.rst | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index 89f34ad..7d30f63 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -13,6 +13,7 @@ initialize an instance of that class. An example code to use it is: .. code-block:: none from pymic.util.parse_config import * + from pymic.net_run.agent_seg import SegmentationAgent config_name = "a_config_file.cfg" config = parse_config(config_name) @@ -72,6 +73,7 @@ to set the customized datasets. For example: .. code-block:: none from torch.utils.data import Dataset + from pymic.net_run.agent_seg import SegmentationAgent class MyDataset(Dataset): ... @@ -147,6 +149,7 @@ The following is some example code for this: from pymic.transform.trans_dict import TransformDict from pymic.transform.abstract_transform import AbstractTransform + from pymic.net_run.agent_seg import SegmentationAgent # customized transform class MyTransform(AbstractTransform): @@ -194,7 +197,7 @@ The following is some example code for this: .. code-block:: none import torch.nn as nn - from pymic.net.net_dict_seg import SegNetDict + from pymic.net_run.agent_seg import SegmentationAgent # customized network class MyNetwork(nn.Module): @@ -202,7 +205,7 @@ The following is some example code for this: super(MyNetwork, self).__init__() ... - def forward(self, x): + def forward(self, x): ... net = MyNetwork(params) From de0ef4fbfae0666da2833e431d3db6037eda6b90 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 27 Aug 2022 22:20:59 +0800 Subject: [PATCH 064/225] Update usage.fsl.rst custom loss --- docs/source/usage.fsl.rst | 51 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index 7d30f63..dd6d0a5 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -211,4 +211,53 @@ The following is some example code for this: net = MyNetwork(params) agent = SegmentationAgent(config, stage) agent.set_network(net) - agent.run() \ No newline at end of file + agent.run() + +Loss Functions +-------------- + +The setting of loss function is in the ``training`` section of the configuration file, +where the loss function name and hyper-parameters should be provided. +The ``SegLossDict`` in :mod:`pymic.loss.loss_dict_seg` lists all the built-in loss +functions currently implemented in PyMIC. + +The following is an example of the setting of loss: + +.. code-block:: none + + loss_type = DiceLoss + loss_softmax = True + +Note that PyMIC supports using a combination of loss functions. Just set ``loss_type`` +as a list of loss functions, and use ``loss_weight`` to specify the weight of each +loss, such as the following: + +.. code-block:: none + + loss_type = [DiceLoss, CrossEntropyLoss] + loss_weight = [0.5, 0.5] + +You can also define your own loss functions. To integrate your customized +loss function to the PyMIC pipeline, just add it to the ``SegLossDict``, and you can +also specify the parameters via a configuration file for the customized loss. +The following is some example code for this: + +.. code-block:: none + + from pymic.loss.loss_dict_seg import SegLossDict + from pymic.net_run.agent_seg import SegmentationAgent + + # customized loss + class MyLoss(nn.Module): + def __init__(self, params = None): + super(MyLoss, self).__init__() + ... + + def forward(self, loss_input_dict): + ... + + my_loss_dict = SegLossDict + my_loss_dict["MyLoss"] = MyLoss + agent = SegmentationAgent(config, stage) + agent.set_loss_dict(my_loss_dict) + agent.run() From 9991ded24129695cd14987e01f3e7505838122be Mon Sep 17 00:00:00 2001 From: taigw Date: Sun, 28 Aug 2022 22:07:54 +0800 Subject: [PATCH 065/225] Update usage.fsl.rst add training options --- docs/source/usage.fsl.rst | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index dd6d0a5..6a0da8b 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -261,3 +261,23 @@ The following is some example code for this: agent = SegmentationAgent(config, stage) agent.set_loss_dict(my_loss_dict) agent.run() + + +Training Options +---------------- + +In addition to the loss fuction, users can specify several training +options in the ``training`` section of the configuration file. + +Optimizer +^^^^^^^^^ + +For optimizer, users need to set ``optimizer``, ``learning_rate``, +``momentum`` and ``weight_decay``. + + +Learning Rate Scheduler +^^^^^^^^^^^^^^^^^^^^^^^ + +The current supported learning rate schedulers are ``ReduceLROnPlateau`` +and ``MultiStepLR``. From 42b42ca5966c1c7be58a71a70a85bf4b5680c031 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 29 Aug 2022 16:22:24 +0800 Subject: [PATCH 066/225] fix two bugs for classification fix two bugs for classification --- pymic/net_run/agent_cls.py | 2 +- pymic/util/evaluation_cls.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index 8687048..b443225 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -308,7 +308,7 @@ def infer(self): csv_writer = csv.writer(csv_file, delimiter=',', quotechar='"',quoting=csv.QUOTE_MINIMAL) head = ['image', 'label'] - if(len(out_lab_list[0]) > 1): + if(len(out_lab_list[0]) > 2): head = ['image'] + ['label{0:}'.format(i) for i in range(class_num)] csv_writer.writerow(head) for item in out_lab_list: diff --git a/pymic/util/evaluation_cls.py b/pymic/util/evaluation_cls.py index 1aeb3d4..8057b6b 100644 --- a/pymic/util/evaluation_cls.py +++ b/pymic/util/evaluation_cls.py @@ -62,7 +62,7 @@ def binary_evaluation(config): for i in range(len(gt_items)): assert(gt_items.iloc[i, 0] == prob_items.iloc[i, 0]) - gt_data = np.asarray(gt_items.iloc[:, 1]) + gt_data = np.asarray(gt_items.iloc[:, -1]) prob_data = np.asarray(prob_items.iloc[:, 1:]) score_list = [] for metric in metric_list: From 78f98a4c999322023fbc6f331f3748e1b6b064eb Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 29 Aug 2022 17:03:44 +0800 Subject: [PATCH 067/225] Update torch_pretrained_net.py --- pymic/net/cls/torch_pretrained_net.py | 47 ++++++++++++++++++++++++--- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/pymic/net/cls/torch_pretrained_net.py b/pymic/net/cls/torch_pretrained_net.py index 198419f..f197617 100644 --- a/pymic/net/cls/torch_pretrained_net.py +++ b/pymic/net/cls/torch_pretrained_net.py @@ -1,6 +1,7 @@ # pretrained models from pytorch: https://pytorch.org/vision/0.8/models.html from __future__ import print_function, division +import itertools import torch import torch.nn as nn import torchvision.models as models @@ -25,7 +26,7 @@ def __init__(self, params): super(ResNet18, self).__init__() self.params = params cls_num = params['class_num'] - in_chns = params.get('input_chns', 3) + self.in_chns = params.get('input_chns', 3) self.pretrain = params.get('pretrain', True) self.update_layers = params.get('update_layers', 0) self.net = models.resnet18(pretrained = self.pretrain) @@ -33,6 +34,11 @@ def __init__(self, params): # replace the last layer num_ftrs = self.net.fc.in_features self.net.fc = nn.Linear(num_ftrs, cls_num) + + # replace the first layer when in_chns is not 3 + if(self.in_chns != 3): + self.net.conv1 = nn.Conv2d(self.in_chns, 64, kernel_size=(7, 7), + stride=(2, 2), padding=(3, 3), bias=False) def forward(self, x): return self.net(x) @@ -41,7 +47,14 @@ def get_parameters_to_update(self): if(self.pretrain == False or self.update_layers == 0): return self.net.parameters() elif(self.update_layers == -1): - return self.net.fc.parameters() + params = self.net.fc.parameters() + if(self.in_chns !=3): + # combining the two iterables into a single one + # see: https://dzone.com/articles/python-joining-multiple + params = itertools.chain() + for pram in [self.net.fc.parameters(), self.net.conv1.parameters()]: + params = itertools.chain(params, pram) + return params else: raise(ValueError("update_layers can only be 0 (all layers) " + "or -1 (the last layer)")) @@ -51,7 +64,7 @@ def __init__(self, params): super(VGG16, self).__init__() self.params = params cls_num = params['class_num'] - in_chns = params.get('input_chns', 3) + self.in_chns = params.get('input_chns', 3) self.pretrain = params.get('pretrain', True) self.update_layers = params.get('update_layers', 0) self.net = models.vgg16(pretrained = self.pretrain) @@ -59,6 +72,11 @@ def __init__(self, params): # replace the last layer num_ftrs = self.net.classifier[-1].in_features self.net.classifier[-1] = nn.Linear(num_ftrs, cls_num) + + # replace the first layer when in_chns is not 3 + if(self.in_chns != 3): + self.net.conv1 = nn.Conv2d(self.in_chns, 64, kernel_size=(7, 7), + stride=(2, 2), padding=(3, 3), bias=False) def forward(self, x): return self.net(x) @@ -67,7 +85,14 @@ def get_parameters_to_update(self): if(self.pretrain == False or self.update_layers == 0): return self.net.parameters() elif(self.update_layers == -1): - return self.net.classifier[-1].parameters() + params = self.net.classifier[-1].parameters() + if(self.in_chns !=3): + # combining the two iterables into a single one + # see: https://dzone.com/articles/python-joining-multiple + params = itertools.chain() + for pram in [self.net.classifier[-1].parameters(), self.net.conv1.parameters()]: + params = itertools.chain(params, pram) + return params else: raise(ValueError("update_layers can only be 0 (all layers) " + "or -1 (the last layer)")) @@ -85,6 +110,11 @@ def __init__(self, params): # replace the last layer num_ftrs = self.net.last_channel self.net.classifier[-1] = nn.Linear(num_ftrs, cls_num) + + # replace the first layer when in_chns is not 3 + if(self.in_chns != 3): + self.net.conv1 = nn.Conv2d(self.in_chns, 64, kernel_size=(7, 7), + stride=(2, 2), padding=(3, 3), bias=False) def forward(self, x): return self.net(x) @@ -93,7 +123,14 @@ def get_parameters_to_update(self): if(self.pretrain == False or self.update_layers == 0): return self.net.parameters() elif(self.update_layers == -1): - return self.net.classifier[-1].parameters() + params = self.net.classifier[-1].parameters() + if(self.in_chns !=3): + # combining the two iterables into a single one + # see: https://dzone.com/articles/python-joining-multiple + params = itertools.chain() + for pram in [self.net.classifier[-1].parameters(), self.net.conv1.parameters()]: + params = itertools.chain(params, pram) + return params else: raise(ValueError("update_layers can only be 0 (all layers) " + "or -1 (the last layer)")) \ No newline at end of file From 2ba1079105890efe43ecd485b6c5089405196160 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 30 Aug 2022 12:56:01 +0800 Subject: [PATCH 068/225] Update usage.fsl.rst --- docs/source/usage.fsl.rst | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index 6a0da8b..5228106 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -269,6 +269,24 @@ Training Options In addition to the loss fuction, users can specify several training options in the ``training`` section of the configuration file. +Itreations +^^^^^^^^^^ + +For training iterations, the following parameters need to be specified in +the configuration file: + +* iter_start: the start iteration, by default is 0. You can also set it to the +itheration where a pre-trained model stopped for continuing with the trainnig. + +* iter_max: the maximal allowed iteration for training. + +* iter_valid: the iterations needed before evaluating the performance on the +validaiton set. + +* iter_save: The inverse_transform_for_prediction + + + Optimizer ^^^^^^^^^ @@ -280,4 +298,5 @@ Learning Rate Scheduler ^^^^^^^^^^^^^^^^^^^^^^^ The current supported learning rate schedulers are ``ReduceLROnPlateau`` -and ``MultiStepLR``. +and ``MultiStepLR``, which can be specified in ``lr_scheduler`` in +the configuration file. Parameters related to ``ReduceLROnPlateau`` From debff4ba7c12c6963e3acf05d550d63352fd3f19 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 30 Aug 2022 13:26:30 +0800 Subject: [PATCH 069/225] Update usage.fsl.rst --- docs/source/usage.fsl.rst | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index 5228106..db817d6 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -275,15 +275,17 @@ Itreations For training iterations, the following parameters need to be specified in the configuration file: -* iter_start: the start iteration, by default is 0. You can also set it to the -itheration where a pre-trained model stopped for continuing with the trainnig. +* iter_start: the start iteration, by default is 0. None zero value means the +iteration where a pre-trained model stopped for continuing with the trainnig. * iter_max: the maximal allowed iteration for training. -* iter_valid: the iterations needed before evaluating the performance on the -validaiton set. +* iter_valid: if the value is K, it means evaluating the performance on the +validaiton set for every K steps. -* iter_save: The inverse_transform_for_prediction +* iter_save: The iteations for saving the model. If the value is k, it means +the model will be savled every k iterations. It can also be a list of integer numbers, +which specifies the iterations to save the model. From 44949540d027f5a88c4ff35e44fa154fd24bbf7e Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 30 Aug 2022 13:47:42 +0800 Subject: [PATCH 070/225] Update usage.fsl.rst --- docs/source/usage.fsl.rst | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index db817d6..7310d91 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -275,18 +275,20 @@ Itreations For training iterations, the following parameters need to be specified in the configuration file: -* iter_start: the start iteration, by default is 0. None zero value means the +``iter_start``: the start iteration, by default is 0. None zero value means the iteration where a pre-trained model stopped for continuing with the trainnig. -* iter_max: the maximal allowed iteration for training. +``iter_max``: the maximal allowed iteration for training. -* iter_valid: if the value is K, it means evaluating the performance on the +``iter_valid``: if the value is K, it means evaluating the performance on the validaiton set for every K steps. -* iter_save: The iteations for saving the model. If the value is k, it means -the model will be savled every k iterations. It can also be a list of integer numbers, +``iter_save``: The iteations for saving the model. If the value is k, it means +the model will be saved every k iterations. It can also be a list of integer numbers, which specifies the iterations to save the model. +``early_stop_patience``: if the value is k, it means the training will stop when +the performance on the validation set does not improve for k iteations. Optimizer From da2a2519ba9e8cebba4976c8976c7dc843889066 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 30 Aug 2022 14:00:45 +0800 Subject: [PATCH 071/225] Update usage.fsl.rst --- docs/source/usage.fsl.rst | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index 7310d91..567e7bc 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -295,12 +295,20 @@ Optimizer ^^^^^^^^^ For optimizer, users need to set ``optimizer``, ``learning_rate``, -``momentum`` and ``weight_decay``. +``momentum`` and ``weight_decay``. The built-in optimizers include ``SGD``, +``Adam``, ``SparseAdam``, ``Adadelta``, ``Adagrad``, ``Adamax``, ``ASGD``, +``LBFGS``, ``RMSprop``, ``Rprop`` that are implemented in :mod:`torch.optim`. +You can also use customized optimizers via :mod:`SegmentationAgent.set_optimizer()`. Learning Rate Scheduler ^^^^^^^^^^^^^^^^^^^^^^^ -The current supported learning rate schedulers are ``ReduceLROnPlateau`` +The current built-in learning rate schedulers are ``ReduceLROnPlateau`` and ``MultiStepLR``, which can be specified in ``lr_scheduler`` in -the configuration file. Parameters related to ``ReduceLROnPlateau`` +the configuration file. + +Parameters related to ``ReduceLROnPlateau`` include ``lr_gamma``. +Parameters related to ``MultiStepLR`` include ``lr_gamma`` and ``lr_milestones``. + +You can also use customized lr schedulers via :mod:`SegmentationAgent.set_scheduler()`. From 7484bcf608acce15283b995f5151e9ee06dc02f8 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 30 Aug 2022 14:10:07 +0800 Subject: [PATCH 072/225] Update usage.fsl.rst --- docs/source/usage.fsl.rst | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index 567e7bc..5be5408 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -312,3 +312,15 @@ Parameters related to ``ReduceLROnPlateau`` include ``lr_gamma``. Parameters related to ``MultiStepLR`` include ``lr_gamma`` and ``lr_milestones``. You can also use customized lr schedulers via :mod:`SegmentationAgent.set_scheduler()`. + +Other Options +^^^^^^^^^^^^^ + +Other options for training include: + +``gpus``: a list of GPU index for training the model. If the length is larger than +one (such as [0, 1]), it means the model will be trained on multiple GPUs parallelly. + +``ckpt_save_dir``: the path to the folder for saving the trained models. + +``ckpt_prefix``: the prefix of the name to save the checkpoints. \ No newline at end of file From 530352a246005522ca8310d2c38ccba8814b52a4 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 30 Aug 2022 14:36:00 +0800 Subject: [PATCH 073/225] Update usage.fsl.rst add options for testing. --- docs/source/usage.fsl.rst | 53 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 2 deletions(-) diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index 5be5408..87136e8 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -297,7 +297,7 @@ Optimizer For optimizer, users need to set ``optimizer``, ``learning_rate``, ``momentum`` and ``weight_decay``. The built-in optimizers include ``SGD``, ``Adam``, ``SparseAdam``, ``Adadelta``, ``Adagrad``, ``Adamax``, ``ASGD``, -``LBFGS``, ``RMSprop``, ``Rprop`` that are implemented in :mod:`torch.optim`. +``LBFGS``, ``RMSprop`` and ``Rprop`` that are implemented in :mod:`torch.optim`. You can also use customized optimizers via :mod:`SegmentationAgent.set_optimizer()`. @@ -323,4 +323,53 @@ one (such as [0, 1]), it means the model will be trained on multiple GPUs parall ``ckpt_save_dir``: the path to the folder for saving the trained models. -``ckpt_prefix``: the prefix of the name to save the checkpoints. \ No newline at end of file +``ckpt_prefix``: the prefix of the name to save the checkpoints. + + +Inference Options +----------------- + +There are several options for inference after training the model. You can also select +the GPUs for testing, enable sliding window inference or inference with +test-time augmentation, etc. The following is a list of options availble for inference: + +``gpus``: a list of GPU index. Atually, only the first GPU in the list is used. + +``evaluation_mode`` (bool, default is True): set the model to evaluation mode or not. + +``test_time_dropout`` (bool, default is False): use test-time dropout or not. + +``ckpt_mode`` (integer): which checkpoint is used. 0--the last checkpoint; 1--the checkpoint +with the best performance on the validation set; 2--a specified checkpoint. + +``ckpt_name`` (string): the full path to the checkpoint if ``ckpt_mode = 2``. + +``post_process`` (string, default is None): the post process method after inference. +The current available post processing is ``PostKeepLargestComponent``. Uses can also +specify customized post process methods via :mod:`SegmentationAgent.set_postprocessor()`. + +``sliding_window_enable`` (bool, default is False): use sliding window for inference or not. + +``sliding_window_size``: a list for sliding window size when ``sliding_window_enable = True``. + +``sliding_window_stride``: a list for sliding window stride when ``sliding_window_enable = True``. + +``tta_mode`` (integer, default is 0): the mode for Test Time Augmentation (TTA). 0--not using TTA; 1--using +TTA based on horizontal and vertical flipping. + +``output_dir`` (string): the dir to save the prediction output. + +``ignore_dir`` (bool, default is True): if the input image name has a `/`, it will be replaced +with `_` in the output file name. + +``save_probability`` (boold, default is False): save the output probability for each class. + +``label_source`` (list, default is None): a list of label to be converted after prediction. For example, +``label_source = [0, 1]`` and ``label_target = [0, 255]`` will convert label value from 1 to 255. + +``label_target`` (list, default is None): a list of label after conversion. Use this with ``label_source``. + +``filename_replace_source`` (string, default is None): the substring in the filename will be replaced with +a new substring specified by ``filename_replace_target``. + +``filename_replace_target`` (string, default is None): work with ``filename_replace_source``. \ No newline at end of file From 5bb9d276efb931ab44619958702e2552bcde8512 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 30 Aug 2022 14:47:06 +0800 Subject: [PATCH 074/225] Update usage.fsl.rst update options for testing --- docs/source/usage.fsl.rst | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index 87136e8..86a8f48 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -342,17 +342,17 @@ test-time augmentation, etc. The following is a list of options availble for inf ``ckpt_mode`` (integer): which checkpoint is used. 0--the last checkpoint; 1--the checkpoint with the best performance on the validation set; 2--a specified checkpoint. -``ckpt_name`` (string): the full path to the checkpoint if ``ckpt_mode = 2``. +``ckpt_name`` (string, optinal): the full path to the checkpoint if ckpt_mode = 2. ``post_process`` (string, default is None): the post process method after inference. -The current available post processing is ``PostKeepLargestComponent``. Uses can also +The current available post processing is :mod:`PostKeepLargestComponent`. Uses can also specify customized post process methods via :mod:`SegmentationAgent.set_postprocessor()`. ``sliding_window_enable`` (bool, default is False): use sliding window for inference or not. -``sliding_window_size``: a list for sliding window size when ``sliding_window_enable = True``. +``sliding_window_size`` (optinal): a list for sliding window size when sliding_window_enable = True. -``sliding_window_stride``: a list for sliding window stride when ``sliding_window_enable = True``. +``sliding_window_stride`` (optinal): a list for sliding window stride when sliding_window_enable = True. ``tta_mode`` (integer, default is 0): the mode for Test Time Augmentation (TTA). 0--not using TTA; 1--using TTA based on horizontal and vertical flipping. @@ -365,11 +365,11 @@ with `_` in the output file name. ``save_probability`` (boold, default is False): save the output probability for each class. ``label_source`` (list, default is None): a list of label to be converted after prediction. For example, -``label_source = [0, 1]`` and ``label_target = [0, 255]`` will convert label value from 1 to 255. +:mod:`label_source`` = [0, 1] and :mod:`label_target`` = [0, 255] will convert label value from 1 to 255. -``label_target`` (list, default is None): a list of label after conversion. Use this with ``label_source``. +``label_target`` (list, default is None): a list of label after conversion. Use this with :mod:`label_source`. ``filename_replace_source`` (string, default is None): the substring in the filename will be replaced with -a new substring specified by ``filename_replace_target``. +a new substring specified by :mod:`filename_replace_target`. -``filename_replace_target`` (string, default is None): work with ``filename_replace_source``. \ No newline at end of file +``filename_replace_target`` (string, default is None): work with :mod:`filename_replace_source`. \ No newline at end of file From a559184a6586841d07bae3e97bae680f772684f9 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 30 Aug 2022 20:44:34 +0800 Subject: [PATCH 075/225] Update usage.fsl.rst --- docs/source/usage.fsl.rst | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index 86a8f48..574db64 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -42,7 +42,35 @@ hdf5 data that are more efficient to load. To use ``NiftyDataset``, users need to specify the root path of the dataset and the csv file storing the image and label -file names. Note that three csv files are needed, and they are +file names. The configurations include the following items: + +``tensor_type``: data type for tensors. Should be :mod:`float`` or :mod:`double`. + +``task_type``: should be :mod:`seg` for segmentation tasks. + +``root_dir`` (string): the root dir of dataset. + +``modal_num`` (int, default is 1): modalities number. For images with N modalities, +each modality should be saved in an indepdent file. + +``train_csv`` (string): the path of csv files for training set. + +``valid_csv`` (string): the path of csv files for validation set. + +``test_csv`` (string): the path of csv files for testing set. + +``train_batch_size`` (int): the batch size for training set. + +``valid_batch_size`` (int, optional): the batch size for validation set. The defualt value +is set as :mod:`train_batch_size`. + +``test_batch_size`` (int, optional): the batch size for testing set. The defualt value +is 1. + + + + +Note that three csv files are needed, and they are for training, validation and testing, respectively. For example: .. code-block:: none @@ -321,6 +349,10 @@ Other options for training include: ``gpus``: a list of GPU index for training the model. If the length is larger than one (such as [0, 1]), it means the model will be trained on multiple GPUs parallelly. +``deterministic`` (bool, default is True): set the training deterministic or not. + +``random_seed`` (int, optioinal): the random seed customized by user. Default value is 1. + ``ckpt_save_dir``: the path to the folder for saving the trained models. ``ckpt_prefix``: the prefix of the name to save the checkpoints. From 4e6d3f4e5cf16a7680e011f29bc41914bf81ffaf Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 30 Aug 2022 20:50:36 +0800 Subject: [PATCH 076/225] Update usage.fsl.rst --- docs/source/usage.fsl.rst | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index 574db64..b8081b7 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -44,7 +44,7 @@ To use ``NiftyDataset``, users need to specify the root path of the dataset and the csv file storing the image and label file names. The configurations include the following items: -``tensor_type``: data type for tensors. Should be :mod:`float`` or :mod:`double`. +``tensor_type``: data type for tensors. Should be :mod:`float` or :mod:`double`. ``task_type``: should be :mod:`seg` for segmentation tasks. @@ -53,11 +53,11 @@ file names. The configurations include the following items: ``modal_num`` (int, default is 1): modalities number. For images with N modalities, each modality should be saved in an indepdent file. -``train_csv`` (string): the path of csv files for training set. +``train_csv`` (string): the path of csv file for training set. -``valid_csv`` (string): the path of csv files for validation set. +``valid_csv`` (string): the path of csv file for validation set. -``test_csv`` (string): the path of csv files for testing set. +``test_csv`` (string): the path of csv file for testing set. ``train_batch_size`` (int): the batch size for training set. @@ -67,11 +67,12 @@ is set as :mod:`train_batch_size`. ``test_batch_size`` (int, optional): the batch size for testing set. The defualt value is 1. - - - -Note that three csv files are needed, and they are -for training, validation and testing, respectively. For example: +The csv file should have at least two columns (fields), +one for ``image`` and the other for ``label``. If the input image +have multiple modalities with each modality saved in a single +file, then the csv file should have N + 1 columns, where the +first N columns are for the N modalities, and the last column +is for the label. The following is an example for configuration of dataset. .. code-block:: none @@ -85,13 +86,6 @@ for training, validation and testing, respectively. For example: test_csv = config/jsrt_test.csv train_batch_size = 4 -By default, the ``valid_batch_size`` is set to the same as the ``train_batch_size``, -and the ``test_batch_size`` is 1. The csv file should have at least two columns (fields), -one for ``image`` and the other for ``label``. If the input image -have multiple modalities with each modality saved in a single -file, then the csv file should have N + 1 columnes, where the -first N columns are for the N modalities, and the last column -is for the label. To use your own dataset, you can define a dataset as a child class of ``NiftyDataset``, ``H5DataSet``, :mod:`or torch.utils.data.Dataset` From 44150c051c15e891f306bf3c935e5a283dda5ad6 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 30 Aug 2022 20:52:07 +0800 Subject: [PATCH 077/225] Update usage.fsl.rst --- docs/source/usage.fsl.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index b8081b7..c51b852 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -365,7 +365,7 @@ test-time augmentation, etc. The following is a list of options availble for inf ``test_time_dropout`` (bool, default is False): use test-time dropout or not. -``ckpt_mode`` (integer): which checkpoint is used. 0--the last checkpoint; 1--the checkpoint +``ckpt_mode`` (int): which checkpoint is used. 0--the last checkpoint; 1--the checkpoint with the best performance on the validation set; 2--a specified checkpoint. ``ckpt_name`` (string, optinal): the full path to the checkpoint if ckpt_mode = 2. @@ -380,7 +380,7 @@ specify customized post process methods via :mod:`SegmentationAgent.set_postproc ``sliding_window_stride`` (optinal): a list for sliding window stride when sliding_window_enable = True. -``tta_mode`` (integer, default is 0): the mode for Test Time Augmentation (TTA). 0--not using TTA; 1--using +``tta_mode`` (int, default is 0): the mode for Test Time Augmentation (TTA). 0--not using TTA; 1--using TTA based on horizontal and vertical flipping. ``output_dir`` (string): the dir to save the prediction output. From adbc78694b8b7b8753b705509f7290c6f9b11d06 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 30 Aug 2022 20:53:59 +0800 Subject: [PATCH 078/225] Update usage.fsl.rst --- docs/source/usage.fsl.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index c51b852..c05088e 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -391,7 +391,7 @@ with `_` in the output file name. ``save_probability`` (boold, default is False): save the output probability for each class. ``label_source`` (list, default is None): a list of label to be converted after prediction. For example, -:mod:`label_source`` = [0, 1] and :mod:`label_target`` = [0, 255] will convert label value from 1 to 255. +:mod:`label_source` = [0, 1] and :mod:`label_target` = [0, 255] will convert label value from 1 to 255. ``label_target`` (list, default is None): a list of label after conversion. Use this with :mod:`label_source`. From a22fec7e5fab2314b8c975dc6871a85263e8e274 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 30 Aug 2022 20:57:34 +0800 Subject: [PATCH 079/225] add ssl add ssl --- docs/source/usage.rst | 1 + docs/source/usage.ssl.rst | 10 ++++++++++ 2 files changed, 11 insertions(+) create mode 100644 docs/source/usage.ssl.rst diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 35fbcc5..f197fc6 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -13,6 +13,7 @@ workload. usage.quickstart usage.fsl + usage.ssl diff --git a/docs/source/usage.ssl.rst b/docs/source/usage.ssl.rst new file mode 100644 index 0000000..40d9b54 --- /dev/null +++ b/docs/source/usage.ssl.rst @@ -0,0 +1,10 @@ +.. _semi_supervised_learning: + +Semi-Supervised Learning +========================= + +SSLSegAgent +----------- + +:mod:`pymic.net_run_ssl.ssl_abstract.SSLSegAgent` is the abstract class used for +semi-supervised learning. From f96a52a5c2694bd85fd9face6c1757667e466cd4 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 30 Aug 2022 21:09:38 +0800 Subject: [PATCH 080/225] Update usage.ssl.rst --- docs/source/usage.ssl.rst | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/docs/source/usage.ssl.rst b/docs/source/usage.ssl.rst index 40d9b54..ba717f1 100644 --- a/docs/source/usage.ssl.rst +++ b/docs/source/usage.ssl.rst @@ -3,8 +3,25 @@ Semi-Supervised Learning ========================= -SSLSegAgent ------------ +pymic_ssl +--------- + +:mod:`pymic_ssl` is the command for using built-in semi-supervised methods for training. +Similarly to :mod:`pymic_run`, it should be followed by two parameters, specifying the +stage and configuration files. The training and testing commands are: + +.. code-block:: bash + + pymic_ssl train myconfig_ssl.cfg + pymic_ssl test myconfig_ssl.cfg + +.. tip:: + + If the SSL method only involves one network, either ``pymic_ssl`` or ``pymic_run`` + can be used for inference. Their difference only exists in the training stage. + +SSL Configurations +------------------ :mod:`pymic.net_run_ssl.ssl_abstract.SSLSegAgent` is the abstract class used for -semi-supervised learning. +semi-supervised learning. The reccesponding From f5c3bb02a8a660aec9b0d4b58b24adf4bbe03039 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 30 Aug 2022 21:25:17 +0800 Subject: [PATCH 081/225] Update usage.ssl.rst --- docs/source/usage.ssl.rst | 50 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/docs/source/usage.ssl.rst b/docs/source/usage.ssl.rst index ba717f1..0fddeb1 100644 --- a/docs/source/usage.ssl.rst +++ b/docs/source/usage.ssl.rst @@ -23,5 +23,55 @@ stage and configuration files. The training and testing commands are: SSL Configurations ------------------ +In the configuration file for ``pymic_ssl``, in addition to those used in fully +supervised learning, there are some items specified for semi-supervised learning. + +Users should provide values for the following items in ``dataset`` section of +the configuration file: + +``train_csv_unlab`` (string): the csv file for unlabeled dataset. Note that ``train_csv`` +is only used for labeled dataset. + +``train_batch_size_unlab`` (int): the batch size for unlabeled dataset. Note that +``train_batch_size`` means the batch size for the labeled dataset. + +``train_transform_unlab`` (list): a list of transforms used for unlabeled data. + + +The following is an example of the ``dataset`` section for semi-supervised learning: + +.. code-block:: none + + ... + root_dir =../../PyMIC_data/ACDC/preprocess/ + train_csv = config/data/image_train_r10_lab.csv + train_csv_unlab = config/data/image_train_r10_unlab.csv + valid_csv = config/data/image_valid.csv + test_csv = config/data/image_test.csv + + train_batch_size = 4 + train_batch_size_unlab = 4 + + # data transforms + train_transform = [Pad, RandomRotate, RandomCrop, RandomFlip, NormalizeWithMeanStd, GammaCorrection, GaussianNoise, LabelToProbability] + train_transform_unlab = [Pad, RandomRotate, RandomCrop, RandomFlip, NormalizeWithMeanStd, GammaCorrection, GaussianNoise] + valid_transform = [NormalizeWithMeanStd, Pad, LabelToProbability] + test_transform = [NormalizeWithMeanStd, Pad] + ... + +In addition, there is a ``semi_supervised_learning`` section that is specifically designed +for SSL methods. In that section, users need to specify the ``ssl_method`` and configurations +related to the SSL method. For example, the correspoinding configuration for CPS is: + +.. code-block:: none + + ... + [semi_supervised_learning] + ssl_method = CPS + regularize_w = 0.1 + rampup_start = 1000 + rampup_end = 20000 + ... + :mod:`pymic.net_run_ssl.ssl_abstract.SSLSegAgent` is the abstract class used for semi-supervised learning. The reccesponding From afc2e72fd2f329b83a3d39e40a395845b1110883 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 30 Aug 2022 21:35:36 +0800 Subject: [PATCH 082/225] Update usage.ssl.rst --- docs/source/usage.ssl.rst | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/docs/source/usage.ssl.rst b/docs/source/usage.ssl.rst index 0fddeb1..7e1561d 100644 --- a/docs/source/usage.ssl.rst +++ b/docs/source/usage.ssl.rst @@ -8,7 +8,7 @@ pymic_ssl :mod:`pymic_ssl` is the command for using built-in semi-supervised methods for training. Similarly to :mod:`pymic_run`, it should be followed by two parameters, specifying the -stage and configuration files. The training and testing commands are: +stage and configuration file, respectively. The training and testing commands are: .. code-block:: bash @@ -73,5 +73,23 @@ related to the SSL method. For example, the correspoinding configuration for CPS rampup_end = 20000 ... +.. note:: + + The configuration items vary with different SLL methods. Please refer to the API + of each built-in SLL method for details of the correspoinding configuration. + +SSL Methods +----------- + :mod:`pymic.net_run_ssl.ssl_abstract.SSLSegAgent` is the abstract class used for -semi-supervised learning. The reccesponding +semi-supervised learning. The built-in SLL methods are child classes of SSLSegAgent. +Currently, the following SLL methods are implemented in PyMIC: + +|PyMIC Method|Reference|Remarks| +|---|---|---| +|SSLEntropyMinimization|[Grandvalet et al.][em_paper], NeurIPS 2005| Oringinally proposed for classification| +|SSLMeanTeacher| [Tarvainen et al.][mt_paper], NeurIPS 2017| Oringinally proposed for classification| +|SSLUAMT| [Yu et al.][uamt_paper], MICCAI 2019| Uncertainty-aware mean teacher| +|SSLURPC| [Luo et al.][urpc_paper], MedIA 2022| Uncertainty rectified pyramid consistency| +|SSLCCT| [Ouali et al.][cct_paper], CVPR 2020| Cross-pseudo supervision| +|SSLCPS| [Chen et al.][cps_paper], CVPR 2021| Cross-consistency training| \ No newline at end of file From d436bc14446d52b5d94d7ce7f0e2645bc4360ede Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 30 Aug 2022 21:46:34 +0800 Subject: [PATCH 083/225] Update usage.ssl.rst --- docs/source/usage.ssl.rst | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/usage.ssl.rst b/docs/source/usage.ssl.rst index 7e1561d..00590f5 100644 --- a/docs/source/usage.ssl.rst +++ b/docs/source/usage.ssl.rst @@ -29,13 +29,13 @@ supervised learning, there are some items specified for semi-supervised learning Users should provide values for the following items in ``dataset`` section of the configuration file: -``train_csv_unlab`` (string): the csv file for unlabeled dataset. Note that ``train_csv`` -is only used for labeled dataset. +* ``train_csv_unlab`` (string): the csv file for unlabeled dataset. + Note that ``train_csv`` is only used for labeled dataset. -``train_batch_size_unlab`` (int): the batch size for unlabeled dataset. Note that -``train_batch_size`` means the batch size for the labeled dataset. +* ``train_batch_size_unlab`` (int): the batch size for unlabeled dataset. + Note that ``train_batch_size`` means the batch size for the labeled dataset. -``train_transform_unlab`` (list): a list of transforms used for unlabeled data. +* ``train_transform_unlab`` (list): a list of transforms used for unlabeled data. The following is an example of the ``dataset`` section for semi-supervised learning: @@ -82,7 +82,7 @@ SSL Methods ----------- :mod:`pymic.net_run_ssl.ssl_abstract.SSLSegAgent` is the abstract class used for -semi-supervised learning. The built-in SLL methods are child classes of SSLSegAgent. +semi-supervised learning. The built-in SLL methods are child classes of :mod:`SSLSegAgent`. Currently, the following SLL methods are implemented in PyMIC: |PyMIC Method|Reference|Remarks| From 6e229d45337229fafb0cf9f72690a95194522d62 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 30 Aug 2022 21:51:41 +0800 Subject: [PATCH 084/225] Update usage.fsl.rst --- docs/source/usage.fsl.rst | 100 +++++++++++++++++++------------------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index c05088e..c4a0e5b 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -44,28 +44,28 @@ To use ``NiftyDataset``, users need to specify the root path of the dataset and the csv file storing the image and label file names. The configurations include the following items: -``tensor_type``: data type for tensors. Should be :mod:`float` or :mod:`double`. +* ``tensor_type``: data type for tensors. Should be :mod:`float` or :mod:`double`. -``task_type``: should be :mod:`seg` for segmentation tasks. +* ``task_type``: should be :mod:`seg` for segmentation tasks. -``root_dir`` (string): the root dir of dataset. +* ``root_dir`` (string): the root dir of dataset. -``modal_num`` (int, default is 1): modalities number. For images with N modalities, -each modality should be saved in an indepdent file. +* ``modal_num`` (int, default is 1): modalities number. For images with N modalities, + each modality should be saved in an indepdent file. -``train_csv`` (string): the path of csv file for training set. +* ``train_csv`` (string): the path of csv file for training set. -``valid_csv`` (string): the path of csv file for validation set. +* ``valid_csv`` (string): the path of csv file for validation set. -``test_csv`` (string): the path of csv file for testing set. +* ``test_csv`` (string): the path of csv file for testing set. -``train_batch_size`` (int): the batch size for training set. +* ``train_batch_size`` (int): the batch size for training set. -``valid_batch_size`` (int, optional): the batch size for validation set. The defualt value -is set as :mod:`train_batch_size`. +* ``valid_batch_size`` (int, optional): the batch size for validation set. + The defualt value is set as :mod:`train_batch_size`. -``test_batch_size`` (int, optional): the batch size for testing set. The defualt value -is 1. +* ``test_batch_size`` (int, optional): the batch size for testing set. + The defualt value is 1. The csv file should have at least two columns (fields), one for ``image`` and the other for ``label``. If the input image @@ -297,20 +297,20 @@ Itreations For training iterations, the following parameters need to be specified in the configuration file: -``iter_start``: the start iteration, by default is 0. None zero value means the -iteration where a pre-trained model stopped for continuing with the trainnig. +* ``iter_start``: the start iteration, by default is 0. None zero value means the + iteration where a pre-trained model stopped for continuing with the trainnig. -``iter_max``: the maximal allowed iteration for training. +* ``iter_max``: the maximal allowed iteration for training. -``iter_valid``: if the value is K, it means evaluating the performance on the -validaiton set for every K steps. +* ``iter_valid``: if the value is K, it means evaluating the performance on the + validaiton set for every K steps. -``iter_save``: The iteations for saving the model. If the value is k, it means -the model will be saved every k iterations. It can also be a list of integer numbers, -which specifies the iterations to save the model. +* ``iter_save``: The iteations for saving the model. If the value is k, it means + the model will be saved every k iterations. It can also be a list of integer numbers, + which specifies the iterations to save the model. -``early_stop_patience``: if the value is k, it means the training will stop when -the performance on the validation set does not improve for k iteations. +* ``early_stop_patience``: if the value is k, it means the training will stop when + the performance on the validation set does not improve for k iteations. Optimizer @@ -340,16 +340,16 @@ Other Options Other options for training include: -``gpus``: a list of GPU index for training the model. If the length is larger than -one (such as [0, 1]), it means the model will be trained on multiple GPUs parallelly. +* ``gpus``: a list of GPU index for training the model. If the length is larger than + one (such as [0, 1]), it means the model will be trained on multiple GPUs parallelly. -``deterministic`` (bool, default is True): set the training deterministic or not. +* ``deterministic`` (bool, default is True): set the training deterministic or not. -``random_seed`` (int, optioinal): the random seed customized by user. Default value is 1. +* ``random_seed`` (int, optioinal): the random seed customized by user. Default value is 1. -``ckpt_save_dir``: the path to the folder for saving the trained models. +* ``ckpt_save_dir``: the path to the folder for saving the trained models. -``ckpt_prefix``: the prefix of the name to save the checkpoints. +* ``ckpt_prefix``: the prefix of the name to save the checkpoints. Inference Options @@ -359,43 +359,43 @@ There are several options for inference after training the model. You can also s the GPUs for testing, enable sliding window inference or inference with test-time augmentation, etc. The following is a list of options availble for inference: -``gpus``: a list of GPU index. Atually, only the first GPU in the list is used. +* ``gpus``: a list of GPU index. Atually, only the first GPU in the list is used. -``evaluation_mode`` (bool, default is True): set the model to evaluation mode or not. +* ``evaluation_mode`` (bool, default is True): set the model to evaluation mode or not. -``test_time_dropout`` (bool, default is False): use test-time dropout or not. +* ``test_time_dropout`` (bool, default is False): use test-time dropout or not. -``ckpt_mode`` (int): which checkpoint is used. 0--the last checkpoint; 1--the checkpoint +* ``ckpt_mode`` (int): which checkpoint is used. 0--the last checkpoint; 1--the checkpoint with the best performance on the validation set; 2--a specified checkpoint. -``ckpt_name`` (string, optinal): the full path to the checkpoint if ckpt_mode = 2. +* ``ckpt_name`` (string, optinal): the full path to the checkpoint if ckpt_mode = 2. -``post_process`` (string, default is None): the post process method after inference. +* ``post_process`` (string, default is None): the post process method after inference. The current available post processing is :mod:`PostKeepLargestComponent`. Uses can also specify customized post process methods via :mod:`SegmentationAgent.set_postprocessor()`. -``sliding_window_enable`` (bool, default is False): use sliding window for inference or not. +* ``sliding_window_enable`` (bool, default is False): use sliding window for inference or not. -``sliding_window_size`` (optinal): a list for sliding window size when sliding_window_enable = True. +* ``sliding_window_size`` (optinal): a list for sliding window size when sliding_window_enable = True. -``sliding_window_stride`` (optinal): a list for sliding window stride when sliding_window_enable = True. +* ``sliding_window_stride`` (optinal): a list for sliding window stride when sliding_window_enable = True. -``tta_mode`` (int, default is 0): the mode for Test Time Augmentation (TTA). 0--not using TTA; 1--using -TTA based on horizontal and vertical flipping. +* ``tta_mode`` (int, default is 0): the mode for Test Time Augmentation (TTA). 0--not using TTA; 1--using + TTA based on horizontal and vertical flipping. -``output_dir`` (string): the dir to save the prediction output. +* ``output_dir`` (string): the dir to save the prediction output. -``ignore_dir`` (bool, default is True): if the input image name has a `/`, it will be replaced -with `_` in the output file name. +* ``ignore_dir`` (bool, default is True): if the input image name has a `/`, it will be replaced + with `_` in the output file name. -``save_probability`` (boold, default is False): save the output probability for each class. +* ``save_probability`` (boold, default is False): save the output probability for each class. -``label_source`` (list, default is None): a list of label to be converted after prediction. For example, -:mod:`label_source` = [0, 1] and :mod:`label_target` = [0, 255] will convert label value from 1 to 255. +* ``label_source`` (list, default is None): a list of label to be converted after prediction. For example, + :mod:`label_source` = [0, 1] and :mod:`label_target` = [0, 255] will convert label value from 1 to 255. -``label_target`` (list, default is None): a list of label after conversion. Use this with :mod:`label_source`. +* ``label_target`` (list, default is None): a list of label after conversion. Use this with :mod:`label_source`. -``filename_replace_source`` (string, default is None): the substring in the filename will be replaced with -a new substring specified by :mod:`filename_replace_target`. +* ``filename_replace_source`` (string, default is None): the substring in the filename will be replaced with + a new substring specified by :mod:`filename_replace_target`. -``filename_replace_target`` (string, default is None): work with :mod:`filename_replace_source`. \ No newline at end of file +* ``filename_replace_target`` (string, default is None): work with :mod:`filename_replace_source`. \ No newline at end of file From ab3d56e676f5e6b6bb1b36019dabb69d8540f030 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 30 Aug 2022 22:04:01 +0800 Subject: [PATCH 085/225] fix typo fix typo --- docs/source/usage.fsl.rst | 6 +++--- docs/source/usage.ssl.rst | 11 +---------- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index c4a0e5b..1470a7b 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -366,13 +366,13 @@ test-time augmentation, etc. The following is a list of options availble for inf * ``test_time_dropout`` (bool, default is False): use test-time dropout or not. * ``ckpt_mode`` (int): which checkpoint is used. 0--the last checkpoint; 1--the checkpoint -with the best performance on the validation set; 2--a specified checkpoint. + with the best performance on the validation set; 2--a specified checkpoint. * ``ckpt_name`` (string, optinal): the full path to the checkpoint if ckpt_mode = 2. * ``post_process`` (string, default is None): the post process method after inference. -The current available post processing is :mod:`PostKeepLargestComponent`. Uses can also -specify customized post process methods via :mod:`SegmentationAgent.set_postprocessor()`. + The current available post processing is :mod:`PostKeepLargestComponent`. Uses can also + specify customized post process methods via :mod:`SegmentationAgent.set_postprocessor()`. * ``sliding_window_enable`` (bool, default is False): use sliding window for inference or not. diff --git a/docs/source/usage.ssl.rst b/docs/source/usage.ssl.rst index 00590f5..22dd1bf 100644 --- a/docs/source/usage.ssl.rst +++ b/docs/source/usage.ssl.rst @@ -83,13 +83,4 @@ SSL Methods :mod:`pymic.net_run_ssl.ssl_abstract.SSLSegAgent` is the abstract class used for semi-supervised learning. The built-in SLL methods are child classes of :mod:`SSLSegAgent`. -Currently, the following SLL methods are implemented in PyMIC: - -|PyMIC Method|Reference|Remarks| -|---|---|---| -|SSLEntropyMinimization|[Grandvalet et al.][em_paper], NeurIPS 2005| Oringinally proposed for classification| -|SSLMeanTeacher| [Tarvainen et al.][mt_paper], NeurIPS 2017| Oringinally proposed for classification| -|SSLUAMT| [Yu et al.][uamt_paper], MICCAI 2019| Uncertainty-aware mean teacher| -|SSLURPC| [Luo et al.][urpc_paper], MedIA 2022| Uncertainty rectified pyramid consistency| -|SSLCCT| [Ouali et al.][cct_paper], CVPR 2020| Cross-pseudo supervision| -|SSLCPS| [Chen et al.][cps_paper], CVPR 2021| Cross-consistency training| \ No newline at end of file +Currently, the following SLL methods are implemented in PyMIC: \ No newline at end of file From 935f948912f4e29ac06b81f9a91da121e9605d84 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 1 Sep 2022 09:34:56 +0800 Subject: [PATCH 086/225] Update usage.ssl.rst --- docs/source/usage.ssl.rst | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/docs/source/usage.ssl.rst b/docs/source/usage.ssl.rst index 22dd1bf..b4c827a 100644 --- a/docs/source/usage.ssl.rst +++ b/docs/source/usage.ssl.rst @@ -83,4 +83,19 @@ SSL Methods :mod:`pymic.net_run_ssl.ssl_abstract.SSLSegAgent` is the abstract class used for semi-supervised learning. The built-in SLL methods are child classes of :mod:`SSLSegAgent`. -Currently, the following SLL methods are implemented in PyMIC: \ No newline at end of file +The available SSL methods implemnted in PyMIC are listed in :mod:`pymic.net_run_ssl.ssl_main.SSLMethodDict`, +and they are: + +* ``EntropyMinimization``: (`NeurIPS 2005 `_) + Using entorpy minimization to regularize unannotated samples. + +* ``MeanTeacher``: (`NeurIPS 2017 `_) Use self-ensembling mean teacher to supervise the student model on + unannotated samples. + +* ``UAMT``: (`MICCAI 2019 `_) Uncertainty aware mean teacher. + +* ``CCT``: (`CVPR 2020 `_) Cross-consistency training. + +* ``CPS``: (`CVPR 2021 `_) Cross-pseudo supervision. + +* ``URPC``: (`MIA 2022 `_) Uncertainty rectified pyramid consistency. \ No newline at end of file From b0d3da9280ec82abe9dfa9f69e5cd4a310c50b5d Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 1 Sep 2022 10:02:33 +0800 Subject: [PATCH 087/225] Update usage.ssl.rst --- docs/source/usage.ssl.rst | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/docs/source/usage.ssl.rst b/docs/source/usage.ssl.rst index b4c827a..acd3c8a 100644 --- a/docs/source/usage.ssl.rst +++ b/docs/source/usage.ssl.rst @@ -78,8 +78,8 @@ related to the SSL method. For example, the correspoinding configuration for CPS The configuration items vary with different SLL methods. Please refer to the API of each built-in SLL method for details of the correspoinding configuration. -SSL Methods ------------ +Built-in SSL Methods +-------------------- :mod:`pymic.net_run_ssl.ssl_abstract.SSLSegAgent` is the abstract class used for semi-supervised learning. The built-in SLL methods are child classes of :mod:`SSLSegAgent`. @@ -98,4 +98,29 @@ and they are: * ``CPS``: (`CVPR 2021 `_) Cross-pseudo supervision. -* ``URPC``: (`MIA 2022 `_) Uncertainty rectified pyramid consistency. \ No newline at end of file +* ``URPC``: (`MIA 2022 `_) Uncertainty rectified pyramid consistency. + +Customized SSL Methods +---------------------- + +PyMIC alo supports customizing SSL methods by inheriting the :mod:`SSLSegAgent` class. +You may only need to rewrite the :mod:`training()` method and reuse most part of the +existing pipeline, such as data loading, validation and inference methods. For example: + +.. code-block:: none + + from pymic.net_run_ssl.ssl_abstract import SSLSegAgent + + class MySSLMethod(SSLSegAgent): + def __init__(self, config, stage = 'train'): + super(MySSLMethod, self).__init__(config, stage) + ... + + def training(self): + ... + + agent = MySSLMethod(config, stage) + agent.run() + +You may need to check the source code of built-in SLL methods to be more familar with +how to implement your own SLL method. \ No newline at end of file From 59a25644685c52fd82fdcde7d89366bd155297de Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 1 Sep 2022 10:11:42 +0800 Subject: [PATCH 088/225] add wsl method add wsl method --- docs/source/usage.rst | 1 + docs/source/usage.wsl.rst | 134 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 135 insertions(+) create mode 100644 docs/source/usage.wsl.rst diff --git a/docs/source/usage.rst b/docs/source/usage.rst index f197fc6..d8cca7f 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -14,6 +14,7 @@ workload. usage.quickstart usage.fsl usage.ssl + usage.wsl diff --git a/docs/source/usage.wsl.rst b/docs/source/usage.wsl.rst new file mode 100644 index 0000000..f278d56 --- /dev/null +++ b/docs/source/usage.wsl.rst @@ -0,0 +1,134 @@ +.. _weakly_supervised_learning: + +Weakly-Supervised Learning +========================== + +pymic_wsl +--------- + +:mod:`pymic_wsl` is the command for using built-in weakly-supervised methods for training. +Similarly to :mod:`pymic_run`, it should be followed by two parameters, specifying the +stage and configuration file, respectively. The training and testing commands are: + +.. code-block:: bash + + pymic_wsl train myconfig_wsl.cfg + pymic_wsl test myconfig_wsl.cfg + +.. tip:: + + If the WSL method only involves one network, either ``pymic_wsl`` or ``pymic_run`` + can be used for inference. Their difference only exists in the training stage. + +.. note:: + + Currently, the weakly supervised methods supported by PyMIC are only for learning + from partial annotations, such scribble-based annotation. Learning from image-level + or point annotations may involve several training stages and will be considered + in the future. + + +WSL Configurations +------------------ + +In the configuration file for ``pymic_wsl``, in addition to those used in fully +supervised learning, there are some items specified for weakly-supervised learning. + +Users should provide values for the following items in ``dataset`` section of +the configuration file: + +* ``train_csv_unlab`` (string): the csv file for unlabeled dataset. + Note that ``train_csv`` is only used for labeled dataset. + +* ``train_batch_size_unlab`` (int): the batch size for unlabeled dataset. + Note that ``train_batch_size`` means the batch size for the labeled dataset. + +* ``train_transform_unlab`` (list): a list of transforms used for unlabeled data. + + +The following is an example of the ``dataset`` section for semi-supervised learning: + +.. code-block:: none + + ... + root_dir =../../PyMIC_data/ACDC/preprocess/ + train_csv = config/data/image_train_r10_lab.csv + train_csv_unlab = config/data/image_train_r10_unlab.csv + valid_csv = config/data/image_valid.csv + test_csv = config/data/image_test.csv + + train_batch_size = 4 + train_batch_size_unlab = 4 + + # data transforms + train_transform = [Pad, RandomRotate, RandomCrop, RandomFlip, NormalizeWithMeanStd, GammaCorrection, GaussianNoise, LabelToProbability] + train_transform_unlab = [Pad, RandomRotate, RandomCrop, RandomFlip, NormalizeWithMeanStd, GammaCorrection, GaussianNoise] + valid_transform = [NormalizeWithMeanStd, Pad, LabelToProbability] + test_transform = [NormalizeWithMeanStd, Pad] + ... + +In addition, there is a ``semi_supervised_learning`` section that is specifically designed +for SSL methods. In that section, users need to specify the ``ssl_method`` and configurations +related to the SSL method. For example, the correspoinding configuration for CPS is: + +.. code-block:: none + + ... + [semi_supervised_learning] + ssl_method = CPS + regularize_w = 0.1 + rampup_start = 1000 + rampup_end = 20000 + ... + +.. note:: + + The configuration items vary with different SLL methods. Please refer to the API + of each built-in SLL method for details of the correspoinding configuration. + +Built-in WSL Methods +-------------------- + +:mod:`pymic.net_run_ssl.ssl_abstract.SSLSegAgent` is the abstract class used for +semi-supervised learning. The built-in SLL methods are child classes of :mod:`SSLSegAgent`. +The available SSL methods implemnted in PyMIC are listed in :mod:`pymic.net_run_ssl.ssl_main.SSLMethodDict`, +and they are: + +* ``EntropyMinimization``: (`NeurIPS 2005 `_) + Using entorpy minimization to regularize unannotated samples. + +* ``MeanTeacher``: (`NeurIPS 2017 `_) Use self-ensembling mean teacher to supervise the student model on + unannotated samples. + +* ``UAMT``: (`MICCAI 2019 `_) Uncertainty aware mean teacher. + +* ``CCT``: (`CVPR 2020 `_) Cross-consistency training. + +* ``CPS``: (`CVPR 2021 `_) Cross-pseudo supervision. + +* ``URPC``: (`MIA 2022 `_) Uncertainty rectified pyramid consistency. + +Customized WSL Methods +---------------------- + +PyMIC alo supports customizing SSL methods by inheriting the :mod:`SSLSegAgent` class. +You may only need to rewrite the :mod:`training()` method and reuse most part of the +existing pipeline, such as data loading, validation and inference methods. For example: + +.. code-block:: none + + from pymic.net_run_ssl.ssl_abstract import SSLSegAgent + + class MySSLMethod(SSLSegAgent): + def __init__(self, config, stage = 'train'): + super(MySSLMethod, self).__init__(config, stage) + ... + + def training(self): + ... + + agent = MySSLMethod(config, stage) + agent.run() + +You may need to check the source code of built-in SLL methods to be more familar with +how to implement your own SLL method. \ No newline at end of file From 9654365166cea7b64d18b5b9b4524febb00b4275 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 1 Sep 2022 10:43:23 +0800 Subject: [PATCH 089/225] Update usage.wsl.rst --- docs/source/usage.wsl.rst | 59 +++++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 27 deletions(-) diff --git a/docs/source/usage.wsl.rst b/docs/source/usage.wsl.rst index f278d56..e920f46 100644 --- a/docs/source/usage.wsl.rst +++ b/docs/source/usage.wsl.rst @@ -34,51 +34,56 @@ WSL Configurations In the configuration file for ``pymic_wsl``, in addition to those used in fully supervised learning, there are some items specified for weakly-supervised learning. -Users should provide values for the following items in ``dataset`` section of -the configuration file: +First, in the :mod:`train_transform` list, a special transform named :mod:`PartialLabelToProbability` +should be used to transform patial labels into a one-hot probability map and a weighting +map of pixels (i.e., the weight of a pixel is 1 if labeled and 0 otherwise). The patial +cross entropy loss on labeled pixels is actually implemented by a weighted cross entropy loss. +The loss setting is `loss_type = CrossEntropyLoss`. -* ``train_csv_unlab`` (string): the csv file for unlabeled dataset. - Note that ``train_csv`` is only used for labeled dataset. +Second, there is a ``weakly_supervised_learning`` section that is specifically designed +for WSL methods. In that section, users need to specify the ``wsl_method`` and configurations +related to the WSL method. For example, the correspoinding configuration for GatedCRF is: -* ``train_batch_size_unlab`` (int): the batch size for unlabeled dataset. - Note that ``train_batch_size`` means the batch size for the labeled dataset. -* ``train_transform_unlab`` (list): a list of transforms used for unlabeled data. - - -The following is an example of the ``dataset`` section for semi-supervised learning: .. code-block:: none + [dataset] ... - root_dir =../../PyMIC_data/ACDC/preprocess/ - train_csv = config/data/image_train_r10_lab.csv - train_csv_unlab = config/data/image_train_r10_unlab.csv + root_dir = ../../PyMIC_data/ACDC/preprocess + train_csv = config/data/image_train.csv valid_csv = config/data/image_valid.csv test_csv = config/data/image_test.csv train_batch_size = 4 - train_batch_size_unlab = 4 # data transforms - train_transform = [Pad, RandomRotate, RandomCrop, RandomFlip, NormalizeWithMeanStd, GammaCorrection, GaussianNoise, LabelToProbability] - train_transform_unlab = [Pad, RandomRotate, RandomCrop, RandomFlip, NormalizeWithMeanStd, GammaCorrection, GaussianNoise] - valid_transform = [NormalizeWithMeanStd, Pad, LabelToProbability] - test_transform = [NormalizeWithMeanStd, Pad] + train_transform = [Pad, RandomCrop, RandomFlip, NormalizeWithMeanStd, PartialLabelToProbability] + valid_transform = [NormalizeWithMeanStd, Pad, LabelToProbability] + test_transform = [NormalizeWithMeanStd, Pad] ... -In addition, there is a ``semi_supervised_learning`` section that is specifically designed -for SSL methods. In that section, users need to specify the ``ssl_method`` and configurations -related to the SSL method. For example, the correspoinding configuration for CPS is: - -.. code-block:: none + [network] + ... + [training] + ... + loss_type = CrossEntropyLoss ... - [semi_supervised_learning] - ssl_method = CPS + + [weakly_supervised_learning] + wsl_method = GatedCRF regularize_w = 0.1 - rampup_start = 1000 - rampup_end = 20000 + rampup_start = 2000 + rampup_end = 15000 + GatedCRFLoss_W0 = 1.0 + GatedCRFLoss_XY0 = 5 + GatedCRFLoss_rgb = 0.1 + GatedCRFLoss_W1 = 1.0 + GatedCRFLoss_XY1 = 3 + GatedCRFLoss_Radius = 5 + + [testing] ... .. note:: From 1c3c6be44bb5f8e4b7ba183e89e523d2ec7e0307 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 1 Sep 2022 10:50:48 +0800 Subject: [PATCH 090/225] update wsl methods update wsl methods --- docs/source/usage.ssl.rst | 10 +++++----- docs/source/usage.wsl.rst | 38 +++++++++++++++++++++----------------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/docs/source/usage.ssl.rst b/docs/source/usage.ssl.rst index acd3c8a..a2d53d0 100644 --- a/docs/source/usage.ssl.rst +++ b/docs/source/usage.ssl.rst @@ -75,14 +75,14 @@ related to the SSL method. For example, the correspoinding configuration for CPS .. note:: - The configuration items vary with different SLL methods. Please refer to the API - of each built-in SLL method for details of the correspoinding configuration. + The configuration items vary with different SSL methods. Please refer to the API + of each built-in SSL method for details of the correspoinding configuration. Built-in SSL Methods -------------------- :mod:`pymic.net_run_ssl.ssl_abstract.SSLSegAgent` is the abstract class used for -semi-supervised learning. The built-in SLL methods are child classes of :mod:`SSLSegAgent`. +semi-supervised learning. The built-in SSL methods are child classes of :mod:`SSLSegAgent`. The available SSL methods implemnted in PyMIC are listed in :mod:`pymic.net_run_ssl.ssl_main.SSLMethodDict`, and they are: @@ -122,5 +122,5 @@ existing pipeline, such as data loading, validation and inference methods. For e agent = MySSLMethod(config, stage) agent.run() -You may need to check the source code of built-in SLL methods to be more familar with -how to implement your own SLL method. \ No newline at end of file +You may need to check the source code of built-in SSL methods to be more familar with +how to implement your own SSL method. \ No newline at end of file diff --git a/docs/source/usage.wsl.rst b/docs/source/usage.wsl.rst index e920f46..dea0400 100644 --- a/docs/source/usage.wsl.rst +++ b/docs/source/usage.wsl.rst @@ -94,46 +94,50 @@ related to the WSL method. For example, the correspoinding configuration for Gat Built-in WSL Methods -------------------- -:mod:`pymic.net_run_ssl.ssl_abstract.SSLSegAgent` is the abstract class used for -semi-supervised learning. The built-in SLL methods are child classes of :mod:`SSLSegAgent`. -The available SSL methods implemnted in PyMIC are listed in :mod:`pymic.net_run_ssl.ssl_main.SSLMethodDict`, +:mod:`pymic.net_run_wsl.wsl_abstract.WSLSegAgent` is the abstract class used for +weakly-supervised learning. The built-in SLL methods are child classes of :mod:`WSLSegAgent`. +The available WSL methods implemnted in PyMIC are listed in :mod:`pymic.net_run_wsl.wsl_main.WSLMethodDict`, and they are: * ``EntropyMinimization``: (`NeurIPS 2005 `_) - Using entorpy minimization to regularize unannotated samples. + Using entorpy minimization to regularize unannotated pixels. -* ``MeanTeacher``: (`NeurIPS 2017 `_) Use self-ensembling mean teacher to supervise the student model on - unannotated samples. +* ``GatedCRF``: (`arXiv 2019 `_) + Use gated CRF to regularize unannotated pixels. -* ``UAMT``: (`MICCAI 2019 `_) Uncertainty aware mean teacher. +* ``TotalVariation``: (`arXiv 2022 `_) + Use Total Variation to regularize unannotated pixels. -* ``CCT``: (`CVPR 2020 `_) Cross-consistency training. +* ``MumfordShah``: (`TIP 2020 `_) + Use Mumford Shah loss to regularize unannotated pixels. -* ``CPS``: (`CVPR 2021 `_) Cross-pseudo supervision. +* ``USTM``: (`PR 2022 `_) + Adapt USTM with transform-consistency regularization. -* ``URPC``: (`MIA 2022 `_) Uncertainty rectified pyramid consistency. +* ``DMPLS``: (`MICCAI 2022 `_) + Dynamically mixed pseudo label supervision Customized WSL Methods ---------------------- -PyMIC alo supports customizing SSL methods by inheriting the :mod:`SSLSegAgent` class. +PyMIC alo supports customizing WSL methods by inheriting the :mod:`WSLSegAgent` class. You may only need to rewrite the :mod:`training()` method and reuse most part of the existing pipeline, such as data loading, validation and inference methods. For example: .. code-block:: none - from pymic.net_run_ssl.ssl_abstract import SSLSegAgent + from pymic.net_run_wsl.wsl_abstract import WSLSegAgent - class MySSLMethod(SSLSegAgent): + class MyWSLMethod(WSLSegAgent): def __init__(self, config, stage = 'train'): - super(MySSLMethod, self).__init__(config, stage) + super(MyWSLMethod, self).__init__(config, stage) ... def training(self): ... - agent = MySSLMethod(config, stage) + agent = MyWSLMethod(config, stage) agent.run() -You may need to check the source code of built-in SLL methods to be more familar with -how to implement your own SLL method. \ No newline at end of file +You may need to check the source code of built-in WSL methods to be more familar with +how to implement your own WSL method. \ No newline at end of file From f74faf36df9f848098bb1ec428683c0a88fd7714 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 1 Sep 2022 10:58:45 +0800 Subject: [PATCH 091/225] fix typo fix typo --- docs/source/usage.ssl.rst | 2 +- docs/source/usage.wsl.rst | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/usage.ssl.rst b/docs/source/usage.ssl.rst index a2d53d0..143d3f8 100644 --- a/docs/source/usage.ssl.rst +++ b/docs/source/usage.ssl.rst @@ -24,7 +24,7 @@ SSL Configurations ------------------ In the configuration file for ``pymic_ssl``, in addition to those used in fully -supervised learning, there are some items specified for semi-supervised learning. +supervised learning, there are some items specificalized for semi-supervised learning. Users should provide values for the following items in ``dataset`` section of the configuration file: diff --git a/docs/source/usage.wsl.rst b/docs/source/usage.wsl.rst index dea0400..df391c6 100644 --- a/docs/source/usage.wsl.rst +++ b/docs/source/usage.wsl.rst @@ -32,7 +32,7 @@ WSL Configurations ------------------ In the configuration file for ``pymic_wsl``, in addition to those used in fully -supervised learning, there are some items specified for weakly-supervised learning. +supervised learning, there are some items specificalized for weakly-supervised learning. First, in the :mod:`train_transform` list, a special transform named :mod:`PartialLabelToProbability` should be used to transform patial labels into a one-hot probability map and a weighting @@ -88,8 +88,8 @@ related to the WSL method. For example, the correspoinding configuration for Gat .. note:: - The configuration items vary with different SLL methods. Please refer to the API - of each built-in SLL method for details of the correspoinding configuration. + The configuration items vary with different WSL methods. Please refer to the API + of each built-in WSL method for details of the correspoinding configuration. Built-in WSL Methods -------------------- From 2df44ce443a97d899eb9061693c6fee2ec51c3dc Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 1 Sep 2022 11:45:01 +0800 Subject: [PATCH 092/225] update usage.nll update usage.nll --- docs/source/usage.fsl.rst | 2 + docs/source/usage.nll.rst | 121 ++++++++++++++++++++++++++++++++++++++ docs/source/usage.wsl.rst | 2 +- 3 files changed, 124 insertions(+), 1 deletion(-) create mode 100644 docs/source/usage.nll.rst diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index 1470a7b..585d7ec 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -235,6 +235,8 @@ The following is some example code for this: agent.set_network(net) agent.run() +.. _fsl_loss: + Loss Functions -------------- diff --git a/docs/source/usage.nll.rst b/docs/source/usage.nll.rst new file mode 100644 index 0000000..ddc68aa --- /dev/null +++ b/docs/source/usage.nll.rst @@ -0,0 +1,121 @@ +.. _noisy_label_learning: + +Noisy Label Learning +==================== + +pymic_nll +--------- + +:mod:`pymic_nll` is the command for using built-in NLL methods for training. +Similarly to :mod:`pymic_run`, it should be followed by two parameters, specifying the +stage and configuration file, respectively. The training and testing commands are: + +.. code-block:: bash + + pymic_nll train myconfig_nll.cfg + pymic_nll test myconfig_nll.cfg + +.. tip:: + + If the NLL method only involves one network, either ``pymic_nll`` or ``pymic_run`` + can be used for inference. Their difference only exists in the training stage. + +.. note:: + + Some NLL methods only use noise-robust loss functions without complex + training process, and just combining the standard :mod:`SegmentationAgent` with such + loss function works for training. ``pymic_run`` instead of ``pymic_nll`` should + be used for these methods. + + +NLL Configurations +------------------ + +In the configuration file for ``pymic_nll``, in addition to those used in standard fully +supervised learning, there is a ``noisy_label_learning`` section that is specifically designed +for NLL methods. In that section, users need to specify the ``nll_method`` and configurations +related to the NLL method. For example, the correspoinding configuration for CoTeaching is: + +.. code-block:: none + + [dataset] + ... + + [network] + ... + + [training] + ... + + [noisy_label_learning] + nll_method = CoTeaching + co_teaching_select_ratio = 0.8 + rampup_start = 1000 + rampup_end = 8000 + + [testing] + ... + +.. note:: + + The configuration items vary with different NLL methods. Please refer to the API + of each built-in NLL method for details of the correspoinding configuration. + +Built-in NLL Methods +-------------------- + +Some NLL methods only use noise-robust loss functions. They are used with ``pymic_run`` +for training. Just set ``loss_type`` to one of them in the configuration file, similarly +to the fully supervised learning. + +* ``GCELoss``: (`NeurIPS 2018 `_) + Generalized cross entropy loss. + +* ``MAELoss``: (`AAAI 2017 `_) + Mean Absolute Error loss. + +* ``NRDiceLoss``: (`TMI 2020 `_) + Noise-robust Dice loss. + +The other NLL methods are implemented in child classes of +:mod:`pymic.net_run_nll.nll_abstract.NLLSegAgent`, and they are: + +* ``CLSLSR``: (`MICCAI 2020 `_) + Confident learning with spatial label smoothing regularization. + +* ``CoTeaching``: (`NeurIPS 2018 `_) + Co-teaching between two networks for learning from noisy labels. + +* ``TriNet``: (`MICCAI 2020 `_) + Tri-network combined with sample selection. + +* ``DAST``: (`JBHI 2022 `_) + Divergence-aware selective training. + +Customized NLL Methods +---------------------- + +PyMIC alo supports customizing NLL methods by inheriting the :mod:`NLLSegAgent` class. +You may only need to rewrite the :mod:`training()` method and reuse most part of the +existing pipeline, such as data loading, validation and inference methods. For example: + +.. code-block:: none + + from pymic.net_run_nll.nll_abstract import NLLSegAgent + + class MyNLLMethod(NLLSegAgent): + def __init__(self, config, stage = 'train'): + super(MyNLLMethod, self).__init__(config, stage) + ... + + def training(self): + ... + + agent = MyNLLMethod(config, stage) + agent.run() + +You may need to check the source code of built-in NLL methods to be more familar with +how to implement your own NLL method. + +In addition, if you want to design a new noise-robust loss fucntion, +just follow :doc:`_fsl_loss` to impelement and use the customized loss. \ No newline at end of file diff --git a/docs/source/usage.wsl.rst b/docs/source/usage.wsl.rst index df391c6..00471f6 100644 --- a/docs/source/usage.wsl.rst +++ b/docs/source/usage.wsl.rst @@ -95,7 +95,7 @@ Built-in WSL Methods -------------------- :mod:`pymic.net_run_wsl.wsl_abstract.WSLSegAgent` is the abstract class used for -weakly-supervised learning. The built-in SLL methods are child classes of :mod:`WSLSegAgent`. +weakly-supervised learning. The built-in WSL methods are child classes of :mod:`WSLSegAgent`. The available WSL methods implemnted in PyMIC are listed in :mod:`pymic.net_run_wsl.wsl_main.WSLMethodDict`, and they are: From 2657cb180d59fba3ee15fdcdeb108318c4bfa855 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 1 Sep 2022 13:17:16 +0800 Subject: [PATCH 093/225] Update usage.rst --- docs/source/usage.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index d8cca7f..4a4af53 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -15,6 +15,7 @@ workload. usage.fsl usage.ssl usage.wsl + usage.nll From 1594c2c11e7ad64ab21c6fcacf57224a7beb26c2 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 1 Sep 2022 13:30:37 +0800 Subject: [PATCH 094/225] Update usage.nll.rst --- docs/source/usage.nll.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/usage.nll.rst b/docs/source/usage.nll.rst index ddc68aa..f6edd33 100644 --- a/docs/source/usage.nll.rst +++ b/docs/source/usage.nll.rst @@ -118,4 +118,4 @@ You may need to check the source code of built-in NLL methods to be more familar how to implement your own NLL method. In addition, if you want to design a new noise-robust loss fucntion, -just follow :doc:`_fsl_loss` to impelement and use the customized loss. \ No newline at end of file +just follow :doc:`usage.fsl` to impelement and use the customized loss. \ No newline at end of file From 45ad055857cab3ec9931ecf3488bf7d398401015 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 2 Sep 2022 21:27:18 +0800 Subject: [PATCH 095/225] Update api.rst --- docs/source/api.rst | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index ec94338..4faa9b7 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -1,7 +1,9 @@ API === -.. autosummary:: - :toctree: generated +UNet2D +------ - lumache +.. automodule:: pymic.net.net2d.unet2d + :members: + :show-inheritance: From b5ae5c6a23135bc619a36cc61c447f50fb9cc067 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 2 Sep 2022 22:04:31 +0800 Subject: [PATCH 096/225] Update unet2d.py --- pymic/net/net2d/unet2d.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pymic/net/net2d/unet2d.py b/pymic/net/net2d/unet2d.py index a361f0f..c17cd0c 100644 --- a/pymic/net/net2d/unet2d.py +++ b/pymic/net/net2d/unet2d.py @@ -135,6 +135,9 @@ def forward(self, x): class UNet2D(nn.Module): def __init__(self, params): + """ + 2D UNet + """ super(UNet2D, self).__init__() self.params = params self.in_chns = self.params['in_chns'] From 857e970e4e1b84f85e37562edb967981cd668cf4 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 2 Sep 2022 22:48:10 +0800 Subject: [PATCH 097/225] Update conf.py add path --- docs/source/conf.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/conf.py b/docs/source/conf.py index 9096c8f..cf1b568 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,6 +1,9 @@ # Configuration file for the Sphinx documentation builder. # -- Project information +import os +import sys +sys.path.insert(0, os.path.abspath('./../..')) project = 'PyMIC' copyright = '2021, HiLab' From ffa89bb6db5013fb37bf5e19352c093e5fe8ad90 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 3 Sep 2022 09:45:33 +0800 Subject: [PATCH 098/225] add api add api --- docs/source/api.rst | 29 ++++++-- docs/source/modules.rst | 7 ++ docs/source/pymic.io.rst | 37 ++++++++++ docs/source/pymic.layer.rst | 45 ++++++++++++ docs/source/pymic.loss.cls.rst | 53 +++++++++++++++ docs/source/pymic.loss.rst | 38 +++++++++++ docs/source/pymic.loss.seg.rst | 109 ++++++++++++++++++++++++++++++ docs/source/pymic.net.cls.rst | 21 ++++++ docs/source/pymic.net.net2d.rst | 85 +++++++++++++++++++++++ docs/source/pymic.net.net3d.rst | 45 ++++++++++++ docs/source/pymic.net.rst | 39 +++++++++++ docs/source/pymic.net_run.rst | 61 +++++++++++++++++ docs/source/pymic.net_run_nll.rst | 61 +++++++++++++++++ docs/source/pymic.net_run_ssl.rst | 77 +++++++++++++++++++++ docs/source/pymic.net_run_wsl.rst | 77 +++++++++++++++++++++ docs/source/pymic.rst | 27 ++++++++ docs/source/pymic.transform.rst | 109 ++++++++++++++++++++++++++++++ docs/source/pymic.util.rst | 93 +++++++++++++++++++++++++ 18 files changed, 1008 insertions(+), 5 deletions(-) create mode 100644 docs/source/modules.rst create mode 100644 docs/source/pymic.io.rst create mode 100644 docs/source/pymic.layer.rst create mode 100644 docs/source/pymic.loss.cls.rst create mode 100644 docs/source/pymic.loss.rst create mode 100644 docs/source/pymic.loss.seg.rst create mode 100644 docs/source/pymic.net.cls.rst create mode 100644 docs/source/pymic.net.net2d.rst create mode 100644 docs/source/pymic.net.net3d.rst create mode 100644 docs/source/pymic.net.rst create mode 100644 docs/source/pymic.net_run.rst create mode 100644 docs/source/pymic.net_run_nll.rst create mode 100644 docs/source/pymic.net_run_ssl.rst create mode 100644 docs/source/pymic.net_run_wsl.rst create mode 100644 docs/source/pymic.rst create mode 100644 docs/source/pymic.transform.rst create mode 100644 docs/source/pymic.util.rst diff --git a/docs/source/api.rst b/docs/source/api.rst index 4faa9b7..2fd9039 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -1,9 +1,28 @@ API === -UNet2D ------- +Module contents +--------------- -.. automodule:: pymic.net.net2d.unet2d - :members: - :show-inheritance: +.. automodule:: pymic + :members: + :undoc-members: + :show-inheritance: + + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + pymic.io + pymic.layer + pymic.loss + pymic.net + pymic.net_run + pymic.net_run_nll + pymic.net_run_ssl + pymic.net_run_wsl + pymic.transform + pymic.util \ No newline at end of file diff --git a/docs/source/modules.rst b/docs/source/modules.rst new file mode 100644 index 0000000..07e93c1 --- /dev/null +++ b/docs/source/modules.rst @@ -0,0 +1,7 @@ +pymic +===== + +.. toctree:: + :maxdepth: 4 + + pymic diff --git a/docs/source/pymic.io.rst b/docs/source/pymic.io.rst new file mode 100644 index 0000000..812522f --- /dev/null +++ b/docs/source/pymic.io.rst @@ -0,0 +1,37 @@ +pymic.io package +================ + +Submodules +---------- + +pymic.io.h5\_dataset module +--------------------------- + +.. automodule:: pymic.io.h5_dataset + :members: + :undoc-members: + :show-inheritance: + +pymic.io.image\_read\_write module +---------------------------------- + +.. automodule:: pymic.io.image_read_write + :members: + :undoc-members: + :show-inheritance: + +pymic.io.nifty\_dataset module +------------------------------ + +.. automodule:: pymic.io.nifty_dataset + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.io + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.layer.rst b/docs/source/pymic.layer.rst new file mode 100644 index 0000000..211beb7 --- /dev/null +++ b/docs/source/pymic.layer.rst @@ -0,0 +1,45 @@ +pymic.layer package +=================== + +Submodules +---------- + +pymic.layer.activation module +----------------------------- + +.. automodule:: pymic.layer.activation + :members: + :undoc-members: + :show-inheritance: + +pymic.layer.convolution module +------------------------------ + +.. automodule:: pymic.layer.convolution + :members: + :undoc-members: + :show-inheritance: + +pymic.layer.deconvolution module +-------------------------------- + +.. automodule:: pymic.layer.deconvolution + :members: + :undoc-members: + :show-inheritance: + +pymic.layer.space2channel module +-------------------------------- + +.. automodule:: pymic.layer.space2channel + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.layer + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.loss.cls.rst b/docs/source/pymic.loss.cls.rst new file mode 100644 index 0000000..73b816b --- /dev/null +++ b/docs/source/pymic.loss.cls.rst @@ -0,0 +1,53 @@ +pymic.loss.cls package +====================== + +Submodules +---------- + +pymic.loss.cls.ce module +------------------------ + +.. automodule:: pymic.loss.cls.ce + :members: + :undoc-members: + :show-inheritance: + +pymic.loss.cls.l1 module +------------------------ + +.. automodule:: pymic.loss.cls.l1 + :members: + :undoc-members: + :show-inheritance: + +pymic.loss.cls.mse module +------------------------- + +.. automodule:: pymic.loss.cls.mse + :members: + :undoc-members: + :show-inheritance: + +pymic.loss.cls.nll module +------------------------- + +.. automodule:: pymic.loss.cls.nll + :members: + :undoc-members: + :show-inheritance: + +pymic.loss.cls.util module +-------------------------- + +.. automodule:: pymic.loss.cls.util + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.loss.cls + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.loss.rst b/docs/source/pymic.loss.rst new file mode 100644 index 0000000..437e7c0 --- /dev/null +++ b/docs/source/pymic.loss.rst @@ -0,0 +1,38 @@ +pymic.loss package +================== + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + pymic.loss.cls + pymic.loss.seg + +Submodules +---------- + +pymic.loss.loss\_dict\_cls module +--------------------------------- + +.. automodule:: pymic.loss.loss_dict_cls + :members: + :undoc-members: + :show-inheritance: + +pymic.loss.loss\_dict\_seg module +--------------------------------- + +.. automodule:: pymic.loss.loss_dict_seg + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.loss + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.loss.seg.rst b/docs/source/pymic.loss.seg.rst new file mode 100644 index 0000000..c283ed3 --- /dev/null +++ b/docs/source/pymic.loss.seg.rst @@ -0,0 +1,109 @@ +pymic.loss.seg package +====================== + +Submodules +---------- + +pymic.loss.seg.ce module +------------------------ + +.. automodule:: pymic.loss.seg.ce + :members: + :undoc-members: + :show-inheritance: + +pymic.loss.seg.combined module +------------------------------ + +.. automodule:: pymic.loss.seg.combined + :members: + :undoc-members: + :show-inheritance: + +pymic.loss.seg.deep\_sup module +------------------------------- + +.. automodule:: pymic.loss.seg.deep_sup + :members: + :undoc-members: + :show-inheritance: + +pymic.loss.seg.dice module +-------------------------- + +.. automodule:: pymic.loss.seg.dice + :members: + :undoc-members: + :show-inheritance: + +pymic.loss.seg.exp\_log module +------------------------------ + +.. automodule:: pymic.loss.seg.exp_log + :members: + :undoc-members: + :show-inheritance: + +pymic.loss.seg.gatedcrf module +------------------------------ + +.. automodule:: pymic.loss.seg.gatedcrf + :members: + :undoc-members: + :show-inheritance: + +pymic.loss.seg.gatedcrf\_backup module +-------------------------------------- + +.. automodule:: pymic.loss.seg.gatedcrf_backup + :members: + :undoc-members: + :show-inheritance: + +pymic.loss.seg.mse module +------------------------- + +.. automodule:: pymic.loss.seg.mse + :members: + :undoc-members: + :show-inheritance: + +pymic.loss.seg.mumford\_shah module +----------------------------------- + +.. automodule:: pymic.loss.seg.mumford_shah + :members: + :undoc-members: + :show-inheritance: + +pymic.loss.seg.slsr module +-------------------------- + +.. automodule:: pymic.loss.seg.slsr + :members: + :undoc-members: + :show-inheritance: + +pymic.loss.seg.ssl module +------------------------- + +.. automodule:: pymic.loss.seg.ssl + :members: + :undoc-members: + :show-inheritance: + +pymic.loss.seg.util module +-------------------------- + +.. automodule:: pymic.loss.seg.util + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.loss.seg + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.net.cls.rst b/docs/source/pymic.net.cls.rst new file mode 100644 index 0000000..42842f7 --- /dev/null +++ b/docs/source/pymic.net.cls.rst @@ -0,0 +1,21 @@ +pymic.net.cls package +===================== + +Submodules +---------- + +pymic.net.cls.torch\_pretrained\_net module +------------------------------------------- + +.. automodule:: pymic.net.cls.torch_pretrained_net + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.net.cls + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.net.net2d.rst b/docs/source/pymic.net.net2d.rst new file mode 100644 index 0000000..d978dfe --- /dev/null +++ b/docs/source/pymic.net.net2d.rst @@ -0,0 +1,85 @@ +pymic.net.net2d package +======================= + +Submodules +---------- + +pymic.net.net2d.cople\_net module +--------------------------------- + +.. automodule:: pymic.net.net2d.cople_net + :members: + :undoc-members: + :show-inheritance: + +pymic.net.net2d.scse2d module +----------------------------- + +.. automodule:: pymic.net.net2d.scse2d + :members: + :undoc-members: + :show-inheritance: + +pymic.net.net2d.unet2d module +----------------------------- + +.. automodule:: pymic.net.net2d.unet2d + :members: + :undoc-members: + :show-inheritance: + +pymic.net.net2d.unet2d\_attention module +---------------------------------------- + +.. automodule:: pymic.net.net2d.unet2d_attention + :members: + :undoc-members: + :show-inheritance: + +pymic.net.net2d.unet2d\_cct module +---------------------------------- + +.. automodule:: pymic.net.net2d.unet2d_cct + :members: + :undoc-members: + :show-inheritance: + +pymic.net.net2d.unet2d\_dual\_branch module +------------------------------------------- + +.. automodule:: pymic.net.net2d.unet2d_dual_branch + :members: + :undoc-members: + :show-inheritance: + +pymic.net.net2d.unet2d\_nest module +----------------------------------- + +.. automodule:: pymic.net.net2d.unet2d_nest + :members: + :undoc-members: + :show-inheritance: + +pymic.net.net2d.unet2d\_scse module +----------------------------------- + +.. automodule:: pymic.net.net2d.unet2d_scse + :members: + :undoc-members: + :show-inheritance: + +pymic.net.net2d.unet2d\_urpc module +----------------------------------- + +.. automodule:: pymic.net.net2d.unet2d_urpc + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.net.net2d + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.net.net3d.rst b/docs/source/pymic.net.net3d.rst new file mode 100644 index 0000000..a09c187 --- /dev/null +++ b/docs/source/pymic.net.net3d.rst @@ -0,0 +1,45 @@ +pymic.net.net3d package +======================= + +Submodules +---------- + +pymic.net.net3d.scse3d module +----------------------------- + +.. automodule:: pymic.net.net3d.scse3d + :members: + :undoc-members: + :show-inheritance: + +pymic.net.net3d.unet2d5 module +------------------------------ + +.. automodule:: pymic.net.net3d.unet2d5 + :members: + :undoc-members: + :show-inheritance: + +pymic.net.net3d.unet3d module +----------------------------- + +.. automodule:: pymic.net.net3d.unet3d + :members: + :undoc-members: + :show-inheritance: + +pymic.net.net3d.unet3d\_scse module +----------------------------------- + +.. automodule:: pymic.net.net3d.unet3d_scse + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.net.net3d + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.net.rst b/docs/source/pymic.net.rst new file mode 100644 index 0000000..01a3209 --- /dev/null +++ b/docs/source/pymic.net.rst @@ -0,0 +1,39 @@ +pymic.net package +================= + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + pymic.net.cls + pymic.net.net2d + pymic.net.net3d + +Submodules +---------- + +pymic.net.net\_dict\_cls module +------------------------------- + +.. automodule:: pymic.net.net_dict_cls + :members: + :undoc-members: + :show-inheritance: + +pymic.net.net\_dict\_seg module +------------------------------- + +.. automodule:: pymic.net.net_dict_seg + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.net + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.net_run.rst b/docs/source/pymic.net_run.rst new file mode 100644 index 0000000..74aab12 --- /dev/null +++ b/docs/source/pymic.net_run.rst @@ -0,0 +1,61 @@ +pymic.net\_run package +====================== + +Submodules +---------- + +pymic.net\_run.agent\_abstract module +------------------------------------- + +.. automodule:: pymic.net_run.agent_abstract + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.agent\_cls module +-------------------------------- + +.. automodule:: pymic.net_run.agent_cls + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.agent\_seg module +-------------------------------- + +.. automodule:: pymic.net_run.agent_seg + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.get\_optimizer module +------------------------------------ + +.. automodule:: pymic.net_run.get_optimizer + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.infer\_func module +--------------------------------- + +.. automodule:: pymic.net_run.infer_func + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.net\_run module +------------------------------ + +.. automodule:: pymic.net_run.net_run + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.net_run + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.net_run_nll.rst b/docs/source/pymic.net_run_nll.rst new file mode 100644 index 0000000..120d23c --- /dev/null +++ b/docs/source/pymic.net_run_nll.rst @@ -0,0 +1,61 @@ +pymic.net\_run\_nll package +=========================== + +Submodules +---------- + +pymic.net\_run\_nll.nll\_cl module +---------------------------------- + +.. automodule:: pymic.net_run_nll.nll_cl + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run\_nll.nll\_clslsr module +-------------------------------------- + +.. automodule:: pymic.net_run_nll.nll_clslsr + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run\_nll.nll\_co\_teaching module +-------------------------------------------- + +.. automodule:: pymic.net_run_nll.nll_co_teaching + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run\_nll.nll\_dast module +------------------------------------ + +.. automodule:: pymic.net_run_nll.nll_dast + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run\_nll.nll\_main module +------------------------------------ + +.. automodule:: pymic.net_run_nll.nll_main + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run\_nll.nll\_trinet module +-------------------------------------- + +.. automodule:: pymic.net_run_nll.nll_trinet + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.net_run_nll + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.net_run_ssl.rst b/docs/source/pymic.net_run_ssl.rst new file mode 100644 index 0000000..236e2d6 --- /dev/null +++ b/docs/source/pymic.net_run_ssl.rst @@ -0,0 +1,77 @@ +pymic.net\_run\_ssl package +=========================== + +Submodules +---------- + +pymic.net\_run\_ssl.ssl\_abstract module +---------------------------------------- + +.. automodule:: pymic.net_run_ssl.ssl_abstract + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run\_ssl.ssl\_cct module +----------------------------------- + +.. automodule:: pymic.net_run_ssl.ssl_cct + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run\_ssl.ssl\_cps module +----------------------------------- + +.. automodule:: pymic.net_run_ssl.ssl_cps + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run\_ssl.ssl\_em module +---------------------------------- + +.. automodule:: pymic.net_run_ssl.ssl_em + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run\_ssl.ssl\_main module +------------------------------------ + +.. automodule:: pymic.net_run_ssl.ssl_main + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run\_ssl.ssl\_mt module +---------------------------------- + +.. automodule:: pymic.net_run_ssl.ssl_mt + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run\_ssl.ssl\_uamt module +------------------------------------ + +.. automodule:: pymic.net_run_ssl.ssl_uamt + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run\_ssl.ssl\_urpc module +------------------------------------ + +.. automodule:: pymic.net_run_ssl.ssl_urpc + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.net_run_ssl + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.net_run_wsl.rst b/docs/source/pymic.net_run_wsl.rst new file mode 100644 index 0000000..5eda921 --- /dev/null +++ b/docs/source/pymic.net_run_wsl.rst @@ -0,0 +1,77 @@ +pymic.net\_run\_wsl package +=========================== + +Submodules +---------- + +pymic.net\_run\_wsl.wsl\_abstract module +---------------------------------------- + +.. automodule:: pymic.net_run_wsl.wsl_abstract + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run\_wsl.wsl\_dmpls module +------------------------------------- + +.. automodule:: pymic.net_run_wsl.wsl_dmpls + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run\_wsl.wsl\_em module +---------------------------------- + +.. automodule:: pymic.net_run_wsl.wsl_em + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run\_wsl.wsl\_gatedcrf module +---------------------------------------- + +.. automodule:: pymic.net_run_wsl.wsl_gatedcrf + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run\_wsl.wsl\_main module +------------------------------------ + +.. automodule:: pymic.net_run_wsl.wsl_main + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run\_wsl.wsl\_mumford\_shah module +--------------------------------------------- + +.. automodule:: pymic.net_run_wsl.wsl_mumford_shah + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run\_wsl.wsl\_tv module +---------------------------------- + +.. automodule:: pymic.net_run_wsl.wsl_tv + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run\_wsl.wsl\_ustm module +------------------------------------ + +.. automodule:: pymic.net_run_wsl.wsl_ustm + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.net_run_wsl + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.rst b/docs/source/pymic.rst new file mode 100644 index 0000000..7545740 --- /dev/null +++ b/docs/source/pymic.rst @@ -0,0 +1,27 @@ +pymic package +============= + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + pymic.io + pymic.layer + pymic.loss + pymic.net + pymic.net_run + pymic.net_run_nll + pymic.net_run_ssl + pymic.net_run_wsl + pymic.transform + pymic.util + +Module contents +--------------- + +.. automodule:: pymic + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.transform.rst b/docs/source/pymic.transform.rst new file mode 100644 index 0000000..1dafa3b --- /dev/null +++ b/docs/source/pymic.transform.rst @@ -0,0 +1,109 @@ +pymic.transform package +======================= + +Submodules +---------- + +pymic.transform.abstract\_transform module +------------------------------------------ + +.. automodule:: pymic.transform.abstract_transform + :members: + :undoc-members: + :show-inheritance: + +pymic.transform.crop module +--------------------------- + +.. automodule:: pymic.transform.crop + :members: + :undoc-members: + :show-inheritance: + +pymic.transform.flip module +--------------------------- + +.. automodule:: pymic.transform.flip + :members: + :undoc-members: + :show-inheritance: + +pymic.transform.gray2rgb module +------------------------------- + +.. automodule:: pymic.transform.gray2rgb + :members: + :undoc-members: + :show-inheritance: + +pymic.transform.intensity module +-------------------------------- + +.. automodule:: pymic.transform.intensity + :members: + :undoc-members: + :show-inheritance: + +pymic.transform.label\_convert module +------------------------------------- + +.. automodule:: pymic.transform.label_convert + :members: + :undoc-members: + :show-inheritance: + +pymic.transform.normalize module +-------------------------------- + +.. automodule:: pymic.transform.normalize + :members: + :undoc-members: + :show-inheritance: + +pymic.transform.pad module +-------------------------- + +.. automodule:: pymic.transform.pad + :members: + :undoc-members: + :show-inheritance: + +pymic.transform.rescale module +------------------------------ + +.. automodule:: pymic.transform.rescale + :members: + :undoc-members: + :show-inheritance: + +pymic.transform.rotate module +----------------------------- + +.. automodule:: pymic.transform.rotate + :members: + :undoc-members: + :show-inheritance: + +pymic.transform.threshold module +-------------------------------- + +.. automodule:: pymic.transform.threshold + :members: + :undoc-members: + :show-inheritance: + +pymic.transform.trans\_dict module +---------------------------------- + +.. automodule:: pymic.transform.trans_dict + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.transform + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.util.rst b/docs/source/pymic.util.rst new file mode 100644 index 0000000..3a6c3ab --- /dev/null +++ b/docs/source/pymic.util.rst @@ -0,0 +1,93 @@ +pymic.util package +================== + +Submodules +---------- + +pymic.util.average\_model module +-------------------------------- + +.. automodule:: pymic.util.average_model + :members: + :undoc-members: + :show-inheritance: + +pymic.util.evaluation\_cls module +--------------------------------- + +.. automodule:: pymic.util.evaluation_cls + :members: + :undoc-members: + :show-inheritance: + +pymic.util.evaluation\_seg module +--------------------------------- + +.. automodule:: pymic.util.evaluation_seg + :members: + :undoc-members: + :show-inheritance: + +pymic.util.general module +------------------------- + +.. automodule:: pymic.util.general + :members: + :undoc-members: + :show-inheritance: + +pymic.util.image\_process module +-------------------------------- + +.. automodule:: pymic.util.image_process + :members: + :undoc-members: + :show-inheritance: + +pymic.util.parse\_config module +------------------------------- + +.. automodule:: pymic.util.parse_config + :members: + :undoc-members: + :show-inheritance: + +pymic.util.post\_process module +------------------------------- + +.. automodule:: pymic.util.post_process + :members: + :undoc-members: + :show-inheritance: + +pymic.util.preprocess module +---------------------------- + +.. automodule:: pymic.util.preprocess + :members: + :undoc-members: + :show-inheritance: + +pymic.util.ramps module +----------------------- + +.. automodule:: pymic.util.ramps + :members: + :undoc-members: + :show-inheritance: + +pymic.util.rename\_model module +------------------------------- + +.. automodule:: pymic.util.rename_model + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.util + :members: + :undoc-members: + :show-inheritance: From 05a20ca7d190b007bc193934cf358158e2ad63e9 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 3 Sep 2022 12:15:15 +0800 Subject: [PATCH 099/225] update doc update doc --- docs/source/api.rst | 12 ------------ pymic/io/image_read_write.py | 17 ++++++++++++----- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 2fd9039..206000d 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -1,18 +1,6 @@ API === -Module contents ---------------- - -.. automodule:: pymic - :members: - :undoc-members: - :show-inheritance: - - -Subpackages ------------ - .. toctree:: :maxdepth: 4 diff --git a/pymic/io/image_read_write.py b/pymic/io/image_read_write.py index a855b6c..798a8e2 100644 --- a/pymic/io/image_read_write.py +++ b/pymic/io/image_read_write.py @@ -7,7 +7,8 @@ from PIL import Image def load_nifty_volume_as_4d_array(filename): - """Read a nifty image and return a dictionay storing data array, spacing and direction + """ + Read a nifty image and return a dictionay storing data array, spacing and direction output['data_array'] 4d array with shape [C, D, H, W] output['spacing'] a list of spacing in z, y, x axis output['direction'] a 3x3 matrix for direction @@ -53,7 +54,11 @@ def load_rgb_image_as_3d_array(filename): def load_image_as_nd_array(image_name): """ - return a 4D array with shape [C, D, H, W], or 3D array with shape [C, H, W] + load an image and return a 4D array with shape [C, D, H, W], + or 3D array with shape [C, H, W]. + + Args: + image_name (string): the image name. """ if (image_name.endswith(".nii.gz") or image_name.endswith(".nii") or image_name.endswith(".mha")): @@ -67,11 +72,13 @@ def load_image_as_nd_array(image_name): def save_array_as_nifty_volume(data, image_name, reference_name = None): """ - save a numpy array as nifty image - inputs: + Save a numpy array as nifty image + + Args: data: a numpy array with shape [Depth, Height, Width] image_name: the ouput file name - reference_name: file name of the reference image of which affine and header are used + reference_name: file name of the reference image of which + meta information is used outputs: None """ img = sitk.GetImageFromArray(data) From 6be8f279b6a703e3f9b721efd6b5716f2fa91ac3 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 3 Sep 2022 12:27:59 +0800 Subject: [PATCH 100/225] Update image_read_write.py --- pymic/io/image_read_write.py | 71 +++++++++++++++++++++++++----------- 1 file changed, 50 insertions(+), 21 deletions(-) diff --git a/pymic/io/image_read_write.py b/pymic/io/image_read_write.py index 798a8e2..782e67f 100644 --- a/pymic/io/image_read_write.py +++ b/pymic/io/image_read_write.py @@ -8,10 +8,17 @@ def load_nifty_volume_as_4d_array(filename): """ - Read a nifty image and return a dictionay storing data array, spacing and direction + Read a nifty image and return a dictionay storing data array, origin, + spacing and direction. output['data_array'] 4d array with shape [C, D, H, W] output['spacing'] a list of spacing in z, y, x axis output['direction'] a 3x3 matrix for direction + + Args: + filename (str): the input file name + + Returns: + dict: a dictionay storing data array, origin, spacing and direction. """ img_obj = sitk.ReadImage(filename) data_array = sitk.GetArrayFromImage(img_obj) @@ -33,6 +40,19 @@ def load_nifty_volume_as_4d_array(filename): return output def load_rgb_image_as_3d_array(filename): + """ + Read an RGB image and return a dictionay storing data array, origin, + spacing and direction. + output['data_array'] 3d array with shape [D, H, W] + output['spacing'] a list of spacing in z, y, x axis + output['direction'] a 3x3 matrix for direction + + Args: + filename (str): the input file name + + Returns: + dict: a dictionay storing data array, origin, spacing and direction. + """ image = np.asarray(Image.open(filename)) image_shape = image.shape image_dim = len(image_shape) @@ -58,7 +78,10 @@ def load_image_as_nd_array(image_name): or 3D array with shape [C, H, W]. Args: - image_name (string): the image name. + image_name (str): the image name. + + Returns: + dict: a dictionay storing data array, origin, spacing and direction. """ if (image_name.endswith(".nii.gz") or image_name.endswith(".nii") or image_name.endswith(".mha")): @@ -75,11 +98,10 @@ def save_array_as_nifty_volume(data, image_name, reference_name = None): Save a numpy array as nifty image Args: - data: a numpy array with shape [Depth, Height, Width] - image_name: the ouput file name - reference_name: file name of the reference image of which - meta information is used - outputs: None + data(numpy.ndarray): a numpy array with shape [Depth, Height, Width] + image_name (str): the ouput file name + reference_name (str): file name of the reference image of which + meta information is used """ img = sitk.GetImageFromArray(data) if(reference_name is not None): @@ -92,11 +114,11 @@ def save_array_as_nifty_volume(data, image_name, reference_name = None): def save_array_as_rgb_image(data, image_name): """ - save a numpy array as rgb image - inputs: - data: a numpy array with shape [3, H, W] or [H, W, 3] or [H, W] - image_name: the output file name - outputs: None + Save a numpy array as rgb image + + Args: + data (numpy.ndarray): a numpy array with shape [3, H, W] or [H, W, 3] or [H, W] + image_name (str): the output file name """ data_dim = len(data.shape) if(data_dim == 3): @@ -108,11 +130,13 @@ def save_array_as_rgb_image(data, image_name): def save_nd_array_as_image(data, image_name, reference_name = None): """ - save a 3D or 2D numpy array as medical image or RGB image - inputs: - data: a numpy array with shape [D, H, W] or [C, H, W] - image_name: the output file name - outputs: None + Save a 3D or 2D numpy array as medical image or RGB image + + Args: + data (numpy.ndarray): a numpy array with shape [D, H, W] or [C, H, W] + image_name (str): the output file name + reference_name (str): file name of the reference image of which + meta information is used """ data_dim = len(data.shape) assert(data_dim == 2 or data_dim == 3) @@ -131,10 +155,15 @@ def save_nd_array_as_image(data, image_name, reference_name = None): def rotate_nifty_volume_to_LPS(filename_or_image_dict, origin = None, direction = None): ''' - filename_or_image_dict: filename of the nifty file (str) or image dictionary - returned by load_nifty_volume_as_4d_array. If supplied with the former, - the flipped image data will be saved to override the original file. - If supplied with the later, only flipped image data will be returned. + Rotate the axis of a 3D volume to LPS + + Args: + filename_or_image_dict (str or dict): filename of the nifty file (str) or image dictionary + returned by load_nifty_volume_as_4d_array. If supplied with the former, + the flipped image data will be saved to override the original file. + If supplied with the later, only flipped image data will be returned. + origin (list or tuple): the origin of the image. + direction (list or tuple): the direction of the image. ''' if type(filename_or_image_dict) == str: From ddbb6db320515e9e68026d9ba68c2a4001d57d7c Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 3 Sep 2022 12:31:37 +0800 Subject: [PATCH 101/225] Update image_read_write.py --- pymic/io/image_read_write.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pymic/io/image_read_write.py b/pymic/io/image_read_write.py index 782e67f..534c790 100644 --- a/pymic/io/image_read_write.py +++ b/pymic/io/image_read_write.py @@ -9,10 +9,10 @@ def load_nifty_volume_as_4d_array(filename): """ Read a nifty image and return a dictionay storing data array, origin, - spacing and direction. - output['data_array'] 4d array with shape [C, D, H, W] - output['spacing'] a list of spacing in z, y, x axis - output['direction'] a 3x3 matrix for direction + spacing and direction.\n + output['data_array'] 4d array with shape [C, D, H, W];\n + output['spacing'] a list of spacing in z, y, x axis;\n + output['direction'] a 3x3 matrix for direction. Args: filename (str): the input file name @@ -42,10 +42,10 @@ def load_nifty_volume_as_4d_array(filename): def load_rgb_image_as_3d_array(filename): """ Read an RGB image and return a dictionay storing data array, origin, - spacing and direction. - output['data_array'] 3d array with shape [D, H, W] - output['spacing'] a list of spacing in z, y, x axis - output['direction'] a 3x3 matrix for direction + spacing and direction. \n + output['data_array'] 3d array with shape [D, H, W]; \n + output['spacing'] a list of spacing in z, y, x axis; \n + output['direction'] a 3x3 matrix for direction. Args: filename (str): the input file name From b79ec02934ff7a491f6e0a5b04326185bd1d6b83 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 3 Sep 2022 12:36:16 +0800 Subject: [PATCH 102/225] Update image_read_write.py --- pymic/io/image_read_write.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pymic/io/image_read_write.py b/pymic/io/image_read_write.py index 534c790..2a1e59b 100644 --- a/pymic/io/image_read_write.py +++ b/pymic/io/image_read_write.py @@ -158,10 +158,10 @@ def rotate_nifty_volume_to_LPS(filename_or_image_dict, origin = None, direction Rotate the axis of a 3D volume to LPS Args: - filename_or_image_dict (str or dict): filename of the nifty file (str) or image dictionary - returned by load_nifty_volume_as_4d_array. If supplied with the former, - the flipped image data will be saved to override the original file. - If supplied with the later, only flipped image data will be returned. + filename_or_image_dict (str): filename of the nifty file (str) or image dictionary + returned by load_nifty_volume_as_4d_array. If supplied with the former, + the flipped image data will be saved to override the original file. + If supplied with the later, only flipped image data will be returned. origin (list or tuple): the origin of the image. direction (list or tuple): the direction of the image. ''' From 95e38f8f6edb132964245ae30e5c362897c64860 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 3 Sep 2022 15:15:16 +0800 Subject: [PATCH 103/225] Update image_read_write.py --- pymic/io/image_read_write.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/pymic/io/image_read_write.py b/pymic/io/image_read_write.py index 2a1e59b..e06924e 100644 --- a/pymic/io/image_read_write.py +++ b/pymic/io/image_read_write.py @@ -98,10 +98,10 @@ def save_array_as_nifty_volume(data, image_name, reference_name = None): Save a numpy array as nifty image Args: - data(numpy.ndarray): a numpy array with shape [Depth, Height, Width] - image_name (str): the ouput file name + data (numpy.ndarray): a numpy array with shape [Depth, Height, Width].\n + image_name (str): the ouput file name.\n reference_name (str): file name of the reference image of which - meta information is used + meta information is used. """ img = sitk.GetImageFromArray(data) if(reference_name is not None): @@ -117,8 +117,9 @@ def save_array_as_rgb_image(data, image_name): Save a numpy array as rgb image Args: - data (numpy.ndarray): a numpy array with shape [3, H, W] or [H, W, 3] or [H, W] - image_name (str): the output file name + data (numpy.ndarray): a numpy array with shape [3, H, W] or + [H, W, 3] or [H, W]. \n + image_name (str): the output file name. """ data_dim = len(data.shape) if(data_dim == 3): @@ -133,10 +134,10 @@ def save_nd_array_as_image(data, image_name, reference_name = None): Save a 3D or 2D numpy array as medical image or RGB image Args: - data (numpy.ndarray): a numpy array with shape [D, H, W] or [C, H, W] - image_name (str): the output file name + data (numpy.ndarray): a numpy array with shape [D, H, W] or [C, H, W]. \n + image_name (str): the output file name. \n reference_name (str): file name of the reference image of which - meta information is used + meta information is used. """ data_dim = len(data.shape) assert(data_dim == 2 or data_dim == 3) @@ -161,9 +162,13 @@ def rotate_nifty_volume_to_LPS(filename_or_image_dict, origin = None, direction filename_or_image_dict (str): filename of the nifty file (str) or image dictionary returned by load_nifty_volume_as_4d_array. If supplied with the former, the flipped image data will be saved to override the original file. - If supplied with the later, only flipped image data will be returned. - origin (list or tuple): the origin of the image. + If supplied with the later, only flipped image data will be returned.\n + origin (list or tuple): the origin of the image.\n direction (list or tuple): the direction of the image. + + Returns: + dict: a dictionary for image data and meta info, with ``data_array``, + ``origin``, ``direction`` and ``spacing``. ''' if type(filename_or_image_dict) == str: From 5a5d5bd5898e2eec7fc9dceece80d3ff594df718 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 3 Sep 2022 16:05:54 +0800 Subject: [PATCH 104/225] update docs update docs --- pymic/io/h5_dataset.py | 5 +++ pymic/io/nifty_dataset.py | 36 +++++++++++------ pymic/loss/seg/ce.py | 81 +++++++++++++++++++++++++-------------- 3 files changed, 82 insertions(+), 40 deletions(-) diff --git a/pymic/io/h5_dataset.py b/pymic/io/h5_dataset.py index a270d93..c6e5298 100644 --- a/pymic/io/h5_dataset.py +++ b/pymic/io/h5_dataset.py @@ -16,6 +16,11 @@ class H5DataSet(Dataset): Dataset for loading images stored in h5 format. It generates 4D tensors with dimention order [C, D, H, W] for 3D images, and 3D tensors with dimention order [C, H, W] for 2D images + + Args: + root_dir (str): thr root dir of images. \n + sample_list_name (str): a file name for sample list. \n + tranform (list): A list of transform objects applied on a sample. """ def __init__(self, root_dir, sample_list_name, transform = None): self.root_dir = root_dir diff --git a/pymic/io/nifty_dataset.py b/pymic/io/nifty_dataset.py index 590aa15..f09d24f 100644 --- a/pymic/io/nifty_dataset.py +++ b/pymic/io/nifty_dataset.py @@ -10,22 +10,21 @@ from pymic.io.image_read_write import load_image_as_nd_array class NiftyDataset(Dataset): - """Dataset for loading images. It generates 4D tensors with + """ + Dataset for loading images for segmentation. It generates 4D tensors with dimention order [C, D, H, W] for 3D images, and 3D tensors - with dimention order [C, H, W] for 2D images""" + with dimention order [C, H, W] for 2D images. + Args: + root_dir (str): Directory with all the images. \n + csv_file (str): Path to the csv file with image names. \n + modal_num (int): Number of modalities. \n + with_label (bool): Load the data with segmentation ground truth or not. \n + with_weight(bool): Load pixel-wise weight map or not. \n + transform (list): list of transform to be applied on a sample. + """ def __init__(self, root_dir, csv_file, modal_num = 1, with_label = False, transform=None): - """ - Args: - root_dir (string): Directory with all the images. - csv_file (string): Path to the csv file with image names. - modal_num (int): Number of modalities. - with_label (bool): Load the data with segmentation ground truth. - with_weight(bool): Load pixel-wise weight map. - transform (callable, optional): Optional transform to be applied - on a sample. - """ self.root_dir = root_dir self.csv_items = pd.read_csv(csv_file) self.modal_num = modal_num @@ -89,6 +88,19 @@ def __getitem__(self, idx): class ClassificationDataset(NiftyDataset): + """ + Dataset for loading images for classification. It generates 4D tensors with + dimention order [C, D, H, W] for 3D images, and 3D tensors + with dimention order [C, H, W] for 2D images. + + Args: + root_dir (str): Directory with all the images. \n + csv_file (str): Path to the csv file with image names. \n + modal_num (int): Number of modalities. \n + class_num (int): class number of the classificaiton task. \n + with_label (bool): Load the data with segmentation ground truth or not. \n + transform (list): list of transform to be applied on a sample. + """ def __init__(self, root_dir, csv_file, modal_num = 1, class_num = 2, with_label = False, transform=None): super(ClassificationDataset, self).__init__(root_dir, diff --git a/pymic/loss/seg/ce.py b/pymic/loss/seg/ce.py index cdef1a0..e88ed8c 100644 --- a/pymic/loss/seg/ce.py +++ b/pymic/loss/seg/ce.py @@ -6,6 +6,14 @@ from pymic.loss.seg.util import reshape_tensor_to_2D class CrossEntropyLoss(nn.Module): + """ + Cross entropy loss for segmentation tasks. + The arguments should be written in the `params` dictionary, and it has the + following fields: + + Args: + `loss_softmax` (bool): Apply softmax to the prediction of network or not. \n + """ def __init__(self, params = None): super(CrossEntropyLoss, self).__init__() if(params is None): @@ -14,6 +22,22 @@ def __init__(self, params = None): self.softmax = params.get('loss_softmax', True) def forward(self, loss_input_dict): + """ + Forward pass for calculating the loss. + The arguments should be written in the `loss_input_dict` dictionary, + and it has the following fields: + + Args: + `prediction` (tensor): prediction of a network, with the + shape of [N, C, D, H, W] or [N, C, H, W]. \n + `ground_truth` (tensor): ground truth, with the + shape of [N, C, D, H, W] or [N, C, H, W]. \n + `pixel_weight` (tensor or None): pixel weight map, with the + shape of [N, D, H, W] or [N, H, W]. \n + + Returns: + tensor: the loss value. + """ predict = loss_input_dict['prediction'] soft_y = loss_input_dict['ground_truth'] pix_w = loss_input_dict.get('pixel_weight', None) @@ -36,39 +60,23 @@ def forward(self, loss_input_dict): ce = torch.sum(pix_w * ce) / (pix_w.sum() + 1e-5) return ce -class PartialCrossEntropyLoss(nn.Module): - def __init__(self, params): - super(CrossEntropyLoss, self).__init__() - self.softmax = params.get('loss_softmax', True) - - def forward(self, loss_input_dict): - predict = loss_input_dict['prediction'] - soft_y = loss_input_dict['ground_truth'] - - if(isinstance(predict, (list, tuple))): - predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) - predict = reshape_tensor_to_2D(predict) - soft_y = reshape_tensor_to_2D(soft_y) - - # for numeric stability - predict = predict * 0.999 + 5e-4 - ce = - soft_y* torch.log(predict) - ce = torch.sum(ce, dim = 1) # shape is [N] - ce = torch.mean(ce) - return ce - class GeneralizedCELoss(nn.Module): """ Generalized cross entropy loss to deal with noisy labels. - Z. Zhang et al. Generalized Cross Entropy Loss for Training Deep Neural Networks - with Noisy Labels, NeurIPS 2018. + + Reference: Z. Zhang et al. Generalized Cross Entropy Loss for Training Deep Neural Networks + with Noisy Labels, NeurIPS 2018. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + Args: + `GeneralizedCELoss_Enable_Pixel_Weight` (bool): Use pixel weighting or not. \n + `GeneralizedCELoss_Enable_Class_Weight` (bool): Use class weighting or not. \n + `GeneralizedCELoss_q` (float): hyper-parameter in the range of (0, 1). \n + `loss_softmax` (bool): Apply softmax to the network's prediction or not. """ def __init__(self, params): - """ - q: in (0, 1), becmomes MAE when q = 1 - """ super(GeneralizedCELoss, self).__init__() self.enable_pix_weight = params.get('GeneralizedCELoss_Enable_Pixel_Weight', False) self.enable_cls_weight = params.get('GeneralizedCELoss_Enable_Class_Weight', False) @@ -76,6 +84,23 @@ def __init__(self, params): self.softmax = params.get('loss_softmax', True) def forward(self, loss_input_dict): + ''' + Forward pass for calculating the loss. + The arguments should be written in the `loss_input_dict` dictionary, + and it has the following fields: + + Args: + `prediction` (tensor): prediction of a network, with the + shape of [N, C, D, H, W] or [N, C, H, W]. \n + `ground_truth` (tensor): ground truth, with the + shape of [N, C, D, H, W] or [N, C, H, W]. \n + `pixel_weight` (tensor or None): pixel weight map, with the + shape of [N, D, H, W] or [N, H, W]. \n + `class_weight` (tensor or None): class weight map. + + Returns: + tensor: the loss value. + ''' predict = loss_input_dict['prediction'] soft_y = loss_input_dict['ground_truth'] From fe09952712a6009b454c2d87b55a6076d0f07911 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 3 Sep 2022 17:59:48 +0800 Subject: [PATCH 105/225] update docs update docs --- docs/source/pymic.loss.seg.rst | 8 ---- pymic/loss/seg/abstract.py | 32 +++++++++++++++ pymic/loss/seg/ce.py | 66 +++++++------------------------ pymic/loss/seg/combined.py | 12 +++++- pymic/loss/seg/deep_sup.py | 12 +++++- pymic/loss/seg/dice.py | 40 ++++++++++++++----- pymic/loss/seg/exp_log.py | 24 ++++++++++- pymic/loss/seg/gatedcrf.py | 30 +++++--------- pymic/loss/seg/gatedcrf_backup.py | 48 ---------------------- pymic/loss/seg/mse.py | 21 ++++++++-- pymic/loss/seg/mumford_shah.py | 38 ++++++++++++++---- pymic/loss/seg/slsr.py | 32 +++++++++------ pymic/loss/seg/ssl.py | 32 ++++++++++++++- pymic/loss/seg/util.py | 27 ++++++++++--- 14 files changed, 252 insertions(+), 170 deletions(-) create mode 100644 pymic/loss/seg/abstract.py delete mode 100644 pymic/loss/seg/gatedcrf_backup.py diff --git a/docs/source/pymic.loss.seg.rst b/docs/source/pymic.loss.seg.rst index c283ed3..d37f41f 100644 --- a/docs/source/pymic.loss.seg.rst +++ b/docs/source/pymic.loss.seg.rst @@ -52,14 +52,6 @@ pymic.loss.seg.gatedcrf module :undoc-members: :show-inheritance: -pymic.loss.seg.gatedcrf\_backup module --------------------------------------- - -.. automodule:: pymic.loss.seg.gatedcrf_backup - :members: - :undoc-members: - :show-inheritance: - pymic.loss.seg.mse module ------------------------- diff --git a/pymic/loss/seg/abstract.py b/pymic/loss/seg/abstract.py new file mode 100644 index 0000000..b5af5a0 --- /dev/null +++ b/pymic/loss/seg/abstract.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import torch.nn as nn + +class AbstractSegLoss(nn.Module): + """ + Cross entropy loss for segmentation tasks. + The arguments should be written in the `params` dictionary, and it has the + following fields: + + Args: + `loss_softmax` (bool): Apply softmax to the prediction of network or not. \n + """ + def __init__(self, params = None): + super(AbstractSegLoss, self).__init__() + + def forward(self, loss_input_dict): + """ + Forward pass for calculating the loss. + The arguments should be written in the `loss_input_dict` dictionary, + and it has the following fields: + + :param `prediction`: (tensor) Prediction of a network, with the + shape of [N, C, D, H, W] or [N, C, H, W]. + :param `ground_truth`: (tensor) Ground truth, with the + shape of [N, C, D, H, W] or [N, C, H, W]. + + :return: Loss function value. + """ + pass diff --git a/pymic/loss/seg/ce.py b/pymic/loss/seg/ce.py index e88ed8c..f9ee8ea 100644 --- a/pymic/loss/seg/ce.py +++ b/pymic/loss/seg/ce.py @@ -3,16 +3,16 @@ import torch import torch.nn as nn +from pymic.loss.seg.abstract import AbstractSegLoss from pymic.loss.seg.util import reshape_tensor_to_2D -class CrossEntropyLoss(nn.Module): +class CrossEntropyLoss(AbstractSegLoss): """ Cross entropy loss for segmentation tasks. The arguments should be written in the `params` dictionary, and it has the following fields: - Args: - `loss_softmax` (bool): Apply softmax to the prediction of network or not. \n + :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. """ def __init__(self, params = None): super(CrossEntropyLoss, self).__init__() @@ -22,22 +22,6 @@ def __init__(self, params = None): self.softmax = params.get('loss_softmax', True) def forward(self, loss_input_dict): - """ - Forward pass for calculating the loss. - The arguments should be written in the `loss_input_dict` dictionary, - and it has the following fields: - - Args: - `prediction` (tensor): prediction of a network, with the - shape of [N, C, D, H, W] or [N, C, H, W]. \n - `ground_truth` (tensor): ground truth, with the - shape of [N, C, D, H, W] or [N, C, H, W]. \n - `pixel_weight` (tensor or None): pixel weight map, with the - shape of [N, D, H, W] or [N, H, W]. \n - - Returns: - tensor: the loss value. - """ predict = loss_input_dict['prediction'] soft_y = loss_input_dict['ground_truth'] pix_w = loss_input_dict.get('pixel_weight', None) @@ -60,7 +44,7 @@ def forward(self, loss_input_dict): ce = torch.sum(pix_w * ce) / (pix_w.sum() + 1e-5) return ce -class GeneralizedCELoss(nn.Module): +class GeneralizedCELoss(AbstractSegLoss): """ Generalized cross entropy loss to deal with noisy labels. @@ -70,37 +54,20 @@ class GeneralizedCELoss(nn.Module): The arguments should be written in the `params` dictionary, and it has the following fields: - Args: - `GeneralizedCELoss_Enable_Pixel_Weight` (bool): Use pixel weighting or not. \n - `GeneralizedCELoss_Enable_Class_Weight` (bool): Use class weighting or not. \n - `GeneralizedCELoss_q` (float): hyper-parameter in the range of (0, 1). \n - `loss_softmax` (bool): Apply softmax to the network's prediction or not. + :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. + :param `loss_gce_q`: (float): hyper-parameter in the range of (0, 1). + :param `loss_with_pixel_weight`: (optional, bool): Use pixel weighting or not. + :param `loss_class_weight`: (optional, list or none): If not none, a list of weight for each class. + """ def __init__(self, params): super(GeneralizedCELoss, self).__init__() - self.enable_pix_weight = params.get('GeneralizedCELoss_Enable_Pixel_Weight', False) - self.enable_cls_weight = params.get('GeneralizedCELoss_Enable_Class_Weight', False) - self.q = params.get('GeneralizedCELoss_q', 0.5) self.softmax = params.get('loss_softmax', True) - - def forward(self, loss_input_dict): - ''' - Forward pass for calculating the loss. - The arguments should be written in the `loss_input_dict` dictionary, - and it has the following fields: - - Args: - `prediction` (tensor): prediction of a network, with the - shape of [N, C, D, H, W] or [N, C, H, W]. \n - `ground_truth` (tensor): ground truth, with the - shape of [N, C, D, H, W] or [N, C, H, W]. \n - `pixel_weight` (tensor or None): pixel weight map, with the - shape of [N, D, H, W] or [N, H, W]. \n - `class_weight` (tensor or None): class weight map. + self.q = params.get('loss_gce_q', 0.5) + self.enable_pix_weight = params.get('loss_with_pixel_weight', False) + self.cls_weight = params.get('loss_class_weight', None) - Returns: - tensor: the loss value. - ''' + def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] soft_y = loss_input_dict['ground_truth'] @@ -112,11 +79,8 @@ def forward(self, loss_input_dict): soft_y = reshape_tensor_to_2D(soft_y) gce = (1.0 - torch.pow(predict, self.q)) / self.q * soft_y - if(self.enable_cls_weight): - cls_w = loss_input_dict.get('class_weight', None) - if(cls_w is None): - raise ValueError("Class weight is enabled but not defined") - gce = torch.sum(gce * cls_w, dim = 1) + if(self.cls_weight is not None): + gce = torch.sum(gce * self.cls_w, dim = 1) else: gce = torch.sum(gce, dim = 1) diff --git a/pymic/loss/seg/combined.py b/pymic/loss/seg/combined.py index 1a75874..f4c4431 100644 --- a/pymic/loss/seg/combined.py +++ b/pymic/loss/seg/combined.py @@ -3,8 +3,18 @@ import torch import torch.nn as nn +from pymic.loss.seg.abstract import AbstractSegLoss -class CombinedLoss(nn.Module): +class CombinedLoss(AbstractSegLoss): + ''' + A combination of a list of loss functions. + Arguments should be saved in the `params` dictionary. + + :param `loss_type`: (list) A list of loss function name. + :param `loss_weight`: (list) A list of weights for each loss fucntion. + :param loss_dict: (dictionary) A dictionary of avaiable loss functions. + + ''' def __init__(self, params, loss_dict): super(CombinedLoss, self).__init__() loss_names = params['loss_type'] diff --git a/pymic/loss/seg/deep_sup.py b/pymic/loss/seg/deep_sup.py index cd6fdc0..7362fd8 100644 --- a/pymic/loss/seg/deep_sup.py +++ b/pymic/loss/seg/deep_sup.py @@ -2,8 +2,18 @@ from __future__ import print_function, division import torch.nn as nn +from pymic.loss.seg.abstract import AbstractSegLoss -class DeepSuperviseLoss(nn.Module): +class DeepSuperviseLoss(AbstractSegLoss): + ''' + Combine deep supervision with a basic loss function. + Arguments should be provided in the `params` dictionary, and it has the + following fields: + + :param `deep_suervise_weight`: (list) A list of weight for each deep supervision scale. \n + :param `base_loss`: (nn.Module) The basic function used for each scale. + + ''' def __init__(self, params): super(DeepSuperviseLoss, self).__init__() self.deep_sup_weight = params.get('deep_suervise_weight', None) diff --git a/pymic/loss/seg/dice.py b/pymic/loss/seg/dice.py index 583b95c..c207f37 100644 --- a/pymic/loss/seg/dice.py +++ b/pymic/loss/seg/dice.py @@ -3,9 +3,17 @@ import torch import torch.nn as nn +from pymic.loss.seg.abstract import AbstractSegLoss from pymic.loss.seg.util import reshape_tensor_to_2D, get_classwise_dice -class DiceLoss(nn.Module): +class DiceLoss(AbstractSegLoss): + ''' + Dice loss for segmentation tasks. + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. + ''' def __init__(self, params = None): super(DiceLoss, self).__init__() if(params is None): @@ -27,11 +35,18 @@ def forward(self, loss_input_dict): dice_loss = 1.0 - dice_score.mean() return dice_loss -class FocalDiceLoss(nn.Module): +class FocalDiceLoss(AbstractSegLoss): """ - focal Dice according to the following paper: - Pei Wang and Albert C. S. Chung, Focal Dice Loss and Image Dilation for - Brain Tumor Segmentation, 2018 + Focal Dice according to the following paper: + + * Pei Wang and Albert C. S. Chung, Focal Dice Loss and Image Dilation for + Brain Tumor Segmentation, 2018. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. + :param `FocalDiceLoss_beta`: (float) The hyper-parameter to set (>=1.0). """ def __init__(self, params = None): super(FocalDiceLoss, self).__init__() @@ -54,12 +69,19 @@ def forward(self, loss_input_dict): dice_loss = 1.0 - dice_score.mean() return dice_loss -class NoiseRobustDiceLoss(nn.Module): +class NoiseRobustDiceLoss(AbstractSegLoss): """ Noise-robust Dice loss according to the following paper. - G. Wang et al. A Noise-Robust Framework for Automatic Segmentation of COVID-19 - Pneumonia Lesions From CT Images, IEEE TMI, 2020. - https://doi.org/10.1109/TMI.2020.3000314 + + * G. Wang et al. A Noise-Robust Framework for Automatic Segmentation of COVID-19 + Pneumonia Lesions From CT Images, + `IEEE TMI `_, 2020. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. + :param `NoiseRobustDiceLoss_gamma`: (float) The hyper-parameter gammar to set (1, 2). """ def __init__(self, params): super(NoiseRobustDiceLoss, self).__init__() diff --git a/pymic/loss/seg/exp_log.py b/pymic/loss/seg/exp_log.py index 8f54a67..33bd031 100644 --- a/pymic/loss/seg/exp_log.py +++ b/pymic/loss/seg/exp_log.py @@ -9,8 +9,16 @@ class ExpLogLoss(nn.Module): """ The exponential logarithmic loss in this paper: - K. Wong et al.: 3D Segmentation with Exponential Logarithmic Loss for Highly - Unbalanced Object Sizes. MICCAI 2018. + + * K. Wong et al.: 3D Segmentation with Exponential Logarithmic Loss for Highly + Unbalanced Object Sizes. MICCAI 2018. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. + :param `ExpLogLoss_w_dice`: (float) Weight of ExpLog Dice loss in the range of [0, 1]. + :param `ExpLogLoss_gamma`: (float) Hyper-parameter gamma. """ def __init__(self, params): super(ExpLogLoss, self).__init__() @@ -19,6 +27,18 @@ def __init__(self, params): self.gamma = params['ExpLogLoss_gamma'.lower()] def forward(self, loss_input_dict): + """ + Forward pass for calculating the loss. + The arguments should be written in the `loss_input_dict` dictionary, + and it has the following fields: + + :param `prediction`: (tensor) Prediction of a network, with the + shape of [N, C, D, H, W] or [N, C, H, W]. + :param `ground_truth`: (tensor) Ground truth, with the + shape of [N, C, D, H, W] or [N, C, H, W]. + + :return: Loss function value. + """ predict = loss_input_dict['prediction'] soft_y = loss_input_dict['ground_truth'] diff --git a/pymic/loss/seg/gatedcrf.py b/pymic/loss/seg/gatedcrf.py index e95ec43..5e23655 100644 --- a/pymic/loss/seg/gatedcrf.py +++ b/pymic/loss/seg/gatedcrf.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ -the original implementation is from: -https://github.com/LEONOB2014/GatedCRFLoss/blob/master/models/model_loss_semseg_gatedcrf.py +The code is adapted from the original implementation at `Github. +`_ """ import torch import torch.nn.functional as F @@ -9,17 +9,12 @@ class ModelLossSemsegGatedCRF(torch.nn.Module): """ - This module provides an implementation of the Gated CRF Loss for Weakly Supervised Semantic Image Segmentation. + Gated CRF Loss for Weakly Supervised Semantic Image Segmentation. This loss function promotes consistent label assignment guided by input features, such as RGBXY. - Please consider using the following bibtex for citation: - @article{obukhov2019gated, - author={Anton Obukhov and Stamatios Georgoulis and Dengxin Dai and Luc {Van Gool}}, - title={Gated {CRF} Loss for Weakly Supervised Semantic Image Segmentation}, - journal={CoRR}, - volume={abs/1906.04651}, - year={2019}, - url={http://arxiv.org/abs/1906.04651}, - } + + * Reference: Anton Obukhov, Stamatios Georgoulis, Dengxin Dai and Luc Van Gool: + Gated CRF Loss for Weakly Supervised Semantic Image Segmentation. `CoRR + `_ 2019. """ def forward( self, y_hat_softmax, kernels_desc, kernels_radius, sample, height_input, width_input, @@ -27,18 +22,13 @@ def forward( ): """ Performs the forward pass of the loss. + :param y_hat_softmax: A tensor of predicted per-pixel class probabilities of size NxCxHxW :param kernels_desc: A list of dictionaries, each describing one Gaussian kernel composition from modalities. The final kernel is a weighted sum of individual kernels. Following example is a composition of RGBXY and XY kernels: - kernels_desc: [{ - 'weight': 0.9, # Weight of RGBXY kernel - 'xy': 6, # Sigma for XY - 'rgb': 0.1, # Sigma for RGB - },{ - 'weight': 0.1, # Weight of XY kernel - 'xy': 6, # Sigma for XY - }] + kernels_desc: [{'weight': 0.9,'xy': 6,'rgb': 0.1},{'weight': 0.1,'xy': 6}] + :param kernels_radius: Defines size of bounding box region around each pixel in which the kernel is constructed. :param sample: A dictionary with modalities (except 'xy') used in kernels_desc parameter. Each of the provided modalities is allowed to be larger than the shape of y_hat_softmax, in such case downsampling will be diff --git a/pymic/loss/seg/gatedcrf_backup.py b/pymic/loss/seg/gatedcrf_backup.py deleted file mode 100644 index 710fa8b..0000000 --- a/pymic/loss/seg/gatedcrf_backup.py +++ /dev/null @@ -1,48 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import torch -import torch.nn as nn -import numpy as np -from pymic.loss.seg.ce import CrossEntropyLoss -from pymic.loss.seg.gatedcrf_util import ModelLossSemsegGatedCRF - -class GatedCRFLoss(nn.Module): - def __init__(self, params): - super(GatedCRFLoss, self).__init__() - self.gcrf_loss = ModelLossSemsegGatedCRF() - self.softmax = params.get('loss_softmax', True) - w0 = params['GatedCRFLoss_W0'.lower()] - xy0= params['GatedCRFLoss_XY0'.lower()] - rgb= params['GatedCRFLoss_rgb'.lower()] - w1 = params['GatedCRFLoss_W1'.lower()] - xy1= params['GatedCRFLoss_XY1'.lower()] - kernel0 = {'weight': w0, 'xy': xy0, 'rgb': rgb} - kernel1 = {'weight': w1, 'xy': xy1} - self.kernels = [kernel0, kernel1] - self.radius = params['GatedCRFLoss_Radius'.lower()] - - def forward(self, loss_input_dict): - predict = loss_input_dict['prediction'] - image = loss_input_dict['image'] # should be normalized by mean, std - scribble= loss_input_dict['scribbles'] - validity_mask = loss_input_dict['validity_mask'] - - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) - - batch_dict = {'rgb': image, - 'semseg_scribbles': scribble} - x_shape = list(predict.shape) - l_crf = {'loss': 0} - if(self.gcrf_w > 0): - l_crf = self.gcrf_loss(predict, - self.kernels, - self.radius, - batch_dict, - x_shape[-2], - x_shape[-1], - mask_src=validity_mask, - out_kernels_vis=False, - ) - return l_crf['loss'] \ No newline at end of file diff --git a/pymic/loss/seg/mse.py b/pymic/loss/seg/mse.py index 305f188..511cd28 100644 --- a/pymic/loss/seg/mse.py +++ b/pymic/loss/seg/mse.py @@ -1,7 +1,15 @@ import torch import torch.nn as nn +from pymic.loss.seg.abstract import AbstractSegLoss -class MSELoss(nn.Module): +class MSELoss(AbstractSegLoss): + """ + Mean Sequare Loss for segmentation tasks. + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. + """ def __init__(self, params): super(MSELoss, self).__init__() if(params is None): @@ -22,7 +30,14 @@ def forward(self, loss_input_dict): return mse -class MAELoss(nn.MSELoss): +class MAELoss(AbstractSegLoss): + """ + Mean Absolute Loss for segmentation tasks. + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. + """ def __init__(self, params): super(MAELoss, self).__init__(params) if(params is None): @@ -30,7 +45,7 @@ def __init__(self, params): else: self.softmax = params.get('loss_softmax', True) - def get_prediction_error(self, loss_input_dict): + def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] soft_y = loss_input_dict['ground_truth'] diff --git a/pymic/loss/seg/mumford_shah.py b/pymic/loss/seg/mumford_shah.py index f167b71..9583dc0 100644 --- a/pymic/loss/seg/mumford_shah.py +++ b/pymic/loss/seg/mumford_shah.py @@ -6,12 +6,21 @@ class MumfordShahLoss(nn.Module): """ - Implementation of Mumford Shah Loss in this paper: - Boah Kim and Jong Chul Ye: Mumford–Shah Loss Functional - for Image Segmentation With Deep Learning. IEEE TIP, 2019. - The oringial implementation is availabel at: - https://github.com/jongcye/CNN_MumfordShah_Loss + Implementation of Mumford Shah Loss for weakly supervised learning. + + * Reference: Boah Kim and Jong Chul Ye: Mumford–Shah Loss Functional + for Image Segmentation With Deep Learning. IEEE TIP, 2019. + + The oringial implementation is availabel at `Github. + `_ Currently only 2D version is supported. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. + :param `MumfordShahLoss_penalty`: (optional, str) `l1` or `l2`. Default is `l1`. + :param `MumfordShahLoss_lambda`: (optional, float) Hyper-parameter lambda, default is 1.0. """ def __init__(self, params = None): super(MumfordShahLoss, self).__init__() @@ -23,8 +32,11 @@ def __init__(self, params = None): def get_levelset_loss(self, output, target): """ - output: softmax output of a network - target: the input image + Get the level set loss value. + + :param `output`: (tensor) softmax output of a network. + :param `target`: (tensor) the input image. + :return: the level set loss. """ outshape = output.shape tarshape = target.shape @@ -49,6 +61,18 @@ def get_gradient_loss(self, pred, penalty = "l2"): return loss def forward(self, loss_input_dict): + """ + Forward pass for calculating the loss. + The arguments should be written in the `loss_input_dict` dictionary, + and it has the following fields: + + :param `prediction`: (tensor) Prediction of a network, with the + shape of [N, C, D, H, W] or [N, C, H, W]. + :param `image`: (tensor) Image, with the + shape of [N, C, D, H, W] or [N, C, H, W]. + + :return: Loss function value. + """ predict = loss_input_dict['prediction'] image = loss_input_dict['image'] if(isinstance(predict, (list, tuple))): diff --git a/pymic/loss/seg/slsr.py b/pymic/loss/seg/slsr.py index 706d2fc..18aba3a 100644 --- a/pymic/loss/seg/slsr.py +++ b/pymic/loss/seg/slsr.py @@ -1,19 +1,30 @@ # -*- coding: utf-8 -*- -""" -Spatial Label Smoothing Regularization (SLSR) loss for learning from -noisy annotatins according to the following paper: - Minqing Zhang, Jiantao Gao et al.: - Characterizing Label Errors: Confident Learning for Noisy-Labeled Image - Segmentation, MICCAI 2020. - https://link.springer.com/chapter/10.1007/978-3-030-59710-8_70 -""" + from __future__ import print_function, division import torch import torch.nn as nn +from pymic.loss.seg.abstract import AbstractSegLoss from pymic.loss.seg.util import reshape_tensor_to_2D -class SLSRLoss(nn.Module): +class SLSRLoss(AbstractSegLoss): + """ + Spatial Label Smoothing Regularization (SLSR) loss for learning from + noisy annotatins. This loss requires pixel weighting, please make sure + that a `pixel_weight` field is provided for the csv file of the training images. + + The pixel wight here is actually the confidence mask, i.e., if the value is one, + it means the label of corresponding pixel is noisy and should be smoothed. + + * Reference: Minqing Zhang, Jiantao Gao et al.: Characterizing Label Errors: Confident Learning for Noisy-Labeled Image + Segmentation, `MICCAI 2020. `_ + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. + :param `slsrloss_epsilon`: (optional, float) Hyper-parameter epsilon. Default is 0.25. + """ def __init__(self, params): super(SLSRLoss, self).__init__() if(params is None): @@ -25,9 +36,6 @@ def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] soft_y = loss_input_dict['ground_truth'] pix_w = loss_input_dict.get('pixel_weight', None) - # the pixel wight here is actually the confidence mask - # i.e., if the value is one, it means the label of corresponding - # pixel is noisy and should be replaced with smoothed label. if(isinstance(predict, (list, tuple))): predict = predict[0] diff --git a/pymic/loss/seg/ssl.py b/pymic/loss/seg/ssl.py index 5b22157..7eaf981 100644 --- a/pymic/loss/seg/ssl.py +++ b/pymic/loss/seg/ssl.py @@ -9,7 +9,11 @@ class EntropyLoss(nn.Module): """ - Minimize the entropy for each pixel + Entropy Minimization for segmentation tasks. + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. """ def __init__(self, params = None): super(EntropyLoss, self).__init__() @@ -19,6 +23,16 @@ def __init__(self, params = None): self.softmax = params.get('loss_softmax', True) def forward(self, loss_input_dict): + """ + Forward pass for calculating the loss. + The arguments should be written in the `loss_input_dict` dictionary, + and it has the following fields: + + :param `prediction`: (tensor) Prediction of a network, with the + shape of [N, C, D, H, W] or [N, C, H, W]. + + :return: Loss function value. + """ predict = loss_input_dict['prediction'] if(isinstance(predict, (list, tuple))): @@ -35,7 +49,11 @@ def forward(self, loss_input_dict): class TotalVariationLoss(nn.Module): """ - Minimize the total variation of a segmentation + Total Variation Loss for segmentation tasks. + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. """ def __init__(self, params = None): super(TotalVariationLoss, self).__init__() @@ -45,6 +63,16 @@ def __init__(self, params = None): self.softmax = params.get('loss_softmax', True) def forward(self, loss_input_dict): + """ + Forward pass for calculating the loss. + The arguments should be written in the `loss_input_dict` dictionary, + and it has the following fields: + + :param `prediction`: (tensor) Prediction of a network, with the + shape of [N, C, D, H, W] or [N, C, H, W]. + + :return: Loss function value. + """ predict = loss_input_dict['prediction'] if(isinstance(predict, (list, tuple))): diff --git a/pymic/loss/seg/util.py b/pymic/loss/seg/util.py index ed6aba3..4ca41d2 100644 --- a/pymic/loss/seg/util.py +++ b/pymic/loss/seg/util.py @@ -7,9 +7,13 @@ def get_soft_label(input_tensor, num_class, data_type = 'float'): """ - convert a label tensor to one-hot label - input_tensor: tensor with shae [B, 1, D, H, W] or [B, 1, H, W] - output_tensor: shape [B, num_class, D, H, W] or [B, num_class, H, W] + Convert a label tensor to one-hot label for segmentation tasks. + + :param `input_tensor`: (tensor) Tensor with shae [B, 1, D, H, W] or [B, 1, H, W]. + :param `num_class`: (int) The class number. + :param `data_type`: (optional, str) Type of data, `float` (default) or `double`. + + :return: A tensor with shape [B, num_class, D, H, W] or [B, num_class, H, W] """ shape = input_tensor.shape @@ -31,7 +35,7 @@ def get_soft_label(input_tensor, num_class, data_type = 'float'): def reshape_tensor_to_2D(x): """ - reshape input variables of shape [B, C, D, H, W] to [voxel_n, C] + Reshape input tensor of shape [N, C, D, H, W] or [N, C, H, W] to [voxel_n, C] """ tensor_dim = len(x.size()) num_class = list(x.size())[1] @@ -47,7 +51,12 @@ def reshape_tensor_to_2D(x): def reshape_prediction_and_ground_truth(predict, soft_y): """ - reshape input variables of shape [B, C, D, H, W] to [voxel_n, C] + Reshape input variables two 2D. + + :param predict: (tensor) A tensor of shape [N, C, D, H, W] or [N, C, H, W]. + :param soft_y: (tensor) A tensor of shape [N, C, D, H, W] or [N, C, H, W]. + + :return: Two output tensors with shape [voxel_n, C] that correspond to the two inputs. """ tensor_dim = len(predict.size()) num_class = list(predict.size())[1] @@ -67,7 +76,13 @@ def reshape_prediction_and_ground_truth(predict, soft_y): def get_classwise_dice(predict, soft_y, pix_w = None): """ - get dice scores for each class in predict (after softmax) and soft_y + Get dice scores for each class in predict (after softmax) and soft_y. + + :param predict: (tensor) Prediction of a segmentation network after softmax. + :param soft_y: (tensor) The one-hot segmentation ground truth. + :param pix_w: (optional, tensor) The pixel weight map. Default is None. + + :return: Dice score for each class. """ if(pix_w is None): From 8fec1d0767f117f5c529de49afc9278654e8e55f Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 3 Sep 2022 18:25:07 +0800 Subject: [PATCH 106/225] Update agent_abstract.py --- pymic/net_run/agent_abstract.py | 52 +++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index 8ffadc1..f047d24 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -21,6 +21,19 @@ def seed_torch(seed=1): torch.backends.cudnn.deterministic = True class NetRunAgent(object): + """ + The abstract class for medical image segmentation. + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + The config dictionary should have at least four sections: `dataset`, + `network`, `training` and `inference`. See :doc:`usage.quickstart` and + :doc:`usage.fsl` for example. + + """ __metaclass__ = ABCMeta def __init__(self, config, stage = 'train'): assert(stage in ['train', 'inference', 'test']) @@ -46,26 +59,65 @@ def __init__(self, config, stage = 'train'): logging.info("deterministric is true") def set_datasets(self, train_set, valid_set, test_set): + """ + Set customized datasets for training and inference. + + :param train_set: (torch.utils.data.Dataset) The training set. + :param valid_set: (torch.utils.data.Dataset) The validation set. + :param test_set: (torch.utils.data.Dataset) The testing set. + """ self.train_set = train_set self.valid_set = valid_set self.test_set = test_set def set_transform_dict(self, custom_transform_dict): + """ + Set the available Transforms, including customized Transforms. + + :param custom_transform_dict: (dictionary) A dictionary of + available Transforms. + """ self.transform_dict = custom_transform_dict def set_network(self, net): + """ + Set the network. + + :param net: (nn.Module) A deep learning network. + """ self.net = net def set_loss_dict(self, loss_dict): + """ + Set the available loss functions, including customized loss functions. + + :param loss_dict: (dictionary) A dictionary of + available loss functions. + """ self.loss_dict = loss_dict def set_optimizer(self, optimizer): + """ + Set the optimizer. + + :param optimizer: An optimizer. + """ self.optimizer = optimizer def set_scheduler(self, scheduler): + """ + Set the learning rate scheduler. + + :param scheduler: A learning rate scheduler. + """ self.scheduler = scheduler def set_inferer(self, inferer): + """ + Set the inferer. + + :param inferer: An inferer object. + """ self.inferer = inferer def get_checkpoint_name(self): From 96aaf583d11407220e2b21e9ea13ebc7aa3c9877 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 3 Sep 2022 21:52:43 +0800 Subject: [PATCH 107/225] update doc for semi-supervised segmentation update doc for semi-supervised segmentation --- pymic/net_run/agent_abstract.py | 93 ++++++++++++++++++++++++++++++- pymic/net_run/agent_cls.py | 27 +++++++-- pymic/net_run/agent_seg.py | 76 +++++++------------------ pymic/net_run/infer_func.py | 18 ++++++ pymic/net_run/net_run.py | 3 + pymic/net_run_ssl/ssl_abstract.py | 17 ++++-- pymic/net_run_ssl/ssl_cct.py | 23 ++++++-- pymic/net_run_ssl/ssl_cps.py | 19 +++++-- pymic/net_run_ssl/ssl_em.py | 19 +++++-- pymic/net_run_ssl/ssl_main.py | 3 + pymic/net_run_ssl/ssl_mt.py | 19 +++++-- pymic/net_run_ssl/ssl_uamt.py | 19 +++++-- pymic/net_run_ssl/ssl_urpc.py | 21 +++++-- 13 files changed, 258 insertions(+), 99 deletions(-) diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index f047d24..dc6b1ad 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -11,6 +11,11 @@ from pymic.net_run.get_optimizer import get_lr_scheduler, get_optimizer def seed_torch(seed=1): + """ + Set random seed. + + :param seed: (int) the seed for random. + """ random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) np.random.seed(seed) @@ -47,6 +52,7 @@ def __init__(self, config, stage = 'train'): self.net = None self.optimizer = None self.scheduler = None + self.net_dict = None self.loss_dict = None self.transform_dict = None self.inferer = None @@ -87,12 +93,19 @@ def set_network(self, net): """ self.net = net + def set_net_dict(self, net_dict): + """ + Set the available networks, including customized networks. + + :param net_dict: (dictionary) A dictionary of available networks. + """ + self.net_dict = net_dict + def set_loss_dict(self, loss_dict): """ Set the available loss functions, including customized loss functions. - :param loss_dict: (dictionary) A dictionary of - available loss functions. + :param loss_dict: (dictionary) A dictionary of available loss functions. """ self.loss_dict = loss_dict @@ -121,6 +134,9 @@ def set_inferer(self, inferer): self.inferer = inferer def get_checkpoint_name(self): + """ + Get the checkpoint name for inference based on config['testing']['ckpt_mode']. + """ ckpt_mode = self.config['testing']['ckpt_mode'] if(ckpt_mode == 0 or ckpt_mode == 1): ckpt_dir = self.config['training']['ckpt_save_dir'] @@ -138,33 +154,94 @@ def get_checkpoint_name(self): @abstractmethod def get_stage_dataset_from_config(self, stage): + """ + Create dataset based on training, validation or inference stage. + + :param stage: (str) `train`, `valid` or `test`. + """ raise(ValueError("not implemented")) @abstractmethod def get_parameters_to_update(self): + """ + Get parameters for update. + """ + raise(ValueError("not implemented")) + + @abstractmethod + def get_loss_value(self, data, pred, gt, param = None): + """ + Get the loss value. Assume `pred` and `gt` has been sent to self.device. + `data` is obtained by dataloader, and is a dictionary containing extra + information, such as pixel-level weight. By default, such information + is not used by standard loss functions such as Dice loss and cross entropy loss. + + + :param data: (dictionary) A data dictionary obtained by dataloader. + :param pred: (tensor) Prediction result by the network. + :param gt: (tensor) Ground truth. + :param param: (dictionary) Other parameters if needed. + """ raise(ValueError("not implemented")) @abstractmethod def create_network(self): + """ + Create network based on configuration. + """ + raise(ValueError("not implemented")) + + @abstractmethod + def create_loss_calculator(self): + """ + Create loss function object. + """ raise(ValueError("not implemented")) @abstractmethod def training(self): + """ + Train the network + """ raise(ValueError("not implemented")) @abstractmethod def validation(self): + """ + Evaluate the performance on the validation set. + """ raise(ValueError("not implemented")) @abstractmethod def train_valid(self): + """ + Train and valid. + """ raise(ValueError("not implemented")) @abstractmethod def infer(self): + """ + Inference on testing set. + """ + raise(ValueError("not implemented")) + + @abstractmethod + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): + """ + Write scalars using SummaryWriter. + + :param train_scalars: (dictionary) Scalars for training set. + :param valid_scalars: (dictionary) Scalars for validation set. + :param lr_value: (float) Current learning rate. + :param glob_it: (int) Current iteration number. + """ raise(ValueError("not implemented")) def create_dataset(self): + """ + Create datasets for training, validation or testing based on configuraiton. + """ if(self.stage == 'train'): if(self.train_set is None): self.train_set = self.get_stage_dataset_from_config('train') @@ -200,6 +277,12 @@ def worker_init_fn(worker_id): batch_size = bn_test, shuffle=False, num_workers= bn_test) def create_optimizer(self, params): + """ + Create optimizer based on configuration. + + :param params: network parameters for optimization. Usually it is obtained by + `self.get_parameters_to_update()`. + """ opt_params = self.config['training'] if(self.optimizer is None): self.optimizer = get_optimizer(opt_params['optimizer'], @@ -213,12 +296,18 @@ def create_optimizer(self, params): self.scheduler = get_lr_scheduler(self.optimizer, opt_params) def convert_tensor_type(self, input_tensor): + """ + Convert the type of an input tensor to float or double based on configuration. + """ if(self.tensor_type == 'float'): return input_tensor.float() else: return input_tensor.double() def run(self): + """ + Run the training or inference code according to configuration. + """ self.create_dataset() self.create_network() if(self.stage == 'train'): diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index 8687048..bb04a9d 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -20,6 +20,18 @@ warnings.filterwarnings('ignore', '.*output shape of zoom.*') class ClassificationAgent(NetRunAgent): + """ + The agent for image classificaiton tasks. + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + The config dictionary should have at least four sections: `dataset`, + `network`, `training` and `inference`. See :doc:`usage.quickstart` and + :doc:`usage.fsl` for example. + """ def __init__(self, config, stage = 'train'): super(ClassificationAgent, self).__init__(config, stage) self.transform_dict = TransformDict @@ -87,14 +99,20 @@ def create_loss_calculator(self): else: raise ValueError("Undefined loss function {0:}".format(loss_name)) - def get_loss_value(self, data, inputs, outputs, labels): + def get_loss_value(self, data, pred, gt, param = None): loss_input_dict = {} - loss_input_dict['prediction'] = outputs - loss_input_dict['ground_truth'] = labels + loss_input_dict['prediction'] = pred + loss_input_dict['ground_truth'] = gt loss_value = self.loss_calculater(loss_input_dict) return loss_value def get_evaluation_score(self, outputs, labels): + """ + Get evaluation score for a prediction. + + :param outputs: (tensor) Prediction obtained by a network. + :param labels: (tensor) The ground truth. + """ metrics = self.config['training'].get("evaluation_metric", "accuracy") if(metrics != "accuracy"): # default classification accuracy raise ValueError("Not implemeted for metric {0:}".format(metrics)) @@ -170,12 +188,13 @@ def validation(self): valid_scalers = {'loss': avg_loss, metrics: avg_score} return valid_scalers - def write_scalars(self, train_scalars, valid_scalars, glob_it): + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): metrics =self.config['training'].get("evaluation_metric", "accuracy") loss_scalar ={'train':train_scalars['loss'], 'valid':valid_scalars['loss']} acc_scalar ={'train':train_scalars[metrics],'valid':valid_scalars[metrics]} self.summ_writer.add_scalars('loss', loss_scalar, glob_it) self.summ_writer.add_scalars(metrics, acc_scalar, glob_it) + self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) logging.info("{0:} it {1:}".format(str(datetime.now())[:-7], glob_it)) logging.info('train loss {0:.4f}, avg {1:} {2:.4f}'.format( diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 5a53f3e..bba4508 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -33,6 +33,7 @@ class SegmentationAgent(NetRunAgent): def __init__(self, config, stage = 'train'): super(SegmentationAgent, self).__init__(config, stage) self.transform_dict = TransformDict + self.net_dict = SegNetDict self.postprocess_dict = PostProcessDict self.postprocessor = None @@ -70,9 +71,9 @@ def get_stage_dataset_from_config(self, stage): def create_network(self): if(self.net is None): net_name = self.config['network']['net_type'] - if(net_name not in SegNetDict): + if(net_name not in self.net_dict): raise ValueError("Undefined network {0:}".format(net_name)) - self.net = SegNetDict[net_name](self.config['network']) + self.net = self.net_dict[net_name](self.config['network']) if(self.tensor_type == 'float'): self.net.float() else: @@ -83,50 +84,6 @@ def create_network(self): def get_parameters_to_update(self): return self.net.parameters() - def get_class_level_weight(self): - class_num = self.config['network']['class_num'] - class_weight= self.config['training'].get('loss_class_weight', None) - if(class_weight is None): - class_weight = torch.ones(class_num) - else: - assert(len(class_weight) == class_num) - class_weight = torch.from_numpy(np.asarray(class_weight)) - class_weight = self.convert_tensor_type(class_weight) - return class_weight - - def get_image_level_weight(self, data): - imageweight_enb = self.config['training'].get('loss_with_image_weight', False) - img_w = None - if(imageweight_enb): - if(self.net.training): - if('image_weight' not in data): - raise ValueError("image weight is enabled not not provided") - img_w = data['image_weight'] - else: - img_w = data.get('image_weight', None) - if(img_w is None): - batch_size = data['image'].shape[0] - img_w = torch.ones(batch_size) - img_w = self.convert_tensor_type(img_w) - return img_w - - def get_pixel_level_weight(self, data): - pixelweight_enb = self.config['training'].get('loss_with_pixel_weight', False) - pix_w = None - if(pixelweight_enb): - if(self.net.training): - if('pixel_weight' not in data): - raise ValueError("pixel weight is enabled but not provided") - pix_w = data['pixel_weight'] - else: - pix_w = data.get('pixel_weight', None) - if(pix_w is None): - pix_w_shape = list(data['label_prob'].shape) - pix_w_shape[1] = 1 - pix_w = torch.ones(pix_w_shape) - pix_w = self.convert_tensor_type(pix_w) - return pix_w - def create_loss_calculator(self): if(self.loss_dict is None): self.loss_dict = SegLossDict @@ -145,13 +102,6 @@ def create_loss_calculator(self): self.loss_calculator = base_loss def get_loss_value(self, data, pred, gt, param = None): - """ - Assume pred and gt has been sent to self.device - data is obtained by dataloaderis and is dictionary containing extra - information, such as pixel-level weight, class-level weight. - By default, such information is not used by standard loss functions - such as Dice loss and cross entropy loss. - """ loss_input_dict = {'prediction':pred, 'ground_truth': gt} if data.get('pixel_weight', None) is not None: loss_input_dict['pixel_weight'] = data['pixel_weight'].to(pred.device) @@ -159,6 +109,12 @@ def get_loss_value(self, data, pred, gt, param = None): return loss_value def set_postprocessor(self, postprocessor): + """ + Set post processor after prediction. + + :param postprocessor: post processor, such as an instance of + `pymic.util.post_process.PostProcess`. + """ self.postprocessor = postprocessor def training(self): @@ -451,14 +407,14 @@ def test_time_dropout(m): infer_time = time.time() - start_time infer_time_list.append(infer_time) - self.save_ouputs(data) + self.save_outputs(data) infer_time_list = np.asarray(infer_time_list) time_avg, time_std = infer_time_list.mean(), infer_time_list.std() logging.info("testing time {0:} +/- {1:}".format(time_avg, time_std)) def infer_with_multiple_checkpoints(self): """ - inference with ensemble of multilple check points + Inference with ensemble of multilple check points. """ device_ids = self.config['testing']['gpus'] device = torch.device("cuda:{0:}".format(device_ids[0])) @@ -505,12 +461,18 @@ def infer_with_multiple_checkpoints(self): infer_time = time.time() - start_time infer_time_list.append(infer_time) - self.save_ouputs(data) + self.save_outputs(data) infer_time_list = np.asarray(infer_time_list) time_avg, time_std = infer_time_list.mean(), infer_time_list.std() logging.info("testing time {0:} +/- {1:}".format(time_avg, time_std)) - def save_ouputs(self, data): + def save_outputs(self, data): + """ + Save prediction output. + + :param data: (dictionary) A data dictionary with prediciton result and other + information such as input image name. + """ output_dir = self.config['testing']['output_dir'] ignore_dir = self.config['testing'].get('filename_ignore_dir', True) save_prob = self.config['testing'].get('save_probability', False) diff --git a/pymic/net_run/infer_func.py b/pymic/net_run/infer_func.py index 78184fe..81d0b53 100644 --- a/pymic/net_run/infer_func.py +++ b/pymic/net_run/infer_func.py @@ -5,6 +5,18 @@ from torch.nn.functional import interpolate class Inferer(object): + """ + The class for inference. + The arguments should be written in the `config` dictionary, + and it has the following fields: + + :param `sliding_window_enable`: (optional, bool) Default is `False`. + :param `sliding_window_size`: (optional, list) The sliding window size. + :param `sliding_window_stride`: (optional, list) The sliding window stride. + :param `tta_mode`: (optional, int) The test time augmentation mode. Default + is 0 (no test time augmentation). The other option is 1 (augmentation + with horinzontal and vertical flipping). + """ def __init__(self, config): self.config = config @@ -127,6 +139,12 @@ def __infer_with_sliding_window(self, image): return output_list def run(self, model, image): + """ + Using `model` for inference on `image`. + + :param model: (nn.Module) a network. + :param image: (tensor) An image. + """ self.model = model tta_mode = self.config.get('tta_mode', 0) if(tta_mode == 0): diff --git a/pymic/net_run/net_run.py b/pymic/net_run/net_run.py index 4c953ad..aec6fa0 100644 --- a/pymic/net_run/net_run.py +++ b/pymic/net_run/net_run.py @@ -8,6 +8,9 @@ from pymic.net_run.agent_seg import SegmentationAgent def main(): + """ + The main function for running a network for training or inference. + """ if(len(sys.argv) < 3): print('Number of arguments should be 3. e.g.') print(' pymic_run train config.cfg') diff --git a/pymic/net_run_ssl/ssl_abstract.py b/pymic/net_run_ssl/ssl_abstract.py index 1d18c4d..e847d6c 100644 --- a/pymic/net_run_ssl/ssl_abstract.py +++ b/pymic/net_run_ssl/ssl_abstract.py @@ -15,10 +15,16 @@ class SSLSegAgent(SegmentationAgent): """ - Implementation of the following paper: - Yves Grandvalet and Yoshua Bengio, - Semi-supervised Learningby Entropy Minimization. - NeurIPS, 2005. + Abstract class for semi-supervised segmentation. + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. """ def __init__(self, config, stage = 'train'): super(SSLSegAgent, self).__init__(config, stage) @@ -26,6 +32,9 @@ def __init__(self, config, stage = 'train'): self.train_set_unlab = None def get_unlabeled_dataset_from_config(self): + """ + Create a dataset for the unlabeled images based on configuration. + """ root_dir = self.config['dataset']['root_dir'] modal_num = self.config['dataset'].get('modal_num', 1) transform_names = self.config['dataset']['train_transform_unlab'] diff --git a/pymic/net_run_ssl/ssl_cct.py b/pymic/net_run_ssl/ssl_cct.py index d0c4f24..0bc23a2 100644 --- a/pymic/net_run_ssl/ssl_cct.py +++ b/pymic/net_run_ssl/ssl_cct.py @@ -62,12 +62,23 @@ def softmax_js_loss(inputs, targets, **_): class SSLCCT(SSLSegAgent): """ - Cross-Consistency Training according to the following paper: - Yassine Ouali, Celine Hudelot and Myriam Tami: - Semi-Supervised Semantic Segmentation With Cross-Consistency Training. - CVPR 2020. - https://arxiv.org/abs/2003.09005 - Code adapted from: https://github.com/yassouali/CCT + Cross-Consistency Training for semi-supervised segmentation. It requires a network + with multiple decoders for learning, such as `pymic.net.net2d.unet2d_cct.UNet2D_CCT`. + + * Reference: Yassine Ouali, Celine Hudelot and Myriam Tami: + Semi-Supervised Semantic Segmentation With Cross-Consistency Training. + `CVPR 2020. `_ + + The Code is adapted from `Github `_ + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. """ def training(self): class_num = self.config['network']['class_num'] diff --git a/pymic/net_run_ssl/ssl_cps.py b/pymic/net_run_ssl/ssl_cps.py index 2264d0d..239443c 100644 --- a/pymic/net_run_ssl/ssl_cps.py +++ b/pymic/net_run_ssl/ssl_cps.py @@ -30,11 +30,20 @@ def forward(self, x): class SSLCPS(SSLSegAgent): """ - Using cross pseudo supervision according to the following paper: - Xiaokang Chen, Yuhui Yuan, Gang Zeng, Jingdong Wang, - Semi-Supervised Semantic Segmentation with Cross Pseudo Supervision, - CVPR 2021, pp. 2613-2022. - https://arxiv.org/abs/2106.01226 + Using cross pseudo supervision for semi-supervised segmentation. + + * Reference: Xiaokang Chen, Yuhui Yuan, Gang Zeng, Jingdong Wang, + Semi-Supervised Semantic Segmentation with Cross Pseudo Supervision, + `CVPR 2021 `_, pp. 2613-2022. + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. """ def __init__(self, config, stage = 'train'): super(SSLCPS, self).__init__(config, stage) diff --git a/pymic/net_run_ssl/ssl_em.py b/pymic/net_run_ssl/ssl_em.py index 49dd22f..745501f 100644 --- a/pymic/net_run_ssl/ssl_em.py +++ b/pymic/net_run_ssl/ssl_em.py @@ -14,11 +14,20 @@ class SSLEntropyMinimization(SSLSegAgent): """ - Implementation of the following paper: - Yves Grandvalet and Yoshua Bengio: - Semi-supervised Learningby Entropy Minimization. - NeurIPS, 2005. - https://papers.nips.cc/paper/2004/file/96f2b50b5d3613adf9c27049b2a888c7-Paper.pdf + Using Entropy Minimization for semi-supervised segmentation. + + * Reference: Yves Grandvalet and Yoshua Bengio: + Semi-supervised Learningby Entropy Minimization. + `NeurIPS, 2005. `_ + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. """ def __init__(self, config, stage = 'train'): super(SSLEntropyMinimization, self).__init__(config, stage) diff --git a/pymic/net_run_ssl/ssl_main.py b/pymic/net_run_ssl/ssl_main.py index d904ab1..7c4f2b9 100644 --- a/pymic/net_run_ssl/ssl_main.py +++ b/pymic/net_run_ssl/ssl_main.py @@ -21,6 +21,9 @@ 'URPC': SSLURPC} def main(): + """ + Main function for running a semi-supervised method. + """ if(len(sys.argv) < 3): print('Number of arguments should be 3. e.g.') print(' pymic_ssl train config.cfg') diff --git a/pymic/net_run_ssl/ssl_mt.py b/pymic/net_run_ssl/ssl_mt.py index 0456726..bf4dacd 100644 --- a/pymic/net_run_ssl/ssl_mt.py +++ b/pymic/net_run_ssl/ssl_mt.py @@ -13,11 +13,20 @@ class SSLMeanTeacher(SSLSegAgent): """ - Mean Teacher for semi-supervised learning according to the following paper: - Antti Tarvainen, Harri Valpola: Mean teachers are better role models: Weight-averaged - consistency targets improve semi-supervised deep learning results. - NeurIPS 2017. - https://arxiv.org/abs/1703.01780 + Mean Teacher for semi-supervised segmentation. + + * Reference: Antti Tarvainen, Harri Valpola: Mean teachers are better role models: + Weight-averaged consistency targets improve semi-supervised deep learning results. + `NeurIPS 2017. `_ + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. """ def __init__(self, config, stage = 'train'): super(SSLMeanTeacher, self).__init__(config, stage) diff --git a/pymic/net_run_ssl/ssl_uamt.py b/pymic/net_run_ssl/ssl_uamt.py index 360dab1..40a742b 100644 --- a/pymic/net_run_ssl/ssl_uamt.py +++ b/pymic/net_run_ssl/ssl_uamt.py @@ -12,11 +12,20 @@ class SSLUncertaintyAwareMeanTeacher(SSLMeanTeacher): """ - Uncertainty Aware Mean Teacher according to the following paper: - Lequan Yu, Shujun Wang, Xiaomeng Li, Chi-Wing Fu, and Pheng-Ann Heng. - Uncertainty-aware Self-ensembling Model for Semi-supervised 3D Left - Atrium Segmentation, MICCAI 2019. - https://arxiv.org/abs/1907.07034 + Uncertainty Aware Mean Teacher for semi-supervised segmentation. + + * Reference: Lequan Yu, Shujun Wang, Xiaomeng Li, Chi-Wing Fu, and Pheng-Ann Heng. + Uncertainty-aware Self-ensembling Model for Semi-supervised 3D Left Atrium + Segmentation, `MICCAI 2019. `_ + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. """ def training(self): class_num = self.config['network']['class_num'] diff --git a/pymic/net_run_ssl/ssl_urpc.py b/pymic/net_run_ssl/ssl_urpc.py index 20b3d84..336e525 100644 --- a/pymic/net_run_ssl/ssl_urpc.py +++ b/pymic/net_run_ssl/ssl_urpc.py @@ -13,12 +13,21 @@ class SSLURPC(SSLSegAgent): """ - Uncertainty-Rectified Pyramid Consistency according to the following paper: - Xiangde Luo, Guotai Wang*, Wenjun Liao, Jieneng Chen, Tao Song, Yinan Chen, - Shichuan Zhang, Dimitris N. Metaxas, Shaoting Zhang. - Semi-Supervised Medical Image Segmentation via Uncertainty Rectified Pyramid Consistency . - Medical Image Analysis 2022. - https://doi.org/10.1016/j.media.2022.102517 + Uncertainty-Rectified Pyramid Consistency for semi-supervised segmentation. + + * Reference: Xiangde Luo, Guotai Wang*, Wenjun Liao, Jieneng Chen, Tao Song, Yinan Chen, + Shichuan Zhang, Dimitris N. Metaxas, Shaoting Zhang. + Semi-Supervised Medical Image Segmentation via Uncertainty Rectified Pyramid Consistency . + `Medical Image Analysis 2022. `_ + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. """ def training(self): class_num = self.config['network']['class_num'] From 0f3710e65c622e7446aca341b6dfee823d77e552 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 3 Sep 2022 22:36:55 +0800 Subject: [PATCH 108/225] add docs for nll and wsl add docs for nll and wsl --- docs/source/pymic.net_run_nll.rst | 8 -- pymic/net_run_nll/nll_cl.py | 192 -------------------------- pymic/net_run_nll/nll_clslsr.py | 39 ++++-- pymic/net_run_nll/nll_co_teaching.py | 26 ++-- pymic/net_run_nll/nll_dast.py | 45 ++++-- pymic/net_run_nll/nll_main.py | 3 + pymic/net_run_nll/nll_trinet.py | 28 ++-- pymic/net_run_wsl/wsl_abstract.py | 14 +- pymic/net_run_wsl/wsl_dmpls.py | 21 ++- pymic/net_run_wsl/wsl_em.py | 15 +- pymic/net_run_wsl/wsl_gatedcrf.py | 20 ++- pymic/net_run_wsl/wsl_main.py | 5 +- pymic/net_run_wsl/wsl_mumford_shah.py | 18 ++- pymic/net_run_wsl/wsl_tv.py | 11 +- pymic/net_run_wsl/wsl_ustm.py | 21 ++- 15 files changed, 188 insertions(+), 278 deletions(-) delete mode 100644 pymic/net_run_nll/nll_cl.py diff --git a/docs/source/pymic.net_run_nll.rst b/docs/source/pymic.net_run_nll.rst index 120d23c..40ee9ef 100644 --- a/docs/source/pymic.net_run_nll.rst +++ b/docs/source/pymic.net_run_nll.rst @@ -4,14 +4,6 @@ pymic.net\_run\_nll package Submodules ---------- -pymic.net\_run\_nll.nll\_cl module ----------------------------------- - -.. automodule:: pymic.net_run_nll.nll_cl - :members: - :undoc-members: - :show-inheritance: - pymic.net\_run\_nll.nll\_clslsr module -------------------------------------- diff --git a/pymic/net_run_nll/nll_cl.py b/pymic/net_run_nll/nll_cl.py deleted file mode 100644 index 8173471..0000000 --- a/pymic/net_run_nll/nll_cl.py +++ /dev/null @@ -1,192 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Caculating the confidence map of labels of training samples, -which is used in the method of SLSR. - Minqing Zhang et al., Characterizing Label Errors: Confident Learning - for Noisy-Labeled Image Segmentation, MICCAI 2020. -""" - -from __future__ import print_function, division -import cleanlab -import logging -import os -import scipy -import sys -import torch -import numpy as np -import pandas as pd -import torch.nn as nn -import torchvision.transforms as transforms -from PIL import Image -from pymic.io.nifty_dataset import NiftyDataset -from pymic.transform.trans_dict import TransformDict -from pymic.util.parse_config import * -from pymic.net_run.agent_seg import SegmentationAgent -from pymic.net_run.infer_func import Inferer - -def get_confident_map(gt, pred, CL_type = 'both'): - """ - gt: ground truth label (one-hot) with shape of NXC - pred: digit prediction of network with shape of NXC - """ - prob = scipy.special.softmax(pred, axis = 1) - if CL_type in ['both', 'Qij']: - noise = cleanlab.pruning.get_noise_indices(gt, prob, prune_method='both', n_jobs=1) - elif CL_type == 'Cij': - noise = cleanlab.pruning.get_noise_indices(gt, pred, prune_method='both', n_jobs=1) - elif CL_type == 'intersection': - noise_qij = cleanlab.pruning.get_noise_indices(gt, prob, prune_method='both', n_jobs=1) - noise_cij = cleanlab.pruning.get_noise_indices(gt, pred, prune_method='both', n_jobs=1) - noise = noise_qij & noise_cij - elif CL_type == 'union': - noise_qij = cleanlab.pruning.get_noise_indices(gt, prob, prune_method='both', n_jobs=1) - noise_cij = cleanlab.pruning.get_noise_indices(gt, pred, prune_method='both', n_jobs=1) - noise = noise_qij | noise_cij - elif CL_type in ['prune_by_class', 'prune_by_noise_rate']: - noise = cleanlab.pruning.get_noise_indices(gt, prob, prune_method=CL_type, n_jobs=1) - return noise - -class NLLConfidentLearn(SegmentationAgent): - def __init__(self, config, stage = 'test'): - super(NLLConfidentLearn, self).__init__(config, stage) - - def infer_with_cl(self): - device_ids = self.config['testing']['gpus'] - device = torch.device("cuda:{0:}".format(device_ids[0])) - self.net.to(device) - - if(self.config['testing'].get('evaluation_mode', True)): - self.net.eval() - if(self.config['testing'].get('test_time_dropout', False)): - def test_time_dropout(m): - if(type(m) == nn.Dropout): - logging.info('dropout layer') - m.train() - self.net.apply(test_time_dropout) - - ckpt_mode = self.config['testing']['ckpt_mode'] - ckpt_name = self.get_checkpoint_name() - if(ckpt_mode == 3): - assert(isinstance(ckpt_name, (tuple, list))) - self.infer_with_multiple_checkpoints() - return - else: - if(isinstance(ckpt_name, (tuple, list))): - raise ValueError("ckpt_mode should be 3 if ckpt_name is a list") - - # load network parameters and set the network as evaluation mode - checkpoint = torch.load(ckpt_name, map_location = device) - self.net.load_state_dict(checkpoint['model_state_dict']) - - if(self.inferer is None): - infer_cfg = self.config['testing'] - class_num = self.config['network']['class_num'] - infer_cfg['class_num'] = class_num - self.inferer = Inferer(infer_cfg) - pred_list = [] - gt_list = [] - filename_list = [] - with torch.no_grad(): - for data in self.test_loader: - images = self.convert_tensor_type(data['image']) - labels = self.convert_tensor_type(data['label_prob']) - names = data['names'] - filename_list.append(names) - images = images.to(device) - - pred = self.inferer.run(self.net, images) - # convert tensor to numpy - if(isinstance(pred, (tuple, list))): - pred = [item.cpu().numpy() for item in pred] - else: - pred = pred.cpu().numpy() - data['predict'] = pred - # inverse transform - for transform in self.transform_list[::-1]: - if (transform.inverse): - data = transform.inverse_transform_for_prediction(data) - - pred = data['predict'] - # conver prediction from N, C, H, W to (N*H*W)*C - print(names, pred.shape, labels.shape) - pred_2d = np.swapaxes(pred, 1, 2) - pred_2d = np.swapaxes(pred_2d, 2, 3) - pred_2d = pred_2d.reshape(-1, class_num) - lab = labels.cpu().numpy() - lab_2d = np.swapaxes(lab, 1, 2) - lab_2d = np.swapaxes(lab_2d, 2, 3) - lab_2d = lab_2d.reshape(-1, class_num) - pred_list.append(pred_2d) - gt_list.append(lab_2d) - - pred_cat = np.concatenate(pred_list) - gt_cat = np.concatenate(gt_list) - gt = np.argmax(gt_cat, axis = 1) - gt = gt.reshape(-1).astype(np.uint8) - print(gt.shape, pred_cat.shape) - conf = get_confident_map(gt, pred_cat) - conf = conf.reshape(-1, 256, 256).astype(np.uint8) * 255 - save_dir = self.config['testing']['output_dir'] + "_conf" - for idx in range(len(filename_list)): - filename = filename_list[idx][0].split('/')[-1] - conf_map = Image.fromarray(conf[idx]) - dst_path = os.path.join(save_dir, filename) - conf_map.save(dst_path) - -def get_confidence_map(): - if(len(sys.argv) < 2): - print('Number of arguments should be 3. e.g.') - print(' python nll_cl.py config.cfg') - exit() - cfg_file = str(sys.argv[1]) - config = parse_config(cfg_file) - config = synchronize_config(config) - - # set dataset - transform_names = config['dataset']['valid_transform'] - transform_list = [] - transform_dict = TransformDict - if(transform_names is None or len(transform_names) == 0): - data_transform = None - else: - transform_param = config['dataset'] - transform_param['task'] = 'segmentation' - for name in transform_names: - if(name not in transform_dict): - raise(ValueError("Undefined transform {0:}".format(name))) - one_transform = transform_dict[name](transform_param) - transform_list.append(one_transform) - data_transform = transforms.Compose(transform_list) - print('transform list', transform_list) - csv_file = config['dataset']['train_csv'] - modal_num = config['dataset'].get('modal_num', 1) - dataset = NiftyDataset(root_dir = config['dataset']['root_dir'], - csv_file = csv_file, - modal_num = modal_num, - with_label= True, - transform = data_transform ) - - agent = NLLConfidentLearn(config, 'test') - agent.set_datasets(None, None, dataset) - agent.transform_list = transform_list - agent.create_dataset() - agent.create_network() - agent.infer_with_cl() - - # create training csv for confidence learning - df_train = pd.read_csv(csv_file) - pixel_weight = [] - weight_dir = config['testing']['output_dir'] + "_conf" - for i in range(len(df_train["label"])): - lab_name = df_train["label"][i].split('/')[-1] - weight_name = "../" + weight_dir + '/' + lab_name - pixel_weight.append(weight_name) - train_cl_dict = {"image": df_train["image"], - "pixel_weight": pixel_weight, - "label": df_train["label"]} - train_cl_csv = csv_file.replace(".csv", "_cl.csv") - df_cl = pd.DataFrame.from_dict(train_cl_dict) - df_cl.to_csv(train_cl_csv, index = False) - -if __name__ == "__main__": - get_confidence_map() \ No newline at end of file diff --git a/pymic/net_run_nll/nll_clslsr.py b/pymic/net_run_nll/nll_clslsr.py index 9ee7182..722b588 100644 --- a/pymic/net_run_nll/nll_clslsr.py +++ b/pymic/net_run_nll/nll_clslsr.py @@ -1,14 +1,5 @@ # -*- coding: utf-8 -*- -""" -Caculating the confidence map of labels of training samples, -which is used in the method of SLSR. - Minqing Zhang et al., Characterizing Label Errors: Confident Learning - for Noisy-Labeled Image Segmentation, MICCAI 2020. - https://link.springer.com/chapter/10.1007/978-3-030-59710-8_70 -""" - from __future__ import print_function, division -import cleanlab import logging import os import scipy @@ -27,9 +18,16 @@ def get_confident_map(gt, pred, CL_type = 'both'): """ - gt: ground truth label (one-hot) with shape of NXC - pred: digit prediction of network with shape of NXC + Get the confidence map based on the label and prediction. + + :param gt: (tensor) One-hot label with shape of NXC. + :param pred: (tensor) Digit prediction of network with shape of NXC. + :param CL_type: (str) A string in {'both', 'Qij', 'Cij', 'intersection', + 'union', 'prune_by_class', 'prune_by_noise_rate'}. + + :return: A tensor representing the noisiness of each pixel. """ + import cleanlab prob = scipy.special.softmax(pred, axis = 1) if CL_type in ['both', 'Qij']: noise = cleanlab.pruning.get_noise_indices(gt, prob, prune_method='both', n_jobs=1) @@ -48,10 +46,24 @@ def get_confident_map(gt, pred, CL_type = 'both'): return noise class NLLCLSLSR(SegmentationAgent): + """ + An agent to estimatate the confidence of noisy labels during inference. + + * Reference: Minqing Zhang et al., Characterizing Label Errors: Confident Learning + for Noisy-Labeled Image Segmentation, + `MICCAI 2020. `_ + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + """ def __init__(self, config, stage = 'test'): super(NLLCLSLSR, self).__init__(config, stage) def infer_with_cl(self): + """ + Inference with confidence estimation. + """ device_ids = self.config['testing']['gpus'] device = torch.device("cuda:{0:}".format(device_ids[0])) self.net.to(device) @@ -135,9 +147,12 @@ def test_time_dropout(m): conf_map.save(dst_path) def get_confidence_map(): + """ + The main function to get the confidence map during inference. + """ if(len(sys.argv) < 2): print('Number of arguments should be 3. e.g.') - print(' python nll_cl.py config.cfg') + print(' python nll_clslsr.py config.cfg') exit() cfg_file = str(sys.argv[1]) config = parse_config(cfg_file) diff --git a/pymic/net_run_nll/nll_co_teaching.py b/pymic/net_run_nll/nll_co_teaching.py index bcaec4e..06240d6 100644 --- a/pymic/net_run_nll/nll_co_teaching.py +++ b/pymic/net_run_nll/nll_co_teaching.py @@ -1,14 +1,5 @@ # -*- coding: utf-8 -*- -""" -Implementation of Co-teaching for learning from noisy samples for -segmentation tasks according to the following paper: - Bo Han et al., Co-teaching: Robust Training of Deep NeuralNetworks - with Extremely Noisy Labels, NeurIPS, 2018 -The author's original implementation was: -https://github.com/bhanML/Co-teaching - -""" from __future__ import print_function, division import logging import os @@ -45,9 +36,20 @@ def forward(self, x): class NLLCoTeaching(SegmentationAgent): """ - Co-teaching: Robust Training of Deep Neural Networks with Extremely - Noisy Labels - https://arxiv.org/abs/1804.06872 + Co-teaching for noisy-label learning. + + * Reference: Bo Han, Quanming Yao, Xingrui Yu, Gang Niu, Miao Xu, Weihua Hu, + Ivor Tsang, Masashi Sugiyama. Co-teaching: Robust Training of Deep Neural Networks with Extremely + Noisy Labels. `NeurIPS 201. `_ + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `noisy_label_learning` is needed. See :doc:`usage.nll` for details. """ def __init__(self, config, stage = 'train'): super(NLLCoTeaching, self).__init__(config, stage) diff --git a/pymic/net_run_nll/nll_dast.py b/pymic/net_run_nll/nll_dast.py index 19a59a2..f057f64 100644 --- a/pymic/net_run_nll/nll_dast.py +++ b/pymic/net_run_nll/nll_dast.py @@ -1,12 +1,4 @@ # -*- coding: utf-8 -*- -""" -Implementation of DAST for noise robust learning according to the following paper. - Shuojue Yang, Guotai Wang, Hui Sun, Xiangde Luo, Peng Sun, Kang Li, Qijun Wang, - Shaoting Zhang: Learning COVID-19 Pneumonia Lesion Segmentation from Imperfect - Annotations via Divergence-Aware Selective Training. - JBHI 2022. https://ieeexplore.ieee.org/document/9770406 -""" - from __future__ import print_function, division import random import torch @@ -24,7 +16,9 @@ class Rank(object): """ - Dynamically rank the current training sample with specific metrics + Dynamically rank the current training sample with specific metrics. + + :param quene_length: (int) The lenght for a quene. """ def __init__(self, quene_length = 100): self.vals = [] @@ -34,9 +28,8 @@ def add_val(self, val): """ Update the quene and calculate the order of the input value. - Return - --------- - rank: rank of the input value with a range of (0, self.quenen_length) + :param val: (float) a value adding to the quene. + :return: rank of the input value with a range of (0, self.quene_length) """ if len(self.vals) < self.quene_length: self.vals.append(val) @@ -80,9 +73,11 @@ def get_ce(prob, soft_y, size_avg = True): @torch.no_grad() def select_criterion(no_noisy_sample, cl_noisy_sample, label): """ - no_noisy_sample: noisy branch's output probability for noisy sample - cl_noisy_sample: clean branch's output probability for noisy sample - label: noisy label + Obtain the sample selection criterion score. + + :param no_noisy_sample: noisy branch's output probability for noisy sample. + :param cl_noisy_sample: clean branch's output probability for noisy sample. + :param label: noisy label. """ l_n = get_ce(no_noisy_sample, label, size_avg = False) l_c = get_ce(cl_noisy_sample, label, size_avg = False) @@ -94,6 +89,23 @@ def select_criterion(no_noisy_sample, cl_noisy_sample, label): return loss_n, loss_c class NLLDAST(SegmentationAgent): + """ + Divergence-Aware Selective Training for noisy label learning. + + * Reference: Shuojue Yang, Guotai Wang, Hui Sun, Xiangde Luo, Peng Sun, + Kang Li, Qijun Wang, Shaoting Zhang: Learning COVID-19 Pneumonia Lesion + Segmentation from Imperfect Annotations via Divergence-Aware Selective Training. + `JBHI 2022. `_ + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `noisy_label_learning` is needed. See :doc:`usage.nll` for details. + """ def __init__(self, config, stage = 'train'): super(NLLDAST, self).__init__(config, stage) self.train_set_noise = None @@ -103,6 +115,9 @@ def __init__(self, config, stage = 'train'): self.clean_rank = None def get_noisy_dataset_from_config(self): + """ + Create a dataset for images with noisy labels based on configuraiton. + """ root_dir = self.config['dataset']['root_dir'] modal_num = self.config['dataset'].get('modal_num', 1) transform_names = self.config['dataset']['train_transform'] diff --git a/pymic/net_run_nll/nll_main.py b/pymic/net_run_nll/nll_main.py index cc07a44..a855f2b 100644 --- a/pymic/net_run_nll/nll_main.py +++ b/pymic/net_run_nll/nll_main.py @@ -14,6 +14,9 @@ "DAST": NLLDAST} def main(): + """ + The main function for noisy label learning methods. + """ if(len(sys.argv) < 3): print('Number of arguments should be 3. e.g.') print(' pymic_nll train config.cfg') diff --git a/pymic/net_run_nll/nll_trinet.py b/pymic/net_run_nll/nll_trinet.py index 6af5449..7ac74af 100644 --- a/pymic/net_run_nll/nll_trinet.py +++ b/pymic/net_run_nll/nll_trinet.py @@ -1,12 +1,5 @@ # -*- coding: utf-8 -*- -""" -Implementation of trinet for learning from noisy samples for -segmentation tasks according to the following paper: - Tianwei Zhang, Lequan Yu, Na Hu, Su Lv, Shi Gu: - Robust Medical Image Segmentation from Non-expert Annotations with Tri-network. - MICCAI 2020. - https://link.springer.com/chapter/10.1007/978-3-030-59719-1_25 -""" + from __future__ import print_function, division import logging import os @@ -25,8 +18,6 @@ from pymic.util.parse_config import * from pymic.util.ramps import get_rampup_ratio - - class TriNet(nn.Module): def __init__(self, params): super(TriNet, self).__init__() @@ -46,6 +37,23 @@ def forward(self, x): return (out1 + out2 + out3) / 3 class NLLTriNet(SegmentationAgent): + """ + Implementation of trinet for learning from noisy samples for + segmentation tasks. + + * Reference: Tianwei Zhang, Lequan Yu, Na Hu, Su Lv, Shi Gu: + Robust Medical Image Segmentation from Non-expert Annotations with Tri-network. + `MICCAI 2020. `_ + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `noisy_label_learning` is needed. See :doc:`usage.nll` for details. + """ def __init__(self, config, stage = 'train'): super(NLLTriNet, self).__init__(config, stage) diff --git a/pymic/net_run_wsl/wsl_abstract.py b/pymic/net_run_wsl/wsl_abstract.py index d64063e..f317dd0 100644 --- a/pymic/net_run_wsl/wsl_abstract.py +++ b/pymic/net_run_wsl/wsl_abstract.py @@ -5,13 +5,19 @@ class WSLSegAgent(SegmentationAgent): """ - Training and testing agent for weakly supervised segmentation + Abstract agent for weakly supervised segmentation. + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `weakly_supervised_learning` is needed. See :doc:`usage.wsl` for details. """ def __init__(self, config, stage = 'train'): super(WSLSegAgent, self).__init__(config, stage) - - def training(self): - pass def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): loss_scalar ={'train':train_scalars['loss'], diff --git a/pymic/net_run_wsl/wsl_dmpls.py b/pymic/net_run_wsl/wsl_dmpls.py index 8ee9e53..15e74a0 100644 --- a/pymic/net_run_wsl/wsl_dmpls.py +++ b/pymic/net_run_wsl/wsl_dmpls.py @@ -14,12 +14,21 @@ class WSLDMPLS(WSLSegAgent): """ - Implementation of the following paper: - Xiangde Luo, Minhao Hu, Wenjun Liao, Shuwei Zhai, Tao Song, Guotai Wang, - Shaoting Zhang. ScribblScribble-Supervised Medical Image Segmentation via - Dual-Branch Network and Dynamically Mixed Pseudo Labels Supervision. - MICCAI 2022. - https://arxiv.org/abs/2203.02106 + Weakly supervised segmentation based on Dynamically Mixed Pseudo Labels Supervision. + + * Reference: Xiangde Luo, Minhao Hu, Wenjun Liao, Shuwei Zhai, Tao Song, Guotai Wang, + Shaoting Zhang. ScribblScribble-Supervised Medical Image Segmentation via + Dual-Branch Network and Dynamically Mixed Pseudo Labels Supervision. + `MICCAI 2022. `_ + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `weakly_supervised_learning` is needed. See :doc:`usage.wsl` for details. """ def __init__(self, config, stage = 'train'): net_type = config['network']['net_type'] diff --git a/pymic/net_run_wsl/wsl_em.py b/pymic/net_run_wsl/wsl_em.py index 3b2d595..387bd93 100644 --- a/pymic/net_run_wsl/wsl_em.py +++ b/pymic/net_run_wsl/wsl_em.py @@ -14,7 +14,20 @@ class WSLEntropyMinimization(WSLSegAgent): """ - Weakly suepervised segmentation with Entropy Minimization Regularization. + Weakly supervised segmentation based on Entropy Minimization. + + * Reference: Yves Grandvalet and Yoshua Bengio: + Semi-supervised Learningby Entropy Minimization. + `NeurIPS, 2005. `_ + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `weakly_supervised_learning` is needed. See :doc:`usage.wsl` for details. """ def __init__(self, config, stage = 'train'): super(WSLEntropyMinimization, self).__init__(config, stage) diff --git a/pymic/net_run_wsl/wsl_gatedcrf.py b/pymic/net_run_wsl/wsl_gatedcrf.py index 64e0f1b..270a955 100644 --- a/pymic/net_run_wsl/wsl_gatedcrf.py +++ b/pymic/net_run_wsl/wsl_gatedcrf.py @@ -13,12 +13,20 @@ class WSLGatedCRF(WSLSegAgent): """ - Implementation of the Gated CRF Loss for Weakly Supervised Semantic Image Segmentation. - Anton Obukhov, Stamatios Georgoulis, Dengxin Dai, Luc Van Gool: - Gated CRF Loss for Weakly Supervised Semantic Image Segmentation. - CoRR, abs/1906.04651, 2019 - http://arxiv.org/abs/1906.04651 - } + Implementation of the Gated CRF loss for weakly supervised segmentation. + + * Reference: Anton Obukhov, Stamatios Georgoulis, Dengxin Dai, Luc Van Gool: + Gated CRF Loss for Weakly Supervised Semantic Image Segmentation. + `CoRR `_, abs/1906.04651, 2019. + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `weakly_supervised_learning` is needed. See :doc:`usage.wsl` for details. """ def __init__(self, config, stage = 'train'): super(WSLGatedCRF, self).__init__(config, stage) diff --git a/pymic/net_run_wsl/wsl_main.py b/pymic/net_run_wsl/wsl_main.py index abedb6b..d8e791d 100644 --- a/pymic/net_run_wsl/wsl_main.py +++ b/pymic/net_run_wsl/wsl_main.py @@ -20,9 +20,12 @@ 'DMPLS': WSLDMPLS} def main(): + """ + The main function for training and inference of weakly supervised segmentation. + """ if(len(sys.argv) < 3): print('Number of arguments should be 3. e.g.') - print(' pymic_ssl train config.cfg') + print(' pymic_wsl train config.cfg') exit() stage = str(sys.argv[1]) cfg_file = str(sys.argv[2]) diff --git a/pymic/net_run_wsl/wsl_mumford_shah.py b/pymic/net_run_wsl/wsl_mumford_shah.py index df4c68f..862d761 100644 --- a/pymic/net_run_wsl/wsl_mumford_shah.py +++ b/pymic/net_run_wsl/wsl_mumford_shah.py @@ -13,10 +13,20 @@ class WSLMumfordShah(WSLSegAgent): """ - Weakly supervised learning with Mumford Shah Loss according to this paper: - Boah Kim and Jong Chul Ye: Mumford–Shah Loss Functional - for Image Segmentation With Deep Learning. IEEE TIP, 2019. - https://doi.org/10.1109/TIP.2019.2941265 + Weakly supervised learning with Mumford Shah Loss. + + * Reference: Boah Kim and Jong Chul Ye: Mumford–Shah Loss Functional + for Image Segmentation With Deep Learning. + `IEEE TIP `_, 2019. + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `weakly_supervised_learning` is needed. See :doc:`usage.wsl` for details. """ def __init__(self, config, stage = 'train'): super(WSLMumfordShah, self).__init__(config, stage) diff --git a/pymic/net_run_wsl/wsl_tv.py b/pymic/net_run_wsl/wsl_tv.py index 2e56cb4..3fd55a3 100644 --- a/pymic/net_run_wsl/wsl_tv.py +++ b/pymic/net_run_wsl/wsl_tv.py @@ -14,7 +14,16 @@ class WSLTotalVariation(WSLSegAgent): """ - Weakly suepervised segmentation with Total Variation Regularization. + Weakly suepervised segmentation with Total Variation regularization. + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `weakly_supervised_learning` is needed. See :doc:`usage.wsl` for details. """ def __init__(self, config, stage = 'train'): super(WSLTotalVariation, self).__init__(config, stage) diff --git a/pymic/net_run_wsl/wsl_ustm.py b/pymic/net_run_wsl/wsl_ustm.py index 0a2f7e1..b033f9d 100644 --- a/pymic/net_run_wsl/wsl_ustm.py +++ b/pymic/net_run_wsl/wsl_ustm.py @@ -16,12 +16,21 @@ class WSLUSTM(WSLSegAgent): """ - USTM for scribble-supervised segmentation according to the following paper: - Xiaoming Liu, Quan Yuan, Yaozong Gao, Helei He, Shuo Wang, Xiao Tang, - Jinshan Tang, Dinggang Shen: - Weakly Supervised Segmentation of COVID19 Infection with Scribble Annotation on CT Images. - Patter Recognition, 2022. - https://doi.org/10.1016/j.patcog.2021.108341 + USTM for scribble-supervised segmentation. + + * Reference: Xiaoming Liu, Quan Yuan, Yaozong Gao, Helei He, Shuo Wang, + Xiao Tang, Jinshan Tang, Dinggang Shen: Weakly Supervised Segmentation + of COVID19 Infection with Scribble Annotation on CT Images. + `Patter Recognition `_, 2022. + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `weakly_supervised_learning` is needed. See :doc:`usage.wsl` for details. """ def __init__(self, config, stage = 'train'): super(WSLUSTM, self).__init__(config, stage) From 7ecdb0c8cba485d1fab8038cab81faf3b8de5271 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 5 Sep 2022 22:26:50 +0800 Subject: [PATCH 109/225] update transform and docs update transform and docs --- docs/source/pymic.transform.rst | 8 -- docs/source/pymic.util.rst | 16 --- pymic/transform/abstract_transform.py | 13 ++ pymic/transform/crop.py | 122 ++++++++++-------- pymic/transform/flip.py | 20 +-- pymic/transform/gray2rgb.py | 31 ----- pymic/transform/intensity.py | 49 +++++-- pymic/transform/label_convert.py | 65 ++++++---- pymic/transform/normalize.py | 79 ++++++++---- pymic/transform/pad.py | 25 ++-- pymic/transform/rescale.py | 39 ++++-- pymic/transform/rotate.py | 36 +++--- pymic/transform/threshold.py | 68 ++++++---- pymic/transform/trans_dict.py | 2 - pymic/util/average_model.py | 23 ---- pymic/util/evaluation_cls.py | 38 +++++- .../{rename_model.py => model_operate.py} | 19 +++ 17 files changed, 386 insertions(+), 267 deletions(-) delete mode 100644 pymic/transform/gray2rgb.py delete mode 100644 pymic/util/average_model.py rename pymic/util/{rename_model.py => model_operate.py} (60%) diff --git a/docs/source/pymic.transform.rst b/docs/source/pymic.transform.rst index 1dafa3b..6b78bf1 100644 --- a/docs/source/pymic.transform.rst +++ b/docs/source/pymic.transform.rst @@ -28,14 +28,6 @@ pymic.transform.flip module :undoc-members: :show-inheritance: -pymic.transform.gray2rgb module -------------------------------- - -.. automodule:: pymic.transform.gray2rgb - :members: - :undoc-members: - :show-inheritance: - pymic.transform.intensity module -------------------------------- diff --git a/docs/source/pymic.util.rst b/docs/source/pymic.util.rst index 3a6c3ab..65ef60d 100644 --- a/docs/source/pymic.util.rst +++ b/docs/source/pymic.util.rst @@ -4,14 +4,6 @@ pymic.util package Submodules ---------- -pymic.util.average\_model module --------------------------------- - -.. automodule:: pymic.util.average_model - :members: - :undoc-members: - :show-inheritance: - pymic.util.evaluation\_cls module --------------------------------- @@ -76,14 +68,6 @@ pymic.util.ramps module :undoc-members: :show-inheritance: -pymic.util.rename\_model module -------------------------------- - -.. automodule:: pymic.util.rename_model - :members: - :undoc-members: - :show-inheritance: - Module contents --------------- diff --git a/pymic/transform/abstract_transform.py b/pymic/transform/abstract_transform.py index 4958bde..5556ee0 100644 --- a/pymic/transform/abstract_transform.py +++ b/pymic/transform/abstract_transform.py @@ -2,11 +2,24 @@ from __future__ import print_function, division class AbstractTransform(object): + """ + The abstract class for Transform. + """ def __init__(self, params): self.task = params['Task'.lower()] def __call__(self, sample): + """ + Forward pass of the transform. + + :arg sample: (dict) A dictionary for the input sample obtained by dataloader. + """ return sample def inverse_transform_for_prediction(self, sample): + """ + Inverse transform for the sample dictionary. + Especially, it will update sample['predict'] obtained by a network's + prediction based on the inverse transform. This function is only useful for spatial transforms. + """ raise(ValueError("not implemented")) diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index ec486dc..a27288d 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -13,19 +13,23 @@ class CenterCrop(AbstractTransform): """ Crop the given image at the center. - input shape should be [C, D, H, W] or [C, H, W]) + Input shape should be [C, D, H, W] or [C, H, W]. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `CenterCrop_output_size`: (list or tuple) The output size. + [D, H, W] for 3D images and [H, W] for 2D images. + If D is None, then the z-axis is not cropped. + :param `CenterCrop_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `True`. """ def __init__(self, params): - """ - output_size (tuple/list): Desired spatial output size. - [D, H, W] for 3D images and [H, W] for 2D images - If D is None, then the z-axis is not cropped - """ + super(CenterCrop, self).__init__(params) self.output_size = params['CenterCrop_output_size'.lower()] self.inverse = params.get('CenterCrop_inverse'.lower(), True) - self.task = params['Task'.lower()] - def get_crop_param(self, sample): + def _get_crop_param(self, sample): input_shape = sample['image'].shape input_dim = len(input_shape) - 1 assert(input_dim == len(self.output_size)) @@ -46,7 +50,7 @@ def get_crop_param(self, sample): def __call__(self, sample): image = sample['image'] - sample, crop_min, crop_max = self.get_crop_param(sample) + sample, crop_min, crop_max = self._get_crop_param(sample) image_t = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) sample['image'] = image_t @@ -63,7 +67,7 @@ def __call__(self, sample): sample['pixel_weight'] = weight return sample - def get_param_for_inverse_transform(self, sample): + def _get_param_for_inverse_transform(self, sample): if(isinstance(sample['CenterCrop_Param'], list) or \ isinstance(sample['CenterCrop_Param'], tuple)): params = json.loads(sample['CenterCrop_Param'][0]) @@ -72,9 +76,7 @@ def get_param_for_inverse_transform(self, sample): return params def inverse_transform_for_prediction(self, sample): - ''' rescale sample['predict'] (5D or 4D) to the original spatial shape. - origin_shape is a 4D or 3D vector as saved in __call__().''' - params = self.get_param_for_inverse_transform(sample) + params = self._get_param_for_inverse_transform(sample) origin_shape = params[0] crop_min = params[1] crop_max = params[2] @@ -101,22 +103,27 @@ def inverse_transform_for_prediction(self, sample): return sample class CropWithBoundingBox(CenterCrop): - """Crop the image (shape [C, D, H, W] or [C, H, W]) based on bounding box + """ + Crop the image (shape [C, D, H, W] or [C, H, W]) based on a bounding box. + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `CropWithBoundingBox_start`: (None, or list/tuple) The start index + along each spatial axis. If None, calculate the start index automatically + so that the cropped region is centered at the non-zero region. + :param `CropWithBoundingBox_output_size`: (None or tuple/list): + Desired spatial output size. + If None, set it as the size of bounding box of non-zero region. + :param `CropWithBoundingBox_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `True`. """ def __init__(self, params): - """ - start (None or tuple/list): The start index along each spatial axis. - if None, calculate the start index automatically so that - the cropped region is centered at the non-zero region. - output_size (None or tuple/list): Desired spatial output size. - if None, set it as the size of bounding box of non-zero region - """ self.start = params['CropWithBoundingBox_start'.lower()] self.output_size = params['CropWithBoundingBox_output_size'.lower()] self.inverse = params.get('CropWithBoundingBox_inverse'.lower(), True) self.task = params['task'] - def get_crop_param(self, sample): + def _get_crop_param(self, sample): image = sample['image'] input_shape = sample['image'].shape input_dim = len(input_shape) - 1 @@ -146,7 +153,7 @@ def get_crop_param(self, sample): print("for crop", crop_min, crop_max) return sample, crop_min, crop_max - def get_param_for_inverse_transform(self, sample): + def _get_param_for_inverse_transform(self, sample): if(isinstance(sample['CropWithBoundingBox_Param'], list) or \ isinstance(sample['CropWithBoundingBox_Param'], tuple)): params = json.loads(sample['CropWithBoundingBox_Param'][0]) @@ -156,20 +163,26 @@ def get_param_for_inverse_transform(self, sample): class RandomCrop(CenterCrop): - """Randomly crop the input image (shape [C, D, H, W] or [C, H, W]) + """Randomly crop the input image (shape [C, D, H, W] or [C, H, W]). + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `RandomCrop_output_size`: (list/tuple) Desired output size [D, H, W] or [H, W]. + The output channel is the same as the input channel. + If D is None for 3D images, the z-axis is not cropped. + :param `RandomCrop_foreground_focus`: (optional, bool) + If true, allow crop around the foreground. Default is False. + :param `RandomCrop_foreground_ratio`: (optional, float) + Specifying the probability of foreground focus cropping when + `RandomCrop_foreground_focus` is True. + :param `RandomCrop_mask_label`: (optional, None, or list/tuple) + Specifying the foreground labels for foreground focus cropping when + `RandomCrop_foreground_focus` is True. + :param `RandomCrop_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `True`. """ def __init__(self, params): - """ - output_size (tuple or list): Desired output size [D, H, W] or [H, W]. - the output channel is the same as the input channel. - If D is None for 3D images, the z-axis is not cropped - foreground_focus (bool): If true, allow crop around the foreground. - foreground_ratio (float): Specifying the probability of foreground - focus cropping when foreground_focus is true. - mask_label (None, or tuple / list): Specifying the foreground labels for foreground - focus cropping - """ - # super(RandomCrop, self).__init__(params) self.output_size = params['RandomCrop_output_size'.lower()] self.fg_focus = params.get('RandomCrop_foreground_focus'.lower(), False) self.fg_ratio = params.get('RandomCrop_foreground_ratio'.lower(), 0.5) @@ -180,7 +193,7 @@ def __init__(self, params): if(self.mask_label is not None): assert isinstance(self.mask_label, (list, tuple)) - def get_crop_param(self, sample): + def _get_crop_param(self, sample): image = sample['image'] input_shape = image.shape input_dim = len(input_shape) - 1 @@ -217,7 +230,7 @@ def get_crop_param(self, sample): sample['RandomCrop_Param'] = json.dumps((input_shape, crop_min, crop_max)) return sample, crop_min, crop_max - def get_param_for_inverse_transform(self, sample): + def _get_param_for_inverse_transform(self, sample): if(isinstance(sample['RandomCrop_Param'], list) or \ isinstance(sample['RandomCrop_Param'], tuple)): params = json.loads(sample['RandomCrop_Param'][0]) @@ -226,26 +239,31 @@ def get_param_for_inverse_transform(self, sample): return params class RandomResizedCrop(CenterCrop): - """Randomly crop the input image (shape [C, H, W]) - Only 2D images are supported + """ + Randomly crop the input image (shape [C, H, W]). Only 2D images are supported. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `RandomResizedCrop_output_size`: (list/tuple) Desired output size [H, W]. + The output channel is the same as the input channel. + :param `RandomResizedCrop_scale`: (list/tuple) Range of scale, e.g. (0.08, 1.0). + :param `RandomResizedCrop_ratio`: (list/tuple) Range of aspect ratio, e.g. (0.75, 1.33). + :param `RandomResizedCrop_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `False`. Currently, the inverse transform is not supported, and + this transform is assumed to be used only during training stage. """ def __init__(self, params): - """ - output_size (tuple or list): Desired output size [H, W]. - the output channel is the same as the input channel. - scale (tuple or list): range of scale, e.g. (0.08, 1.0) - ratio (tuple or list): range of aspect ratio, e.g. (0.75, 1.33) - """ self.output_size = params['RandomResizedCrop_output_size'.lower()] self.scale = params['RandomResizedCrop_scale'.lower()] self.ratio = params['RandomResizedCrop_ratio'.lower()] - self.inverse = params.get('RandomResizedCrop_inverse'.lower(), True) + self.inverse = params.get('RandomResizedCrop_inverse'.lower(), False) self.task = params['Task'.lower()] assert isinstance(self.output_size, (list, tuple)) assert isinstance(self.scale, (list, tuple)) assert isinstance(self.ratio, (list, tuple)) - def get_crop_param(self, sample): + def _get_crop_param(self, sample): image = sample['image'] input_shape = image.shape input_dim = len(input_shape) - 1 @@ -273,7 +291,7 @@ def __call__(self, sample): image = sample['image'] input_shape = image.shape input_dim = len(input_shape) - 1 - sample, crop_min, crop_max = self.get_crop_param(sample) + sample, crop_min, crop_max = self._get_crop_param(sample) image_t = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) crp_shape = image_t.shape @@ -294,10 +312,4 @@ def __call__(self, sample): weight = crop_ND_volume_with_bounding_box(weight, crop_min, crop_max) weight = ndimage.interpolation.zoom(weight, scale, order = 1) sample['pixel_weight'] = weight - return sample - - def inverse_transform_for_prediction(self, sample): - """ - not implemented - """ - raise(ValueError("not implemented")) \ No newline at end of file + return sample \ No newline at end of file diff --git a/pymic/transform/flip.py b/pymic/transform/flip.py index 7b4a995..ca0915e 100644 --- a/pymic/transform/flip.py +++ b/pymic/transform/flip.py @@ -12,13 +12,19 @@ class RandomFlip(AbstractTransform): - """ random flip the image (shape [C, D, H, W] or [C, H, W]) """ + """ Random flip the image. The shape is [C, D, H, W] or [C, H, W]. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `RandomFlip_flip_depth`: (bool) + Random flip along depth axis or not, only used for 3D images. + :param `RandomFlip_flip_height`: (bool) Random flip along height axis or not. + :param `RandomFlip_flip_width`: (bool) Random flip along width axis or not. + :param `RandomFlip_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `True`. + """ def __init__(self, params): - """ - flip_depth (bool) : random flip along depth axis or not, only used for 3D images - flip_height (bool): random flip along height axis or not - flip_width (bool) : random flip along width axis or not - """ super(RandomFlip, self).__init__(params) self.flip_depth = params['RandomFlip_flip_depth'.lower()] self.flip_height = params['RandomFlip_flip_height'.lower()] @@ -54,8 +60,6 @@ def __call__(self, sample): return sample def inverse_transform_for_prediction(self, sample): - ''' flip sample['predict'] (5D or 4D) to the original direction. - flip_axis is a list as saved in __call__().''' if(isinstance(sample['RandomFlip_Param'], list) or \ isinstance(sample['RandomFlip_Param'], tuple)): flip_axis = json.loads(sample['RandomFlip_Param'][0]) diff --git a/pymic/transform/gray2rgb.py b/pymic/transform/gray2rgb.py deleted file mode 100644 index 5b8c214..0000000 --- a/pymic/transform/gray2rgb.py +++ /dev/null @@ -1,31 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import torch -import json -import math -import random -import numpy as np -from scipy import ndimage -from pymic.transform.abstract_transform import AbstractTransform -from pymic.util.image_process import * - - -class GrayscaleToRGB(AbstractTransform): - """ - apply random gamma correction to each channel - """ - def __init__(self, params): - """ - (gamma_min, gamma_max) specify the range of gamma - """ - super(GrayscaleToRGB, self).__init__(params) - self.inverse = params.get('GrayscaleToRGB_inverse'.lower(), False) - - def __call__(self, sample): - image= sample['image'] - assert(image.shape[0] == 1 or image.shape[0] == 3) - if(image.shape[0] == 1): - sample['image'] = np.concatenate([image, image, image]) - return sample - diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index b9e6070..171604b 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -13,18 +13,26 @@ class GammaCorrection(AbstractTransform): """ - apply random gamma correction to each channel + Apply random gamma correction to given channels. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `GammaCorrection_channels`: (list) A list of int for specifying the channels. + :param `GammaCorrection_gamma_min`: (float) The minimal gamma value. + :param `GammaCorrection_gamma_max`: (float) The maximal gamma value. + :param `GammaCorrection_probability`: (optional, float) + The probability of applying GammaCorrection. Default is 0.5. + :param `GammaCorrection_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. """ def __init__(self, params): - """ - (gamma_min, gamma_max) specify the range of gamma - """ super(GammaCorrection, self).__init__(params) self.channels = params['GammaCorrection_channels'.lower()] self.gamma_min = params['GammaCorrection_gamma_min'.lower()] self.gamma_max = params['GammaCorrection_gamma_max'.lower()] self.prob = params.get('GammaCorrection_probability'.lower(), 0.5) - self.inverse = params.get('GammaCorrection_inverse'.lower(), False) + self.inverse = params.get('GammaCorrection_inverse'.lower(), False) def __call__(self, sample): if(np.random.uniform() > self.prob): @@ -44,12 +52,20 @@ def __call__(self, sample): class GaussianNoise(AbstractTransform): """ - apply random gamma correction to each channel + Add Gaussian Noise to given channels. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `GaussianNoise_channels`: (list) A list of int for specifying the channels. + :param `GaussianNoise_mean`: (float) The mean value of noise. + :param `GaussianNoise_std`: (float) The std of noise. + :param `GaussianNoise_probability`: (optional, float) + The probability of applying GaussianNoise. Default is 0.5. + :param `GaussianNoise_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. """ def __init__(self, params): - """ - (gamma_min, gamma_max) specify the range of gamma - """ super(GaussianNoise, self).__init__(params) self.channels = params['GaussianNoise_channels'.lower()] self.mean = params['GaussianNoise_mean'.lower()] @@ -68,3 +84,18 @@ def __call__(self, sample): sample['image'] = image return sample + +class GrayscaleToRGB(AbstractTransform): + """ + Convert gray scale images to RGB by copying channels. + """ + def __init__(self, params): + super(GrayscaleToRGB, self).__init__(params) + self.inverse = params.get('GrayscaleToRGB_inverse'.lower(), False) + + def __call__(self, sample): + image= sample['image'] + assert(image.shape[0] == 1 or image.shape[0] == 3) + if(image.shape[0] == 1): + sample['image'] = np.concatenate([image, image, image]) + return sample \ No newline at end of file diff --git a/pymic/transform/label_convert.py b/pymic/transform/label_convert.py index 3d2f2cc..0dcae37 100644 --- a/pymic/transform/label_convert.py +++ b/pymic/transform/label_convert.py @@ -12,7 +12,7 @@ class ReduceLabelDim(AbstractTransform): """ - remove the first dimension of label tensor + Remove the first dimension of label tensor. """ def __init__(self, params): super(ReduceLabelDim, self).__init__(params) @@ -25,12 +25,18 @@ def __call__(self, sample): return sample class LabelConvert(AbstractTransform): - """ Convert a list of labels to another list""" + """ + Convert the label based on a source list and target list. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `LabelConvert_source_list`: (list) A list of labels to be converted. + :param `LabelConvert_target_list`: (list) The target label list. + :param `LabelConvert_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. + """ def __init__(self, params): - """ - source_list (tuple/list): A list of labels to be converted - target_list (tuple/list): The target label list - """ super(LabelConvert, self).__init__(params) self.source_list = params['LabelConvert_source_list'.lower()] self.target_list = params['LabelConvert_target_list'.lower()] @@ -44,7 +50,9 @@ def __call__(self, sample): return sample class LabelConvertNonzero(AbstractTransform): - """ Convert label into binary (nonzero as 1)""" + """ + Convert label into binary, i.e., setting nonzero labels as 1. + """ def __init__(self, params): super(LabelConvertNonzero, self).__init__(params) self.inverse = params.get('LabelConvertNonzero_inverse'.lower(), False) @@ -56,11 +64,17 @@ def __call__(self, sample): return sample class LabelToProbability(AbstractTransform): - """Convert one-channel label map to one-hot multi-channel probability map""" + """ + Convert one-channel label map to one-hot multi-channel probability map. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `LabelToProbability_class_num`: (int) The class number in the label map. + :param `LabelToProbability_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. + """ def __init__(self, params): - """ - class_num (int): the class number in the label map - """ super(LabelToProbability, self).__init__(params) self.class_num = params['LabelToProbability_class_num'.lower()] self.inverse = params.get('LabelToProbability_inverse'.lower(), False) @@ -81,15 +95,21 @@ def __call__(self, sample): class PartialLabelToProbability(AbstractTransform): - """Convert one-channel label map to one-hot multi-channel probability map - Note that the label map represents partial labels. - For segmentation tasks only. - 0: background - 1 to C-1: foreground (C-classes) - C: unknown label. - the output consists of: - label_prob: one-hot probability map - pixel_weight: weigh of pixels, 0 if the label is unknown + """ + Convert one-channel partial label map to one-hot multi-channel probability map. + This is used for segmentation tasks only. In the input label map, 0 represents the + background class, 1 to C-1 represent the foreground classes, and C represents + unlabeled pixels. In the output dictionary, `label_prob` is the one-hot probability + map, and `pixel_weight` represents a weighting map, where the weight for a pixel + is 0 if the label is unkown. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `PartialLabelToProbability_class_num`: (int) The class number for the + segmentation task. + :param `PartialLabelToProbability_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. """ def __init__(self, params): """ @@ -107,11 +127,6 @@ def __call__(self, sample): label_prob[i] = label == i*np.ones_like(label) sample['label_prob'] = label_prob sample['pixel_weight'] = 1.0 - np.asarray([label == self.class_num], np.float32) - - # # for gated CRF loss - # scribble = label - 1 - # scribble[label == 0] = 255 - # sample['scribbles'] = scribble return sample diff --git a/pymic/transform/normalize.py b/pymic/transform/normalize.py index c5e81ad..4e493dd 100644 --- a/pymic/transform/normalize.py +++ b/pymic/transform/normalize.py @@ -8,19 +8,31 @@ class NormalizeWithMeanStd(AbstractTransform): - """Nomralize the image (shape [C, D, H, W] or [C, H, W]) with mean and std for given channels + """ + Normalize the image based on mean and std. The image should have a shape + of [C, D, H, W] or [C, H, W]. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `NormalizeWithMeanStd_channels`: (list/tuple or None) + A list or tuple of int for specifying the channels. + If None, the transform operates on all the channels. + :param `NormalizeWithMeanStd_mean`: (list/tuple or None) + The mean values along each specified channel. + If None, the mean values are calculated automatically. + :param `NormalizeWithMeanStd_std`: (list/tuple or None) + The std values along each specified channel. + If None, the std values are calculated automatically. + :param `NormalizeWithMeanStd_ignore_non_positive`: (optional, bool) + Only used when mean and std are not given. Default is False. + If True, calculate mean and std in the positive region for normalization, + and set non-positive region to random. If False, calculate + the mean and std values in the entire image region. + :param `NormalizeWithMeanStd_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. """ def __init__(self, params): - """ - :param chanels: (None or tuple/list) the indices of channels to be noramlized. - :param mean: (None or tuple/list): The mean values along each channel. - :param std : (None or tuple/list): The std values along each channel. - When mean and std are not provided, calculate them from the entire image - region or the non-positive region. - :param ignore_non_positive: (bool) Only used when mean and std are not given. - Use positive region to calculate mean and std, and set non-positive region to random. - :param inverse: (bool) Whether inverse transform is needed or not. - """ super(NormalizeWithMeanStd, self).__init__(params) self.chns = params['NormalizeWithMeanStd_channels'.lower()] self.mean = params.get('NormalizeWithMeanStd_mean'.lower(), None) @@ -56,15 +68,24 @@ def __call__(self, sample): class NormalizeWithMinMax(AbstractTransform): - """Nomralize the image (shape [C, D, H, W] or [C, H, W]) with min and max for given channels + """Nomralize the image to [0, 1]. The shape should be [C, D, H, W] or [C, H, W]. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `NormalizeWithMinMax_channels`: (list/tuple or None) + A list or tuple of int for specifying the channels. + If None, the transform operates on all the channels. + :param `NormalizeWithMinMax_threshold_lower`: (list/tuple or None) + The min values along each specified channel. + If None, the min values are calculated automatically. + :param `NormalizeWithMinMax_threshold_upper`: (list/tuple or None) + The max values along each specified channel. + If None, the max values are calculated automatically. + :param `NormalizeWithMinMax_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. """ def __init__(self, params): - """ - :param chanels: (None or tuple/list) the indices of channels to be noramlized. - :param threshold_lower: (tuple/list/None) The lower threshold value along each channel. - :param threshold_upper: (typle/list/None) The upper threshold value along each channel. - :param inverse: (bool) Whether inverse transform is needed or not. - """ super(NormalizeWithMinMax, self).__init__(params) self.chns = params['NormalizeWithMinMax_channels'.lower()] self.thred_lower = params['NormalizeWithMinMax_threshold_lower'.lower()] @@ -91,15 +112,23 @@ def __call__(self, sample): return sample class NormalizeWithPercentiles(AbstractTransform): - """Nomralize the image (shape [C, D, H, W] or [C, H, W]) with percentiles for given channels + """Nomralize the image to [0, 1] with percentiles for given channels. + The shape should be [C, D, H, W] or [C, H, W]. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `NormalizeWithPercentiles_channels`: (list/tuple or None) + A list or tuple of int for specifying the channels. + If None, the transform operates on all the channels. + :param `NormalizeWithPercentiles_percentile_lower`: (float) + The min percentile, which must be between 0 and 100 inclusive. + :param `NormalizeWithPercentiles_percentile_upper`: (float) + The max percentile, which must be between 0 and 100 inclusive. + :param `NormalizeWithMinMax_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. """ def __init__(self, params): - """ - :param chanels: (None or tuple/list) the indices of channels to be noramlized. - :param percentile_lower: (tuple/list/None) The lower percentile along each channel. - :param percentile_upper: (typle/list/None) The upper percentile along each channel. - :param inverse: (bool) Whether inverse transform is needed or not. - """ super(NormalizeWithPercentiles, self).__init__(params) self.chns = params['NormalizeWithPercentiles_channels'.lower()] self.percent_lower = params['NormalizeWithPercentiles_percentile_lower'.lower()] diff --git a/pymic/transform/pad.py b/pymic/transform/pad.py index 06b565c..0ec196c 100644 --- a/pymic/transform/pad.py +++ b/pymic/transform/pad.py @@ -13,17 +13,25 @@ class Pad(AbstractTransform): """ - Pad the image (shape [C, D, H, W] or [C, H, W]) to an new spatial shape, - the real output size will be max(image_size, output_size) + Pad an image to an new spatial shape. + The image has a shape of [C, D, H, W] or [C, H, W]. + The real output size will be max(image_size, output_size). + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `Pad_output_size`: (list/tuple) The output size along each spatial axis. + :param `Pad_ceil_mode`: (optional, bool) If true (by default), the real output size will + be the minimal integer multiples of output_size higher than the input size. + For example, the input image has a shape of [3, 100, 100], `Pad_output_size` + = [32, 32], and the real output size will be [3, 128, 128] if `Pad_ceil_mode` = True. + :param `Pad_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `True`. """ def __init__(self, params): - """ - :param output_size: (tuple/list) the size along each spatial axis. - :param ceil_mode: (bool) if true, the real output size is integer multiples of output_size. - """ super(Pad, self).__init__(params) self.output_size = params['Pad_output_size'.lower()] - self.ceil_mode = params['Pad_ceil_mode'.lower()] + self.ceil_mode = params.get('Pad_ceil_mode'.lower(), False) self.inverse = params.get('Pad_inverse'.lower(), True) def __call__(self, sample): @@ -62,9 +70,6 @@ def __call__(self, sample): return sample def inverse_transform_for_prediction(self, sample): - ''' crop sample['predict'] (5D or 4D) to the original spatial shape. - origin_shape is a 4D or 3D vector as saved in __call__().''' - # raise ValueError("not implemented") if(isinstance(sample['Pad_Param'], list) or isinstance(sample['Pad_Param'], tuple)): params = json.loads(sample['Pad_Param'][0]) else: diff --git a/pymic/transform/rescale.py b/pymic/transform/rescale.py index 1039e46..04dd458 100644 --- a/pymic/transform/rescale.py +++ b/pymic/transform/rescale.py @@ -12,16 +12,19 @@ class Rescale(AbstractTransform): - """Rescale the image in a sample to a given size.""" + """Rescale the image to a given size. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `Rescale_output_size`: (list/tuple or int) The output size along each spatial axis, + such as [D, H, W] or [H, W]. If D is None, the input image is only reslcaled in 2D. + If int, the smallest axis is matched to output_size keeping aspect ratio the same + as the input. + :param `Rescale_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `True`. + """ def __init__(self, params): - """ - output_size (tuple/list or int): Desired output size. - If tuple/list, output_size should in the format of [D, H, W] or [H, W]. - Channel number is kept the same as the input. If D is None, the input image - is only reslcaled in 2D. - If int, the smallest axis is matched to output_size keeping - aspect ratio the same. - """ super(Rescale, self).__init__(params) self.output_size = params["Rescale_output_size".lower()] self.inverse = params.get("Rescale_inverse".lower(), True) @@ -59,8 +62,6 @@ def __call__(self, sample): return sample def inverse_transform_for_prediction(self, sample): - ''' rescale sample['predict'] (5D or 4D) to the original spatial shape. - origin_shape is a 4D or 3D vector as saved in __call__().''' if(isinstance(sample['Rescale_origin_shape'], list) or \ isinstance(sample['Rescale_origin_shape'], tuple)): origin_shape = json.loads(sample['Rescale_origin_shape'][0]) @@ -78,7 +79,19 @@ def inverse_transform_for_prediction(self, sample): return sample class RandomRescale(AbstractTransform): - """Rescale the image in a sample to a given size.""" + """ + Rescale the input image randomly along each spatial axis. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `RandomRescale_lower_bound`: (list/tuple or int) + Desired minimal rescale ratio. If tuple/list, the length should be 3 or 2. + :param `RandomRescale_upper_bound`: (list/tuple or int) + Desired maximal rescale ratio. If tuple/list, the length should be 3 or 2. + :param `RandomRescale_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `True`. + """ def __init__(self, params): """ ratio0 (tuple/list or int): Desired minimal rescale ratio. @@ -123,8 +136,6 @@ def __call__(self, sample): return sample def inverse_transform_for_prediction(self, sample): - ''' rescale sample['predict'] (5D or 4D) to the original spatial shape. - origin_shape is a 4D or 3D vector as saved in __call__().''' if(isinstance(sample['RandomRescale_origin_shape'], list) or \ isinstance(sample['RandomRescale_origin_shape'], tuple)): origin_shape = json.loads(sample['RandomRescale_origin_shape'][0]) diff --git a/pymic/transform/rotate.py b/pymic/transform/rotate.py index 1bc1bc7..ba2d655 100644 --- a/pymic/transform/rotate.py +++ b/pymic/transform/rotate.py @@ -13,15 +13,24 @@ class RandomRotate(AbstractTransform): """ - random rotate the image (shape [C, D, H, W] or [C, H, W]) + Random rotate an image, wiht a shape of [C, D, H, W] or [C, H, W]. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `RandomRotate_angle_range_d`: (list/tuple or None) + Rotation angle (degree) range along depth axis (x-y plane), e.g., (-90, 90). + If None, no rotation along this axis. + :param `RandomRotate_angle_range_h`: (list/tuple or None) + Rotation angle (degree) range along height axis (x-z plane), e.g., (-90, 90). + If None, no rotation along this axis. Only used for 3D images. + :param `RandomRotate_angle_range_w`: (list/tuple or None) + Rotation angle (degree) range along width axis (y-z plane), e.g., (-90, 90). + If None, no rotation along this axis. Only used for 3D images. + :param `RandomRotate_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `True`. """ def __init__(self, params): - """ - angle_range_d (tuple/list/None) : rorate angle range along depth axis (degree), - only used for 3D images - angle_range_h (tuple/list/None) : rorate angle range along height axis (degree) - angle_range_w (tuple/list/None) : rorate angle range along width axis (degree) - """ super(RandomRotate, self).__init__(params) self.angle_range_d = params['RandomRotate_angle_range_d'.lower()] self.angle_range_h = params['RandomRotate_angle_range_h'.lower()] @@ -30,11 +39,11 @@ def __init__(self, params): def __apply_transformation(self, image, transform_param_list, order = 1): """ - apply rotation transformation to an ND image - Args: - image (nd array): the input nd image - transform_param_list (list): a list of roration angle and axes - order (int): interpolation order + Apply rotation transformation to an ND image. + + :param image: The input ND image. + :param transform_param_list: (list) A list of roration angle and axes. + :param order: (int) Interpolation order. """ for angle, axes in transform_param_list: image = ndimage.rotate(image, angle, axes, reshape = False, order = order) @@ -70,9 +79,6 @@ def __call__(self, sample): return sample def inverse_transform_for_prediction(self, sample): - ''' rorate sample['predict'] (5D or 4D) to the original direction. - transform_param_list is a list as saved in __call__().''' - # get the paramters for invers transformation if(isinstance(sample['RandomRotate_Param'], list) or \ isinstance(sample['RandomRotate_Param'], tuple)): transform_param_list = json.loads(sample['RandomRotate_Param'][0]) diff --git a/pymic/transform/threshold.py b/pymic/transform/threshold.py index a76c133..11e6d0e 100644 --- a/pymic/transform/threshold.py +++ b/pymic/transform/threshold.py @@ -12,20 +12,29 @@ class ChannelWiseThreshold(AbstractTransform): - """Threshold the image (shape [C, D, H, W] or [C, H, W]) for each channel + """ + Thresholding the image for given channels. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `ChannelWiseThreshold_channels`: (list/tuple or None) + A list of specified channels for thresholding. If None (by default), + all the channels will be thresholded. + :param `ChannelWiseThreshold_threshold_lower`: (list/tuple or None) + The lower threshold for the given channels. + :param `ChannelWiseThreshold_threshold_upper`: (list/tuple or None) + The upper threshold for the given channels. + :param `ChannelWiseThreshold_replace_lower`: (list/tuple or None) + The output value for pixels with an input value lower than the threshold_lower. + :param `ChannelWiseThreshold_replace_upper`: (list/tuple or None) + The output value for pixels with an input value higher than the threshold_upper. + :param `ChannelWiseThreshold_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. """ def __init__(self, params): - """ - channels (tuple/list/None): the list of specified channels for thresholding. Default value - is all the channels. - threshold_lower (tuple/list/None): The lower threshold values for specified channels. - threshold_upper (tuple/list/None): The uppoer threshold values for specified channels. - replace_lower (tuple/list/None): new values for pixels with intensity smaller than - threshold_lower. Default value is - replace_upper (tuple/list/None): new values for pixels with intensity larger than threshold_upper. - """ super(ChannelWiseThreshold, self).__init__(params) - self.channlels = params['ChannelWiseThreshold_channels'.lower()] + self.channels = params['ChannelWiseThreshold_channels'.lower()] self.threshold_lower = params['ChannelWiseThreshold_threshold_lower'.lower()] self.threshold_upper = params['ChannelWiseThreshold_threshold_upper'.lower()] self.replace_lower = params['ChannelWiseThreshold_replace_lower'.lower()] @@ -34,7 +43,7 @@ def __init__(self, params): def __call__(self, sample): image= sample['image'] - channels = range(image.shape[0]) if self.channlels is None else self.channlels + channels = range(image.shape[0]) if self.channels is None else self.channels for i in range(len(channels)): chn = channels[i] if((self.threshold_lower is not None) and (self.threshold_lower[i] is not None)): @@ -55,20 +64,32 @@ def __call__(self, sample): class ChannelWiseThresholdWithNormalize(AbstractTransform): """ - Note that this can be replaced by ChannelWiseThreshold + NormalizeWithMinMax + Apply thresholding and normalization for given channels. + Pixel intensity will be truncated to the range of (lower, upper) and then + normalized. If mean_std_mode is True, the mean and std values for pixel + in the target range is calculated for normalization, and input intensity + outside that range will be replaced by random values. Otherwise, the intensity + will be normalized to [0, 1]. - Threshold the image (shape [C, D, H, W] or [C, H, W]) for each channel - and then normalize the image based on remaining pixels + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `ChannelWiseThresholdWithNormalize_channels`: (list/tuple or None) + A list of specified channels for thresholding. If None (by default), + all the channels will be affected by this transform. + :param `ChannelWiseThresholdWithNormalize_threshold_lower`: (list/tuple or None) + The lower threshold for the given channels. + :param `ChannelWiseThresholdWithNormalize_threshold_upper`: (list/tuple or None) + The upper threshold for the given channels. + :param `ChannelWiseThresholdWithNormalize_mean_std_mode`: (bool) + If True, using mean and std for normalization. If False, using min and max + values for normalization. + :param `ChannelWiseThresholdWithNormalize_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. """ def __init__(self, params): - """ - :param threshold_lower: (tuple/list/None) The lower threshold value along each channel. - :param threshold_upper: (typle/list/None) The upper threshold value along each channel. - :param mean_std_mode: (bool) If true, nomalize the image based on mean and std values, - and pixels values outside the threshold value are replaced random number. - If false, use the min and max values for normalization. - """ super(ChannelWiseThresholdWithNormalize, self).__init__(params) + self.channels = params['ChannelWiseThresholdWithNormalize_channels'.lower()] self.threshold_lower = params['ChannelWiseThresholdWithNormalize_threshold_lower'.lower()] self.threshold_upper = params['ChannelWiseThresholdWithNormalize_threshold_upper'.lower()] self.mean_std_mode = params['ChannelWiseThresholdWithNormalize_mean_std_mode'.lower()] @@ -76,7 +97,8 @@ def __init__(self, params): def __call__(self, sample): image= sample['image'] - for chn in range(image.shape[0]): + channels = range(image.shape[0]) if self.channels is None else self.channels + for chn in channels: v0 = self.threshold_lower[chn] v1 = self.threshold_upper[chn] if(self.mean_std_mode == True): diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index d90e431..e0ef85c 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -1,9 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division from pymic.transform.intensity import * -from pymic.transform.gray2rgb import GrayscaleToRGB from pymic.transform.flip import RandomFlip -from pymic.transform.intensity import GaussianNoise from pymic.transform.pad import Pad from pymic.transform.rotate import RandomRotate from pymic.transform.rescale import Rescale, RandomRescale diff --git a/pymic/util/average_model.py b/pymic/util/average_model.py deleted file mode 100644 index 73a537f..0000000 --- a/pymic/util/average_model.py +++ /dev/null @@ -1,23 +0,0 @@ - -import torch - -checkpoint_name1 = "/home/guotai/projects/PyMIC/examples/brats/model/casecade/wt/unet3d_4_8000.pt" -checkpoint1 = torch.load(checkpoint_name1) -state_dict1 = checkpoint1['model_state_dict'] - -checkpoint_name2 = "/home/guotai/projects/PyMIC/examples/brats/model/casecade/wt/unet3d_4_10000.pt" -checkpoint2 = torch.load(checkpoint_name2) -state_dict2 = checkpoint2['model_state_dict'] - -checkpoint_name3 = "/home/guotai/projects/PyMIC/examples/brats/model/casecade/wt/unet3d_4_12000.pt" -checkpoint3 = torch.load(checkpoint_name3) -state_dict3 = checkpoint3['model_state_dict'] - -state_dict = {} -for item in state_dict1: - print(item) - state_dict[item] = (state_dict1[item] + state_dict2[item] + state_dict3[item])/3 - -save_dict = {'model_state_dict': state_dict} -save_name = "/home/guotai/projects/PyMIC/examples/brats/model/casecade/wt/unet3d_4_avg.pt" -torch.save(save_dict, save_name) \ No newline at end of file diff --git a/pymic/util/evaluation_cls.py b/pymic/util/evaluation_cls.py index 1aeb3d4..ff7e4e4 100644 --- a/pymic/util/evaluation_cls.py +++ b/pymic/util/evaluation_cls.py @@ -1,4 +1,7 @@ # -*- coding: utf-8 -*- +""" +Evaluation module for classification tasks. +""" from __future__ import absolute_import, print_function import os @@ -16,16 +19,25 @@ from pymic.util.parse_config import parse_config def accuracy(gt_label, pred_label): + """ + Calculate the accuracy. + """ correct_pred = gt_label == pred_label acc = (correct_pred.sum() + 0.0 ) / len(gt_label) return acc def sensitivity(gt_label, pred_label): + """ + Calculate the sensitivity for binary prediction. + """ pos_pred = gt_label * pred_label senst = (pos_pred.sum() + 0.0) / gt_label.sum() return senst def specificity(gt_label, pred_label): + """ + Calculate the specificity for binary prediction. + """ gt_label = 1 - gt_label pred_label = 1 - pred_label neg_pred = gt_label * pred_label @@ -34,8 +46,13 @@ def specificity(gt_label, pred_label): def get_evaluation_score(gt_label, pred_prob, metric): """ - the gt_label is 1-d array - currently only binary classification is considered + Get an evaluation score for binary classification. + + :param gt_label: (array) Ground truth label. + :param pred_prob: (array) Predicted positive probability. + :param metric: (str) One of the evaluation metrics in + {`accuracy`, `recall`, `sensitivity`, `specificity`, + `precision`, `auc`}. """ pred_lab = np.argmax(pred_prob, axis = 1) if(metric == "accuracy"): @@ -53,6 +70,15 @@ def get_evaluation_score(gt_label, pred_prob, metric): return score def binary_evaluation(config): + """ + Evaluation of binary classification performance. + The arguments are given in the `config` dictionary. + It should have the following fields: + + :param metric_list: (list) A list of evaluation metrics. + :param ground_truth_csv: (str) The csv file for ground truth. + :param predict_prob_csv: (str) The csv file for prediction probability. + """ metric_list = config['metric_list'] gt_csv = config['ground_truth_csv'] prob_csv= config['predict_prob_csv'] @@ -79,7 +105,13 @@ def binary_evaluation(config): def nexcl_evaluation(config): """ - evaluation for nonexclusive classification + Evaluation of non-exclusive binary classification performance. + The arguments are given in the `config` dictionary. + It should have the following fields: + + :param metric_list: (list) A list of evaluation metrics. + :param ground_truth_csv: (str) The csv file for ground truth. + :param predict_prob_csv: (str) The csv file for prediction probability. """ metric_list = config['metric_list'] gt_csv = config['ground_truth_csv'] diff --git a/pymic/util/rename_model.py b/pymic/util/model_operate.py similarity index 60% rename from pymic/util/rename_model.py rename to pymic/util/model_operate.py index 02c9ed2..9ba2eb3 100644 --- a/pymic/util/rename_model.py +++ b/pymic/util/model_operate.py @@ -13,6 +13,25 @@ def rename_model_variable(input_file, output_file, input_var_list, output_var_li checkpoint['model_state_dict'] = state_dict torch.save(checkpoint, output_file) + +def get_average_model(checkpoint_name1, checkpoint_name2, checkpoint_name3, save_name): + checkpoint1 = torch.load(checkpoint_name1) + state_dict1 = checkpoint1['model_state_dict'] + + checkpoint2 = torch.load(checkpoint_name2) + state_dict2 = checkpoint2['model_state_dict'] + + checkpoint3 = torch.load(checkpoint_name3) + state_dict3 = checkpoint3['model_state_dict'] + + state_dict = {} + for item in state_dict1: + print(item) + state_dict[item] = (state_dict1[item] + state_dict2[item] + state_dict3[item])/3 + + save_dict = {'model_state_dict': state_dict} + torch.save(save_dict, save_name) + if __name__ == "__main__": input_file = '/home/guotai/disk2t/projects/dlls/training_fetal_brain/exp1/model/unet2dres_bn1_20000.pt' output_file = '/home/guotai/disk2t/projects/dlls/training_fetal_brain/exp1/model/unet2dres_bn1_20000_rename.pt' From 450a8486f0e6f17306ac6f1fe6d27625ea8adefa Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 6 Sep 2022 11:27:28 +0800 Subject: [PATCH 110/225] update docs update docs --- pymic/util/evaluation_cls.py | 22 +++++ pymic/util/evaluation_seg.py | 162 +++++++++++++++++++++++++++-------- pymic/util/general.py | 13 ++- pymic/util/image_process.py | 97 +++++++++++++++------ 4 files changed, 228 insertions(+), 66 deletions(-) diff --git a/pymic/util/evaluation_cls.py b/pymic/util/evaluation_cls.py index ff7e4e4..29e3494 100644 --- a/pymic/util/evaluation_cls.py +++ b/pymic/util/evaluation_cls.py @@ -76,6 +76,8 @@ def binary_evaluation(config): It should have the following fields: :param metric_list: (list) A list of evaluation metrics. + The supported metrics are {`accuracy`, `recall`, `sensitivity`, `specificity`, + `precision`, `auc`}. :param ground_truth_csv: (str) The csv file for ground truth. :param predict_prob_csv: (str) The csv file for prediction probability. """ @@ -110,6 +112,8 @@ def nexcl_evaluation(config): It should have the following fields: :param metric_list: (list) A list of evaluation metrics. + The supported metrics are {`accuracy`, `recall`, `sensitivity`, `specificity`, + `precision`, `auc`}. :param ground_truth_csv: (str) The csv file for ground truth. :param predict_prob_csv: (str) The csv file for prediction probability. """ @@ -153,6 +157,24 @@ def nexcl_evaluation(config): csv_writer.writerow(item) def main(): + """ + Main function for evaluation of classification results. + A configuration file is needed for runing. e.g., + + .. code-block:: none + + pymic_evaluate_cls config.cfg + + The configuration file should have an `evaluation` section with + the following fields: + + :param task_type: (str) `cls` or `cls_nexcl`. + :param metric_list: (list) A list of evaluation metrics. + The supported metrics are {`accuracy`, `recall`, `sensitivity`, `specificity`, + `precision`, `auc`}. + :param ground_truth_csv: (str) The csv file for ground truth. + :param predict_prob_csv: (str) The csv file for prediction probability. + """ if(len(sys.argv) < 2): print('Number of arguments should be 2. e.g.') print(' pymic_evaluate_cls config.cfg') diff --git a/pymic/util/evaluation_seg.py b/pymic/util/evaluation_seg.py index b04880a..ec02297 100644 --- a/pymic/util/evaluation_seg.py +++ b/pymic/util/evaluation_seg.py @@ -1,4 +1,7 @@ # -*- coding: utf-8 -*- +""" +Evaluation module for segmenation tasks. +""" from __future__ import absolute_import, print_function import csv import os @@ -14,13 +17,18 @@ from pymic.util.image_process import * from pymic.util.parse_config import parse_config -# Dice evaluation + def binary_dice(s, g, resize = False): """ - calculate the Dice score of two N-d volumes. - s: the segmentation volume of numpy array - g: the ground truth volume of numpy array - resize: if s and g have different shapes, resize s to match g. + Calculate the Dice score of two N-d volumes for binary segmentation. + + :param s: The segmentation volume of numpy array. + :param g: the ground truth volume of numpy array. + :param resize: (optional, bool) + If s and g have different shapes, resize s to match g. + Default is `True`. + + :return: The Dice value. """ assert(len(s.shape)== len(g.shape)) if(resize): @@ -39,30 +47,43 @@ def binary_dice(s, g, resize = False): return dice def dice_of_images(s_name, g_name): + """ + Calculate the Dice score given the image names of binary segmentation + and ground truth, respectively. + + :param s_name: (str) The filename of segmentation result. + :param g_name: (str) The filename of ground truth. + + :return: The Dice value. + """ s = load_image_as_nd_array(s_name)['data_array'] g = load_image_as_nd_array(g_name)['data_array'] dice = binary_dice(s, g) return dice -# IOU evaluation + def binary_iou(s,g): + """ + Calculate the IoU score of two N-d volumes for binary segmentation. + + :param s: The segmentation volume of numpy array. + :param g: the ground truth volume of numpy array. + + :return: The IoU value. + """ assert(len(s.shape)== len(g.shape)) intersecion = np.multiply(s, g) union = np.asarray(s + g >0, np.float32) iou = (intersecion.sum() + 1e-5)/(union.sum() + 1e-5) return iou -def iou_of_images(s_name, g_name): - s = load_image_as_nd_array(s_name)['data_array'] - g = load_image_as_nd_array(g_name)['data_array'] - margin = (3, 8, 8) - g = get_detection_binary_bounding_box(g, margin) - return binary_iou(s, g) - # Hausdorff and ASSD evaluation def get_edge_points(img): """ - get edge points of a binary segmentation result + Get edge points of a binary segmentation result. + + :param img: (numpy.array) a 2D or 3D array of binary segmentation. + :return: an edge map. """ dim = len(img.shape) if(dim == 2): @@ -76,11 +97,14 @@ def get_edge_points(img): def binary_hd95(s, g, spacing = None): """ - get the hausdorff distance between a binary segmentation and the ground truth - inputs: - s: a 3D or 2D binary image for segmentation - g: a 2D or 2D binary image for ground truth - spacing: a list for image spacing, length should be 3 or 2 + Get the 95 percentile of hausdorff distance between a binary segmentation + and the ground truth. + + :param s: (numpy.array) a 2D or 3D binary image for segmentation. + :param g: (numpy.array) a 2D or 2D binary image for ground truth. + :param spacing: (list) A list for image spacing, length should be 2 or 3. + + :return: The HD95 value. """ s_edge = get_edge_points(s) g_edge = get_edge_points(g) @@ -109,11 +133,14 @@ def binary_hd95(s, g, spacing = None): def binary_assd(s, g, spacing = None): """ - get the average symetric surface distance between a binary segmentation and the ground truth - inputs: - s: a 3D or 2D binary image for segmentation - g: a 2D or 2D binary image for ground truth - spacing: a list for image spacing, length should be 3 or 2 + Get the Average Symetric Surface Distance (ASSD) between a binary segmentation + and the ground truth. + + :param s: (numpy.array) a 2D or 3D binary image for segmentation. + :param g: (numpy.array) a 2D or 2D binary image for ground truth. + :param spacing: (list) A list for image spacing, length should be 2 or 3. + + :return: The ASSD value. """ s_edge = get_edge_points(s) g_edge = get_edge_points(g) @@ -139,14 +166,34 @@ def binary_assd(s, g, spacing = None): return assd # relative volume error evaluation -def binary_relative_volume_error(s_volume, g_volume): - s_v = float(s_volume.sum()) - g_v = float(g_volume.sum()) +def binary_relative_volume_error(s, g): + """ + Get the Relative Volume Error (RVE) between a binary segmentation + and the ground truth. + + :param s: (numpy.array) a 2D or 3D binary image for segmentation. + :param g: (numpy.array) a 2D or 2D binary image for ground truth. + + :return: The RVE value. + """ + s_v = float(s.sum()) + g_v = float(g.sum()) assert(g_v > 0) rve = abs(s_v - g_v)/g_v return rve def get_binary_evaluation_score(s_volume, g_volume, spacing, metric): + """ + Evaluate the performance of binary segmentation using a specified metric. + The metric options are {`dice`, `iou`, `assd`, `hd95`, `rve`, `volume`}. + + :param s_volume: (numpy.array) a 2D or 3D binary image for segmentation. + :param g_volume: (numpy.array) a 2D or 2D binary image for ground truth. + :param spacing: (list) A list for image spacing, length should be 2 or 3. + :param metric: (str) The metric name. + + :return: The metric value. + """ if(len(s_volume.shape) == 4): assert(s_volume.shape[0] == 1 and g_volume.shape[0] == 1) s_volume = np.reshape(s_volume, s_volume.shape[1:]) @@ -158,19 +205,14 @@ def get_binary_evaluation_score(s_volume, g_volume, spacing, metric): if(metric_lower == "dice"): score = binary_dice(s_volume, g_volume) - elif(metric_lower == "iou"): score = binary_iou(s_volume,g_volume) - elif(metric_lower == 'assd'): score = binary_assd(s_volume, g_volume, spacing) - elif(metric_lower == "hd95"): score = binary_hd95(s_volume, g_volume, spacing) - elif(metric_lower == "rve"): score = binary_relative_volume_error(s_volume, g_volume) - elif(metric_lower == "volume"): voxel_size = 1.0 for dim in range(len(spacing)): @@ -182,6 +224,21 @@ def get_binary_evaluation_score(s_volume, g_volume, spacing, metric): return score def get_multi_class_evaluation_score(s_volume, g_volume, label_list, fuse_label, spacing, metric): + """ + Evaluate the segmentation performance using a specified metric for a list of labels. + The metric options are {`dice`, `iou`, `assd`, `hd95`, `rve`, `volume`}. + If `fuse_label` is `True`, the labels in `label_list` will be merged as foreground + and other labels will be merged as background as a binary segmentation result. + + :param s_volume: (numpy.array) A 2D or 3D image for segmentation. + :param g_volume: (numpy.array) A 2D or 2D image for ground truth. + :param label_list: (list) A list of target labels. + :param fuse_label: (bool) Fuse the labels in `label_list` or not. + :param spacing: (list) A list for image spacing, length should be 2 or 3. + :param metric: (str) The metric name. + + :return: The metric value list. + """ if(fuse_label): s_volume_sub = np.zeros_like(s_volume) g_volume_sub = np.zeros_like(g_volume) @@ -198,8 +255,31 @@ def get_multi_class_evaluation_score(s_volume, g_volume, label_list, fuse_label, score_list.append(temp_score) return score_list -def evaluation(config_file): - config = parse_config(config_file)['evaluation'] +def evaluation(config): + """ + Run evaluation of segmentation results based on a configuration dictionary `config`. + The following fields should be provided in `config`: + + :param metric: (str) The metric for evaluation. + The metric options are {`dice`, `iou`, `assd`, `hd95`, `rve`, `volume`}. + :param label_list: (list) The list of labels for evaluation. + :param label_fuse: (option, bool) If true, fuse the labels in the `label_list` + as the foreground, and other labels as the background. Default is False. + :param organ_name: (str) The name of the organ for segmentation. + :param ground_truth_folder_root: (str) The root dir of ground truth images. + :param segmentation_folder_root: (str) The root dir of segmentation images. + :param evaluation_image_pair: (str) The csv file that provide the segmentation + images and the corresponding ground truth images. + :param ground_truth_label_convert_source: (optional, list) The list of source + labels for label conversion in the ground truth. + :param ground_truth_label_convert_target: (optional, list) The list of target + labels for label conversion in the ground truth. + :param segmentation_label_convert_source: (optional, list) The list of source + labels for label conversion in the segmentation. + :param segmentation_label_convert_target: (optional, list) The list of target + labels for label conversion in the segmentation. + """ + metric = config['metric'] label_list = config['label_list'] label_fuse = config.get('label_fuse', False) @@ -271,13 +351,25 @@ def evaluation(config_file): print("{0:} std ".format(metric), score_std) def main(): + """ + Main function for evaluation of segmentation results. + A configuration file is needed for runing. e.g., + + .. code-block:: none + + pymic_evaluate_cls config.cfg + + The configuration file should have an `evaluation` section. + See :mod:`pymic.util.evaluation_seg.evaluation` for details of the configuration required. + """ if(len(sys.argv) < 2): print('Number of arguments should be 2. e.g.') print(' pymic_evaluate_seg config.cfg') exit() config_file = str(sys.argv[1]) assert(os.path.isfile(config_file)) - evaluation(config_file) + config = parse_config(config_file)['evaluation'] + evaluation(config) if __name__ == '__main__': main() diff --git a/pymic/util/general.py b/pymic/util/general.py index 063d654..75b6af1 100644 --- a/pymic/util/general.py +++ b/pymic/util/general.py @@ -4,14 +4,19 @@ import numpy as np def keyword_match(a,b): + """ + Test if two string are the same when converted to lower case. + """ return a.lower() == b.lower() def get_one_hot_seg(label, class_num): """ - convert a segmentation label to one-hot - label: a tensor with a shape of [N, 1, D, H, W] or [N, 1, H, W] - class_num: class number. - output: an one-hot tensor with a shape of [N, C, D, H, W] or [N, C, H, W] + Convert a segmentation label to one-hot. + + :param label: A tensor with a shape of [N, 1, D, H, W] or [N, 1, H, W] + :param class_num: Class number. + + :return: a one-hot tensor with a shape of [N, C, D, H, W] or [N, C, H, W]. """ size = list(label.size()) if(size[1] != 1): diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index 896e8c1..9a484e8 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -7,46 +7,68 @@ def get_ND_bounding_box(volume, margin = None): """ - get the bounding box of nonzero region in an ND volume + Get the bounding box of nonzero region in an ND volume. + + :param volume: An ND numpy array. + :param margin: (list) + The margin of bounding box along each axis. + + :return bb_min: (list) A list for the minimal value of each axis + of the bounding box. + :return bb_max: (list) A list for the maximal value of each axis + of the bounding box. """ input_shape = volume.shape if(margin is None): margin = [0] * len(input_shape) assert(len(input_shape) == len(margin)) indxes = np.nonzero(volume) - idx_min = [] - idx_max = [] + bb_min = [] + bb_max = [] for i in range(len(input_shape)): - idx_min.append(int(indxes[i].min())) - idx_max.append(int(indxes[i].max()) + 1) + bb_min.append(int(indxes[i].min())) + bb_max.append(int(indxes[i].max()) + 1) for i in range(len(input_shape)): - idx_min[i] = max(idx_min[i] - margin[i], 0) - idx_max[i] = min(idx_max[i] + margin[i], input_shape[i]) - return idx_min, idx_max + bb_min[i] = max(bb_min[i] - margin[i], 0) + bb_max[i] = min(bb_max[i] + margin[i], input_shape[i]) + return bb_min, bb_max -def crop_ND_volume_with_bounding_box(volume, min_idx, max_idx): +def crop_ND_volume_with_bounding_box(volume, bb_min, bb_max): """ - crop/extract a subregion form an nd image. + Extract a subregion form an ND image. + + :param volume: The input ND array. + :param bb_min: (list) The lower bound of the bounding box for each axis. + :param bb_max: (list) The upper bound of the bounding box for each axis. + + :return: A croped ND image. """ dim = len(volume.shape) assert(dim >= 2 and dim <= 5) - assert(max_idx[0] - min_idx[0] <= volume.shape[0]) + assert(bb_max[0] - bb_min[0] <= volume.shape[0]) if(dim == 2): - output = volume[min_idx[0]:max_idx[0], min_idx[1]:max_idx[1]] + output = volume[bb_min[0]:bb_max[0], bb_min[1]:bb_max[1]] elif(dim == 3): - output = volume[min_idx[0]:max_idx[0], min_idx[1]:max_idx[1], min_idx[2]:max_idx[2]] + output = volume[bb_min[0]:bb_max[0], bb_min[1]:bb_max[1], bb_min[2]:bb_max[2]] elif(dim == 4): - output = volume[min_idx[0]:max_idx[0], min_idx[1]:max_idx[1], min_idx[2]:max_idx[2], min_idx[3]:max_idx[3]] + output = volume[bb_min[0]:bb_max[0], bb_min[1]:bb_max[1], bb_min[2]:bb_max[2], bb_min[3]:bb_max[3]] elif(dim == 5): - output = volume[min_idx[0]:max_idx[0], min_idx[1]:max_idx[1], min_idx[2]:max_idx[2], min_idx[3]:max_idx[3], min_idx[4]:max_idx[4]] + output = volume[bb_min[0]:bb_max[0], bb_min[1]:bb_max[1], bb_min[2]:bb_max[2], bb_min[3]:bb_max[3], bb_min[4]:bb_max[4]] else: raise ValueError("the dimension number shoud be 2 to 5") return output def set_ND_volume_roi_with_bounding_box_range(volume, bb_min, bb_max, sub_volume, addition = True): """ - set a subregion to an nd image. if addition is True, the original volume is added by the subregion. + Set the subregion of an ND image. If `addition` is `True`, the original volume is added by the given sub volume. + + :param volume: The input ND volume. + :param bb_min: (list) The lower bound of the bounding box for each axis. + :param bb_max: (list) The upper bound of the bounding box for each axis. + :param sub_volume: The sub volume to replace the target region of the orginal volume. + :param addition: (optional, bool) If True, the sub volume will be added + to the target region of the input volume. """ dim = len(bb_min) out = volume @@ -75,6 +97,13 @@ def set_ND_volume_roi_with_bounding_box_range(volume, bb_min, bb_max, sub_volume return out def crop_and_pad_ND_array_to_desired_shape(image, out_shape, pad_mod): + """ + Crop and pad an image to a given shape. + + :param image: The input ND array. + :param out_shape: (list) The desired output shape. + :param pad_mod: (str) See `numpy.pad `_ + """ in_shape = image.shape dim = len(in_shape) crop_shape = [min(out_shape[i], in_shape[i]) for i in range(dim)] @@ -109,8 +138,12 @@ def crop_and_pad_ND_array_to_desired_shape(image, out_shape, pad_mod): def get_largest_k_components(image, k = 1): """ - get the largest K components from 2D or 3D binary image - image: nd array + Get the largest K components from 2D or 3D binary image. + + :param image: The input ND array for binary segmentation. + :param k: (int) The value of k. + + :return: An output array with only the largest K components of the input. """ dim = len(image.shape) if(image.sum() == 0 ): @@ -131,8 +164,13 @@ def get_largest_k_components(image, k = 1): def get_euclidean_distance(image, dim = 3, spacing = [1.0, 1.0, 1.0]): """ - get euclidean distance transform of 2D or 3D binary images - the output distance map is unsigned + Get euclidean distance transform of 3D binary images. + The output distance map is unsigned. + + :param image: The input 3D array. + :param dim: (int) Using 2D (dim = 2) or 3D (dim = 3) distance transforms. + :param spacing: (list) The spacing along each axis. + """ img_shape = image.shape input_dim = len(img_shape) @@ -155,10 +193,11 @@ def get_euclidean_distance(image, dim = 3, spacing = [1.0, 1.0, 1.0]): def convert_label(label, source_list, target_list): """ - convert a label map based a source list and a target list of labels - label: nd array - source_list: a list of labels that will be converted, e.g. [0, 1, 2, 4] - target_list: a list of target labels, e.g. [0, 1, 2, 3] + Convert a label map based a source list and a target list of labels + + :param label: (numpy.array) The input label map. + :param source_list: A list of labels that will be converted, e.g. [0, 1, 2, 4] + :param target_list: A list of target labels, e.g. [0, 1, 2, 3] """ assert(len(source_list) == len(target_list)) label_converted = np.zeros_like(label) @@ -170,9 +209,13 @@ def convert_label(label, source_list, target_list): def resample_sitk_image_to_given_spacing(image, spacing, order): """ - image: an sitk image object - spacing: 3D tuple / list for spacing along x, y, z direction - order: order for interpolation + Resample an sitk image objct to a given spacing. + + :param image: The input sitk image object. + :param spacing: (list/tuple) Target spacing along x, y, z direction. + :param order: (int) Order for interpolation. + + :return: A resampled sitk image object. """ spacing0 = image.GetSpacing() data = sitk.GetArrayFromImage(image) From 451a85c3ce2cb23321eee1e8531cfd5fdb1b6cce Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 7 Sep 2022 22:46:49 +0800 Subject: [PATCH 111/225] update docs of loss and network update docs of loss and network --- docs/source/pymic.loss.cls.rst | 28 +------ docs/source/pymic.loss.seg.rst | 8 ++ docs/source/usage.fsl.rst | 2 +- pymic/io/image_read_write.py | 64 ++++++--------- pymic/io/nifty_dataset.py | 27 ++++--- pymic/loss/cls/basic.py | 107 ++++++++++++++++++++++++++ pymic/loss/cls/ce.py | 42 ---------- pymic/loss/cls/l1.py | 52 ------------- pymic/loss/cls/mse.py | 22 ------ pymic/loss/cls/nll.py | 18 ----- pymic/loss/cls/util.py | 9 ++- pymic/loss/loss_dict_cls.py | 15 +++- pymic/loss/loss_dict_seg.py | 31 +++++++- pymic/loss/seg/abstract.py | 17 ++-- pymic/loss/seg/ce.py | 22 +++--- pymic/loss/seg/combined.py | 6 +- pymic/loss/seg/deep_sup.py | 4 +- pymic/loss/seg/dice.py | 18 ++--- pymic/loss/seg/exp_log.py | 20 +---- pymic/loss/seg/gatedcrf.py | 13 ++-- pymic/loss/seg/mse.py | 16 +--- pymic/loss/seg/mumford_shah.py | 3 +- pymic/loss/seg/slsr.py | 5 +- pymic/loss/seg/ssl.py | 16 +--- pymic/net/net2d/cople_net.py | 52 +++++++++---- pymic/net/net2d/scse2d.py | 52 ++++++------- pymic/net/net2d/unet2d.py | 98 +++++++++++++++++------ pymic/net/net2d/unet2d_cct.py | 51 ++++++++---- pymic/net/net2d/unet2d_dual_branch.py | 25 ++++-- pymic/net/net2d/unet2d_nest.py | 30 +++++--- pymic/net/net2d/unet2d_scse.py | 40 +++++++--- pymic/net/net_dict_seg.py | 15 ++++ pymic/net_run_wsl/wsl_gatedcrf.py | 4 +- pymic/transform/trans_dict.py | 29 +++++++ pymic/util/post_process.py | 17 +++- pymic/util/preprocess.py | 16 +++- pymic/util/ramps.py | 58 ++++++++++---- 37 files changed, 614 insertions(+), 438 deletions(-) create mode 100644 pymic/loss/cls/basic.py delete mode 100644 pymic/loss/cls/ce.py delete mode 100644 pymic/loss/cls/l1.py delete mode 100644 pymic/loss/cls/mse.py delete mode 100644 pymic/loss/cls/nll.py diff --git a/docs/source/pymic.loss.cls.rst b/docs/source/pymic.loss.cls.rst index 73b816b..6f4cfca 100644 --- a/docs/source/pymic.loss.cls.rst +++ b/docs/source/pymic.loss.cls.rst @@ -4,34 +4,10 @@ pymic.loss.cls package Submodules ---------- -pymic.loss.cls.ce module +pymic.loss.cls.basic module ------------------------ -.. automodule:: pymic.loss.cls.ce - :members: - :undoc-members: - :show-inheritance: - -pymic.loss.cls.l1 module ------------------------- - -.. automodule:: pymic.loss.cls.l1 - :members: - :undoc-members: - :show-inheritance: - -pymic.loss.cls.mse module -------------------------- - -.. automodule:: pymic.loss.cls.mse - :members: - :undoc-members: - :show-inheritance: - -pymic.loss.cls.nll module -------------------------- - -.. automodule:: pymic.loss.cls.nll +.. automodule:: pymic.loss.cls.basic :members: :undoc-members: :show-inheritance: diff --git a/docs/source/pymic.loss.seg.rst b/docs/source/pymic.loss.seg.rst index d37f41f..d858bda 100644 --- a/docs/source/pymic.loss.seg.rst +++ b/docs/source/pymic.loss.seg.rst @@ -4,6 +4,14 @@ pymic.loss.seg package Submodules ---------- +pymic.loss.seg.abstract module +------------------------ + +.. automodule:: pymic.loss.seg.abstract + :members: + :undoc-members: + :show-inheritance: + pymic.loss.seg.ce module ------------------------ diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index 585d7ec..053daf7 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -209,7 +209,7 @@ hyper-parameters. For example, the following is a configuration for using ``2DUN bilinear = False deep_supervise= False -The ``SegNetDict`` in :mod:`pymic.net.neg_dict_seg` lists all the built-in network +The ``SegNetDict`` in :mod:`pymic.net.net_dict_seg` lists all the built-in network structures currently implemented in PyMIC. You can also define your own networks. To integrate your customized diff --git a/pymic/io/image_read_write.py b/pymic/io/image_read_write.py index e06924e..cb65e19 100644 --- a/pymic/io/image_read_write.py +++ b/pymic/io/image_read_write.py @@ -10,15 +10,12 @@ def load_nifty_volume_as_4d_array(filename): """ Read a nifty image and return a dictionay storing data array, origin, spacing and direction.\n - output['data_array'] 4d array with shape [C, D, H, W];\n - output['spacing'] a list of spacing in z, y, x axis;\n - output['direction'] a 3x3 matrix for direction. + output['data_array'] 4D array with shape [C, D, H, W];\n + output['spacing'] A list of spacing in z, y, x axis;\n + output['direction'] A 3x3 matrix for direction. - Args: - filename (str): the input file name - - Returns: - dict: a dictionay storing data array, origin, spacing and direction. + :param filename: (str) The input file name + :return: A dictionay storing data array, origin, spacing and direction. """ img_obj = sitk.ReadImage(filename) data_array = sitk.GetArrayFromImage(img_obj) @@ -43,15 +40,12 @@ def load_rgb_image_as_3d_array(filename): """ Read an RGB image and return a dictionay storing data array, origin, spacing and direction. \n - output['data_array'] 3d array with shape [D, H, W]; \n + output['data_array'] 3D array with shape [D, H, W]; \n output['spacing'] a list of spacing in z, y, x axis; \n output['direction'] a 3x3 matrix for direction. - Args: - filename (str): the input file name - - Returns: - dict: a dictionay storing data array, origin, spacing and direction. + :param filename: (str) The input file name + :return: A dictionay storing data array, origin, spacing and direction. """ image = np.asarray(Image.open(filename)) image_shape = image.shape @@ -74,14 +68,11 @@ def load_rgb_image_as_3d_array(filename): def load_image_as_nd_array(image_name): """ - load an image and return a 4D array with shape [C, D, H, W], + Load an image and return a 4D array with shape [C, D, H, W], or 3D array with shape [C, H, W]. - Args: - image_name (str): the image name. - - Returns: - dict: a dictionay storing data array, origin, spacing and direction. + :param filename: (str) The input file name + :return: A dictionay storing data array, origin, spacing and direction. """ if (image_name.endswith(".nii.gz") or image_name.endswith(".nii") or image_name.endswith(".mha")): @@ -97,10 +88,9 @@ def save_array_as_nifty_volume(data, image_name, reference_name = None): """ Save a numpy array as nifty image - Args: - data (numpy.ndarray): a numpy array with shape [Depth, Height, Width].\n - image_name (str): the ouput file name.\n - reference_name (str): file name of the reference image of which + :param data: (numpy.ndarray) A numpy array with shape [Depth, Height, Width]. + :param image_name: (str) The ouput file name. + :param reference_name: (str) File name of the reference image of which meta information is used. """ img = sitk.GetImageFromArray(data) @@ -114,12 +104,11 @@ def save_array_as_nifty_volume(data, image_name, reference_name = None): def save_array_as_rgb_image(data, image_name): """ - Save a numpy array as rgb image + Save a numpy array as rgb image. - Args: - data (numpy.ndarray): a numpy array with shape [3, H, W] or - [H, W, 3] or [H, W]. \n - image_name (str): the output file name. + :param data: (numpy.ndarray) A numpy array with shape [3, H, W] or + [H, W, 3] or [H, W]. + :param image_name: (str) The output file name. """ data_dim = len(data.shape) if(data_dim == 3): @@ -133,10 +122,9 @@ def save_nd_array_as_image(data, image_name, reference_name = None): """ Save a 3D or 2D numpy array as medical image or RGB image - Args: - data (numpy.ndarray): a numpy array with shape [D, H, W] or [C, H, W]. \n - image_name (str): the output file name. \n - reference_name (str): file name of the reference image of which + :param data: (numpy.ndarray) A numpy array with shape [3, H, W] or + [H, W, 3] or [H, W]. + :param reference_name: (str) File name of the reference image of which meta information is used. """ data_dim = len(data.shape) @@ -158,16 +146,14 @@ def rotate_nifty_volume_to_LPS(filename_or_image_dict, origin = None, direction ''' Rotate the axis of a 3D volume to LPS - Args: - filename_or_image_dict (str): filename of the nifty file (str) or image dictionary + :param filename_or_image_dict: (str) Filename of the nifty file (str) or image dictionary returned by load_nifty_volume_as_4d_array. If supplied with the former, the flipped image data will be saved to override the original file. If supplied with the later, only flipped image data will be returned.\n - origin (list or tuple): the origin of the image.\n - direction (list or tuple): the direction of the image. + :param origin: (list/tuple) The origin of the image. + :param direction: (list or tuple) The direction of the image. - Returns: - dict: a dictionary for image data and meta info, with ``data_array``, + :return: A dictionary for image data and meta info, with ``data_array``, ``origin``, ``direction`` and ``spacing``. ''' diff --git a/pymic/io/nifty_dataset.py b/pymic/io/nifty_dataset.py index f09d24f..bb1ff23 100644 --- a/pymic/io/nifty_dataset.py +++ b/pymic/io/nifty_dataset.py @@ -15,13 +15,12 @@ class NiftyDataset(Dataset): dimention order [C, D, H, W] for 3D images, and 3D tensors with dimention order [C, H, W] for 2D images. - Args: - root_dir (str): Directory with all the images. \n - csv_file (str): Path to the csv file with image names. \n - modal_num (int): Number of modalities. \n - with_label (bool): Load the data with segmentation ground truth or not. \n - with_weight(bool): Load pixel-wise weight map or not. \n - transform (list): list of transform to be applied on a sample. + :param root_dir: (str) Directory with all the images. + :param csv_file: (str) Path to the csv file with image names. + :param modal_num: (int) Number of modalities. + :param with_label: (bool) Load the data with segmentation ground truth or not. + :param transform: (list) List of transforms to be applied on a sample. + The built-in transforms can listed in :mod:`pymic.transform.trans_dict`. """ def __init__(self, root_dir, csv_file, modal_num = 1, with_label = False, transform=None): @@ -93,13 +92,13 @@ class ClassificationDataset(NiftyDataset): dimention order [C, D, H, W] for 3D images, and 3D tensors with dimention order [C, H, W] for 2D images. - Args: - root_dir (str): Directory with all the images. \n - csv_file (str): Path to the csv file with image names. \n - modal_num (int): Number of modalities. \n - class_num (int): class number of the classificaiton task. \n - with_label (bool): Load the data with segmentation ground truth or not. \n - transform (list): list of transform to be applied on a sample. + :param root_dir: (str) Directory with all the images. + :param csv_file: (str) Path to the csv file with image names. + :param modal_num: (int) Number of modalities. + :param class_num: (int) Class number of the classificaiton task. + :param with_label: (bool) Load the data with segmentation ground truth or not. + :param transform: (list) List of transforms to be applied on a sample. + The built-in transforms can listed in :mod:`pymic.transform.trans_dict`. """ def __init__(self, root_dir, csv_file, modal_num = 1, class_num = 2, with_label = False, transform=None): diff --git a/pymic/loss/cls/basic.py b/pymic/loss/cls/basic.py new file mode 100644 index 0000000..4c90943 --- /dev/null +++ b/pymic/loss/cls/basic.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import torch.nn as nn + +class AbstractClassificationLoss(nn.Module): + """ + Abstract Classification Loss. + """ + def __init__(self, params = None): + super(AbstractClassificationLoss, self).__init__() + + def forward(self, loss_input_dict): + """ + The arguments should be written in the `loss_input_dict` dictionary, and it has the + following fields. + + :param prediction: A prediction with shape of [N, C] where C is the class number. + :param ground_truth: The corresponding ground truth, with shape of [N, 1]. + + Note that `prediction` is the digit output of a network, before using softmax. + """ + pass + +class CrossEntropyLoss(AbstractClassificationLoss): + """ + Standard Softmax-based CE loss. + """ + def __init__(self, params = None): + super(CrossEntropyLoss, self).__init__(params) + self.ce_loss = nn.CrossEntropyLoss() + + def forward(self, loss_input_dict): + predict = loss_input_dict['prediction'] + labels = loss_input_dict['ground_truth'] + loss = self.ce_loss(predict, labels) + return loss + +class SigmoidCELoss(AbstractClassificationLoss): + """ + Sigmoid-based CE loss. + """ + def __init__(self, params = None): + super(SigmoidCELoss, self).__init__(params) + + def forward(self, loss_input_dict): + predict = loss_input_dict['prediction'] + labels = loss_input_dict['ground_truth'] + predict = nn.Sigmoid()(predict) * 0.999 + 5e-4 + loss = - labels * torch.log(predict) - (1 - labels) * torch.log( 1 - predict) + loss = loss.mean() + return loss + +class L1Loss(AbstractClassificationLoss): + """ + L1 (MAE) loss for classification + """ + def __init__(self, params = None): + super(L1Loss, self).__init__(params) + self.l1_loss = nn.L1Loss() + + def forward(self, loss_input_dict): + predict = loss_input_dict['prediction'] + labels = loss_input_dict['ground_truth'][:, None] # reshape to N, 1 + softmax = nn.Softmax(dim = 1) + predict = softmax(predict) + num_class = list(predict.size())[1] + data_type = 'float' if(predict.dtype is torch.float32) else 'double' + soft_y = get_soft_label(labels, num_class, data_type) + loss = self.l1_loss(predict, soft_y) + return loss + +class MSELoss(AbstractClassificationLoss): + """ + Mean Square Error loss for classification. + """ + def __init__(self, params = None): + super(MSELoss, self).__init__(params) + self.mse_loss = nn.MSELoss() + + def forward(self, loss_input_dict): + predict = loss_input_dict['prediction'] + labels = loss_input_dict['ground_truth'][:, None] # reshape to N, 1 + softmax = nn.Softmax(dim = 1) + predict = softmax(predict) + num_class = list(predict.size())[1] + data_type = 'float' if(predict.dtype is torch.float32) else 'double' + soft_y = get_soft_label(labels, num_class, data_type) + loss = self.mse_loss(predict, soft_y) + return loss + +class NLLLoss(AbstractClassificationLoss): + """ + The negative log likelihood loss for classification. + """ + def __init__(self, params = None): + super(NLLLoss, self).__init__(params) + self.nll_loss = nn.NLLLoss() + + def forward(self, loss_input_dict): + predict = loss_input_dict['prediction'] + labels = loss_input_dict['ground_truth'] + logsoft = nn.LogSoftmax(dim = 1) + predict = logsoft(predict) + loss = self.nll_loss(predict, labels) + return loss \ No newline at end of file diff --git a/pymic/loss/cls/ce.py b/pymic/loss/cls/ce.py deleted file mode 100644 index 5313c6e..0000000 --- a/pymic/loss/cls/ce.py +++ /dev/null @@ -1,42 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import torch -import torch.nn as nn - -class CrossEntropyLoss(nn.Module): - """ - Standard Softmax-based CE loss - Args: - predict has a shape of [N, C] where C is the class number - labels has a shape of [N] - - note that predict is the digit output of a network, before using softmax - """ - def __init__(self, params): - super(CrossEntropyLoss, self).__init__() - self.ce_loss = nn.CrossEntropyLoss() - - def forward(self, loss_input_dict): - predict = loss_input_dict['prediction'] - labels = loss_input_dict['ground_truth'] - loss = self.ce_loss(predict, labels) - return loss - -class SigmoidCELoss(nn.Module): - """ - Sigmoid-based CE loss - Args: - predict has a shape of [N, C] where C is the class number - labels has a shape of [N, C] with binary values - note that predict is the digit output of a network, before using sigmoid.""" - def __init__(self, params): - super(SigmoidCELoss, self).__init__() - - def forward(self, loss_input_dict): - predict = loss_input_dict['prediction'] - labels = loss_input_dict['ground_truth'] - predict = nn.Sigmoid()(predict) * 0.999 + 5e-4 - loss = - labels * torch.log(predict) - (1 - labels) * torch.log( 1 - predict) - loss = loss.mean() - return loss \ No newline at end of file diff --git a/pymic/loss/cls/l1.py b/pymic/loss/cls/l1.py deleted file mode 100644 index 57a4d7a..0000000 --- a/pymic/loss/cls/l1.py +++ /dev/null @@ -1,52 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import torch -import torch.nn as nn -from pymic.loss.cls.util import get_soft_label - -class L1Loss(nn.Module): - """ - L1 (MAE) loss for classification - """ - def __init__(self, params): - super(L1Loss, self).__init__() - self.l1_loss = nn.L1Loss() - - def forward(self, loss_input_dict): - predict = loss_input_dict['prediction'] - labels = loss_input_dict['ground_truth'][:, None] # reshape to N, 1 - softmax = nn.Softmax(dim = 1) - predict = softmax(predict) - num_class = list(predict.size())[1] - data_type = 'float' if(predict.dtype is torch.float32) else 'double' - soft_y = get_soft_label(labels, num_class, data_type) - loss = self.l1_loss(predict, soft_y) - return loss - -class RectifiedLoss(nn.Module): - def __init__(self, params): - super(RectifiedLoss, self).__init__() - # self.l1_loss = nn.L1Loss() - - def forward(self, loss_input_dict): - predict = loss_input_dict['prediction'] - labels = loss_input_dict['ground_truth'][:, None] # reshape to N, 1 - - # softmax = nn.Softmax(dim = 1) - # predict = softmax(predict) - num_class = list(predict.size())[1] - data_type = 'float' if(predict.dtype is torch.float32) else 'double' - soft_y = get_soft_label(labels, num_class, data_type) - g = 2* soft_y - 1 - loss = torch.exp((g*1.5- predict) * g) - mask = predict < g - if (data_type == 'float'): - mask = mask.float() - else: - mask = mask.double() - w = (mask - 0.5) * g + 0.5 - loss = w * loss + 0.1*(g - predict) * (g - predict) - loss = loss.mean() - # loss = self.l1_loss(predict, soft_y) - return loss \ No newline at end of file diff --git a/pymic/loss/cls/mse.py b/pymic/loss/cls/mse.py deleted file mode 100644 index 179e207..0000000 --- a/pymic/loss/cls/mse.py +++ /dev/null @@ -1,22 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import torch -import torch.nn as nn -from pymic.loss.cls.util import get_soft_label - -class MSELoss(nn.Module): - def __init__(self, params): - super(MSELoss, self).__init__() - self.mse_loss = nn.MSELoss() - - def forward(self, loss_input_dict): - predict = loss_input_dict['prediction'] - labels = loss_input_dict['ground_truth'][:, None] # reshape to N, 1 - softmax = nn.Softmax(dim = 1) - predict = softmax(predict) - num_class = list(predict.size())[1] - data_type = 'float' if(predict.dtype is torch.float32) else 'double' - soft_y = get_soft_label(labels, num_class, data_type) - loss = self.mse_loss(predict, soft_y) - return loss diff --git a/pymic/loss/cls/nll.py b/pymic/loss/cls/nll.py deleted file mode 100644 index ef37181..0000000 --- a/pymic/loss/cls/nll.py +++ /dev/null @@ -1,18 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import torch -import torch.nn as nn - -class NLLLoss(nn.Module): - def __init__(self, params): - super(NLLLoss, self).__init__() - self.nll_loss = nn.NLLLoss() - - def forward(self, loss_input_dict): - predict = loss_input_dict['prediction'] - labels = loss_input_dict['ground_truth'] - logsoft = nn.LogSoftmax(dim = 1) - predict = logsoft(predict) - loss = self.nll_loss(predict, labels) - return loss diff --git a/pymic/loss/cls/util.py b/pymic/loss/cls/util.py index b0a25b1..87e2e55 100644 --- a/pymic/loss/cls/util.py +++ b/pymic/loss/cls/util.py @@ -7,9 +7,12 @@ def get_soft_label(input_tensor, num_class, data_type = 'float'): """ - convert a label tensor to one-hot soft label - input_tensor: tensor with shae [B, 1] - output_tensor: shape [B, num_class] + Convert a label tensor to one-hot soft label. + + :param input_tensor: Tensor with shape of [B, 1]. + :param output_tensor: Tensor with shape of [B, num_class]. + :param num_class: (int) Class number. + :param data_type: (str) `float` or `double`. """ tensor_list = [] for i in range(num_class): diff --git a/pymic/loss/loss_dict_cls.py b/pymic/loss/loss_dict_cls.py index 7d723ab..e07f46b 100644 --- a/pymic/loss/loss_dict_cls.py +++ b/pymic/loss/loss_dict_cls.py @@ -1,9 +1,16 @@ # -*- coding: utf-8 -*- +""" +Built-in loss functions for classification. + +* CrossEntropyLoss :mod:`pymic.loss.cls.basic.CrossEntropyLoss` +* SigmoidCELoss :mod:`pymic.loss.cls.basic.SigmoidCELoss` +* L1Loss :mod:`pymic.loss.cls.basic.L1Loss` +* MSELoss :mod:`pymic.loss.cls.basic.MSELoss` +* NLLLoss :mod:`pymic.loss.cls.basic.NLLLoss` + +""" from __future__ import print_function, division -from pymic.loss.cls.ce import CrossEntropyLoss, SigmoidCELoss -from pymic.loss.cls.l1 import L1Loss -from pymic.loss.cls.nll import NLLLoss -from pymic.loss.cls.mse import MSELoss +from pymic.loss.cls.basic import * PyMICClsLossDict = {"CrossEntropyLoss": CrossEntropyLoss, "SigmoidCELoss": SigmoidCELoss, diff --git a/pymic/loss/loss_dict_seg.py b/pymic/loss/loss_dict_seg.py index a8a53ad..97c537e 100644 --- a/pymic/loss/loss_dict_seg.py +++ b/pymic/loss/loss_dict_seg.py @@ -1,19 +1,42 @@ # -*- coding: utf-8 -*- +""" +Built-in loss functions for segmentation. +The following are for fully supervised learning, or learnig from noisy labels: + +* CrossEntropyLoss :mod:`pymic.loss.seg.ce.CrossEntropyLoss` +* GeneralizedCELoss :mod:`pymic.loss.seg.ce.GeneralizedCELoss` +* DiceLoss :mod:`pymic.loss.seg.dice.DiceLoss` +* FocalDiceLoss :mod:`pymic.loss.seg.dice.FocalDiceLoss` +* NoiseRobustDiceLoss :mod:`pymic.loss.seg.dice.NoiseRobustDiceLoss` +* ExpLogLoss :mod:`pymic.loss.seg.exp_log.ExpLogLoss` +* MAELoss :mod:`pymic.loss.seg.mse.MAELoss` +* MSELoss :mod:`pymic.loss.seg.mse.MSELoss` +* SLSRLoss :mod:`pymic.loss.seg.slsr.SLSRLoss` + +The following are for semi-supervised or weakly supervised learning: + +* EntropyLoss :mod:`pymic.loss.seg.ssl.EntropyLoss` +* GatedCRFLoss: :mod:`pymic.loss.seg.gatedcrf.GatedCRFLoss` +* MumfordShahLoss :mod:`pymic.loss.seg.mumford_shah.MumfordShahLoss` +* TotalVariationLoss :mod:`pymic.loss.seg.ssl.TotalVariationLoss` +""" from __future__ import print_function, division import torch.nn as nn from pymic.loss.seg.ce import CrossEntropyLoss, GeneralizedCELoss from pymic.loss.seg.dice import DiceLoss, FocalDiceLoss, NoiseRobustDiceLoss -from pymic.loss.seg.slsr import SLSRLoss from pymic.loss.seg.exp_log import ExpLogLoss from pymic.loss.seg.mse import MSELoss, MAELoss +from pymic.loss.seg.slsr import SLSRLoss -SegLossDict = {'CrossEntropyLoss': CrossEntropyLoss, +SegLossDict = { + 'CrossEntropyLoss': CrossEntropyLoss, 'GeneralizedCELoss': GeneralizedCELoss, - 'SLSRLoss': SLSRLoss, 'DiceLoss': DiceLoss, 'FocalDiceLoss': FocalDiceLoss, 'NoiseRobustDiceLoss': NoiseRobustDiceLoss, 'ExpLogLoss': ExpLogLoss, + 'MAELoss': MAELoss, 'MSELoss': MSELoss, - 'MAELoss': MAELoss} + 'SLSRLoss': SLSRLoss + } diff --git a/pymic/loss/seg/abstract.py b/pymic/loss/seg/abstract.py index b5af5a0..f42d816 100644 --- a/pymic/loss/seg/abstract.py +++ b/pymic/loss/seg/abstract.py @@ -6,16 +6,20 @@ class AbstractSegLoss(nn.Module): """ - Cross entropy loss for segmentation tasks. - The arguments should be written in the `params` dictionary, and it has the + Abstract class for loss function of segmentation tasks. + The parameters should be written in the `params` dictionary, and it has the following fields: - Args: - `loss_softmax` (bool): Apply softmax to the prediction of network or not. \n + :param `loss_softmax`: (optional, bool) + Apply softmax to the prediction of network or not. Default is True. """ def __init__(self, params = None): super(AbstractSegLoss, self).__init__() - + if(params is None): + self.softmax = True + else: + self.softmax = params.get('loss_softmax', True) + def forward(self, loss_input_dict): """ Forward pass for calculating the loss. @@ -26,7 +30,8 @@ def forward(self, loss_input_dict): shape of [N, C, D, H, W] or [N, C, H, W]. :param `ground_truth`: (tensor) Ground truth, with the shape of [N, C, D, H, W] or [N, C, H, W]. - + :param `pixel_weight`: (optional) Pixel-wise weight map, with the + shape of [N, 1, D, H, W] or [N, 1, H, W]. Default is None. :return: Loss function value. """ pass diff --git a/pymic/loss/seg/ce.py b/pymic/loss/seg/ce.py index f9ee8ea..bbe6d02 100644 --- a/pymic/loss/seg/ce.py +++ b/pymic/loss/seg/ce.py @@ -9,17 +9,16 @@ class CrossEntropyLoss(AbstractSegLoss): """ Cross entropy loss for segmentation tasks. - The arguments should be written in the `params` dictionary, and it has the + + The parameters should be written in the `params` dictionary, and it has the following fields: - :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. + :param `loss_softmax`: (optional, bool) + Apply softmax to the prediction of network or not. Default is True. """ def __init__(self, params = None): - super(CrossEntropyLoss, self).__init__() - if(params is None): - self.softmax = True - else: - self.softmax = params.get('loss_softmax', True) + super(CrossEntropyLoss, self).__init__(params) + def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] @@ -48,10 +47,10 @@ class GeneralizedCELoss(AbstractSegLoss): """ Generalized cross entropy loss to deal with noisy labels. - Reference: Z. Zhang et al. Generalized Cross Entropy Loss for Training Deep Neural Networks - with Noisy Labels, NeurIPS 2018. + * Reference: Z. Zhang et al. Generalized Cross Entropy Loss for Training Deep Neural Networks + with Noisy Labels, NeurIPS 2018. - The arguments should be written in the `params` dictionary, and it has the + The parameters should be written in the `params` dictionary, and it has the following fields: :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. @@ -61,8 +60,7 @@ class GeneralizedCELoss(AbstractSegLoss): """ def __init__(self, params): - super(GeneralizedCELoss, self).__init__() - self.softmax = params.get('loss_softmax', True) + super(GeneralizedCELoss, self).__init__(params) self.q = params.get('loss_gce_q', 0.5) self.enable_pix_weight = params.get('loss_with_pixel_weight', False) self.cls_weight = params.get('loss_class_weight', None) diff --git a/pymic/loss/seg/combined.py b/pymic/loss/seg/combined.py index f4c4431..4e9aad2 100644 --- a/pymic/loss/seg/combined.py +++ b/pymic/loss/seg/combined.py @@ -8,15 +8,17 @@ class CombinedLoss(AbstractSegLoss): ''' A combination of a list of loss functions. - Arguments should be saved in the `params` dictionary. + Parameters should be saved in the `params` dictionary. + :param `loss_softmax`: (optional, bool) + Apply softmax to the prediction of network or not. Default is True. :param `loss_type`: (list) A list of loss function name. :param `loss_weight`: (list) A list of weights for each loss fucntion. :param loss_dict: (dictionary) A dictionary of avaiable loss functions. ''' def __init__(self, params, loss_dict): - super(CombinedLoss, self).__init__() + super(CombinedLoss, self).__init__(params) loss_names = params['loss_type'] self.loss_weight = params['loss_weight'] assert (len(loss_names) == len(self.loss_weight)) diff --git a/pymic/loss/seg/deep_sup.py b/pymic/loss/seg/deep_sup.py index 7362fd8..42ce172 100644 --- a/pymic/loss/seg/deep_sup.py +++ b/pymic/loss/seg/deep_sup.py @@ -10,12 +10,14 @@ class DeepSuperviseLoss(AbstractSegLoss): Arguments should be provided in the `params` dictionary, and it has the following fields: + :param `loss_softmax`: (optional, bool) + Apply softmax to the prediction of network or not. Default is True. :param `deep_suervise_weight`: (list) A list of weight for each deep supervision scale. \n :param `base_loss`: (nn.Module) The basic function used for each scale. ''' def __init__(self, params): - super(DeepSuperviseLoss, self).__init__() + super(DeepSuperviseLoss, self).__init__(params) self.deep_sup_weight = params.get('deep_suervise_weight', None) self.base_loss = params['base_loss'] diff --git a/pymic/loss/seg/dice.py b/pymic/loss/seg/dice.py index c207f37..c3a1134 100644 --- a/pymic/loss/seg/dice.py +++ b/pymic/loss/seg/dice.py @@ -9,17 +9,13 @@ class DiceLoss(AbstractSegLoss): ''' Dice loss for segmentation tasks. - The arguments should be written in the `params` dictionary, and it has the + The parameters should be written in the `params` dictionary, and it has the following fields: :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. ''' def __init__(self, params = None): - super(DiceLoss, self).__init__() - if(params is None): - self.softmax = True - else: - self.softmax = params.get('loss_softmax', True) + super(DiceLoss, self).__init__(params) def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] @@ -42,15 +38,14 @@ class FocalDiceLoss(AbstractSegLoss): * Pei Wang and Albert C. S. Chung, Focal Dice Loss and Image Dilation for Brain Tumor Segmentation, 2018. - The arguments should be written in the `params` dictionary, and it has the + The parameters should be written in the `params` dictionary, and it has the following fields: :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. :param `FocalDiceLoss_beta`: (float) The hyper-parameter to set (>=1.0). """ def __init__(self, params = None): - super(FocalDiceLoss, self).__init__() - self.softmax = params.get('loss_softmax', True) + super(FocalDiceLoss, self).__init__(params) self.beta = params['FocalDiceLoss_beta'.lower()] #beta should be >=1.0 def forward(self, loss_input_dict): @@ -77,15 +72,14 @@ class NoiseRobustDiceLoss(AbstractSegLoss): Pneumonia Lesions From CT Images, `IEEE TMI `_, 2020. - The arguments should be written in the `params` dictionary, and it has the + The parameters should be written in the `params` dictionary, and it has the following fields: :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. :param `NoiseRobustDiceLoss_gamma`: (float) The hyper-parameter gammar to set (1, 2). """ def __init__(self, params): - super(NoiseRobustDiceLoss, self).__init__() - self.softmax = params.get('loss_softmax', True) + super(NoiseRobustDiceLoss, self).__init__(params) self.gamma = params['NoiseRobustDiceLoss_gamma'.lower()] def forward(self, loss_input_dict): diff --git a/pymic/loss/seg/exp_log.py b/pymic/loss/seg/exp_log.py index 33bd031..c1b3f00 100644 --- a/pymic/loss/seg/exp_log.py +++ b/pymic/loss/seg/exp_log.py @@ -4,14 +4,15 @@ import torch import torch.nn as nn +from pymic.loss.seg.abstract import AbstractSegLoss from pymic.loss.seg.util import reshape_tensor_to_2D, get_classwise_dice -class ExpLogLoss(nn.Module): +class ExpLogLoss(AbstractSegLoss): """ The exponential logarithmic loss in this paper: * K. Wong et al.: 3D Segmentation with Exponential Logarithmic Loss for Highly - Unbalanced Object Sizes. MICCAI 2018. + Unbalanced Object Sizes. `MICCAI 2018. `_ The arguments should be written in the `params` dictionary, and it has the following fields: @@ -21,24 +22,11 @@ class ExpLogLoss(nn.Module): :param `ExpLogLoss_gamma`: (float) Hyper-parameter gamma. """ def __init__(self, params): - super(ExpLogLoss, self).__init__() - self.softmax = params.get('loss_softmax', True) + super(ExpLogLoss, self).__init__(params) self.w_dice = params['ExpLogLoss_w_dice'.lower()] self.gamma = params['ExpLogLoss_gamma'.lower()] def forward(self, loss_input_dict): - """ - Forward pass for calculating the loss. - The arguments should be written in the `loss_input_dict` dictionary, - and it has the following fields: - - :param `prediction`: (tensor) Prediction of a network, with the - shape of [N, C, D, H, W] or [N, C, H, W]. - :param `ground_truth`: (tensor) Ground truth, with the - shape of [N, C, D, H, W] or [N, C, H, W]. - - :return: Loss function value. - """ predict = loss_input_dict['prediction'] soft_y = loss_input_dict['ground_truth'] diff --git a/pymic/loss/seg/gatedcrf.py b/pymic/loss/seg/gatedcrf.py index 5e23655..46e17e0 100644 --- a/pymic/loss/seg/gatedcrf.py +++ b/pymic/loss/seg/gatedcrf.py @@ -6,8 +6,7 @@ import torch import torch.nn.functional as F - -class ModelLossSemsegGatedCRF(torch.nn.Module): +class GatedCRFLoss(torch.nn.Module): """ Gated CRF Loss for Weakly Supervised Semantic Image Segmentation. This loss function promotes consistent label assignment guided by input features, such as RGBXY. @@ -60,7 +59,7 @@ def resize_fix_mask(mask, name): assert mask.dim() == 4 and mask.shape[:2] == (N, 1) and mask.dtype == torch.float32, \ f'{name} mask must be a NCHW batch with C=1 and dtype float32' if mask.shape[2:] != (height_pred, width_pred): - mask = ModelLossSemsegGatedCRF._downsample( + mask = GatedCRFLoss._downsample( mask, 'mask', height_pred, width_pred, custom_modality_downsamplers ) mask[mask != mask] = 0.0 # handle NaN @@ -130,18 +129,18 @@ def _create_kernels( if modality == 'weight': continue if modality == 'xy': - feature = ModelLossSemsegGatedCRF._get_mesh(N, height_pred, width_pred, device) + feature = GatedCRFLoss._get_mesh(N, height_pred, width_pred, device) else: assert modality in sample, \ f'Modality {modality} is listed in {i}-th kernel descriptor, but not present in the sample' feature = sample[modality] - feature = ModelLossSemsegGatedCRF._downsample( + feature = GatedCRFLoss._downsample( feature, modality, height_pred, width_pred, custom_modality_downsamplers ) feature /= sigma features.append(feature) features = torch.cat(features, dim=1) - kernel = weight * ModelLossSemsegGatedCRF._create_kernels_from_features(features, kernels_radius) + kernel = weight * GatedCRFLoss._create_kernels_from_features(features, kernels_radius) kernels = kernel if kernels is None else kernel + kernels return kernels @@ -149,7 +148,7 @@ def _create_kernels( def _create_kernels_from_features(features, radius): assert features.dim() == 4, 'Features must be a NCHW batch' N, C, H, W = features.shape - kernels = ModelLossSemsegGatedCRF._unfold(features, radius) + kernels = GatedCRFLoss._unfold(features, radius) kernels = kernels - kernels[:, :, radius, radius, :, :].view(N, C, 1, 1, H, W) kernels = (-0.5 * kernels ** 2).sum(dim=1, keepdim=True).exp() kernels[:, :, radius, radius, :, :] = 0 diff --git a/pymic/loss/seg/mse.py b/pymic/loss/seg/mse.py index 511cd28..ad83899 100644 --- a/pymic/loss/seg/mse.py +++ b/pymic/loss/seg/mse.py @@ -5,17 +5,13 @@ class MSELoss(AbstractSegLoss): """ Mean Sequare Loss for segmentation tasks. - The arguments should be written in the `params` dictionary, and it has the + The parameters should be written in the `params` dictionary, and it has the following fields: :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. """ - def __init__(self, params): - super(MSELoss, self).__init__() - if(params is None): - self.softmax = True - else: - self.softmax = params.get('loss_softmax', True) + def __init__(self, params = None): + super(MSELoss, self).__init__(params) def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] @@ -38,12 +34,8 @@ class MAELoss(AbstractSegLoss): :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. """ - def __init__(self, params): + def __init__(self, params = None): super(MAELoss, self).__init__(params) - if(params is None): - self.softmax = True - else: - self.softmax = params.get('loss_softmax', True) def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] diff --git a/pymic/loss/seg/mumford_shah.py b/pymic/loss/seg/mumford_shah.py index 9583dc0..6da51b5 100644 --- a/pymic/loss/seg/mumford_shah.py +++ b/pymic/loss/seg/mumford_shah.py @@ -15,7 +15,7 @@ class MumfordShahLoss(nn.Module): `_ Currently only 2D version is supported. - The arguments should be written in the `params` dictionary, and it has the + The parameters should be written in the `params` dictionary, and it has the following fields: :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. @@ -26,7 +26,6 @@ def __init__(self, params = None): super(MumfordShahLoss, self).__init__() if(params is None): params = {} - self.softmax = params.get('loss_softmax', True) self.penalty = params.get('MumfordShahLoss_penalty', "l1") self.grad_w = params.get('MumfordShahLoss_lambda', 1.0) diff --git a/pymic/loss/seg/slsr.py b/pymic/loss/seg/slsr.py index 18aba3a..d5c4151 100644 --- a/pymic/loss/seg/slsr.py +++ b/pymic/loss/seg/slsr.py @@ -25,11 +25,10 @@ class SLSRLoss(AbstractSegLoss): :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. :param `slsrloss_epsilon`: (optional, float) Hyper-parameter epsilon. Default is 0.25. """ - def __init__(self, params): - super(SLSRLoss, self).__init__() + def __init__(self, params = None): + super(SLSRLoss, self).__init__(params) if(params is None): params = {} - self.softmax = params.get('loss_softmax', True) self.epsilon = params.get('slsrloss_epsilon', 0.25) def forward(self, loss_input_dict): diff --git a/pymic/loss/seg/ssl.py b/pymic/loss/seg/ssl.py index 7eaf981..f15fc60 100644 --- a/pymic/loss/seg/ssl.py +++ b/pymic/loss/seg/ssl.py @@ -10,17 +10,13 @@ class EntropyLoss(nn.Module): """ Entropy Minimization for segmentation tasks. - The arguments should be written in the `params` dictionary, and it has the + The parameters should be written in the `params` dictionary, and it has the following fields: :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. """ def __init__(self, params = None): - super(EntropyLoss, self).__init__() - if(params is None): - self.softmax = True - else: - self.softmax = params.get('loss_softmax', True) + super(EntropyLoss, self).__init__(params) def forward(self, loss_input_dict): """ @@ -50,17 +46,13 @@ def forward(self, loss_input_dict): class TotalVariationLoss(nn.Module): """ Total Variation Loss for segmentation tasks. - The arguments should be written in the `params` dictionary, and it has the + The parameters should be written in the `params` dictionary, and it has the following fields: :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. """ def __init__(self, params = None): - super(TotalVariationLoss, self).__init__() - if(params is None): - self.softmax = True - else: - self.softmax = params.get('loss_softmax', True) + super(TotalVariationLoss, self).__init__(params) def forward(self, loss_input_dict): """ diff --git a/pymic/net/net2d/cople_net.py b/pymic/net/net2d/cople_net.py index 79533ce..bd54fd6 100644 --- a/pymic/net/net2d/cople_net.py +++ b/pymic/net/net2d/cople_net.py @@ -1,18 +1,12 @@ # -*- coding: utf-8 -*- -""" -Author: Guotai Wang -Date: 12 June, 2020 -Implementation of of COPLENet for COVID-19 pneumonia lesion segmentation from CT images. -Reference: - G. Wang et al. A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions - from CT Images. IEEE Transactions on Medical Imaging, 39(8),2020:2653-2663. DOI:10.1109/TMI.2020.3000314. -""" - from __future__ import print_function, division import torch import torch.nn as nn class ConvLayer(nn.Module): + """ + A combination of Conv2d, BatchNorm2d and LeakyReLU. + """ def __init__(self, in_channels, out_channels, kernel_size = 1): super(ConvLayer, self).__init__() padding = int((kernel_size - 1) / 2) @@ -26,6 +20,9 @@ def forward(self, x): return self.conv(x) class SEBlock(nn.Module): + """ + A Modified Squeeze-and-Excitation block for spatial attention. + """ def __init__(self, in_channels, r): super(SEBlock, self).__init__() @@ -42,6 +39,9 @@ def forward(self, x): return f*x + x class ASPPBlock(nn.Module): + """ + ASPP block. + """ def __init__(self,in_channels, out_channels_list, kernel_size_list, dilation_list): super(ASPPBlock, self).__init__() self.conv_num = len(out_channels_list) @@ -77,7 +77,10 @@ def forward(self, x): return y class ConvBNActBlock(nn.Module): - """Two convolution layers with batch norm, leaky relu, dropout and SE block""" + """ + Two convolution layers with batch norm, leaky relu, + dropout and SE block. + """ def __init__(self,in_channels, out_channels, dropout_p): super(ConvBNActBlock, self).__init__() self.conv_conv = nn.Sequential( @@ -95,7 +98,9 @@ def forward(self, x): return self.conv_conv(x) class DownBlock(nn.Module): - """Downsampling by a concantenation of max-pool and avg-pool, followed by ConvBNActBlock + """ + Downsampling by a concantenation of max-pool and avg-pool, + followed by ConvBNActBlock. """ def __init__(self, in_channels, out_channels, dropout_p): super(DownBlock, self).__init__() @@ -111,7 +116,9 @@ def forward(self, x): return y + x_cat class UpBlock(nn.Module): - """Upssampling followed by ConvBNActBlock""" + """ + Upssampling followed by ConvBNActBlock. + """ def __init__(self, in_channels1, in_channels2, out_channels, bilinear=True, dropout_p = 0.5): super(UpBlock, self).__init__() @@ -132,14 +139,33 @@ def forward(self, x1, x2): return y + x_cat class COPLENet(nn.Module): + """ + Implementation of of COPLENet for COVID-19 pneumonia lesion segmentation from CT images. + + * Reference: G. Wang et al. `A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions + from CT Images `_. + IEEE Transactions on Medical Imaging, 39(8),2020:2653-2663. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param bilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + """ def __init__(self, params): super(COPLENet, self).__init__() self.params = params self.in_chns = self.params['in_chns'] self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] self.n_class = self.params['class_num'] self.bilinear = self.params['bilinear'] - self.dropout = self.params['dropout'] assert(len(self.ft_chns) == 5) f0_half = int(self.ft_chns[0] / 2) diff --git a/pymic/net/net2d/scse2d.py b/pymic/net/net2d/scse2d.py index 6226102..1963a3b 100644 --- a/pymic/net/net2d/scse2d.py +++ b/pymic/net/net2d/scse2d.py @@ -1,14 +1,11 @@ """ -Squeeze hannel Squeeze and Excitation `_ - 2. `Spatial Squeeze and Excitation `_ - and Excitation Module -***************************** - -oringinal file: https://github.com/maodong2056/squeeze_and_excitation/blob/master/squeeze_and_excitation/squeeze_and_excitation.py -Collection of squeeze and excitation classes where each can be inserted as a block into a neural network architechture - 1. `C 3. `Channel and Spatial Squeeze and Excitation `_ -""" +1. Channel Squeeze and Excitation +2. Spatial Squeeze and Excitation +3. Concurrent Spatial and Channel Squeeze & Excitation +Oringinal file is on `Github. +`_ +""" from enum import Enum import torch import torch.nn as nn @@ -16,15 +13,15 @@ class ChannelSELayer(nn.Module): """ - Re-implementation of Squeeze-and-Excitation (SE) block described in: - *Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507* + Re-implementation of Squeeze-and-Excitation (SE) block. + + * Reference: Jie Hu, Li Shen, Gang Sun: Squeeze-and-Excitation Networks. + `CVPR 2018. `_ + + :param num_channels: Number of input channels + :param reduction_ratio: By how much should the num_channels should be reduced. """ def __init__(self, num_channels, reduction_ratio=2): - - """ - :param num_channels: No of input channels - :param reduction_ratio: By how much should the num_channels should be reduced - """ super(ChannelSELayer, self).__init__() num_channels_reduced = num_channels // reduction_ratio self.reduction_ratio = reduction_ratio @@ -53,14 +50,14 @@ def forward(self, input_tensor): class SpatialSELayer(nn.Module): """ - Re-implementation of SE block -- squeezing spatially and exciting channel-wise described in: + Re-implementation of SE block -- squeezing spatially and exciting channel-wise. - *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018* + * Reference: Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in + Fully Convolutional Networks, MICCAI 2018. + + :param num_channels: Number of input channels. """ def __init__(self, num_channels): - """ - :param num_channels: No of input channels - """ super(SpatialSELayer, self).__init__() self.conv = nn.Conv2d(num_channels, 1, 1) self.sigmoid = nn.Sigmoid() @@ -92,14 +89,15 @@ def forward(self, input_tensor, weights=None): class ChannelSpatialSELayer(nn.Module): """ - Re-implementation of concurrent spatial and channel squeeze & excitation: - *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018, arXiv:1803.02579* + Re-implementation of concurrent spatial and channel squeeze & excitation. + + * Reference: Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in + Fully Convolutional Networks, MICCAI 2018. + + :param num_channels: Number of input channels. + :param reduction_ratio: By how much should the num_channels should be reduced. """ def __init__(self, num_channels, reduction_ratio=2): - """ - :param num_channels: No of input channels - :param reduction_ratio: By how much should the num_channels should be reduced - """ super(ChannelSpatialSELayer, self).__init__() self.cSE = ChannelSELayer(num_channels, reduction_ratio) self.sSE = SpatialSELayer(num_channels) diff --git a/pymic/net/net2d/unet2d.py b/pymic/net/net2d/unet2d.py index c17cd0c..9a8aca3 100644 --- a/pymic/net/net2d/unet2d.py +++ b/pymic/net/net2d/unet2d.py @@ -1,12 +1,4 @@ # -*- coding: utf-8 -*- -""" -An implementation of the U-Net paper: - Olaf Ronneberger, Philipp Fischer, Thomas Brox: - U-Net: Convolutional Networks for Biomedical Image Segmentation. - MICCAI (3) 2015: 234-241 -Note that there are some modifications from the original paper, such as -the use of batch normalization, dropout, and leaky relu here. -""" from __future__ import print_function, division import torch @@ -15,11 +7,15 @@ from torch.nn.functional import interpolate class ConvBlock(nn.Module): - """two convolution layers with batch norm and leaky relu""" + """ + Two convolution layers with batch norm and leaky relu. + Droput is used between the two convolution layers. + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + """ def __init__(self,in_channels, out_channels, dropout_p): - """ - dropout_p: probability to be zeroed - """ super(ConvBlock, self).__init__() self.conv_conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), @@ -35,7 +31,13 @@ def forward(self, x): return self.conv_conv(x) class DownBlock(nn.Module): - """Downsampling followed by ConvBlock""" + """ + Downsampling followed by ConvBlock + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + """ def __init__(self, in_channels, out_channels, dropout_p): super(DownBlock, self).__init__() self.maxpool_conv = nn.Sequential( @@ -47,15 +49,18 @@ def forward(self, x): return self.maxpool_conv(x) class UpBlock(nn.Module): - """Upsampling followed by ConvBlock""" + """ + Upsampling followed by ConvBlock + + :param in_channels1: (int) Channel number of high-level features. + :param in_channels2: (int) Channel number of low-level features. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + :param bilinear: (bool) Use bilinear for up-sampling (by default). + If False, deconvolution is used for up-sampling. + """ def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, bilinear=True): - """ - in_channels1: channel of high-level features - in_channels2: channel of low-level features - out_channels: output channel number - dropout_p: probability of dropout - """ super(UpBlock, self).__init__() self.bilinear = bilinear if bilinear: @@ -73,6 +78,18 @@ def forward(self, x1, x2): return self.conv(x) class Encoder(nn.Module): + """ + Encoder of 2D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + """ def __init__(self, params): super(Encoder, self).__init__() self.params = params @@ -100,6 +117,21 @@ def forward(self, x): return output class Decoder(nn.Module): + """ + Decoder of 2D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param bilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + """ def __init__(self, params): super(Decoder, self).__init__() self.params = params @@ -134,10 +166,30 @@ def forward(self, x): return output class UNet2D(nn.Module): + """ + An implementation of 2D U-Net. + + * Reference: Olaf Ronneberger, Philipp Fischer, Thomas Brox: + U-Net: Convolutional Networks for Biomedical Image Segmentation. + MICCAI (3) 2015: 234-241 + + Note that there are some modifications from the original paper, such as + the use of batch normalization, dropout, and leaky relu here. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param bilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + :param deep_supervise: (bool) Using deep supervision for training or not. + """ def __init__(self, params): - """ - 2D UNet - """ super(UNet2D, self).__init__() self.params = params self.in_chns = self.params['in_chns'] diff --git a/pymic/net/net2d/unet2d_cct.py b/pymic/net/net2d/unet2d_cct.py index f7558bc..b3aa557 100644 --- a/pymic/net/net2d/unet2d_cct.py +++ b/pymic/net/net2d/unet2d_cct.py @@ -1,13 +1,4 @@ # -*- coding: utf-8 -*- -""" -An modification the U-Net with auxiliary decoders according to -the CCT paper: - Yassine Ouali, Celine Hudelot and Myriam Tami: - Semi-Supervised Semantic Segmentation With Cross-Consistency Training. - CVPR 2020. - https://arxiv.org/abs/2003.09005 -Code adapted from: https://github.com/yassouali/CCT -""" from __future__ import print_function, division import torch @@ -18,14 +9,13 @@ from pymic.net.net2d.unet2d import Encoder, Decoder def _l2_normalize(d): - # Normalizing per batch axis + """Normalizing per batch axis""" d_reshaped = d.view(d.shape[0], -1, *(1 for _ in range(d.dim() - 2))) d /= torch.norm(d_reshaped, dim=1, keepdim=True) + 1e-8 return d - -def get_r_adv(x_list, decoder, it=1, xi=1e-1, eps=10.0): +def _get_r_adv(x_list, decoder, it=1, xi=1e-1, eps=10.0): """ Virtual Adversarial Training according to https://arxiv.org/abs/1704.03976 @@ -53,6 +43,19 @@ def get_r_adv(x_list, decoder, it=1, xi=1e-1, eps=10.0): class AuxiliaryDecoder(nn.Module): + """ + An Auxiliary Decoder. + `aux_type` should be one of {`DropOut`, `FeatureDrop`, `FeatureNoise` and `VAT`}. + Other parameters for the decoder are given in the `params` dictionary, + see :mod:`pymic.net.net2d.unet2d.Decoder` for details. + In addition, the following fields are needed for pertubation: + + :param Uniform_range: (float) The range of noise. Only needed when `aux_type`=`FeatureNoise`. + :param VAT_it: (float) The iteration number of VAT. Only needed when `aux_type`=`VAT`. + :param VAT_xi: (float) The hyper-parameter xi of VAT. Only needed when `aux_type`=`VAT`. + :param VAT_eps: (float) The hyper-parameter eps of VAT. Only needed when `aux_type`=`VAT`. + + """ def __init__(self, params, aux_type): super(AuxiliaryDecoder, self).__init__() self.params = params @@ -85,7 +88,7 @@ def forward(self, x): it = self.params.get("VAT_it".lower(), 2) xi = self.params.get("VAT_xi".lower(), 1e-6) eps= self.params.get("VAT_eps".lower(), 2.0) - x[-1] = get_r_adv(x, self.decoder, it, xi, eps) + x[-1] = _get_r_adv(x, self.decoder, it, xi, eps) else: raise ValueError("Undefined auxiliary decoder type {0:}".format(self.aux_type)) @@ -94,6 +97,28 @@ def forward(self, x): class UNet2D_CCT(nn.Module): + """ + An modification the U-Net with auxiliary decoders according to + the CCT paper. + + * Reference: Yassine Ouali, Celine Hudelot and Myriam Tami: + Semi-Supervised Semantic Segmentation With Cross-Consistency Training. + `CVPR 2020. `_ + + Code adapted from `Github. `_ + + Parameter for the network backbone are given in the `params` dictionary, + see :mod:`pymic.net.net2d.unet2d.UNet2D` for details. + In addition, the following fields are needed for pertubation + in the auxiliary decoders: + + :param CCT_aux_decoders: (list) A list of auxiliary decoder types. + Supported values are {`DropOut`, `FeatureDrop`, `FeatureNoise` and `VAT`}. + + The parameters for different types of auxiliary decoders should also be + given in the `params` dictionary, + see :mod:`pymic.net.net2d.unet2d_cct.AuxiliaryDecoder` for details. + """ def __init__(self, params): super(UNet2D_CCT, self).__init__() self.params = params diff --git a/pymic/net/net2d/unet2d_dual_branch.py b/pymic/net/net2d/unet2d_dual_branch.py index 3531c89..828bdfe 100644 --- a/pymic/net/net2d/unet2d_dual_branch.py +++ b/pymic/net/net2d/unet2d_dual_branch.py @@ -1,12 +1,4 @@ # -*- coding: utf-8 -*- -""" -Extention of U-Net with two decoders. The network was introduced in -the following paper: - Xiangde Luo, Minhao Hu, Wenjun Liao, Shuwei Zhai, Tao Song, Guotai Wang, - Shaoting Zhang. ScribblScribble-Supervised Medical Image Segmentation via - Dual-Branch Network and Dynamically Mixed Pseudo Labels Supervision. - MICCAI 2022. -""" from __future__ import print_function, division import torch @@ -14,6 +6,23 @@ from pymic.net.net2d.unet2d import * class UNet2D_DualBranch(nn.Module): + """ + A dual branch network using UNet2D as backbone. + + * Reference: Xiangde Luo, Minhao Hu, Wenjun Liao, Shuwei Zhai, Tao Song, Guotai Wang, + Shaoting Zhang. ScribblScribble-Supervised Medical Image Segmentation via + Dual-Branch Network and Dynamically Mixed Pseudo Labels Supervision. + `MICCAI 2022. `_ + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.UNet2D` for details. + In addition, the following field should be included: + + :param output_mode: (str) How to obtain the result during the inference. + `average`: taking average of the two branches. + `first`: takeing the result in the first branch. + `second`: taking the result in the second branch. + """ def __init__(self, params): super(UNet2D_DualBranch, self).__init__() self.output_mode = params.get("output_mode", "average") diff --git a/pymic/net/net2d/unet2d_nest.py b/pymic/net/net2d/unet2d_nest.py index 5071114..efa048f 100644 --- a/pymic/net/net2d/unet2d_nest.py +++ b/pymic/net/net2d/unet2d_nest.py @@ -1,24 +1,36 @@ # -*- coding: utf-8 -*- -""" -An implementation of the Nested U-Net paper: - Zongwei Zhou, et al.: - UNet++: A Nested U-Net Architecture for Medical Image Segmentation. - MICCAI DLMIA workshop, 2018: 3-11. -Note that there are some modifications from the original paper, such as -the use of dropout and leaky relu here. -""" import torch import torch.nn as nn from pymic.net.net2d.unet2d import * class NestedUNet2D(nn.Module): + """ + An implementation of the Nested U-Net. + + * Reference: Zongwei Zhou, et al.: `UNet++: A Nested U-Net Architecture for Medical Image Segmentation. + `_ + MICCAI DLMIA workshop, 2018: 3-11. + + Note that there are some modifications from the original paper, such as + the use of dropout and leaky relu here. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + """ def __init__(self, params): super(NestedUNet2D, self).__init__() self.params = params self.in_chns = self.params['in_chns'] self.filters = self.params['feature_chns'] - self.n_class = self.params['class_num'] self.dropout = self.params['dropout'] + self.n_class = self.params['class_num'] self.pool = nn.MaxPool2d(kernel_size=2, stride=2) self.Up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) diff --git a/pymic/net/net2d/unet2d_scse.py b/pymic/net/net2d/unet2d_scse.py index 6f9d3f7..b1606c9 100644 --- a/pymic/net/net2d/unet2d_scse.py +++ b/pymic/net/net2d/unet2d_scse.py @@ -1,10 +1,4 @@ # -*- coding: utf-8 -*- -""" -Combining U-Net with SCSE module according to the following paper: - Abhijit Guha Roy, Nassir Navab, Christian Wachinger: - Recalibrating Fully Convolutional Networks With Spatial and Channel "Squeeze and Excitation" Blocks. \ - IEEE Trans. Med. Imaging 38(2): 540-549 (2019) -""" from __future__ import print_function, division import torch @@ -13,8 +7,12 @@ from pymic.net.net2d.scse2d import * class ConvScSEBlock(nn.Module): - """two convolution layers with batch norm and leaky relu""" - def __init__(self,in_channels, out_channels, dropout_p): + """ + Two convolutional blocks followed by `ChannelSpatialSELayer`. + Each block consists of `Conv2d` + `BatchNorm2d` + `LeakyReLU`. + A dropout layer is used between the wo blocks. + """ + def __init__(self, in_channels, out_channels, dropout_p): super(ConvScSEBlock, self).__init__() self.conv_conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), @@ -31,7 +29,7 @@ def forward(self, x): return self.conv_conv(x) class DownBlock(nn.Module): - """Downsampling followed by ConvBlock""" + """Downsampling followed by `ConvScSEBlock`.""" def __init__(self, in_channels, out_channels, dropout_p): super(DownBlock, self).__init__() self.maxpool_conv = nn.Sequential( @@ -44,7 +42,7 @@ def forward(self, x): return self.maxpool_conv(x) class UpBlock(nn.Module): - """Upssampling followed by ConvBlock""" + """Up-sampling followed by `ConvScSEBlock`.""" def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, bilinear=True): super(UpBlock, self).__init__() @@ -64,14 +62,34 @@ def forward(self, x1, x2): return self.conv(x) class UNet2D_ScSE(nn.Module): + """ + Combining 2D U-Net with SCSE module. + + * Reference: Abhijit Guha Roy, Nassir Navab, Christian Wachinger: + Recalibrating Fully Convolutional Networks With Spatial and Channel + "Squeeze and Excitation" Blocks. + `IEEE Trans. Med. Imaging 38(2): 540-549 (2019). `_ + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param bilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + """ def __init__(self, params): super(UNet2D_ScSE, self).__init__() self.params = params self.in_chns = self.params['in_chns'] self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] self.n_class = self.params['class_num'] self.bilinear = self.params['bilinear'] - self.dropout = self.params['dropout'] assert(len(self.ft_chns) == 5) self.in_conv= ConvScSEBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index 0ee554e..5dd3bee 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -1,4 +1,19 @@ # -*- coding: utf-8 -*- +""" +Built-in networks for segmentation. + +* UNet2D :mod:`pymic.net.net2d.unet2d.UNet2D` +* UNet2D_DualBranch :mod:`pymic.net.net2d.unet2d_dual_branch.UNet2D_DualBranch` +* UNet2D_URPC :mod:`pymic.net.net2d.unet2d_urpc.UNet2D_URPC` +* UNet2D_CCT :mod:`pymic.net.net2d.unet2d_cct.UNet2D_CCT` +* UNet2D_ScSE :mod:`pymic.net.net2d.unet2d_scse.UNet2D_ScSE` +* AttentionUNet2D :mod:`pymic.net.net2d.unet2d_attention.AttentionUNet2D` +* NestedUNet2D :mod:`pymic.net.net2d.unet2d_nest.NestedUNet2D` +* COPLENet :mod:`pymic.net.net2d.cople_net.COPLENet` +* UNet2D5 :mod:`pymic.net.net3d.unet2d5.UNet2D5` +* UNet3D :mod:`pymic.net.net3d.unet3d.UNet3D` +* UNet3D_ScSE :mod:`pymic.net.net3d.unet3d_scse.UNet3D_ScSE` +""" from __future__ import print_function, division from pymic.net.net2d.unet2d import UNet2D from pymic.net.net2d.unet2d_dual_branch import UNet2D_DualBranch diff --git a/pymic/net_run_wsl/wsl_gatedcrf.py b/pymic/net_run_wsl/wsl_gatedcrf.py index 270a955..85f601e 100644 --- a/pymic/net_run_wsl/wsl_gatedcrf.py +++ b/pymic/net_run_wsl/wsl_gatedcrf.py @@ -7,7 +7,7 @@ from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice -from pymic.loss.seg.gatedcrf import ModelLossSemsegGatedCRF +from pymic.loss.seg.gatedcrf import GatedCRFLoss from pymic.net_run_wsl.wsl_abstract import WSLSegAgent from pymic.util.ramps import get_rampup_ratio @@ -54,7 +54,7 @@ def training(self): train_loss_reg = 0 train_dice_list = [] - gatecrf_loss = ModelLossSemsegGatedCRF() + gatecrf_loss = GatedCRFLoss() self.net.train() for it in range(iter_valid): try: diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index e0ef85c..8da3b38 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -1,4 +1,33 @@ # -*- coding: utf-8 -*- +""" +The built-in transforms in PyMIC are: + +.. code-block:: none + + 'ChannelWiseThreshold': ChannelWiseThreshold, + 'ChannelWiseThresholdWithNormalize': ChannelWiseThresholdWithNormalize, + 'CropWithBoundingBox': CropWithBoundingBox, + 'CenterCrop': CenterCrop, + 'GrayscaleToRGB': GrayscaleToRGB, + 'GammaCorrection': GammaCorrection, + 'GaussianNoise': GaussianNoise, + 'LabelConvert': LabelConvert, + 'LabelConvertNonzero': LabelConvertNonzero, + 'LabelToProbability': LabelToProbability, + 'NormalizeWithMeanStd': NormalizeWithMeanStd, + 'NormalizeWithMinMax': NormalizeWithMinMax, + 'NormalizeWithPercentiles': NormalizeWithPercentiles, + 'PartialLabelToProbability':PartialLabelToProbability, + 'RandomCrop': RandomCrop, + 'RandomResizedCrop': RandomResizedCrop, + 'RandomRescale': RandomRescale, + 'RandomFlip': RandomFlip, + 'RandomRotate': RandomRotate, + 'ReduceLabelDim': ReduceLabelDim, + 'Rescale': Rescale, + 'Pad': Pad. + +""" from __future__ import print_function, division from pymic.transform.intensity import * from pymic.transform.flip import RandomFlip diff --git a/pymic/util/post_process.py b/pymic/util/post_process.py index da133ca..a0a9dff 100644 --- a/pymic/util/post_process.py +++ b/pymic/util/post_process.py @@ -7,6 +7,9 @@ from pymic.util.image_process import get_largest_k_components class PostProcess(object): + """ + The abastract class for post processing. + """ def __init__(self, params): self.params = params @@ -14,13 +17,19 @@ def __call__(self, seg): return seg class PostKeepLargestComponent(PostProcess): + """ + Post process by keeping the largest component. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `KeepLargestComponent_mode`: (int) + `1` means keep the largest component of the union of foreground classes. + `2` means keep the largest component for each foreground class. + """ def __init__(self, params): super(PostKeepLargestComponent, self).__init__(params) self.mode = params.get("KeepLargestComponent_mode".lower(), 1) - """ - mode = 1: keep the largest component of the union of foreground classes. - mode = 2: keep the largest component for each foreground class. - """ def __call__(self, seg): if(self.mode == 1): diff --git a/pymic/util/preprocess.py b/pymic/util/preprocess.py index f60131c..5f20372 100644 --- a/pymic/util/preprocess.py +++ b/pymic/util/preprocess.py @@ -6,6 +6,9 @@ from pymic.util.parse_config import parse_config def get_transform_list(trans_config_file): + """ + Create a list of transforms given a configuration file. + """ config = parse_config(trans_config_file) transform_list = [] @@ -21,11 +24,18 @@ def get_transform_list(trans_config_file): return transform_list def preprocess_with_transform(transforms, img_in_name, img_out_name, - lab_in_name = None, lab_out_name = None): + lab_in_name = None, lab_out_name = None): """ - using data transforms for preprocessing, such as image normalization, - cropping, etc. + Using a list of data transforms for preprocessing, + such as image normalization, cropping, etc. TODO: support multip-modality preprocessing. + + :param transforms: (list) A list of transform objects. + :param img_in_name: (str) Input file name. + :param img_out_name: (str) Output file name. + :param lab_in_name: (optional, str) If None, load the image's + corresponding label for preprocessing as well. + :param lab_out_name: (optional, str) The output label name. """ image_dict = load_image_as_nd_array(img_in_name) sample = {'image': np.asarray(image_dict['data_array'], np.float32), diff --git a/pymic/util/ramps.py b/pymic/util/ramps.py index b58adb6..2f8cc9f 100644 --- a/pymic/util/ramps.py +++ b/pymic/util/ramps.py @@ -1,30 +1,58 @@ # -*- coding: utf-8 -*- -from __future__ import print_function, division -import numpy as np - -"""Functions for ramping hyperparameters up or down +""" +Functions for ramping hyperparameters up or down. Each function takes the current training step or epoch, and the -ramp length (maximal step or epoch), and returns a multiplier between +ramp length (start and end step or epoch), and returns a multiplier between 0 and 1. """ - +from __future__ import print_function, division +import numpy as np def get_rampup_ratio(i, start, end, mode = "linear"): - if( i < start): - rampup = 0.0 - elif(i > end): - rampup = 1.0 - elif(mode == "linear"): + """ + Obtain the rampup ratio. + + :param i: (int) The current iteration. + :param start: (int) The start iteration. + :param end: (int) The end itertation. + :param mode: (str) Valid values are {`linear`, `sigmoid`, `cosine`}. + """ + i = np.clip(i, start, end) + if(mode == "linear"): rampup = (i - start) / (end - start) elif(mode == "sigmoid"): phase = 1.0 - (i - start) / (end - start) rampup = float(np.exp(-5.0 * phase * phase)) + elif(mode == "cosine"): + phase = 1.0 - (i - start) / (end - start) + rampup = float(.5 * (np.cos(np.pi * phase) + 1)) + else: + raise ValueError("Undefined rampup mode {0:}".format(mode)) return rampup -def cosine_rampdown(i, start, end): - """Cosine rampdown from https://arxiv.org/abs/1608.03983""" - i = np.clip(i, 0.0, length) - return float(.5 * (np.cos(np.pi * i / length) + 1)) \ No newline at end of file +def get_rampdown_ratio(i, start, end, mode = "linear"): + """ + Obtain the rampdown ratio. + + :param i: (int) The current iteration. + :param start: (int) The start iteration. + :param end: (int) The end itertation. + :param mode: (str) Valid values are {`linear`, `sigmoid`, `cosine`}. + """ + i = np.clip(i, start, end) + if(mode == "linear"): + rampdown = 1.0 - (i - start) / (end - start) + elif(mode == "sigmoid"): + phase = (i - start) / (end - start) + rampdown = float(np.exp(-5.0 * phase * phase)) + elif(mode == "cosine"): + phase = (i - start) / (end - start) + rampdown = float(.5 * (np.cos(np.pi * phase) + 1)) + else: + raise ValueError("Undefined rampup mode {0:}".format(mode)) + return rampdown + + \ No newline at end of file From 50a82ef28ddbb0a2006f6272a294069af1a48d1c Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 13 Sep 2022 22:38:22 +0800 Subject: [PATCH 112/225] update docs for 3D networks update docs for 3D networks --- pymic/layer/convolution.py | 103 ++++++---------- pymic/layer/deconvolution.py | 37 +++++- pymic/layer/space2channel.py | 5 + pymic/net/cls/torch_pretrained_net.py | 162 ++++++++++++++++++-------- pymic/net/net2d/scse2d.py | 7 +- pymic/net/net2d/unet2d.py | 2 +- pymic/net/net2d/unet2d_scse.py | 20 +++- pymic/net/net2d/unet2d_urpc.py | 36 ++++-- pymic/net/net3d/scse3d.py | 49 ++++---- pymic/net/net3d/unet2d5.py | 70 ++++++++--- pymic/net/net3d/unet3d.py | 83 +++++++------ pymic/net/net3d/unet3d_scse.py | 66 ++++++++--- pymic/net/net_dict_cls.py | 8 ++ 13 files changed, 420 insertions(+), 228 deletions(-) diff --git a/pymic/layer/convolution.py b/pymic/layer/convolution.py index 034cdeb..50e8b47 100644 --- a/pymic/layer/convolution.py +++ b/pymic/layer/convolution.py @@ -7,8 +7,22 @@ class ConvolutionLayer(nn.Module): """ A compose layer with the following components: - convolution -> (batch_norm / layer_norm / group_norm / instance_norm) -> activation -> (dropout) - batch norm and dropout are optional + convolution -> (batch_norm / layer_norm / group_norm / instance_norm) -> (activation) -> (dropout) + Batch norm and activation are optional. + + :param in_channels: (int) The input channel number. + :param out_channels: (int) The output channel number. + :param kernel_size: The size of convolution kernel. It can be either a single + int or a tupe of two or three ints. + :param dim: (int) The dimention of convolution (2 or 3). + :param stride: (int) The stride of convolution. + :param padding: (int) Padding size. + :param dilation: (int) Dilation rate. + :param conv_group: (int) The groupt number of convolution. + :param bias: (bool) Add bias or not for convolution. + :param norm_type: (str or None) Normalization type, can be `batch_norm`, 'group_norm'. + :param norm_group: (int) The number of group for group normalization. + :param acti_func: (str or None) Activation funtion. """ def __init__(self, in_channels, out_channels, kernel_size, dim = 3, stride = 1, padding = 0, dilation = 1, conv_group = 1, bias = True, @@ -50,9 +64,23 @@ def forward(self, x): class DepthSeperableConvolutionLayer(nn.Module): """ - A compose layer with the following components: - convolution -> (batch_norm) -> activation -> (dropout) - batch norm and dropout are optional + Depth seperable convolution with the following components: + 1x1 conv -> group conv -> (batch_norm / layer_norm / group_norm / instance_norm) -> (activation) -> (dropout) + Batch norm and activation are optional. + + :param in_channels: (int) The input channel number. + :param out_channels: (int) The output channel number. + :param kernel_size: The size of convolution kernel. It can be either a single + int or a tupe of two or three ints. + :param dim: (int) The dimention of convolution (2 or 3). + :param stride: (int) The stride of convolution. + :param padding: (int) Padding size. + :param dilation: (int) Dilation rate. + :param conv_group: (int) The groupt number of convolution. + :param bias: (bool) Add bias or not for convolution. + :param norm_type: (str or None) Normalization type, can be `batch_norm`, 'group_norm'. + :param norm_group: (int) The number of group for group normalization. + :param acti_func: (str or None) Activation funtion. """ def __init__(self, in_channels, out_channels, kernel_size, dim = 3, stride = 1, padding = 0, dilation =1, conv_group = 1, bias = True, @@ -97,68 +125,3 @@ def forward(self, x): f = self.acti_func(f) return f -class ConvolutionSepAll3DLayer(nn.Module): - """ - A compose layer with the following components: - convolution -> (batch_norm) -> activation -> (dropout) - batch norm and dropout are optional - """ - def __init__(self, in_channels, out_channels, kernel_size, dim = 3, - stride = 1, padding = 0, dilation =1, groups = 1, bias = True, - batch_norm = True, acti_func = None): - super(ConvolutionSepAll3DLayer, self).__init__() - self.n_in_chns = in_channels - self.n_out_chns = out_channels - self.batch_norm = batch_norm - self.acti_func = acti_func - - assert(dim == 3) - chn = min(in_channels, out_channels) - - self.conv_intra_plane1 = nn.Conv2d(chn, chn, - kernel_size, stride, padding, dilation, chn, bias) - - self.conv_intra_plane2 = nn.Conv2d(chn, chn, - kernel_size, stride, padding, dilation, chn, bias) - - self.conv_intra_plane3 = nn.Conv2d(chn, chn, - kernel_size, stride, padding, dilation, chn, bias) - - self.conv_space_wise = nn.Conv2d(in_channels, out_channels, - 1, stride, 0, dilation, 1, bias) - - if(self.batch_norm): - self.bn = nn.BatchNorm3d(out_channels) - - def forward(self, x): - in_shape = list(x.shape) - assert(len(in_shape) == 5) - [B, C, D, H, W] = in_shape - f0 = x.permute(0, 2, 1, 3, 4) #[B, D, C, H, W] - f0 = f0.contiguous().view([B*D, C, H, W]) - - Cc = min(self.n_in_chns, self.n_out_chns) - Co = self.n_out_chns - if(self.n_in_chns > self.n_out_chns): - f0 = self.conv_space_wise(f0) #[B*D, Cc, H, W] - - f1 = self.conv_intra_plane1(f0) - f2 = f1.contiguous().view([B, D, Cc, H, W]) - f2 = f2.permute(0, 3, 2, 1, 4) #[B, H, Cc, D, W] - f2 = f2.contiguous().view([B*H, Cc, D, W]) - f2 = self.conv_intra_plane2(f2) - f3 = f2.contiguous().view([B, H, Cc, D, W]) - f3 = f3.permute(0, 4, 2, 3, 1) #[B, W, Cc, D, H] - f3 = f3.contiguous().view([B*W, Cc, D, H]) - f3 = self.conv_intra_plane3(f3) - if(self.n_in_chns <= self.n_out_chns): - f3 = self.conv_space_wise(f3) #[B*W, Co, D, H] - - f3 = f3.contiguous().view([B, W, Co, D, H]) - f3 = f3.permute([0, 2, 3, 4, 1]) #[B, Co, D, H, W] - - if(self.batch_norm): - f3 = self.bn(f3) - if(self.acti_func is not None): - f3 = self.acti_func(f3) - return f3 diff --git a/pymic/layer/deconvolution.py b/pymic/layer/deconvolution.py index b9bac83..20c83c1 100644 --- a/pymic/layer/deconvolution.py +++ b/pymic/layer/deconvolution.py @@ -7,8 +7,21 @@ class DeconvolutionLayer(nn.Module): """ A compose layer with the following components: - deconvolution -> (batch_norm) -> activation -> (dropout) - batch norm and dropout are optional + deconvolution -> (batch_norm / layer_norm / group_norm / instance_norm) -> (activation) -> (dropout) + Batch norm and activation are optional. + + :param in_channels: (int) The input channel number. + :param out_channels: (int) The output channel number. + :param kernel_size: The size of convolution kernel. It can be either a single + int or a tupe of two or three ints. + :param dim: (int) The dimention of convolution (2 or 3). + :param stride: (int) The stride of convolution. + :param padding: (int) Padding size. + :param dilation: (int) Dilation rate. + :param groups: (int) The groupt number of convolution. + :param bias: (bool) Add bias or not for convolution. + :param batch_norm: (bool) Use batch norm or not. + :param acti_func: (str or None) Activation funtion. """ def __init__(self, in_channels, out_channels, kernel_size, dim = 3, stride = 1, padding = 0, output_padding = 0, @@ -44,9 +57,23 @@ def forward(self, x): class DepthSeperableDeconvolutionLayer(nn.Module): """ - A compose layer with the following components: - convolution -> (batch_norm) -> activation -> (dropout) - batch norm and dropout are optional + Depth seperable deconvolution with the following components: + 1x1 conv -> deconv -> (batch_norm / layer_norm / group_norm / instance_norm) -> (activation) -> (dropout) + Batch norm and activation are optional. + + :param in_channels: (int) The input channel number. + :param out_channels: (int) The output channel number. + :param kernel_size: The size of convolution kernel. It can be either a single + int or a tupe of two or three ints. + :param dim: (int) The dimention of convolution (2 or 3). + :param stride: (int) The stride of convolution. + :param padding: (int) Padding size for input. + :param output_padding: (int) Padding size for ouput. + :param dilation: (int) Dilation rate. + :param groups: (int) The groupt number of convolution. + :param bias: (bool) Add bias or not for convolution. + :param batch_norm: (bool) Use batch norm or not. + :param acti_func: (str or None) Activation funtion. """ def __init__(self, in_channels, out_channels, kernel_size, dim = 3, stride = 1, padding = 0, output_padding = 0, diff --git a/pymic/layer/space2channel.py b/pymic/layer/space2channel.py index c33766b..79981d8 100644 --- a/pymic/layer/space2channel.py +++ b/pymic/layer/space2channel.py @@ -5,7 +5,10 @@ import torch import torch.nn as nn import SimpleITK as sitk + class SpaceToChannel3D(nn.Module): + """ + Space to channel transform for 3D input.""" def __init__(self): super(SpaceToChannel3D, self).__init__() @@ -34,6 +37,8 @@ def forward(self, x): return x7 class ChannelToSpace3D(nn.Module): + """ + Channel to space transform for 3D input.""" def __init__(self): super(ChannelToSpace3D, self).__init__() diff --git a/pymic/net/cls/torch_pretrained_net.py b/pymic/net/cls/torch_pretrained_net.py index 198419f..edc5d9c 100644 --- a/pymic/net/cls/torch_pretrained_net.py +++ b/pymic/net/cls/torch_pretrained_net.py @@ -1,6 +1,7 @@ # pretrained models from pytorch: https://pytorch.org/vision/0.8/models.html from __future__ import print_function, division +import itertools import torch import torch.nn as nn import torchvision.models as models @@ -20,80 +21,149 @@ # 'mnasnet': models.mnasnet1_0 # } -class ResNet18(nn.Module): +class BuiltInNet(nn.Module): + """ + Built-in Network in Pytorch for classification. + Parameters should be set in the `params` dictionary that contains the + following fields: + + :param input_chns: (int) Input channel number, default is 3. + :param pretrain: (bool) Using pretrained model or not, default is True. + :param update_mode: (str) The strategy for updating layers: "`all`" means updating + all the layers, and "`last`" (by default) means updating the last layer, + as well as the first layer when `input_chns` is not 3. + """ def __init__(self, params): - super(ResNet18, self).__init__() - self.params = params - cls_num = params['class_num'] - in_chns = params.get('input_chns', 3) + super(BuiltInNet, self).__init__() + self.params = params + self.in_chns = params.get('input_chns', 3) self.pretrain = params.get('pretrain', True) - self.update_layers = params.get('update_layers', 0) + self.update_mode = params.get('update_mode', "last") + self.net = None + + def forward(self, x): + return self.net(x) + + def get_parameters_to_update(self): + pass + +class ResNet18(BuiltInNet): + """ + ResNet18 for classification. + Parameters should be set in the `params` dictionary that contains the + following fields: + + :param input_chns: (int) Input channel number, default is 3. + :param pretrain: (bool) Using pretrained model or not, default is True. + :param update_mode: (str) The strategy for updating layers: "`all`" means updating + all the layers, and "`last`" (by default) means updating the last layer, + as well as the first layer when `input_chns` is not 3. + """ + def __init__(self, params): + super(ResNet18, self).__init__(params) self.net = models.resnet18(pretrained = self.pretrain) # replace the last layer num_ftrs = self.net.fc.in_features - self.net.fc = nn.Linear(num_ftrs, cls_num) - - def forward(self, x): - return self.net(x) + self.net.fc = nn.Linear(num_ftrs, params['class_num']) + + # replace the first layer when in_chns is not 3 + if(self.in_chns != 3): + self.net.conv1 = nn.Conv2d(self.in_chns, 64, kernel_size=(7, 7), + stride=(2, 2), padding=(3, 3), bias=False) def get_parameters_to_update(self): - if(self.pretrain == False or self.update_layers == 0): + if(self.update_mode == "all"): return self.net.parameters() - elif(self.update_layers == -1): - return self.net.fc.parameters() + elif(self.update_layers == "last"): + params = self.net.fc.parameters() + if(self.in_chns !=3): + # combining the two iterables into a single one + # see: https://dzone.com/articles/python-joining-multiple + params = itertools.chain() + for pram in [self.net.fc.parameters(), self.net.conv1.parameters()]: + params = itertools.chain(params, pram) + return params else: - raise(ValueError("update_layers can only be 0 (all layers) " + - "or -1 (the last layer)")) + raise(ValueError("update_mode can only be 'all' or 'last'.")) -class VGG16(nn.Module): +class VGG16(BuiltInNet): + """ + VGG16 for classification. + Parameters should be set in the `params` dictionary that contains the + following fields: + + :param input_chns: (int) Input channel number, default is 3. + :param pretrain: (bool) Using pretrained model or not, default is True. + :param update_mode: (str) The strategy for updating layers: "`all`" means updating + all the layers, and "`last`" (by default) means updating the last layer, + as well as the first layer when `input_chns` is not 3. + """ def __init__(self, params): - super(VGG16, self).__init__() - self.params = params - cls_num = params['class_num'] - in_chns = params.get('input_chns', 3) - self.pretrain = params.get('pretrain', True) - self.update_layers = params.get('update_layers', 0) + super(VGG16, self).__init__(params) self.net = models.vgg16(pretrained = self.pretrain) # replace the last layer num_ftrs = self.net.classifier[-1].in_features - self.net.classifier[-1] = nn.Linear(num_ftrs, cls_num) - - def forward(self, x): - return self.net(x) + self.net.classifier[-1] = nn.Linear(num_ftrs, params['class_num']) + + # replace the first layer when in_chns is not 3 + if(self.in_chns != 3): + self.net.features[0] = nn.Conv2d(self.in_chns, 64, kernel_size=(3, 3), + stride=(1, 1), padding=(1, 1), bias=False) def get_parameters_to_update(self): - if(self.pretrain == False or self.update_layers == 0): + if(self.update_mode == "all"): return self.net.parameters() - elif(self.update_layers == -1): - return self.net.classifier[-1].parameters() + elif(self.update_mode == "last"): + params = self.net.classifier[-1].parameters() + if(self.in_chns !=3): + params = itertools.chain() + for pram in [self.net.classifier[-1].parameters(), self.net.net.features[0].parameters()]: + params = itertools.chain(params, pram) + return params else: - raise(ValueError("update_layers can only be 0 (all layers) " + - "or -1 (the last layer)")) + raise(ValueError("update_mode can only be 'all' or 'last'.")) + +class MobileNetV2(BuiltInNet): + """ + MobileNetV2 for classification. + Parameters should be set in the `params` dictionary that contains the + following fields: -class MobileNetV2(nn.Module): + :param input_chns: (int) Input channel number, default is 3. + :param pretrain: (bool) Using pretrained model or not, default is True. + :param update_mode: (str) The strategy for updating layers: "`all`" means updating + all the layers, and "`last`" (by default) means updating the last layer, + as well as the first layer when `input_chns` is not 3. + """ def __init__(self, params): super(MobileNetV2, self).__init__() - self.params = params - cls_num = params['class_num'] - in_chns = params.get('input_chns', 3) - self.pretrain = params.get('pretrain', True) - self.update_layers = params.get('update_layers', 0) self.net = models.mobilenet_v2(pretrained = self.pretrain) # replace the last layer num_ftrs = self.net.last_channel - self.net.classifier[-1] = nn.Linear(num_ftrs, cls_num) - - def forward(self, x): - return self.net(x) + self.net.classifier[-1] = nn.Linear(num_ftrs, params['class_num']) + + # replace the first layer when in_chns is not 3 + if(self.in_chns != 3): + self.net.features[0][0] = nn.Conv2d(self.in_chns, 32, kernel_size=(3, 3), + stride=(2, 2), padding=(1, 1), bias=False) def get_parameters_to_update(self): - if(self.pretrain == False or self.update_layers == 0): + if(self.update_mode == "all"): return self.net.parameters() - elif(self.update_layers == -1): - return self.net.classifier[-1].parameters() + elif(self.update_mode == "last"): + params = self.net.classifier[-1].parameters() + if(self.in_chns !=3): + params = itertools.chain() + for pram in [self.net.classifier[-1].parameters(), self.net.net.features[0][0].parameters()]: + params = itertools.chain(params, pram) + return params else: - raise(ValueError("update_layers can only be 0 (all layers) " + - "or -1 (the last layer)")) \ No newline at end of file + raise(ValueError("update_mode can only be 'all' or 'last'.")) + +if __name__ == "__main__": + params = {"class_num": 2, "pretrain": False, "input_chns": 3} + net = ResNet18(params) + print(net) \ No newline at end of file diff --git a/pymic/net/net2d/scse2d.py b/pymic/net/net2d/scse2d.py index 1963a3b..bd80713 100644 --- a/pymic/net/net2d/scse2d.py +++ b/pymic/net/net2d/scse2d.py @@ -1,11 +1,14 @@ +# -*- coding: utf-8 -*- """ -1. Channel Squeeze and Excitation -2. Spatial Squeeze and Excitation +2D implementation of: \n +1. Channel Squeeze and Excitation \n +2. Spatial Squeeze and Excitation \n 3. Concurrent Spatial and Channel Squeeze & Excitation Oringinal file is on `Github. `_ """ +from __future__ import print_function, division from enum import Enum import torch import torch.nn as nn diff --git a/pymic/net/net2d/unet2d.py b/pymic/net/net2d/unet2d.py index 9a8aca3..2708800 100644 --- a/pymic/net/net2d/unet2d.py +++ b/pymic/net/net2d/unet2d.py @@ -174,7 +174,7 @@ class UNet2D(nn.Module): MICCAI (3) 2015: 234-241 Note that there are some modifications from the original paper, such as - the use of batch normalization, dropout, and leaky relu here. + the use of batch normalization, dropout, leaky relu and deep supervision. Parameters are given in the `params` dictionary, and should include the following fields: diff --git a/pymic/net/net2d/unet2d_scse.py b/pymic/net/net2d/unet2d_scse.py index b1606c9..125843e 100644 --- a/pymic/net/net2d/unet2d_scse.py +++ b/pymic/net/net2d/unet2d_scse.py @@ -11,6 +11,10 @@ class ConvScSEBlock(nn.Module): Two convolutional blocks followed by `ChannelSpatialSELayer`. Each block consists of `Conv2d` + `BatchNorm2d` + `LeakyReLU`. A dropout layer is used between the wo blocks. + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. """ def __init__(self, in_channels, out_channels, dropout_p): super(ConvScSEBlock, self).__init__() @@ -29,7 +33,12 @@ def forward(self, x): return self.conv_conv(x) class DownBlock(nn.Module): - """Downsampling followed by `ConvScSEBlock`.""" + """Downsampling followed by `ConvScSEBlock`. + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + """ def __init__(self, in_channels, out_channels, dropout_p): super(DownBlock, self).__init__() self.maxpool_conv = nn.Sequential( @@ -42,7 +51,14 @@ def forward(self, x): return self.maxpool_conv(x) class UpBlock(nn.Module): - """Up-sampling followed by `ConvScSEBlock`.""" + """Up-sampling followed by `ConvScSEBlock` in U-Net structure. + + :param in_channels1: (int) Input channel number for low-resolution feature map. + :param in_channels2: (int) Input channel number for high-resolution feature map. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + :param bilinear: (bool) Use bilinear for up-sampling or not. + """ def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, bilinear=True): super(UpBlock, self).__init__() diff --git a/pymic/net/net2d/unet2d_urpc.py b/pymic/net/net2d/unet2d_urpc.py index 99d4b35..ee8ab7c 100644 --- a/pymic/net/net2d/unet2d_urpc.py +++ b/pymic/net/net2d/unet2d_urpc.py @@ -1,15 +1,4 @@ # -*- coding: utf-8 -*- -""" -An modification the U-Net to obtain multi-scale prediction according to -the URPC paper (MICCAI 2021): - Xiangde Luo, Wenjun Liao, Jienneg Chen, Tao Song, Yinan Chen, - Shichuan Zhang, Nianyong Chen, Guotai Wang, Shaoting Zhang: - Efficient Semi-Supervised Gross Target Volume of Nasopharyngeal Carcinoma - Segmentation via Uncertainty Rectified Pyramid Consistency. - MICCAI 2021: 318-329 - https://link.springer.com/chapter/10.1007/978-3-030-87196-3_30 -Also see: https://github.com/HiLab-git/SSL4MIS/blob/master/code/networks/unet.py -""" from __future__ import print_function, division import torch @@ -44,14 +33,37 @@ def forward(self, x): return x class UNet2D_URPC(nn.Module): + """ + An modification the U-Net to obtain multi-scale prediction according to + the URPC paper. + + * Reference: Xiangde Luo, Guotai Wang*, Wenjun Liao, Jieneng Chen, Tao Song, Yinan Chen, + Shichuan Zhang, Dimitris N. Metaxas, Shaoting Zhang. + Semi-Supervised Medical Image Segmentation via Uncertainty Rectified Pyramid Consistency . + `Medical Image Analysis 2022. `_ + + Also see: https://github.com/HiLab-git/SSL4MIS/blob/master/code/networks/unet.py + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param bilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + """ def __init__(self, params): super(UNet2D_URPC, self).__init__() self.params = params self.in_chns = self.params['in_chns'] self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] self.n_class = self.params['class_num'] self.bilinear = self.params['bilinear'] - self.dropout = self.params['dropout'] assert(len(self.ft_chns) == 5) self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) diff --git a/pymic/net/net3d/scse3d.py b/pymic/net/net3d/scse3d.py index ac0fa3c..df67689 100644 --- a/pymic/net/net3d/scse3d.py +++ b/pymic/net/net3d/scse3d.py @@ -1,11 +1,13 @@ """ -3D implementation of Spacial channel Squeeze and Excitation - `_ -`Spatial Squeeze and Excitation ` -***************************** +3D implementation of: \n +1. Channel Squeeze and Excitation \n +2. Spatial Squeeze and Excitation \n +3. Concurrent Spatial and Channel Squeeze & Excitation -oringinal file: https://github.com/maodong2056/squeeze_and_excitation/blob/master/squeeze_and_excitation/squeeze_and_excitation.py +Oringinal file is on `Github. +`_ """ +from __future__ import print_function, division from enum import Enum import torch @@ -14,15 +16,15 @@ class ChannelSELayer3D(nn.Module): """ - 3D implementation of Squeeze-and-Excitation (SE) block described in: - *Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507* - """ - def __init__(self, num_channels, reduction_ratio=2): + 3D implementation of Squeeze-and-Excitation (SE) block. + + * Reference: Jie Hu, Li Shen, Gang Sun: Squeeze-and-Excitation Networks. + `CVPR 2018. `_ - """ - :param num_channels: No of input channels - :param reduction_ratio: By how much should the num_channels should be reduced - """ + :param num_channels: Number of input channels + :param reduction_ratio: By how much should the num_channels should be reduced + """ + def __init__(self, num_channels, reduction_ratio=2): super(ChannelSELayer3D, self).__init__() num_channels_reduced = num_channels // reduction_ratio self.reduction_ratio = reduction_ratio @@ -53,12 +55,12 @@ class SpatialSELayer3D(nn.Module): """ 3D Re-implementation of SE block -- squeezing spatially and exciting channel-wise described in: - *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018* + * Reference: Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in + Fully Convolutional Networks, MICCAI 2018. + + :param num_channels: Number of input channels """ def __init__(self, num_channels): - """ - :param num_channels: No of input channels - """ super(SpatialSELayer3D, self).__init__() self.conv = nn.Conv3d(num_channels, 1, 1) self.sigmoid = nn.Sigmoid() @@ -90,14 +92,15 @@ def forward(self, input_tensor, weights=None): class ChannelSpatialSELayer3D(nn.Module): """ - 3D Re-implementation of concurrent spatial and channel squeeze & excitation: - *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018, arXiv:1803.02579* + 3D Re-implementation of concurrent spatial and channel squeeze & excitation. + + * Reference: Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in + Fully Convolutional Networks, MICCAI 2018. + + :param num_channels: Number of input channels + :param reduction_ratio: By how much should the num_channels should be reduced """ def __init__(self, num_channels, reduction_ratio=2): - """ - :param num_channels: No of input channels - :param reduction_ratio: By how much should the num_channels should be reduced - """ super(ChannelSpatialSELayer3D, self).__init__() self.cSE = ChannelSELayer3D(num_channels, reduction_ratio) self.sSE = SpatialSELayer3D(num_channels) diff --git a/pymic/net/net3d/unet2d5.py b/pymic/net/net3d/unet2d5.py index 9e6a72d..308fdde 100644 --- a/pymic/net/net3d/unet2d5.py +++ b/pymic/net/net3d/unet2d5.py @@ -1,25 +1,20 @@ # -*- coding: utf-8 -*- -""" -A 2.5D network combining 3D convolutions with 2D convolutions according to the following paper: - Guotai Wang, Jonathan Shapey, Wenqi Li, Reuben Dorent, Alex Demitriadis, Sotirios Bisdas, - Ian Paddick, Robert Bradford, Shaoting Zhang, Sébastien Ourselin, Tom Vercauteren: - Automatic Segmentation of Vestibular Schwannoma from T2-Weighted MRI by Deep Spatial Attention - with Hardness-Weighted Loss. MICCAI (2) 2019: 264-272. -Note that the attention module in the orininal paper is not used here. -""" from __future__ import print_function, division import torch import torch.nn as nn import numpy as np class ConvBlockND(nn.Module): - """for 2D and 3D convolutional blocks""" + """ + 2D or 3D convolutional block + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + """ def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0): - """ - dim: should be 2 or 3 - dropout_p: probability to be zeroed - """ super(ConvBlockND, self).__init__() assert(dim == 2 or dim == 3) self.dim = dim @@ -49,8 +44,15 @@ def forward(self, x): return output class DownBlock(nn.Module): - """a convolutional block followed by downsampling""" - def __init__(self,in_channels, out_channels, + """`ConvBlockND` block followed by downsampling. + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + :param downsample: (bool) Use downsample or not after convolution. + """ + def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0, downsample = True): super(DownBlock, self).__init__() self.downsample = downsample @@ -86,7 +88,15 @@ def forward(self, x): return output, output_d class UpBlock(nn.Module): - """Upsampling followed by ConConvBlockNDvBlock""" + """Upsampling followed by `ConvBlockND` block + + :param in_channels1: (int) Input channel number for low-resolution feature map. + :param in_channels2: (int) Input channel number for high-resolution feature map. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + :param bilinear: (bool) Use bilinear for up-sampling or not. + """ def __init__(self, in_channels1, in_channels2, out_channels, dim = 2, dropout_p = 0.0, bilinear=True): super(UpBlock, self).__init__() @@ -132,15 +142,41 @@ def forward(self, x1, x2): return output class UNet2D5(nn.Module): + """ + A 2.5D network combining 3D convolutions with 2D convolutions. + + * Reference: Guotai Wang, Jonathan Shapey, Wenqi Li, Reuben Dorent, Alex Demitriadis, + Sotirios Bisdas, Ian Paddick, Robert Bradford, Shaoting Zhang, Sébastien Ourselin, + Tom Vercauteren: Automatic Segmentation of Vestibular Schwannoma from T2-Weighted + MRI by Deep Spatial Attention with Hardness-Weighted Loss. + `MICCAI (2) 2019: 264-272. `_ + + Note that the attention module in the orininal paper is not used here. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param conv_dims: (list) The convolution dimension (2 or 3) for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param bilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + """ def __init__(self, params): super(UNet2D5, self).__init__() self.params = params self.in_chns = self.params['in_chns'] self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] self.dims = self.params['conv_dims'] self.n_class = self.params['class_num'] self.bilinear = self.params['bilinear'] - self.dropout = self.params['dropout'] + assert(len(self.ft_chns) == 5) self.block0 = DownBlock(self.in_chns, self.ft_chns[0], self.dims[0], self.dropout[0], True) diff --git a/pymic/net/net3d/unet3d.py b/pymic/net/net3d/unet3d.py index fdedf4d..af28d73 100644 --- a/pymic/net/net3d/unet3d.py +++ b/pymic/net/net3d/unet3d.py @@ -1,12 +1,4 @@ # -*- coding: utf-8 -*- -""" -An implementation of the 3D U-Net paper: - Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: - 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. - MICCAI (2) 2016: 424-432 -Note that there are some modifications from the original paper, such as -the use of batch normalization, dropout, and leaky relu here. -""" from __future__ import print_function, division import torch @@ -15,11 +7,15 @@ from torch.nn.functional import interpolate class ConvBlock(nn.Module): - """two convolution layers with batch norm and leaky relu""" - def __init__(self,in_channels, out_channels, dropout_p): - """ - dropout_p: probability to be zeroed - """ + """ + Two 3D convolution layers with batch norm and leaky relu. + Droput is used between the two convolution layers. + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + """ + def __init__(self, in_channels, out_channels, dropout_p): super(ConvBlock, self).__init__() self.conv_conv = nn.Sequential( nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1), @@ -35,7 +31,13 @@ def forward(self, x): return self.conv_conv(x) class DownBlock(nn.Module): - """Downsampling followed by ConvBlock""" + """ + 3D downsampling followed by ConvBlock + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + """ def __init__(self, in_channels, out_channels, dropout_p): super(DownBlock, self).__init__() self.maxpool_conv = nn.Sequential( @@ -47,7 +49,16 @@ def forward(self, x): return self.maxpool_conv(x) class UpBlock(nn.Module): - """Upssampling followed by ConvBlock""" + """ + 3D upsampling followed by ConvBlock + + :param in_channels1: (int) Channel number of high-level features. + :param in_channels2: (int) Channel number of low-level features. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + :param trilinear: (bool) Use trilinear for up-sampling (by default). + If False, deconvolution is used for up-sampling. + """ def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, trilinear=True): super(UpBlock, self).__init__() @@ -67,26 +78,30 @@ def forward(self, x1, x2): return self.conv(x) - -class UNetBlock(nn.Module): - def __init__(self,in_channels, out_channels, acti_func, acti_func_param): - super(UNetBlock, self).__init__() - - self.in_chns = in_channels - self.out_chns = out_channels - self.acti_func = acti_func - - self.conv1 = ConvolutionLayer(in_channels, out_channels, 3, - padding = 1, acti_func=get_acti_func(acti_func, acti_func_param)) - self.conv2 = ConvolutionLayer(out_channels, out_channels, 3, - padding = 1, acti_func=get_acti_func(acti_func, acti_func_param)) - - def forward(self, x): - x = self.conv1(x) - x = self.conv2(x) - return x - class UNet3D(nn.Module): + """ + An implementation of the U-Net. + + * Reference: Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: + 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. + `MICCAI (2) 2016: 424-432. `_ + + Note that there are some modifications from the original paper, such as + the use of batch normalization, dropout, leaky relu and deep supervision. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using trilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + :param deep_supervise: (bool) Using deep supervision for training or not. + """ def __init__(self, params): super(UNet3D, self).__init__() self.params = params diff --git a/pymic/net/net3d/unet3d_scse.py b/pymic/net/net3d/unet3d_scse.py index 0f15e25..b2da0dc 100644 --- a/pymic/net/net3d/unet3d_scse.py +++ b/pymic/net/net3d/unet3d_scse.py @@ -1,20 +1,21 @@ # -*- coding: utf-8 -*- -""" -Combining 3D U-Net with SCSE module according to the following paper: - Abhijit Guha Roy, Nassir Navab, Christian Wachinger: - Recalibrating Fully Convolutional Networks With Spatial and Channel "Squeeze and Excitation" Blocks. \ - IEEE Trans. Med. Imaging 38(2): 540-549 (2019) -""" from __future__ import print_function, division - import torch import torch.nn as nn import numpy as np from pymic.net.net3d.scse3d import * class ConvScSEBlock3D(nn.Module): - """two convolution layers with batch norm and leaky relu""" - def __init__(self,in_channels, out_channels, dropout_p): + """ + Two 3D convolutional blocks followed by `ChannelSpatialSELayer3D`. + Each block consists of `Conv3d` + `BatchNorm3d` + `LeakyReLU`. + A dropout layer is used between the wo blocks. + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + """ + def __init__(self, in_channels, out_channels, dropout_p): super(ConvScSEBlock3D, self).__init__() self.conv_conv = nn.Sequential( nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1), @@ -31,7 +32,12 @@ def forward(self, x): return self.conv_conv(x) class DownBlock(nn.Module): - """Downsampling followed by ConvBlock""" + """3D Downsampling followed by `ConvScSEBlock3D`. + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + """ def __init__(self, in_channels, out_channels, dropout_p): super(DownBlock, self).__init__() self.maxpool_conv = nn.Sequential( @@ -43,12 +49,19 @@ def forward(self, x): return self.maxpool_conv(x) class UpBlock(nn.Module): - """Upssampling followed by ConvBlock""" + """3D Up-sampling followed by `ConvScSEBlock3D` in UNet3D_ScSE. + + :param in_channels1: (int) Input channel number for low-resolution feature map. + :param in_channels2: (int) Input channel number for high-resolution feature map. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + :param trilinear: (bool) Use trilinear for up-sampling or not. + """ def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - bilinear=True): + trilinear=True): super(UpBlock, self).__init__() - self.bilinear = bilinear - if bilinear: + self.trilinear = trilinear + if trilinear: self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) else: @@ -56,21 +69,42 @@ def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, self.conv = ConvScSEBlock3D(in_channels2 * 2, out_channels, dropout_p) def forward(self, x1, x2): - if self.bilinear: + if self.trilinear: x1 = self.conv1x1(x1) x1 = self.up(x1) x = torch.cat([x2, x1], dim=1) return self.conv(x) class UNet3D_ScSE(nn.Module): + """ + Combining 3D U-Net with SCSE module. + + * Reference: Abhijit Guha Roy, Nassir Navab, Christian Wachinger: + Recalibrating Fully Convolutional Networks With Spatial and Channel + "Squeeze and Excitation" Blocks. + `IEEE Trans. Med. Imaging 38(2): 540-549 (2019). `_ + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using trilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + """ def __init__(self, params): super(UNet3D_ScSE, self).__init__() self.params = params self.in_chns = self.params['in_chns'] self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] self.n_class = self.params['class_num'] self.bilinear = self.params['trilinear'] - self.dropout = self.params['dropout'] + assert(len(self.ft_chns) == 5) self.in_conv= ConvScSEBlock3D(self.in_chns, self.ft_chns[0], self.dropout[0]) diff --git a/pymic/net/net_dict_cls.py b/pymic/net/net_dict_cls.py index b03ea20..7996e59 100644 --- a/pymic/net/net_dict_cls.py +++ b/pymic/net/net_dict_cls.py @@ -1,4 +1,12 @@ # -*- coding: utf-8 -*- +""" +Built-in networks for classification. + +* resnet18 :mod:`pymic.net.cls.torch_pretrained_net.ResNet18` +* vgg16 :mod:`pymic.net.cls.torch_pretrained_net.VGG16` +* mobilenetv2 :mod:`pymic.net.cls.torch_pretrained_net.MobileNetV2` +""" + from __future__ import print_function, division from pymic.net.cls.torch_pretrained_net import * From 959fd60459deeaef866ab409e21fbb7fe1d8f3ac Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 30 Sep 2022 13:50:32 +0800 Subject: [PATCH 113/225] update evaluation docs update evaluation docs --- docs/source/usage.quickstart.rst | 48 ++++++++++++++++++++++++++++++++ pymic/io/h5_dataset.py | 5 +--- requirements.txt | 2 +- 3 files changed, 50 insertions(+), 5 deletions(-) diff --git a/docs/source/usage.quickstart.rst b/docs/source/usage.quickstart.rst index e1cada6..95cf20d 100644 --- a/docs/source/usage.quickstart.rst +++ b/docs/source/usage.quickstart.rst @@ -115,3 +115,51 @@ file used for segmentation of lung from radiograph, which can be find in label_source = [0, 1] label_target = [0, 255] + +Evaluation +---------- + +To evaluate a model's prediction results compared with the ground truth, +use the ``pymic_eval_seg`` and ``pymic_eval_cls`` commands for segmentation +and classfication tasks, respectively. Both of them accept a configuration +file to specify the evaluation metrics, predicted results, ground truth and +other information. + +For example, for segmentation tasks, run: + +.. code-block:: none + + pymic_eval_seg evaluation.cfg + +The configuration file is like (an example from ``PYMIC_examples/seg_ssl/ACDC``): + +.. code-block:: none + + [evaluation] + metric = dice + label_list = [1,2,3] + organ_name = heart + + ground_truth_folder_root = ../../PyMIC_data/ACDC/preprocess + segmentation_folder_root = result/unet2d_em + evaluation_image_pair = config/data/image_test_gt_seg.csv + +See :mod:`pymic.util.evaluation_seg.evaluation` for details of the configuration required. + +For classification tasks, run: + +.. code-block:: none + + pymic_eval_cls evaluation.cfg + +The configuration file is like (an example from ``PYMIC_examples/classification/CHNCXR``): + +.. code-block:: none + + [evaluation] + metric_list = [accuracy, auc] + ground_truth_csv = config/cxr_test.csv + predict_csv = result/resnet18.csv + predict_prob_csv = result/resnet18_prob.csv + +See :mod:`pymic.util.evaluation_cls.main` for details of the configuration required. diff --git a/pymic/io/h5_dataset.py b/pymic/io/h5_dataset.py index c6e5298..02f94f3 100644 --- a/pymic/io/h5_dataset.py +++ b/pymic/io/h5_dataset.py @@ -1,13 +1,11 @@ # -*- coding: utf-8 -*- from logging import root -import os -from re import S +import os import torch import random import h5py import pandas as pd -from scipy import ndimage from torch.utils.data import Dataset from torch.utils.data.sampler import Sampler @@ -40,7 +38,6 @@ def __getitem__(self, idx): sample = {'image': image, 'label': label} if self.transform: sample = self.transform(sample) - # sample["idx"] = idx return sample class TwoStreamBatchSampler(Sampler): diff --git a/requirements.txt b/requirements.txt index c4c9ac3..c8cd562 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ pandas>=0.25.3 scikit-image>=0.16.2 scikit-learn>=0.22 scipy>=1.3.3 -SimpleITK>=1.2.4 +SimpleITK>=2.0.0 tensorboard>=2.1.0 tensorboardX>=1.9 torch>=1.7.1 From 04211f4b57f908e196391cbfff9d9a76c87c1dae Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 30 Sep 2022 16:37:39 +0800 Subject: [PATCH 114/225] Update README.md --- README.md | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 3ce5929..1cf7b3c 100644 --- a/README.md +++ b/README.md @@ -2,13 +2,24 @@ PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with high dimension and large volume, multiple modalities and difficulies in annotating. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations. -Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. It was originally developed for COVID-19 pneumonia lesion segmentation from CT images. If you use this toolkit, please cite the following paper: +Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. If you use this toolkit, please cite the following paper: -* G. Wang, X. Liu, C. Li, Z. Xu, J. Ruan, H. Zhu, T. Meng, K. Li, N. Huang, S. Zhang. -[A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions from CT Images.][tmi2020] IEEE Transactions on Medical Imaging. 39(8):2653-2663, 2020. DOI: [10.1109/TMI.2020.3000314][tmi2020] +* G. Wang, X. Luo, R. Gu, S. Yang, Y. Qu, S. Zhai, Q. Zhao, K. Li, S. Zhang. (2022). +[PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation.][arxiv2022] arXiv, 2208.09350. -[tmi2020]:https://ieeexplore.ieee.org/document/9109297 +[arxiv2022]:http://arxiv.org/abs/2208.09350 +BibTeX entry: + + @article{Wang2022pymic, + author = {Guotai Wang and Xiangde Luo and Ran Gu and Shuojue Yang and Yijie Qu and Shuwei Zhai and Qianfei Zhao and Kang Li and Shaoting Zhang}, + title = {{PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation}}, + year = {2022}, + url = {http://arxiv.org/abs/2208.09350}, + journal = {arXiv}, + volume = {2208.09350}, + pages = {1-10}, + } # Features PyMIC provides flixible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions: From 55d2edc3f982a9be0e0fa739542faec28f8180f5 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 1 Oct 2022 10:23:44 +0800 Subject: [PATCH 115/225] update classification agent --- pymic/loss/seg/ce.py | 1 - pymic/net_run/agent_cls.py | 8 +++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pymic/loss/seg/ce.py b/pymic/loss/seg/ce.py index bbe6d02..529482b 100644 --- a/pymic/loss/seg/ce.py +++ b/pymic/loss/seg/ce.py @@ -18,7 +18,6 @@ class CrossEntropyLoss(AbstractSegLoss): """ def __init__(self, params = None): super(CrossEntropyLoss, self).__init__(params) - def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index 8ab7729..46bb2d7 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -145,7 +145,8 @@ def training(self): self.optimizer.zero_grad() # forward + backward + optimize outputs = self.net(inputs) - loss = self.get_loss_value(data, inputs, outputs, labels) + + loss = self.get_loss_value(data, outputs, labels) loss.backward() self.optimizer.step() self.scheduler.step() @@ -175,7 +176,7 @@ def validation(self): self.optimizer.zero_grad() # forward + backward + optimize outputs = self.net(inputs) - loss = self.get_loss_value(data, inputs, outputs, labels) + loss = self.get_loss_value(data, outputs, labels) # statistics sample_num += labels.size(0) @@ -243,10 +244,11 @@ def train_valid(self): logging.info("{0:} training start".format(str(datetime.now())[:-7])) self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) for it in range(iter_start, iter_max, iter_valid): + lr_value = self.optimizer.param_groups[0]['lr'] train_scalars = self.training() valid_scalars = self.validation() glob_it = it + iter_valid - self.write_scalars(train_scalars, valid_scalars, glob_it) + self.write_scalars(train_scalars, valid_scalars, lr_value, glob_it) if(valid_scalars[metrics] > self.max_val_score): self.max_val_score = valid_scalars[metrics] From ac1c8fcb63811780c4a7597c36e599f5db51470e Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 1 Oct 2022 10:59:26 +0800 Subject: [PATCH 116/225] Update torch_pretrained_net.py --- pymic/net/cls/torch_pretrained_net.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pymic/net/cls/torch_pretrained_net.py b/pymic/net/cls/torch_pretrained_net.py index edc5d9c..5017f72 100644 --- a/pymic/net/cls/torch_pretrained_net.py +++ b/pymic/net/cls/torch_pretrained_net.py @@ -75,7 +75,7 @@ def __init__(self, params): def get_parameters_to_update(self): if(self.update_mode == "all"): return self.net.parameters() - elif(self.update_layers == "last"): + elif(self.update_mode == "last"): params = self.net.fc.parameters() if(self.in_chns !=3): # combining the two iterables into a single one @@ -119,7 +119,7 @@ def get_parameters_to_update(self): params = self.net.classifier[-1].parameters() if(self.in_chns !=3): params = itertools.chain() - for pram in [self.net.classifier[-1].parameters(), self.net.net.features[0].parameters()]: + for pram in [self.net.classifier[-1].parameters(), self.net.features[0].parameters()]: params = itertools.chain(params, pram) return params else: @@ -138,7 +138,7 @@ class MobileNetV2(BuiltInNet): as well as the first layer when `input_chns` is not 3. """ def __init__(self, params): - super(MobileNetV2, self).__init__() + super(MobileNetV2, self).__init__(params) self.net = models.mobilenet_v2(pretrained = self.pretrain) # replace the last layer @@ -157,7 +157,7 @@ def get_parameters_to_update(self): params = self.net.classifier[-1].parameters() if(self.in_chns !=3): params = itertools.chain() - for pram in [self.net.classifier[-1].parameters(), self.net.net.features[0][0].parameters()]: + for pram in [self.net.classifier[-1].parameters(), self.net.features[0][0].parameters()]: params = itertools.chain(params, pram) return params else: From c5415254375657ea35231bbe1a98caaafef2f336 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 1 Oct 2022 14:06:34 +0800 Subject: [PATCH 117/225] enable ReduceLROnPlateau --- pymic/net_run/agent_cls.py | 43 +++++++++++++++++++++++++++----------- pymic/net_run/agent_seg.py | 2 +- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index 46bb2d7..cbb79d4 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -5,11 +5,12 @@ import csv import logging import time -import torch -from torchvision import transforms import numpy as np +import torch import torch.nn as nn from datetime import datetime +from torch.optim import lr_scheduler +from torchvision import transforms from tensorboardX import SummaryWriter from pymic.io.nifty_dataset import ClassificationDataset from pymic.loss.loss_dict_cls import PyMICClsLossDict @@ -149,7 +150,9 @@ def training(self): loss = self.get_loss_value(data, outputs, labels) loss.backward() self.optimizer.step() - self.scheduler.step() + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step() # statistics sample_num += labels.size(0) @@ -185,7 +188,9 @@ def validation(self): avg_loss = running_loss / sample_num avg_score= running_score.double() / sample_num - metrics =self.config['training'].get("evaluation_metric", "accuracy") + metrics = self.config['training'].get("evaluation_metric", "accuracy") + if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step(avg_score) valid_scalers = {'loss': avg_loss, metrics: avg_score} return valid_scalers @@ -222,7 +227,15 @@ def train_valid(self): iter_max = self.config['training']['iter_max'] iter_valid = self.config['training']['iter_valid'] iter_save = self.config['training']['iter_save'] + early_stop_it = self.config['training'].get('early_stop_patience', None) metrics = self.config['training'].get("evaluation_metric", "accuracy") + if(iter_save is None): + iter_save_list = [iter_max] + elif(isinstance(iter_save, (tuple, list))): + iter_save_list = iter_save + else: + iter_save_list = range(0, iter_max + 1, iter_save) + self.max_val_score = 0.0 self.max_val_it = 0 self.best_model_wts = None @@ -243,29 +256,35 @@ def train_valid(self): logging.info("{0:} training start".format(str(datetime.now())[:-7])) self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) + self.glob_it = iter_start for it in range(iter_start, iter_max, iter_valid): lr_value = self.optimizer.param_groups[0]['lr'] train_scalars = self.training() valid_scalars = self.validation() - glob_it = it + iter_valid - self.write_scalars(train_scalars, valid_scalars, lr_value, glob_it) + self.glob_it = it + iter_valid + self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it) if(valid_scalars[metrics] > self.max_val_score): self.max_val_score = valid_scalars[metrics] - self.max_val_it = glob_it + self.max_val_it = self.glob_it self.best_model_wts = copy.deepcopy(self.net.state_dict()) - if (glob_it % iter_save == 0): - save_dict = {'iteration': glob_it, + stop_now = True if(early_stop_it is not None and \ + self.glob_it - self.max_val_it > early_stop_it) else False + + if ((self.glob_it in iter_save_list) or stop_now): + save_dict = {'iteration': self.glob_it, 'valid_pred': valid_scalars[metrics], 'model_state_dict': self.net.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, glob_it) + save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.glob_it) torch.save(save_dict, save_name) txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt') - txt_file.write(str(glob_it)) + txt_file.write(str(self.glob_it)) txt_file.close() - + if(stop_now): + logging.info("The training is early stopped") + break # save the best performing checkpoint save_dict = {'iteration': self.max_val_it, 'valid_pred': self.max_val_score, diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index bba4508..9220656 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -10,9 +10,9 @@ import numpy as np import torch.nn as nn import torch.optim as optim -from torch.optim import lr_scheduler import torch.nn.functional as F from datetime import datetime +from torch.optim import lr_scheduler from tensorboardX import SummaryWriter from pymic.io.image_read_write import save_nd_array_as_image from pymic.io.nifty_dataset import NiftyDataset From bd38bc983d000879196ae3070392e98a45909040 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 1 Oct 2022 16:00:47 +0800 Subject: [PATCH 118/225] support mixup --- pymic/loss/cls/basic.py | 10 ++------ pymic/net_run/agent_cls.py | 21 ++++++++++------ pymic/net_run/agent_seg.py | 7 +++++- pymic/util/general.py | 50 +++++++++++++++++++++++++++++++++++++- 4 files changed, 71 insertions(+), 17 deletions(-) diff --git a/pymic/loss/cls/basic.py b/pymic/loss/cls/basic.py index 4c90943..56925fc 100644 --- a/pymic/loss/cls/basic.py +++ b/pymic/loss/cls/basic.py @@ -65,10 +65,7 @@ def forward(self, loss_input_dict): labels = loss_input_dict['ground_truth'][:, None] # reshape to N, 1 softmax = nn.Softmax(dim = 1) predict = softmax(predict) - num_class = list(predict.size())[1] - data_type = 'float' if(predict.dtype is torch.float32) else 'double' - soft_y = get_soft_label(labels, num_class, data_type) - loss = self.l1_loss(predict, soft_y) + loss = self.l1_loss(predict, labels) return loss class MSELoss(AbstractClassificationLoss): @@ -84,10 +81,7 @@ def forward(self, loss_input_dict): labels = loss_input_dict['ground_truth'][:, None] # reshape to N, 1 softmax = nn.Softmax(dim = 1) predict = softmax(predict) - num_class = list(predict.size())[1] - data_type = 'float' if(predict.dtype is torch.float32) else 'double' - soft_y = get_soft_label(labels, num_class, data_type) - loss = self.mse_loss(predict, soft_y) + loss = self.mse_loss(predict, labels) return loss class NLLLoss(AbstractClassificationLoss): diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index cbb79d4..f9f7781 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn from datetime import datetime +from random import random from torch.optim import lr_scheduler from torchvision import transforms from tensorboardX import SummaryWriter @@ -17,6 +18,7 @@ from pymic.net.net_dict_cls import TorchClsNetDict from pymic.transform.trans_dict import TransformDict from pymic.net_run.agent_abstract import NetRunAgent +from pymic.util.general import mixup import warnings warnings.filterwarnings('ignore', '.*output shape of zoom.*') @@ -111,16 +113,17 @@ def get_evaluation_score(self, outputs, labels): """ Get evaluation score for a prediction. - :param outputs: (tensor) Prediction obtained by a network. - :param labels: (tensor) The ground truth. + :param outputs: (tensor) Prediction obtained by a network with size N X C. + :param labels: (tensor) The ground truth with size N X C. """ metrics = self.config['training'].get("evaluation_metric", "accuracy") if(metrics != "accuracy"): # default classification accuracy raise ValueError("Not implemeted for metric {0:}".format(metrics)) if(self.task_type == "cls"): - _, preds = torch.max(outputs, 1) - consis= self.convert_tensor_type(preds == labels.data) - score = torch.mean(consis) + out_argmax = torch.argmax(outputs, 1) + lab_argmax = torch.argmax(labels, 1) + consis = self.convert_tensor_type(out_argmax == lab_argmax) + score = torch.mean(consis) elif(self.task_type == "cls_nexcl"): #nonexclusive classification preds = self.convert_tensor_type(outputs > 0.5) consis= self.convert_tensor_type(preds == labels.data) @@ -129,6 +132,7 @@ def get_evaluation_score(self, outputs, labels): def training(self): iter_valid = self.config['training']['iter_valid'] + mixup_prob = self.config['training'].get('mixup_probability', 0.5) sample_num = 0 running_loss = 0 running_score= 0 @@ -140,8 +144,11 @@ def training(self): self.trainIter = iter(self.train_loader) data = next(self.trainIter) inputs = self.convert_tensor_type(data['image']) - labels = data['label'].long() + labels = self.convert_tensor_type(data['label_prob']) + if(random() < mixup_prob): + inputs, labels = mixup(inputs, labels) inputs, labels = inputs.to(self.device), labels.to(self.device) + # zero the parameter gradients self.optimizer.zero_grad() # forward + backward + optimize @@ -174,7 +181,7 @@ def validation(self): self.net.eval() for data in validIter: inputs = self.convert_tensor_type(data['image']) - labels = data['label'].long() + labels = self.convert_tensor_type(data['label_prob']) inputs, labels = inputs.to(self.device), labels.to(self.device) self.optimizer.zero_grad() # forward + backward + optimize diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 9220656..620f29b 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -12,6 +12,7 @@ import torch.optim as optim import torch.nn.functional as F from datetime import datetime +from random import random from torch.optim import lr_scheduler from tensorboardX import SummaryWriter from pymic.io.image_read_write import save_nd_array_as_image @@ -28,6 +29,7 @@ from pymic.transform.trans_dict import TransformDict from pymic.util.post_process import PostProcessDict from pymic.util.image_process import convert_label +from pymic.util.general import mixup class SegmentationAgent(NetRunAgent): def __init__(self, config, stage = 'train'): @@ -120,6 +122,7 @@ def set_postprocessor(self, postprocessor): def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] + mixup_prob = self.config['training'].get('mixup_probability', 0.5) train_loss = 0 train_dice_list = [] self.net.train() @@ -132,7 +135,9 @@ def training(self): # get the inputs inputs = self.convert_tensor_type(data['image']) labels_prob = self.convert_tensor_type(data['label_prob']) - + if(random() < mixup_prob): + inputs, labels_prob = mixup(inputs, labels_prob) + # # for debug # for i in range(inputs.shape[0]): # image_i = inputs[i][0] diff --git a/pymic/util/general.py b/pymic/util/general.py index 75b6af1..99eb49f 100644 --- a/pymic/util/general.py +++ b/pymic/util/general.py @@ -29,4 +29,52 @@ def get_one_hot_seg(label, class_num): one_hot = one_hot.view(*size) one_hot = torch.transpose(one_hot, 1, -1) one_hot = torch.squeeze(one_hot, -1) - return one_hot \ No newline at end of file + return one_hot + +def mixup(inputs, labels): + """Shuffle a minibatch and do linear interpolation between images and labels. + Both classification and segmentation labels are supported. The targets should + be one-hot labels. + + :param inputs: a tensor of input images with size N X C0 x H x W. + :param labels: a tensor of one-hot labels. The shape is N X C for classification + tasks, and N X C X H X W for segmentation tasks. + """ + input_shape = list(inputs.shape) + label_shape = list(labels.shape) + img_dim = len(input_shape) - 2 + N = input_shape[0] # batch size + C = label_shape[1] # class number + rp1 = torch.randperm(N) + inputs1 = inputs[rp1] + labels1 = labels[rp1] + + rp2 = torch.randperm(N) + inputs2 = inputs[rp2] + labels2 = labels[rp2] + + a = np.random.beta(1, 1, [N, 1]) + if(img_dim == 2): + b = np.tile(a[..., None, None], [1] + input_shape[1:]) + elif(img_dim == 3): + b = np.tile(a[..., None, None, None], [1] + input_shape[1:]) + else: + raise ValueError("MixUp only supports 2D and 3D images, but the " + + "input image has {0:} dimensions".format(img_dim)) + + inputs1 = inputs1 * torch.from_numpy(b).float() + inputs2 = inputs2 * torch.from_numpy(1 - b).float() + inputs_mix = inputs1 + inputs2 + + if(len(label_shape) == 2): # for classification tasks + c = np.tile(a, [1, C]) + elif(img_dim == 2): # for 2D segmentation tasks + c = np.tile(a[..., None, None], [1] + label_shape[1:]) + else: # for 3D segmentation tasks + c = np.tile(a[..., None, None, None], [1] + label_shape[1:]) + + labels1 = labels1 * torch.from_numpy(c).float() + labels2 = labels2 * torch.from_numpy(1 - c).float() + labels_mix = labels1 + labels2 + + return inputs_mix, labels_mix From 981d47acf0550fd93af3c15cde0ec31ba49e339b Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 22 Nov 2022 23:00:47 +0800 Subject: [PATCH 119/225] update network unet2d_urpc is not needed as deep supervision is supported by unet2d and unet3d now --- pymic/net/net2d/unet2d_urpc.py | 132 --------------------------------- pymic/net/net_dict_seg.py | 3 - 2 files changed, 135 deletions(-) delete mode 100644 pymic/net/net2d/unet2d_urpc.py diff --git a/pymic/net/net2d/unet2d_urpc.py b/pymic/net/net2d/unet2d_urpc.py deleted file mode 100644 index ee8ab7c..0000000 --- a/pymic/net/net2d/unet2d_urpc.py +++ /dev/null @@ -1,132 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import torch -import torch.nn as nn -import numpy as np -from torch.distributions.uniform import Uniform -from pymic.net.net2d.unet2d import ConvBlock, DownBlock, UpBlock - -def FeatureDropout(x): - attention = torch.mean(x, dim=1, keepdim=True) - max_val, _ = torch.max(attention.view( - x.size(0), -1), dim=1, keepdim=True) - threshold = max_val * np.random.uniform(0.7, 0.9) - threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention) - drop_mask = (attention < threshold).float() - x = x.mul(drop_mask) - return x - -class FeatureNoise(nn.Module): - def __init__(self, uniform_range=0.3): - super(FeatureNoise, self).__init__() - self.uni_dist = Uniform(-uniform_range, uniform_range) - - def feature_based_noise(self, x): - noise_vector = self.uni_dist.sample( - x.shape[1:]).to(x.device).unsqueeze(0) - x_noise = x.mul(noise_vector) + x - return x_noise - - def forward(self, x): - x = self.feature_based_noise(x) - return x - -class UNet2D_URPC(nn.Module): - """ - An modification the U-Net to obtain multi-scale prediction according to - the URPC paper. - - * Reference: Xiangde Luo, Guotai Wang*, Wenjun Liao, Jieneng Chen, Tao Song, Yinan Chen, - Shichuan Zhang, Dimitris N. Metaxas, Shaoting Zhang. - Semi-Supervised Medical Image Segmentation via Uncertainty Rectified Pyramid Consistency . - `Medical Image Analysis 2022. `_ - - Also see: https://github.com/HiLab-git/SSL4MIS/blob/master/code/networks/unet.py - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - """ - def __init__(self, params): - super(UNet2D_URPC, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] - assert(len(self.ft_chns) == 5) - - self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], 0.0, self.bilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], 0.0, self.bilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], 0.0, self.bilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], 0.0, self.bilinear) - - self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, - kernel_size = 3, padding = 1) - self.out_conv_dp4 = nn.Conv2d(self.ft_chns[4], self.n_class, - kernel_size=3, padding=1) - self.out_conv_dp3 = nn.Conv2d(self.ft_chns[3], self.n_class, - kernel_size=3, padding=1) - self.out_conv_dp2 = nn.Conv2d(self.ft_chns[2], self.n_class, - kernel_size=3, padding=1) - self.out_conv_dp1 = nn.Conv2d(self.ft_chns[1], self.n_class, - kernel_size=3, padding=1) - self.feature_noise = FeatureNoise() - - def forward(self, x): - x_shape = list(x.shape) - if(len(x_shape) == 5): - [N, C, D, H, W] = x_shape - new_shape = [N*D, C, H, W] - x = torch.transpose(x, 1, 2) - x = torch.reshape(x, new_shape) - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - x4 = self.down4(x3) - - x = self.up1(x4, x3) - if self.training: - x = nn.functional.dropout(x, p=0.5) - dp3_out = self.out_conv_dp3(x) - - x = self.up2(x, x2) - if self.training: - x = FeatureDropout(x) - dp2_out = self.out_conv_dp2(x) - - x = self.up3(x, x1) - if self.training: - x = self.feature_noise(x) - dp1_out = self.out_conv_dp1(x) - - x = self.up4(x, x0) - dp0_out = self.out_conv(x) - - out_shape = list(dp0_out.shape)[2:] - dp3_out = nn.functional.interpolate(dp3_out, out_shape) - dp2_out = nn.functional.interpolate(dp2_out, out_shape) - dp1_out = nn.functional.interpolate(dp1_out, out_shape) - out = [dp0_out, dp1_out, dp2_out, dp3_out] - - if(len(x_shape) == 5): - new_shape = [N, D] + list(dp0_out.shape)[1:] - for i in range(len(out)): - out[i] = torch.transpose(torch.reshape(out[i], new_shape), 1, 2) - return out \ No newline at end of file diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index 5dd3bee..195896a 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -4,7 +4,6 @@ * UNet2D :mod:`pymic.net.net2d.unet2d.UNet2D` * UNet2D_DualBranch :mod:`pymic.net.net2d.unet2d_dual_branch.UNet2D_DualBranch` -* UNet2D_URPC :mod:`pymic.net.net2d.unet2d_urpc.UNet2D_URPC` * UNet2D_CCT :mod:`pymic.net.net2d.unet2d_cct.UNet2D_CCT` * UNet2D_ScSE :mod:`pymic.net.net2d.unet2d_scse.UNet2D_ScSE` * AttentionUNet2D :mod:`pymic.net.net2d.unet2d_attention.AttentionUNet2D` @@ -17,7 +16,6 @@ from __future__ import print_function, division from pymic.net.net2d.unet2d import UNet2D from pymic.net.net2d.unet2d_dual_branch import UNet2D_DualBranch -from pymic.net.net2d.unet2d_urpc import UNet2D_URPC from pymic.net.net2d.unet2d_cct import UNet2D_CCT from pymic.net.net2d.cople_net import COPLENet from pymic.net.net2d.unet2d_attention import AttentionUNet2D @@ -30,7 +28,6 @@ SegNetDict = { 'UNet2D': UNet2D, 'UNet2D_DualBranch': UNet2D_DualBranch, - 'UNet2D_URPC': UNet2D_URPC, 'UNet2D_CCT': UNet2D_CCT, 'COPLENet': COPLENet, 'AttentionUNet2D': AttentionUNet2D, From 647674dd30249162b0c20e6b71b969c7ee739d0d Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 6 Dec 2022 11:03:10 +0800 Subject: [PATCH 120/225] enable automatic installation of dependencies --- pymic/loss/seg/ssl.py | 5 +- pymic/util/evaluation_seg.py | 115 ++++++++++++++++++----------------- pymic/util/post_process.py | 1 + requirements.txt | 4 +- setup.py | 13 +++- 5 files changed, 77 insertions(+), 61 deletions(-) diff --git a/pymic/loss/seg/ssl.py b/pymic/loss/seg/ssl.py index f15fc60..0bf276f 100644 --- a/pymic/loss/seg/ssl.py +++ b/pymic/loss/seg/ssl.py @@ -6,8 +6,9 @@ import torch.nn as nn import numpy as np from pymic.loss.seg.util import reshape_tensor_to_2D +from pymic.loss.seg.abstract import AbstractSegLoss -class EntropyLoss(nn.Module): +class EntropyLoss(AbstractSegLoss): """ Entropy Minimization for segmentation tasks. The parameters should be written in the `params` dictionary, and it has the @@ -43,7 +44,7 @@ def forward(self, loss_input_dict): avg_ent = torch.mean(entropy) return avg_ent -class TotalVariationLoss(nn.Module): +class TotalVariationLoss(AbstractSegLoss): """ Total Variation Loss for segmentation tasks. The parameters should be written in the `params` dictionary, and it has the diff --git a/pymic/util/evaluation_seg.py b/pymic/util/evaluation_seg.py index ec02297..06fc402 100644 --- a/pymic/util/evaluation_seg.py +++ b/pymic/util/evaluation_seg.py @@ -260,14 +260,15 @@ def evaluation(config): Run evaluation of segmentation results based on a configuration dictionary `config`. The following fields should be provided in `config`: - :param metric: (str) The metric for evaluation. + :param metric_list: (list) The list of metrics for evaluation. The metric options are {`dice`, `iou`, `assd`, `hd95`, `rve`, `volume`}. :param label_list: (list) The list of labels for evaluation. :param label_fuse: (option, bool) If true, fuse the labels in the `label_list` as the foreground, and other labels as the background. Default is False. :param organ_name: (str) The name of the organ for segmentation. :param ground_truth_folder_root: (str) The root dir of ground truth images. - :param segmentation_folder_root: (str) The root dir of segmentation images. + :param segmentation_folder_root: (str or list) The root dir of segmentation images. + When a list is given, each list element should be the root dir of the results of one method. :param evaluation_image_pair: (str) The csv file that provide the segmentation images and the corresponding ground truth images. :param ground_truth_label_convert_source: (optional, list) The list of source @@ -280,7 +281,7 @@ def evaluation(config): labels for label conversion in the segmentation. """ - metric = config['metric'] + metric_list = config['metric_list'] label_list = config['label_list'] label_fuse = config.get('label_fuse', False) organ_name = config['organ_name'] @@ -295,60 +296,62 @@ def evaluation(config): segmentation_label_convert_target = config.get('segmentation_label_convert_target', None) image_items = pd.read_csv(image_pair_csv) - item_num = len(image_items) - for seg_root_n in seg_root: - score_all_data = [] - name_score_list= [] - for i in range(item_num): - gt_name = image_items.iloc[i, 0] - seg_name = image_items.iloc[i, 1] - # seg_name = seg_name.replace(".nii.gz", "_pred.nii.gz") - gt_full_name = gt_root + '/' + gt_name - seg_full_name = seg_root_n + '/' + seg_name - - s_dict = load_image_as_nd_array(seg_full_name) - g_dict = load_image_as_nd_array(gt_full_name) - s_volume = s_dict["data_array"]; s_spacing = s_dict["spacing"] - g_volume = g_dict["data_array"]; g_spacing = g_dict["spacing"] - # for dim in range(len(s_spacing)): - # assert(s_spacing[dim] == g_spacing[dim]) - if((ground_truth_label_convert_source is not None) and \ - ground_truth_label_convert_target is not None): - g_volume = convert_label(g_volume, ground_truth_label_convert_source, \ - ground_truth_label_convert_target) - - if((segmentation_label_convert_source is not None) and \ - segmentation_label_convert_target is not None): - s_volume = convert_label(s_volume, segmentation_label_convert_source, \ - segmentation_label_convert_target) - - score_vector = get_multi_class_evaluation_score(s_volume, g_volume, label_list, - label_fuse, s_spacing, metric ) - if(len(label_list) > 1): - score_vector.append(np.asarray(score_vector).mean()) - score_all_data.append(score_vector) - name_score_list.append([seg_name] + score_vector) - print(seg_name, score_vector) - score_all_data = np.asarray(score_all_data) - score_mean = score_all_data.mean(axis = 0) - score_std = score_all_data.std(axis = 0) - name_score_list.append(['mean'] + list(score_mean)) - name_score_list.append(['std'] + list(score_std)) + item_num = len(image_items) - # save the result as csv - score_csv = "{0:}/{1:}_{2:}_all.csv".format(seg_root_n, organ_name, metric) - with open(score_csv, mode='w') as csv_file: - csv_writer = csv.writer(csv_file, delimiter=',', - quotechar='"',quoting=csv.QUOTE_MINIMAL) - head = ['image'] + ["class_{0:}".format(i) for i in label_list] - if(len(label_list) > 1): - head = head + ["average"] - csv_writer.writerow(head) - for item in name_score_list: - csv_writer.writerow(item) - - print("{0:} mean ".format(metric), score_mean) - print("{0:} std ".format(metric), score_std) + for seg_root_n in seg_root: # for each segmentation method + for metric in metric_list: + score_all_data = [] + name_score_list= [] + for i in range(item_num): + gt_name = image_items.iloc[i, 0] + seg_name = image_items.iloc[i, 1] + # seg_name = seg_name.replace(".nii.gz", "_pred.nii.gz") + gt_full_name = gt_root + '/' + gt_name + seg_full_name = seg_root_n + '/' + seg_name + + s_dict = load_image_as_nd_array(seg_full_name) + g_dict = load_image_as_nd_array(gt_full_name) + s_volume = s_dict["data_array"]; s_spacing = s_dict["spacing"] + g_volume = g_dict["data_array"]; g_spacing = g_dict["spacing"] + # for dim in range(len(s_spacing)): + # assert(s_spacing[dim] == g_spacing[dim]) + if((ground_truth_label_convert_source is not None) and \ + ground_truth_label_convert_target is not None): + g_volume = convert_label(g_volume, ground_truth_label_convert_source, \ + ground_truth_label_convert_target) + + if((segmentation_label_convert_source is not None) and \ + segmentation_label_convert_target is not None): + s_volume = convert_label(s_volume, segmentation_label_convert_source, \ + segmentation_label_convert_target) + + score_vector = get_multi_class_evaluation_score(s_volume, g_volume, label_list, + label_fuse, s_spacing, metric ) + if(len(label_list) > 1): + score_vector.append(np.asarray(score_vector).mean()) + score_all_data.append(score_vector) + name_score_list.append([seg_name] + score_vector) + print(seg_name, score_vector) + score_all_data = np.asarray(score_all_data) + score_mean = score_all_data.mean(axis = 0) + score_std = score_all_data.std(axis = 0) + name_score_list.append(['mean'] + list(score_mean)) + name_score_list.append(['std'] + list(score_std)) + + # save the result as csv + score_csv = "{0:}/{1:}_{2:}_all.csv".format(seg_root_n, organ_name, metric) + with open(score_csv, mode='w') as csv_file: + csv_writer = csv.writer(csv_file, delimiter=',', + quotechar='"',quoting=csv.QUOTE_MINIMAL) + head = ['image'] + ["class_{0:}".format(i) for i in label_list] + if(len(label_list) > 1): + head = head + ["average"] + csv_writer.writerow(head) + for item in name_score_list: + csv_writer.writerow(item) + + print("{0:} mean ".format(metric), score_mean) + print("{0:} std ".format(metric), score_std) def main(): """ diff --git a/pymic/util/post_process.py b/pymic/util/post_process.py index a0a9dff..a889b23 100644 --- a/pymic/util/post_process.py +++ b/pymic/util/post_process.py @@ -43,6 +43,7 @@ def __call__(self, seg): seg_c = np.asarray(seg == c, np.uint8) seg_c = get_largest_k_components(seg_c) output = output + seg_c * c + seg = output return seg PostProcessDict = { diff --git a/requirements.txt b/requirements.txt index c8cd562..6dac753 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,5 @@ scipy>=1.3.3 SimpleITK>=2.0.0 tensorboard>=2.1.0 tensorboardX>=1.9 -torch>=1.7.1 -torchvision>=0.8.2 +torch>=1.1.12 +torchvision>=0.13.0 diff --git a/setup.py b/setup.py index ce7271b..527bdcb 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.3.0", + version = "0.3.1", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, @@ -20,6 +20,17 @@ url = 'https://github.com/HiLab-git/PyMIC', license = 'Apache 2.0', packages = setuptools.find_packages(), + install_requires=[ + "matplotlib>=3.1.2", + "numpy>=1.17.4", + "pandas>=0.25.3", + "scikit-image>=0.16.2", + "scikit-learn>=0.22", + "scipy>=1.3.3", + "SimpleITK>=2.0.0", + "tensorboard>=2.1.0", + "tensorboardX>=1.9", + ], classifiers=[ 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python', From 8928f6bdc278b31b080d0f443e414a8e0524ad03 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 6 Dec 2022 13:11:12 +0800 Subject: [PATCH 121/225] update distance evaluation --- pymic/test/test_assd.py | 37 ++++++++++++++++++++++++++++++++++++ pymic/util/evaluation_seg.py | 26 ++++++------------------- requirements.txt | 4 ++-- setup.py | 6 +++--- 4 files changed, 48 insertions(+), 25 deletions(-) create mode 100644 pymic/test/test_assd.py diff --git a/pymic/test/test_assd.py b/pymic/test/test_assd.py new file mode 100644 index 0000000..35c1804 --- /dev/null +++ b/pymic/test/test_assd.py @@ -0,0 +1,37 @@ +from scipy import ndimage +from PIL import Image +import numpy as np +import SimpleITK as sitk +import matplotlib.pyplot as plt +from pymic.util.evaluation_seg import get_edge_points + +def test_assd_2d(): + img_name = "/home/x/projects/PyMIC_project/PyMIC_examples/PyMIC_data/JSRT/label/JPCLN001.png" + img = Image.open(img_name) + img_array = np.asarray(img) + img_edge = get_edge_points(img_array > 0) + s_dis = ndimage.distance_transform_edt(1-img_edge) + plt.subplot(1,2,1) + plt.imshow(img_edge) + plt.subplot(1,2,2) + plt.imshow(s_dis) + plt.show() + +def test_assd_3d(): + img_name = "/home/x/projects/PyMIC_project/PyMIC_examples/seg_ssl/ACDC/result/unet2d_baseline/patient001_frame01.nii.gz" + img_obj = sitk.ReadImage(img_name) + spacing = img_obj.GetSpacing() + spacing = spacing[::-1] + img_data = sitk.GetArrayFromImage(img_obj) + print(img_data.shape) + print(spacing) + img_edge = get_edge_points(img_data > 0) + s_dis = ndimage.distance_transform_edt(1-img_edge, sampling=spacing) + dis_obj = sitk.GetImageFromArray(s_dis) + dis_obj.CopyInformation(img_obj) + sitk.WriteImage(dis_obj, "test_dis.nii.gz") + + + +if __name__ == "__main__": + test_assd_3d() \ No newline at end of file diff --git a/pymic/util/evaluation_seg.py b/pymic/util/evaluation_seg.py index 06fc402..ba04a73 100644 --- a/pymic/util/evaluation_seg.py +++ b/pymic/util/evaluation_seg.py @@ -6,11 +6,7 @@ import csv import os import sys -import math import pandas as pd -import random -import GeodisTK -import configparser import numpy as np from scipy import ndimage from pymic.io.image_read_write import * @@ -90,7 +86,7 @@ def get_edge_points(img): strt = ndimage.generate_binary_structure(2,1) else: strt = ndimage.generate_binary_structure(3,1) - ero = ndimage.morphology.binary_erosion(img, strt) + ero = ndimage.binary_erosion(img, strt) edge = np.asarray(img, np.uint8) - np.asarray(ero, np.uint8) return edge @@ -114,14 +110,9 @@ def binary_hd95(s, g, spacing = None): spacing = [1.0] * image_dim else: assert(image_dim == len(spacing)) - img = np.zeros_like(s) - if(image_dim == 2): - s_dis = GeodisTK.geodesic2d_raster_scan(img, s_edge, 0.0, 2) - g_dis = GeodisTK.geodesic2d_raster_scan(img, g_edge, 0.0, 2) - elif(image_dim ==3): - s_dis = GeodisTK.geodesic3d_raster_scan(img, s_edge, spacing, 0.0, 2) - g_dis = GeodisTK.geodesic3d_raster_scan(img, g_edge, spacing, 0.0, 2) - + s_dis = ndimage.distance_transform_edt(1-s_edge, sampling = spacing) + g_dis = ndimage.distance_transform_edt(1-g_edge, sampling = spacing) + dist_list1 = s_dis[g_edge > 0] dist_list1 = sorted(dist_list1) dist1 = dist_list1[int(len(dist_list1)*0.95)] @@ -150,13 +141,8 @@ def binary_assd(s, g, spacing = None): spacing = [1.0] * image_dim else: assert(image_dim == len(spacing)) - img = np.zeros_like(s) - if(image_dim == 2): - s_dis = GeodisTK.geodesic2d_raster_scan(img, s_edge, 0.0, 2) - g_dis = GeodisTK.geodesic2d_raster_scan(img, g_edge, 0.0, 2) - elif(image_dim ==3): - s_dis = GeodisTK.geodesic3d_raster_scan(img, s_edge, spacing, 0.0, 2) - g_dis = GeodisTK.geodesic3d_raster_scan(img, g_edge, spacing, 0.0, 2) + s_dis = ndimage.distance_transform_edt(1-s_edge, sampling = spacing) + g_dis = ndimage.distance_transform_edt(1-g_edge, sampling = spacing) ns = s_edge.sum() ng = g_edge.sum() diff --git a/requirements.txt b/requirements.txt index 6dac753..49912a4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ scikit-image>=0.16.2 scikit-learn>=0.22 scipy>=1.3.3 SimpleITK>=2.0.0 -tensorboard>=2.1.0 -tensorboardX>=1.9 +tensorboard +tensorboardX torch>=1.1.12 torchvision>=0.13.0 diff --git a/setup.py b/setup.py index 527bdcb..cdf2295 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.3.1", + version = "0.3.2.2", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, @@ -28,8 +28,8 @@ "scikit-learn>=0.22", "scipy>=1.3.3", "SimpleITK>=2.0.0", - "tensorboard>=2.1.0", - "tensorboardX>=1.9", + "tensorboard", + "tensorboardX", ], classifiers=[ 'License :: OSI Approved :: Apache Software License', From 1e7994755f5e63b41af3888ec8e280b6115f2828 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 6 Dec 2022 16:56:26 +0800 Subject: [PATCH 122/225] update the version all automatic installation of dependencies --- README.md | 4 ++-- setup.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 1cf7b3c..d338581 100644 --- a/README.md +++ b/README.md @@ -47,10 +47,10 @@ Run the following command to install the latest released version of PyMIC: ```bash pip install PYMIC ``` -To install a specific version of PYMIC such as 0.3.0, run: +To install a specific version of PYMIC such as 0.3.1, run: ```bash -pip install PYMIC==0.3.0 +pip install PYMIC==0.3.1 ``` Alternatively, you can download the source code for the latest version. Run the following command to compile and install: diff --git a/setup.py b/setup.py index cdf2295..9a7f38b 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.3.2.2", + version = "0.3.1", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, From 934816c3e49edd00383b3667154c23ba197629fb Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 6 Dec 2022 17:27:05 +0800 Subject: [PATCH 123/225] update typo for num_worker --- pymic/net_run/agent_abstract.py | 2 +- pymic/net_run_nll/nll_dast.py | 2 +- pymic/net_run_ssl/ssl_abstract.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index dc6b1ad..50e67e2 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -259,7 +259,7 @@ def worker_init_fn(worker_id): bn_train = self.config['dataset']['train_batch_size'] bn_valid = self.config['dataset'].get('valid_batch_size', 1) - num_worker = self.config['dataset'].get('num_workder', 16) + num_worker = self.config['dataset'].get('num_worker', 16) g_train, g_valid = torch.Generator(), torch.Generator() g_train.manual_seed(self.random_seed) g_valid.manual_seed(self.random_seed) diff --git a/pymic/net_run_nll/nll_dast.py b/pymic/net_run_nll/nll_dast.py index f057f64..aa4ec1e 100644 --- a/pymic/net_run_nll/nll_dast.py +++ b/pymic/net_run_nll/nll_dast.py @@ -156,7 +156,7 @@ def worker_init_fn(worker_id): worker_init = None bn_train_noise = self.config['dataset']['train_batch_size_noise'] - num_worker = self.config['dataset'].get('num_workder', 16) + num_worker = self.config['dataset'].get('num_worker', 16) self.train_loader_noise = torch.utils.data.DataLoader(self.train_set_noise, batch_size = bn_train_noise, shuffle=True, num_workers= num_worker, worker_init_fn=worker_init) diff --git a/pymic/net_run_ssl/ssl_abstract.py b/pymic/net_run_ssl/ssl_abstract.py index e847d6c..b3ab9cd 100644 --- a/pymic/net_run_ssl/ssl_abstract.py +++ b/pymic/net_run_ssl/ssl_abstract.py @@ -73,7 +73,7 @@ def worker_init_fn(worker_id): worker_init = None bn_train_unlab = self.config['dataset']['train_batch_size_unlab'] - num_worker = self.config['dataset'].get('num_workder', 16) + num_worker = self.config['dataset'].get('num_worker', 16) self.train_loader_unlab = torch.utils.data.DataLoader(self.train_set_unlab, batch_size = bn_train_unlab, shuffle=True, num_workers= num_worker, worker_init_fn=worker_init) From 3834d34d13688df094cb44bba7d476c3e607dd61 Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 21 Dec 2022 14:29:20 +0800 Subject: [PATCH 124/225] update learning rate scheduler to keep consistent with pytorch, one step in learning rate scheduler corresponds to an epoch on the training set --- pymic/net_run/agent_seg.py | 14 +++++--------- pymic/net_run/get_optimizer.py | 18 +++++++++++++----- pymic/net_run_nll/nll_co_teaching.py | 3 --- pymic/net_run_nll/nll_dast.py | 3 --- pymic/net_run_nll/nll_trinet.py | 4 ---- pymic/net_run_ssl/ssl_cct.py | 5 +---- pymic/net_run_ssl/ssl_cps.py | 6 +----- pymic/net_run_ssl/ssl_em.py | 4 ---- pymic/net_run_ssl/ssl_mt.py | 4 ---- pymic/net_run_ssl/ssl_uamt.py | 5 ----- pymic/net_run_ssl/ssl_urpc.py | 5 +---- pymic/net_run_wsl/wsl_dmpls.py | 4 ---- pymic/net_run_wsl/wsl_em.py | 4 ---- pymic/net_run_wsl/wsl_gatedcrf.py | 4 ---- pymic/net_run_wsl/wsl_mumford_shah.py | 4 ---- pymic/net_run_wsl/wsl_tv.py | 4 ---- pymic/net_run_wsl/wsl_ustm.py | 4 ---- 17 files changed, 21 insertions(+), 74 deletions(-) diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 620f29b..7ab0b6e 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -162,10 +162,6 @@ def training(self): loss = self.get_loss_value(data, outputs, labels_prob) loss.backward() self.optimizer.step() - if(self.scheduler is not None and \ - not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step() - train_loss = train_loss + loss.item() # get dice evaluation for each class if(isinstance(outputs, tuple) or isinstance(outputs, list)): @@ -219,10 +215,6 @@ def validation(self): valid_avg_loss = np.asarray(valid_loss_list).mean() valid_cls_dice = np.asarray(valid_dice_list).mean(axis = 0) valid_avg_dice = valid_cls_dice.mean() - - if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step(valid_avg_dice) - valid_scalers = {'loss': valid_avg_loss, 'avg_dice': valid_avg_dice,\ 'class_dice': valid_cls_dice} return valid_scalers @@ -300,9 +292,13 @@ def train_valid(self): t0 = time.time() train_scalars = self.training() t1 = time.time() - valid_scalars = self.validation() t2 = time.time() + if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step(valid_scalars['avg_dice']) + else: + self.scheduler.step() + self.glob_it = it + iter_valid logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) logging.info('learning rate {0:}'.format(lr_value)) diff --git a/pymic/net_run/get_optimizer.py b/pymic/net_run/get_optimizer.py index c4504de..5924ff9 100644 --- a/pymic/net_run/get_optimizer.py +++ b/pymic/net_run/get_optimizer.py @@ -38,20 +38,28 @@ def get_optimizer(name, net_params, optim_params): def get_lr_scheduler(optimizer, sched_params): name = sched_params["lr_scheduler"] + val_it = sched_params["iter_valid"] if(name is None): return None - lr_gamma = sched_params["lr_gamma"] if(keyword_match(name, "ReduceLROnPlateau")): patience_it = sched_params["ReduceLROnPlateau_patience".lower()] - val_it = sched_params["iter_valid"] - patience = patience_it / val_it + patience = patience_it / val_it + lr_gamma = sched_params["lr_gamma"] scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode = "max", factor=lr_gamma, patience = patience) elif(keyword_match(name, "MultiStepLR")): lr_milestones = sched_params["lr_milestones"] - last_iter = sched_params["last_iter"] + lr_milestones = [int(item / val_it) for item in lr_milestones] + epoch_last = sched_params["last_iter"] / val_it + lr_gamma = sched_params["lr_gamma"] scheduler = lr_scheduler.MultiStepLR(optimizer, - lr_milestones, lr_gamma, last_iter) + lr_milestones, lr_gamma, epoch_last) + elif(keyword_match(name, "CosineAnnealingLR")): + epoch_max = sched_params["iter_max"] / val_it + epoch_last = sched_params["last_iter"] / val_it + lr_min = sched_params.get("lr_min", 0) + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, + epoch_max, lr_min, epoch_last) else: raise ValueError("unsupported lr scheduler {0:}".format(name)) return scheduler \ No newline at end of file diff --git a/pymic/net_run_nll/nll_co_teaching.py b/pymic/net_run_nll/nll_co_teaching.py index 06240d6..465a316 100644 --- a/pymic/net_run_nll/nll_co_teaching.py +++ b/pymic/net_run_nll/nll_co_teaching.py @@ -128,9 +128,6 @@ def training(self): loss.backward() self.optimizer.step() - if(self.scheduler is not None and \ - not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step() train_loss_no_select1 = train_loss_no_select1 + loss1.mean().item() train_loss_no_select2 = train_loss_no_select2 + loss2.mean().item() diff --git a/pymic/net_run_nll/nll_dast.py b/pymic/net_run_nll/nll_dast.py index aa4ec1e..07ed312 100644 --- a/pymic/net_run_nll/nll_dast.py +++ b/pymic/net_run_nll/nll_dast.py @@ -239,9 +239,6 @@ def training(self): loss.backward() self.optimizer.step() - if(self.scheduler is not None and \ - not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run_nll/nll_trinet.py b/pymic/net_run_nll/nll_trinet.py index 7ac74af..2694013 100644 --- a/pymic/net_run_nll/nll_trinet.py +++ b/pymic/net_run_nll/nll_trinet.py @@ -8,7 +8,6 @@ import torch import torch.nn as nn import torch.optim as optim -from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice @@ -125,9 +124,6 @@ def training(self): loss.backward() self.optimizer.step() - if(self.scheduler is not None and \ - not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step() train_loss_no_select1 = train_loss_no_select1 + loss1.mean().item() train_loss_no_select2 = train_loss_no_select2 + loss2.mean().item() diff --git a/pymic/net_run_ssl/ssl_cct.py b/pymic/net_run_ssl/ssl_cct.py index 0bc23a2..0cb16aa 100644 --- a/pymic/net_run_ssl/ssl_cct.py +++ b/pymic/net_run_ssl/ssl_cct.py @@ -5,7 +5,6 @@ import torch.nn as nn import torch.nn.functional as F import numpy as np -from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice @@ -139,9 +138,7 @@ def training(self): loss.backward() self.optimizer.step() - if(self.scheduler is not None and \ - not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step() + train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() train_loss_reg = train_loss_reg + loss_reg.item() diff --git a/pymic/net_run_ssl/ssl_cps.py b/pymic/net_run_ssl/ssl_cps.py index 239443c..637fe2e 100644 --- a/pymic/net_run_ssl/ssl_cps.py +++ b/pymic/net_run_ssl/ssl_cps.py @@ -4,7 +4,6 @@ import numpy as np import torch import torch.nn as nn -from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice @@ -26,7 +25,7 @@ def forward(self, x): if(self.training): return out1, out2 else: - return (out1 + out2) / 3 + return (out1 + out2) / 2 class SSLCPS(SSLSegAgent): """ @@ -117,9 +116,6 @@ def training(self): loss.backward() self.optimizer.step() - if(self.scheduler is not None and \ - not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup1 = train_loss_sup1 + loss_sup1.item() diff --git a/pymic/net_run_ssl/ssl_em.py b/pymic/net_run_ssl/ssl_em.py index 745501f..8aebada 100644 --- a/pymic/net_run_ssl/ssl_em.py +++ b/pymic/net_run_ssl/ssl_em.py @@ -3,7 +3,6 @@ import logging import numpy as np import torch -from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice @@ -83,9 +82,6 @@ def training(self): # if (self.config['training']['use']) loss.backward() self.optimizer.step() - if(self.scheduler is not None and \ - not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run_ssl/ssl_mt.py b/pymic/net_run_ssl/ssl_mt.py index bf4dacd..6a5a95f 100644 --- a/pymic/net_run_ssl/ssl_mt.py +++ b/pymic/net_run_ssl/ssl_mt.py @@ -3,7 +3,6 @@ import logging import torch import numpy as np -from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice @@ -102,9 +101,6 @@ def training(self): loss.backward() self.optimizer.step() - if(self.scheduler is not None and \ - not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step() # update EMA alpha = ssl_cfg.get('ema_decay', 0.99) diff --git a/pymic/net_run_ssl/ssl_uamt.py b/pymic/net_run_ssl/ssl_uamt.py index 40a742b..af6640f 100644 --- a/pymic/net_run_ssl/ssl_uamt.py +++ b/pymic/net_run_ssl/ssl_uamt.py @@ -3,7 +3,6 @@ import logging import torch import numpy as np -from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice @@ -104,10 +103,6 @@ def training(self): loss.backward() self.optimizer.step() - if(self.scheduler is not None and \ - not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step() - # update EMA alpha = ssl_cfg.get('ema_decay', 0.99) diff --git a/pymic/net_run_ssl/ssl_urpc.py b/pymic/net_run_ssl/ssl_urpc.py index 336e525..f404e64 100644 --- a/pymic/net_run_ssl/ssl_urpc.py +++ b/pymic/net_run_ssl/ssl_urpc.py @@ -4,7 +4,6 @@ import torch import torch.nn as nn import numpy as np -from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice @@ -96,9 +95,7 @@ def training(self): loss.backward() self.optimizer.step() - if(self.scheduler is not None and \ - not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step() + train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() train_loss_reg = train_loss_reg + loss_reg.item() diff --git a/pymic/net_run_wsl/wsl_dmpls.py b/pymic/net_run_wsl/wsl_dmpls.py index 15e74a0..3081f8e 100644 --- a/pymic/net_run_wsl/wsl_dmpls.py +++ b/pymic/net_run_wsl/wsl_dmpls.py @@ -4,7 +4,6 @@ import numpy as np import random import torch -from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice @@ -91,9 +90,6 @@ def training(self): loss.backward() self.optimizer.step() - if(self.scheduler is not None and \ - not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run_wsl/wsl_em.py b/pymic/net_run_wsl/wsl_em.py index 387bd93..6002547 100644 --- a/pymic/net_run_wsl/wsl_em.py +++ b/pymic/net_run_wsl/wsl_em.py @@ -3,7 +3,6 @@ import logging import numpy as np import torch -from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice @@ -72,9 +71,6 @@ def training(self): loss.backward() self.optimizer.step() - if(self.scheduler is not None and \ - not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run_wsl/wsl_gatedcrf.py b/pymic/net_run_wsl/wsl_gatedcrf.py index 85f601e..41605db 100644 --- a/pymic/net_run_wsl/wsl_gatedcrf.py +++ b/pymic/net_run_wsl/wsl_gatedcrf.py @@ -3,7 +3,6 @@ import logging import numpy as np import torch -from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice @@ -98,9 +97,6 @@ def training(self): loss.backward() self.optimizer.step() - if(self.scheduler is not None and \ - not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run_wsl/wsl_mumford_shah.py b/pymic/net_run_wsl/wsl_mumford_shah.py index 862d761..431ec7a 100644 --- a/pymic/net_run_wsl/wsl_mumford_shah.py +++ b/pymic/net_run_wsl/wsl_mumford_shah.py @@ -3,7 +3,6 @@ import logging import numpy as np import torch -from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice @@ -73,9 +72,6 @@ def training(self): # if (self.config['training']['use']) loss.backward() self.optimizer.step() - if(self.scheduler is not None and \ - not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run_wsl/wsl_tv.py b/pymic/net_run_wsl/wsl_tv.py index 3fd55a3..1612037 100644 --- a/pymic/net_run_wsl/wsl_tv.py +++ b/pymic/net_run_wsl/wsl_tv.py @@ -3,7 +3,6 @@ import logging import numpy as np import torch -from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice @@ -68,9 +67,6 @@ def training(self): # if (self.config['training']['use']) loss.backward() self.optimizer.step() - if(self.scheduler is not None and \ - not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run_wsl/wsl_ustm.py b/pymic/net_run_wsl/wsl_ustm.py index b033f9d..5f69b4a 100644 --- a/pymic/net_run_wsl/wsl_ustm.py +++ b/pymic/net_run_wsl/wsl_ustm.py @@ -5,7 +5,6 @@ import random import torch import torch.nn.functional as F -from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice @@ -121,9 +120,6 @@ def training(self): loss.backward() self.optimizer.step() - if(self.scheduler is not None and \ - not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step() # update EMA alpha = wsl_cfg.get('ema_decay', 0.99) From 452ccef19b2e43811a4f2c86caefc533eecdd32f Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 22 Dec 2022 11:18:32 +0800 Subject: [PATCH 125/225] update rotate and rescale introduce a probability for rotate and rescale --- pymic/net_run/get_optimizer.py | 24 +++++++++++++----------- pymic/transform/rescale.py | 24 ++++++++++++++---------- pymic/transform/rotate.py | 10 ++++++++++ 3 files changed, 37 insertions(+), 21 deletions(-) diff --git a/pymic/net_run/get_optimizer.py b/pymic/net_run/get_optimizer.py index 5924ff9..b3e8f82 100644 --- a/pymic/net_run/get_optimizer.py +++ b/pymic/net_run/get_optimizer.py @@ -10,28 +10,30 @@ def get_optimizer(name, net_params, optim_params): lr = optim_params['learning_rate'] momentum = optim_params['momentum'] weight_decay = optim_params['weight_decay'] + # see https://www.codeleading.com/article/44815584159/ + param_group = [{'params': net_params, 'initial_lr': lr}] if(keyword_match(name, "SGD")): - return optim.SGD(net_params, lr, + return optim.SGD(param_group, lr, momentum = momentum, weight_decay = weight_decay) elif(keyword_match(name, "Adam")): - return optim.Adam(net_params, lr, weight_decay = weight_decay) + return optim.Adam(param_group, lr, weight_decay = weight_decay) elif(keyword_match(name, "SparseAdam")): - return optim.SparseAdam(net_params, lr) + return optim.SparseAdam(param_group, lr) elif(keyword_match(name, "Adadelta")): - return optim.Adadelta(net_params, lr, weight_decay = weight_decay) + return optim.Adadelta(param_group, lr, weight_decay = weight_decay) elif(keyword_match(name, "Adagrad")): - return optim.Adagrad(net_params, lr, weight_decay = weight_decay) + return optim.Adagrad(param_group, lr, weight_decay = weight_decay) elif(keyword_match(name, "Adamax")): - return optim.Adamax(net_params, lr, weight_decay = weight_decay) + return optim.Adamax(param_group, lr, weight_decay = weight_decay) elif(keyword_match(name, "ASGD")): - return optim.ASGD(net_params, lr, weight_decay = weight_decay) + return optim.ASGD(param_group, lr, weight_decay = weight_decay) elif(keyword_match(name, "LBFGS")): - return optim.LBFGS(net_params, lr) + return optim.LBFGS(param_group, lr) elif(keyword_match(name, "RMSprop")): - return optim.RMSprop(net_params, lr, momentum = momentum, + return optim.RMSprop(param_group, lr, momentum = momentum, weight_decay = weight_decay) elif(keyword_match(name, "Rprop")): - return optim.Rprop(net_params, lr) + return optim.Rprop(param_group, lr) else: raise ValueError("unsupported optimizer {0:}".format(name)) @@ -57,7 +59,7 @@ def get_lr_scheduler(optimizer, sched_params): elif(keyword_match(name, "CosineAnnealingLR")): epoch_max = sched_params["iter_max"] / val_it epoch_last = sched_params["last_iter"] / val_it - lr_min = sched_params.get("lr_min", 0) + lr_min = sched_params.get("lr_min", 0) scheduler = lr_scheduler.CosineAnnealingLR(optimizer, epoch_max, lr_min, epoch_last) else: diff --git a/pymic/transform/rescale.py b/pymic/transform/rescale.py index 04dd458..cc1577d 100644 --- a/pymic/transform/rescale.py +++ b/pymic/transform/rescale.py @@ -85,28 +85,30 @@ class RandomRescale(AbstractTransform): The arguments should be written in the `params` dictionary, and it has the following fields: - :param `RandomRescale_lower_bound`: (list/tuple or int) + :param `RandomRescale_lower_bound`: (list/tuple or float) Desired minimal rescale ratio. If tuple/list, the length should be 3 or 2. - :param `RandomRescale_upper_bound`: (list/tuple or int) + :param `RandomRescale_upper_bound`: (list/tuple or float) Desired maximal rescale ratio. If tuple/list, the length should be 3 or 2. + :param `RandomRescale_probability`: (optional, float) + The probability of applying RandomRescale. Default is 0.5. :param `RandomRescale_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `True`. """ def __init__(self, params): - """ - ratio0 (tuple/list or int): Desired minimal rescale ratio. - If tuple/list, the length should be 3 or 2. - ratio1 (tuple/list or int): Desired maximal rescale ratio. - If tuple/list, the length should be 3 or 2. - """ super(RandomRescale, self).__init__(params) self.ratio0 = params["RandomRescale_lower_bound".lower()] self.ratio1 = params["RandomRescale_upper_bound".lower()] + self.prob = params.get('RandomRescale_probability'.lower(), 0.5) self.inverse = params.get("RandomRescale_inverse".lower(), True) assert isinstance(self.ratio0, (float, list, tuple)) assert isinstance(self.ratio1, (float, list, tuple)) def __call__(self, sample): + if(np.random.uniform() > self.prob): + sample['RandomRescale_triggered'] = False + return sample + else: + sample['RandomRescale_triggered'] = True image = sample['image'] input_shape = image.shape input_dim = len(input_shape) - 1 @@ -117,8 +119,8 @@ def __call__(self, sample): scale = [self.ratio0[i] + random.random()*(self.ratio1[i] - self.ratio0[i]) \ for i in range(len(self.ratio0))] else: - scale = [self.ratio0 + random.random()*(self.ratio1 - self.ratio0) \ - for i in range(input_dim)] + scale = self.ratio0 + random.random()*(self.ratio1 - self.ratio0) + scale = [scale] * input_dim scale = [1.0] + scale image_t = ndimage.interpolation.zoom(image, scale, order = 1) @@ -136,6 +138,8 @@ def __call__(self, sample): return sample def inverse_transform_for_prediction(self, sample): + if(not sample['RandomRescale_triggered']): + return sample if(isinstance(sample['RandomRescale_origin_shape'], list) or \ isinstance(sample['RandomRescale_origin_shape'], tuple)): origin_shape = json.loads(sample['RandomRescale_origin_shape'][0]) diff --git a/pymic/transform/rotate.py b/pymic/transform/rotate.py index ba2d655..708cf6a 100644 --- a/pymic/transform/rotate.py +++ b/pymic/transform/rotate.py @@ -27,6 +27,8 @@ class RandomRotate(AbstractTransform): :param `RandomRotate_angle_range_w`: (list/tuple or None) Rotation angle (degree) range along width axis (y-z plane), e.g., (-90, 90). If None, no rotation along this axis. Only used for 3D images. + :param `RandomRotate_probability`: (optional, float) + The probability of applying RandomRotate. Default is 0.5. :param `RandomRotate_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `True`. """ @@ -35,6 +37,7 @@ def __init__(self, params): self.angle_range_d = params['RandomRotate_angle_range_d'.lower()] self.angle_range_h = params['RandomRotate_angle_range_h'.lower()] self.angle_range_w = params['RandomRotate_angle_range_w'.lower()] + self.prob = params.get('RandomRotate_probability'.lower(), 0.5) self.inverse = params.get('RandomRotate_inverse'.lower(), True) def __apply_transformation(self, image, transform_param_list, order = 1): @@ -50,6 +53,11 @@ def __apply_transformation(self, image, transform_param_list, order = 1): return image def __call__(self, sample): + if(np.random.uniform() > self.prob): + sample['RandomRotate_triggered'] = False + return sample + else: + sample['RandomRotate_triggered'] = True image = sample['image'] input_shape = image.shape input_dim = len(input_shape) - 1 @@ -79,6 +87,8 @@ def __call__(self, sample): return sample def inverse_transform_for_prediction(self, sample): + if(not sample['RandomRotate_triggered']): + return sample if(isinstance(sample['RandomRotate_Param'], list) or \ isinstance(sample['RandomRotate_Param'], tuple)): transform_param_list = json.loads(sample['RandomRotate_Param'][0]) From abf0dc62d9e5089b755f4f0994d21fec63a3bcf0 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 23 Dec 2022 12:32:42 +0800 Subject: [PATCH 126/225] add dual-branch 3D unet --- pymic/net/net3d/unet3d.py | 87 +++++++++++++++++++++++++++ pymic/net/net3d/unet3d_dual_branch.py | 46 ++++++++++++++ pymic/net/net_dict_seg.py | 5 +- pymic/transform/rescale.py | 24 ++++---- pymic/transform/rotate.py | 16 ++--- 5 files changed, 157 insertions(+), 21 deletions(-) create mode 100644 pymic/net/net3d/unet3d_dual_branch.py diff --git a/pymic/net/net3d/unet3d.py b/pymic/net/net3d/unet3d.py index af28d73..caf4e5b 100644 --- a/pymic/net/net3d/unet3d.py +++ b/pymic/net/net3d/unet3d.py @@ -77,6 +77,93 @@ def forward(self, x1, x2): x = torch.cat([x2, x1], dim=1) return self.conv(x) +class Encoder(nn.Module): + """ + Encoder of 3D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + """ + def __init__(self, params): + super(Encoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + + self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) + self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) + self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) + self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) + if(len(self.ft_chns) == 5): + self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) + + def forward(self, x): + x0 = self.in_conv(x) + x1 = self.down1(x0) + x2 = self.down2(x1) + x3 = self.down3(x2) + output = [x0, x1, x2, x3] + if(len(self.ft_chns) == 5): + x4 = self.down4(x3) + output.append(x4) + return output + +class Decoder(nn.Module): + """ + Decoder of 3D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + """ + def __init__(self, params): + super(Decoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + self.n_class = self.params['class_num'] + self.trilinear = self.params['trilinear'] + + assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + + if(len(self.ft_chns) == 5): + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear) + self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) + + def forward(self, x): + if(len(self.ft_chns) == 5): + assert(len(x) == 5) + x0, x1, x2, x3, x4 = x + x_d3 = self.up1(x4, x3) + else: + assert(len(x) == 4) + x0, x1, x2, x3 = x + x_d3 = x3 + x_d2 = self.up2(x_d3, x2) + x_d1 = self.up3(x_d2, x1) + x_d0 = self.up4(x_d1, x0) + output = self.out_conv(x_d0) + return output class UNet3D(nn.Module): """ diff --git a/pymic/net/net3d/unet3d_dual_branch.py b/pymic/net/net3d/unet3d_dual_branch.py new file mode 100644 index 0000000..3bede4e --- /dev/null +++ b/pymic/net/net3d/unet3d_dual_branch.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import torch.nn as nn +from pymic.net.net3d.unet3d import * + +class UNet3D_DualBranch(nn.Module): + """ + A dual branch network using UNet3D as backbone. + + * Reference: Xiangde Luo, Minhao Hu, Wenjun Liao, Shuwei Zhai, Tao Song, Guotai Wang, + Shaoting Zhang. ScribblScribble-Supervised Medical Image Segmentation via + Dual-Branch Network and Dynamically Mixed Pseudo Labels Supervision. + `MICCAI 2022. `_ + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.UNet3D` for details. + In addition, the following field should be included: + + :param output_mode: (str) How to obtain the result during the inference. + `average`: taking average of the two branches. + `first`: takeing the result in the first branch. + `second`: taking the result in the second branch. + """ + def __init__(self, params): + super(UNet3D_DualBranch, self).__init__() + self.output_mode = params.get("output_mode", "average") + self.encoder = Encoder(params) + self.decoder1 = Decoder(params) + self.decoder2 = Decoder(params) + + def forward(self, x): + f = self.encoder(x) + output1 = self.decoder1(f) + output2 = self.decoder2(f) + + if(self.training): + return output1, output2 + else: + if(self.output_mode == "average"): + return (output1 + output2)/2 + elif(self.output_mode == "first"): + return output1 + else: + return output2 diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index 195896a..fc7692f 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -24,6 +24,7 @@ from pymic.net.net3d.unet2d5 import UNet2D5 from pymic.net.net3d.unet3d import UNet3D from pymic.net.net3d.unet3d_scse import UNet3D_ScSE +from pymic.net.net3d.unet3d_dual_branch import UNet3D_DualBranch SegNetDict = { 'UNet2D': UNet2D, @@ -35,5 +36,7 @@ 'UNet2D_ScSE': UNet2D_ScSE, 'UNet2D5': UNet2D5, 'UNet3D': UNet3D, - 'UNet3D_ScSE': UNet3D_ScSE + 'UNet3D_ScSE': UNet3D_ScSE, + 'UNet3D_DualBranch': UNet3D_DualBranch + } diff --git a/pymic/transform/rescale.py b/pymic/transform/rescale.py index cc1577d..2a671fd 100644 --- a/pymic/transform/rescale.py +++ b/pymic/transform/rescale.py @@ -1,9 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division -import torch import json -import math import random import numpy as np from scipy import ndimage @@ -104,11 +102,13 @@ def __init__(self, params): assert isinstance(self.ratio1, (float, list, tuple)) def __call__(self, sample): - if(np.random.uniform() > self.prob): - sample['RandomRescale_triggered'] = False - return sample - else: - sample['RandomRescale_triggered'] = True + # if(random.random() > self.prob): + # print("rescale not started") + # sample['RandomRescale_triggered'] = False + # return sample + # else: + # print("rescale started") + # sample['RandomRescale_triggered'] = True image = sample['image'] input_shape = image.shape input_dim = len(input_shape) - 1 @@ -125,7 +125,7 @@ def __call__(self, sample): image_t = ndimage.interpolation.zoom(image, scale, order = 1) sample['image'] = image_t - sample['RandomRescale_origin_shape'] = json.dumps(input_shape) + sample['RandomRescale_Param'] = json.dumps(input_shape) if('label' in sample and self.task == 'segmentation'): label = sample['label'] label = ndimage.interpolation.zoom(label, scale, order = 0) @@ -140,11 +140,11 @@ def __call__(self, sample): def inverse_transform_for_prediction(self, sample): if(not sample['RandomRescale_triggered']): return sample - if(isinstance(sample['RandomRescale_origin_shape'], list) or \ - isinstance(sample['RandomRescale_origin_shape'], tuple)): - origin_shape = json.loads(sample['RandomRescale_origin_shape'][0]) + if(isinstance(sample['RandomRescale_Param'], list) or \ + isinstance(sample['RandomRescale_Param'], tuple)): + origin_shape = json.loads(sample['RandomRescale_Param'][0]) else: - origin_shape = json.loads(sample['RandomRescale_origin_shape']) + origin_shape = json.loads(sample['RandomRescale_Param']) origin_dim = len(origin_shape) - 1 predict = sample['predict'] input_shape = predict.shape diff --git a/pymic/transform/rotate.py b/pymic/transform/rotate.py index 708cf6a..2aa06d4 100644 --- a/pymic/transform/rotate.py +++ b/pymic/transform/rotate.py @@ -1,9 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division -import torch import json -import math import random import numpy as np from scipy import ndimage @@ -53,11 +51,11 @@ def __apply_transformation(self, image, transform_param_list, order = 1): return image def __call__(self, sample): - if(np.random.uniform() > self.prob): - sample['RandomRotate_triggered'] = False - return sample - else: - sample['RandomRotate_triggered'] = True + # if(random.random() > self.prob): + # sample['RandomRotate_triggered'] = False + # return sample + # else: + # sample['RandomRotate_triggered'] = True image = sample['image'] input_shape = image.shape input_dim = len(input_shape) - 1 @@ -74,7 +72,9 @@ def __call__(self, sample): angle_w = np.random.uniform(self.angle_range_w[0], self.angle_range_w[1]) transform_param_list.append([angle_w, (-2, -3)]) assert(len(transform_param_list) > 0) - + # select a random transform from the possible list rather than + # use a combination for higher efficiency + transform_param_list = [random.choice(transform_param_list)] sample['RandomRotate_Param'] = json.dumps(transform_param_list) image_t = self.__apply_transformation(image, transform_param_list, 1) sample['image'] = image_t From 972011a683379baa766a714152459fbe5e2b7f17 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 5 Jan 2023 09:44:15 +0800 Subject: [PATCH 127/225] update transform and lr scheduler add PolynomialLR and Transpose --- pymic/loss/seg/deep_sup.py | 63 ++++++++++++++++--- pymic/net/net2d/unet2d.py | 6 +- pymic/net_run/agent_seg.py | 26 ++++---- pymic/net_run/get_optimizer.py | 12 ++-- pymic/net_run/infer_func.py | 19 +++++- pymic/net_run/net_run.py | 12 +++- pymic/test/test_nifty_dataset.py | 104 ++++++++++++++++++------------- pymic/transform/trans_dict.py | 2 + pymic/transform/transpose.py | 57 +++++++++++++++++ 9 files changed, 223 insertions(+), 78 deletions(-) create mode 100644 pymic/transform/transpose.py diff --git a/pymic/loss/seg/deep_sup.py b/pymic/loss/seg/deep_sup.py index 42ce172..da6d9ef 100644 --- a/pymic/loss/seg/deep_sup.py +++ b/pymic/loss/seg/deep_sup.py @@ -2,8 +2,43 @@ from __future__ import print_function, division import torch.nn as nn +from torch.nn.functional import interpolate from pymic.loss.seg.abstract import AbstractSegLoss +def match_prediction_and_gt_shape(pred, gt, mode = 0): + pred_shape = list(pred.shape) + gt_shape = list(gt.shape) + dim = len(pred_shape) - 2 + shape_match = False + if(dim == 2): + if(pred_shape[-1] == gt_shape[-1] and pred_shape[-2] == gt_shape[-2]): + shape_match = True + else: + if(pred_shape[-1] == gt_shape[-1] and pred_shape[-2] == gt_shape[-2] + and pred_shape[-3] == gt_shape[-3]): + shape_match = True + if(shape_match): + return pred, gt + + interp_mode = 'bilinear' if dim == 2 else 'trilinear' + if(mode == 0): + pred_new = interpolate(pred, gt_shape[2:], mode = interp_mode) + gt_new = gt + elif(mode == 1): + pred_new = pred + gt_new = interpolate(gt, pred_shape[2:], mode = interp_mode) + elif(mode == 2): + pred_new = pred + if(dim == 2): + avg_pool = nn.AdaptiveAvgPool2d(pred_shape[-2:]) + else: + avg_pool = nn.AdaptiveAvgPool3d(pred_shape[-3:]) + gt_new = avg_pool(gt) + else: + raise ValueError("mode shoud be 0, 1 or 2, but {0:} was given".format(mode)) + return pred_new, gt_new + + class DeepSuperviseLoss(AbstractSegLoss): ''' Combine deep supervision with a basic loss function. @@ -12,28 +47,36 @@ class DeepSuperviseLoss(AbstractSegLoss): :param `loss_softmax`: (optional, bool) Apply softmax to the prediction of network or not. Default is True. - :param `deep_suervise_weight`: (list) A list of weight for each deep supervision scale. \n :param `base_loss`: (nn.Module) The basic function used for each scale. + :param `deep_supervise_weight`: (list) A list of weight for each deep supervision scale. + :param `deep_supervise_model`: (int) Mode for deep supervision when the prediction + has a smaller shape than the ground truth. 0: upsample the prediction to the size + of the ground truth. 1: downsample the ground truth to the size of the prediction + via interpolation. 2: downsample the ground truth via adaptive average pooling. ''' def __init__(self, params): super(DeepSuperviseLoss, self).__init__(params) - self.deep_sup_weight = params.get('deep_suervise_weight', None) - self.base_loss = params['base_loss'] + self.base_loss = params['base_loss'] + self.deep_sup_weight = params.get('deep_supervise_weight', None) + self.deep_sup_mode = params.get('deep_supervise_mode', 0) def forward(self, loss_input_dict): - predict = loss_input_dict['prediction'] - if(not isinstance(predict, (list,tuple))): + pred = loss_input_dict['prediction'] + gt = loss_input_dict['ground_truth'] + if(not isinstance(pred, (list,tuple))): raise ValueError("""For deep supervision, the prediction should be a list or a tuple""") - predict_num = len(predict) + pred_num = len(pred) if(self.deep_sup_weight is None): - self.deep_sup_weight = [1.0] * predict_num + self.deep_sup_weight = [1.0] * pred_num else: - assert(predict_num == len(self.deep_sup_weight)) + assert(pred_num == len(self.deep_sup_weight)) loss_sum, weight_sum = 0.0, 0.0 - for i in range(predict_num): - loss_input_dict['prediction'] = predict[i] + for i in range(pred_num): + pred_i, gt_i = match_prediction_and_gt_shape(pred[i], gt, self.deep_sup_mode) + loss_input_dict['prediction'] = pred_i + loss_input_dict['ground_truth'] = gt_i temp_loss = self.base_loss(loss_input_dict) loss_sum += temp_loss * self.deep_sup_weight[i] weight_sum += self.deep_sup_weight[i] diff --git a/pymic/net/net2d/unet2d.py b/pymic/net/net2d/unet2d.py index 2708800..89b0b4f 100644 --- a/pymic/net/net2d/unet2d.py +++ b/pymic/net/net2d/unet2d.py @@ -240,18 +240,14 @@ def forward(self, x): x_d0 = self.up4(x_d1, x0) output = self.out_conv(x_d0) if(self.deep_sup): - out_shape = list(output.shape)[2:] output1 = self.out_conv1(x_d1) - output1 = interpolate(output1, out_shape, mode = 'bilinear') output2 = self.out_conv2(x_d2) - output2 = interpolate(output2, out_shape, mode = 'bilinear') output3 = self.out_conv3(x_d3) - output3 = interpolate(output3, out_shape, mode = 'bilinear') output = [output, output1, output2, output3] if(len(x_shape) == 5): - new_shape = [N, D] + list(output[0].shape)[1:] for i in range(len(output)): + new_shape = [N, D] + list(output[i].shape)[1:] output[i] = torch.transpose(torch.reshape(output[i], new_shape), 1, 2) elif(len(x_shape) == 5): new_shape = [N, D] + list(output.shape)[1:] diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 7ab0b6e..01a569a 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -173,9 +173,9 @@ def training(self): train_dice_list.append(dice_list.cpu().numpy()) train_avg_loss = train_loss / iter_valid train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice.mean() + train_avg_dice = train_cls_dice[1:].mean() - train_scalers = {'loss': train_avg_loss, 'avg_dice':train_avg_dice,\ + train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ 'class_dice': train_cls_dice} return train_scalers @@ -214,14 +214,14 @@ def validation(self): valid_avg_loss = np.asarray(valid_loss_list).mean() valid_cls_dice = np.asarray(valid_dice_list).mean(axis = 0) - valid_avg_dice = valid_cls_dice.mean() - valid_scalers = {'loss': valid_avg_loss, 'avg_dice': valid_avg_dice,\ + valid_avg_dice = valid_cls_dice[1:].mean() + valid_scalers = {'loss': valid_avg_loss, 'avg_fg_dice': valid_avg_dice,\ 'class_dice': valid_cls_dice} return valid_scalers def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): loss_scalar ={'train':train_scalars['loss'], 'valid':valid_scalars['loss']} - dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']} + dice_scalar ={'train':train_scalars['avg_fg_dice'], 'valid':valid_scalars['avg_fg_dice']} self.summ_writer.add_scalars('loss', loss_scalar, glob_it) self.summ_writer.add_scalars('dice', dice_scalar, glob_it) self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) @@ -231,11 +231,11 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): 'valid':valid_scalars['class_dice'][c]} self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) - logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format( - train_scalars['loss'], train_scalars['avg_dice']) + "[" + \ + logging.info('train loss {0:.4f}, avg foreground dice {1:.4f} '.format( + train_scalars['loss'], train_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") - logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( - valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ + logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format( + valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") def train_valid(self): @@ -295,7 +295,7 @@ def train_valid(self): valid_scalars = self.validation() t2 = time.time() if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step(valid_scalars['avg_dice']) + self.scheduler.step(valid_scalars['avg_fg_dice']) else: self.scheduler.step() @@ -304,8 +304,8 @@ def train_valid(self): logging.info('learning rate {0:}'.format(lr_value)) logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1)) self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it) - if(valid_scalars['avg_dice'] > self.max_val_dice): - self.max_val_dice = valid_scalars['avg_dice'] + if(valid_scalars['avg_fg_dice'] > self.max_val_dice): + self.max_val_dice = valid_scalars['avg_fg_dice'] self.max_val_it = self.glob_it if(len(device_ids) > 1): self.best_model_wts = copy.deepcopy(self.net.module.state_dict()) @@ -316,7 +316,7 @@ def train_valid(self): self.glob_it - self.max_val_it > early_stop_it) else False if ((self.glob_it in iter_save_list) or stop_now): save_dict = {'iteration': self.glob_it, - 'valid_pred': valid_scalars['avg_dice'], + 'valid_pred': valid_scalars['avg_fg_dice'], 'model_state_dict': self.net.module.state_dict() \ if len(device_ids) > 1 else self.net.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict()} diff --git a/pymic/net_run/get_optimizer.py b/pymic/net_run/get_optimizer.py index b3e8f82..0575e99 100644 --- a/pymic/net_run/get_optimizer.py +++ b/pymic/net_run/get_optimizer.py @@ -39,8 +39,9 @@ def get_optimizer(name, net_params, optim_params): def get_lr_scheduler(optimizer, sched_params): - name = sched_params["lr_scheduler"] - val_it = sched_params["iter_valid"] + name = sched_params["lr_scheduler"] + val_it = sched_params["iter_valid"] + epoch_last = sched_params["last_iter"] / val_it if(name is None): return None if(keyword_match(name, "ReduceLROnPlateau")): @@ -52,16 +53,19 @@ def get_lr_scheduler(optimizer, sched_params): elif(keyword_match(name, "MultiStepLR")): lr_milestones = sched_params["lr_milestones"] lr_milestones = [int(item / val_it) for item in lr_milestones] - epoch_last = sched_params["last_iter"] / val_it lr_gamma = sched_params["lr_gamma"] scheduler = lr_scheduler.MultiStepLR(optimizer, lr_milestones, lr_gamma, epoch_last) elif(keyword_match(name, "CosineAnnealingLR")): epoch_max = sched_params["iter_max"] / val_it - epoch_last = sched_params["last_iter"] / val_it lr_min = sched_params.get("lr_min", 0) scheduler = lr_scheduler.CosineAnnealingLR(optimizer, epoch_max, lr_min, epoch_last) + elif(keyword_match(name, "PolynomialLR")): + epoch_max = sched_params["iter_max"] / val_it + power = sched_params["lr_power"] + scheduler = lr_scheduler.PolynomialLR(optimizer, + epoch_max, power, epoch_last) else: raise ValueError("unsupported lr scheduler {0:}".format(name)) return scheduler \ No newline at end of file diff --git a/pymic/net_run/infer_func.py b/pymic/net_run/infer_func.py index 81d0b53..43162d0 100644 --- a/pymic/net_run/infer_func.py +++ b/pymic/net_run/infer_func.py @@ -15,7 +15,8 @@ class Inferer(object): :param `sliding_window_stride`: (optional, list) The sliding window stride. :param `tta_mode`: (optional, int) The test time augmentation mode. Default is 0 (no test time augmentation). The other option is 1 (augmentation - with horinzontal and vertical flipping). + with horinzontal and vertical flipping) and 2 (ensemble of inference + in axial, sagittal and coronal views for 2D networks applied to 3D volumes) """ def __init__(self, config): self.config = config @@ -170,6 +171,22 @@ def run(self, model, image): outputs3 = torch.flip(outputs3, [-1]) outputs4 = torch.flip(outputs4, [-2, -1]) outputs = (outputs1 + outputs2 + outputs3 + outputs4) / 4 + elif(tta_mode == 2): + outputs1 = self.__infer(image) + outputs2 = self.__infer(torch.transpose(image, -1, -3)) + outputs3 = self.__infer(torch.transpose(image, -2, -3)) + if(isinstance(outputs1, (tuple, list))): + outputs = [] + for i in range(len(outputs1)): + temp_out1 = outputs1[i] + temp_out2 = torch.transpose(outputs2[i], -1, -3) + temp_out3 = torch.transpose(outputs3[i], -2, -3) + temp_mean = (temp_out1 + temp_out2 + temp_out3) / 3 + outputs.append(temp_mean) + else: + outputs2 = torch.transpose(outputs2, -1, -3) + outputs3 = torch.transpose(outputs3, -2, -3) + outputs = (outputs1 + outputs2 + outputs3) / 3 else: raise ValueError("Undefined tta_mode {0:}".format(tta_mode)) return outputs diff --git a/pymic/net_run/net_run.py b/pymic/net_run/net_run.py index aec6fa0..c8d8dcc 100644 --- a/pymic/net_run/net_run.py +++ b/pymic/net_run/net_run.py @@ -3,6 +3,7 @@ import logging import os import sys +import shutil from pymic.util.parse_config import * from pymic.net_run.agent_cls import ClassificationAgent from pymic.net_run.agent_seg import SegmentationAgent @@ -22,8 +23,15 @@ def main(): log_dir = config['training']['ckpt_save_dir'] if(not os.path.exists(log_dir)): os.makedirs(log_dir, exist_ok=True) - logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, - format='%(message)s') + if(stage == "train"): + dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] + shutil.copy(cfg_file, log_dir + "/" + dst_cfg) + if sys.version.startswith("3.9"): + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, + format='%(message)s', force=True) # for python 3.9 + else: + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, + format='%(message)s') # for python 3.6 logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) task = config['dataset']['task_type'] diff --git a/pymic/test/test_nifty_dataset.py b/pymic/test/test_nifty_dataset.py index 581bc52..b5c4660 100644 --- a/pymic/test/test_nifty_dataset.py +++ b/pymic/test/test_nifty_dataset.py @@ -1,50 +1,68 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import sys +import numpy as np +from pymic.util.parse_config import * +from pymic.net_run.agent_cls import ClassificationAgent +from pymic.net_run.agent_seg import SegmentationAgent +import SimpleITK as sitk -import os -import torch -import pandas as pd -import numpy as np -from skimage import io, transform -from torch.utils.data import Dataset, DataLoader -from torchvision import transforms, utils -from pymic.io.image_read_write import * -from pymic.io.nifty_dataset import NiftyDataset -from pymic.io.transform3d import * +def save_array_as_nifty_volume(data, image_name, reference_name = None): + """ + Save a numpy array as nifty image -if __name__ == "__main__": - root_dir = '/home/guotai/data/brats/BraTS2018_Training' - csv_file = '/home/guotai/projects/torch_brats/brats/config/brats18_train_train.csv' - - crop1 = CropWithBoundingBox(start = None, output_size = [4, 144, 180, 144]) - norm = ChannelWiseNormalize(mean = None, std = None, zero_to_random = True) - labconv = LabelConvert([0, 1, 2, 4], [0, 1, 2, 3]) - crop2 = RandomCrop([128, 128, 128]) - rescale =Rescale([64, 64, 64]) - transform_list = [crop1, norm, labconv, crop2,rescale, ToTensor()] - transformed_dataset = NiftyDataset(root_dir=root_dir, - csv_file = csv_file, - modal_num = 4, - transform = transforms.Compose(transform_list)) - dataloader = DataLoader(transformed_dataset, batch_size=4, - shuffle=True, num_workers=4) - # Helper function to show a batch + :param data: (numpy.ndarray) A numpy array with shape [Depth, Height, Width]. + :param image_name: (str) The ouput file name. + :param reference_name: (str) File name of the reference image of which + meta information is used. + """ + img = sitk.GetImageFromArray(data) + if(reference_name is not None): + img_ref = sitk.ReadImage(reference_name) + #img.CopyInformation(img_ref) + img.SetSpacing(img_ref.GetSpacing()) + img.SetOrigin(img_ref.GetOrigin()) + img.SetDirection(img_ref.GetDirection()) + sitk.WriteImage(img, image_name) +def main(): + """ + The main function for running a network for training or inference. + """ + if(len(sys.argv) < 3): + print('Number of arguments should be 3. e.g.') + print('python test_nifty_dataset.py train config.cfg') + exit() + stage = str(sys.argv[1]) + cfg_file = str(sys.argv[2]) + config = parse_config(cfg_file) + config = synchronize_config(config) + # task = config['dataset']['task_type'] + # assert task in ['cls', 'cls_nexcl', 'seg'] + # if(task == 'cls' or task == 'cls_nexcl'): + # agent = ClassificationAgent(config, stage) + # else: + # agent = SegmentationAgent(config, stage) + agent = SegmentationAgent(config, stage) + agent.create_dataset() + data_loader = agent.train_loader if stage == "train" else agent.test_loader + it = 0 + for data in data_loader: + inputs = agent.convert_tensor_type(data['image']) + labels_prob = agent.convert_tensor_type(data['label_prob']) + for i in range(inputs.shape[0]): + image_i = inputs[i][0] + label_i = np.argmax(labels_prob[i], axis = 0) + print(image_i.shape, label_i.shape) + image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) + label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) + save_array_as_nifty_volume(image_i, image_name, reference_name = None) + save_array_as_nifty_volume(label_i, label_name, reference_name = None) + it = it + 1 + if(it == 10): + break - for i_batch, sample_batched in enumerate(dataloader): - print(i_batch, sample_batched['image'].size(), - sample_batched['label'].size()) +if __name__ == "__main__": + main() + - # # observe 4th batch and stop. - modals = ['flair', 't1ce', 't1', 't2'] - if i_batch == 0: - image = sample_batched['image'].numpy() - label = sample_batched['label'].numpy() - for i in range(image.shape[0]): - for mod in range(4): - image_i = image[i][mod] - label_i = label[i][0] - image_name = "temp/image_{0:}_{1:}.nii.gz".format(i, modals[mod]) - label_name = "temp/label_{0:}.nii.gz".format(i) - save_array_as_nifty_volume(image_i, image_name, reference_name = None) - save_array_as_nifty_volume(label_i, label_name, reference_name = None) \ No newline at end of file diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index 8da3b38..864dae1 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -34,6 +34,7 @@ from pymic.transform.pad import Pad from pymic.transform.rotate import RandomRotate from pymic.transform.rescale import Rescale, RandomRescale +from pymic.transform.transpose import RandomTranspose from pymic.transform.threshold import * from pymic.transform.normalize import * from pymic.transform.crop import * @@ -57,6 +58,7 @@ 'RandomCrop': RandomCrop, 'RandomResizedCrop': RandomResizedCrop, 'RandomRescale': RandomRescale, + 'RandomTranspose': RandomTranspose, 'RandomFlip': RandomFlip, 'RandomRotate': RandomRotate, 'ReduceLabelDim': ReduceLabelDim, diff --git a/pymic/transform/transpose.py b/pymic/transform/transpose.py new file mode 100644 index 0000000..9c73bda --- /dev/null +++ b/pymic/transform/transpose.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import json +import random +import numpy as np +from pymic.transform.abstract_transform import AbstractTransform + + +class RandomTranspose(AbstractTransform): + """ + Random transpose for 3D volumes. Assume the input has a shape of [C, D, H, W], the + output shape will be of [C, D, H, W], [C, W, H, D] or [C, H, D, W] + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `RandomTranspose_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `True`. + """ + def __init__(self, params): + super(RandomTranspose, self).__init__(params) + self.inverse = params.get('RandomTranspose_inverse'.lower(), True) + + def __call__(self, sample): + image = sample['image'] + input_shape = image.shape + input_dim = len(input_shape) - 1 + assert(input_dim == 3) + + rand_num = random.random() + if(rand_num < 0.4): + transpose_axis = None + elif(rand_num < 0.7): + transpose_axis = [0, 3, 2, 1] + else: + transpose_axis = [0, 2, 1, 3] + sample['RandomTranspose_Param'] = json.dumps(transpose_axis) + if(transpose_axis is not None): + image_t = np.transpose(image, transpose_axis) + sample['image'] = image_t + if('label' in sample and self.task == 'segmentation'): + sample['label'] = np.transpose(sample['label'] , transpose_axis) + if('pixel_weight' in sample and self.task == 'segmentation'): + sample['pixel_weight'] = np.transpose(sample['pixel_weight'] , transpose_axis) + + return sample + + def inverse_transform_for_prediction(self, sample): + if(isinstance(sample['RandomTranspose_Param'], list) or \ + isinstance(sample['RandomTranspose_Param'], tuple)): + transpose_axis = json.loads(sample['RandomTranspose_Param'][0]) + else: + transpose_axis = json.loads(sample['RandomTranspose_Param']) + if(transpose_axis is not None): + sample['predict'] = np.transpose(sample['predict'] , transpose_axis) + return sample \ No newline at end of file From 603df8103d625efae05a48ad95e2e9c1c6c943fc Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 9 Feb 2023 13:30:17 +0800 Subject: [PATCH 128/225] Update README.md update citation info --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index d338581..556e357 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,8 @@ PyMIC is a pytorch-based toolkit for medical image computing with annotation-eff Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. If you use this toolkit, please cite the following paper: -* G. Wang, X. Luo, R. Gu, S. Yang, Y. Qu, S. Zhai, Q. Zhao, K. Li, S. Zhang. (2022). -[PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation.][arxiv2022] arXiv, 2208.09350. +* G. Wang, X. Luo, R. Gu, S. Yang, Y. Qu, S. Zhai, Q. Zhao, K. Li, S. Zhang. (2023). +[PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation.][arxiv2022] Computer Methods and Programs in Biomedicine (CMPB). February 2023, 107398. [arxiv2022]:http://arxiv.org/abs/2208.09350 @@ -14,11 +14,11 @@ BibTeX entry: @article{Wang2022pymic, author = {Guotai Wang and Xiangde Luo and Ran Gu and Shuojue Yang and Yijie Qu and Shuwei Zhai and Qianfei Zhao and Kang Li and Shaoting Zhang}, title = {{PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation}}, - year = {2022}, + year = {2023}, url = {http://arxiv.org/abs/2208.09350}, - journal = {arXiv}, - volume = {2208.09350}, - pages = {1-10}, + journal = {Computer Methods and Programs in Biomedicine}, + volume = {February}, + pages = {107398}, } # Features From ab07d4f2bbc7cf0d42a939d79cd7c1d4b0cdf19f Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 9 Feb 2023 23:00:27 +0800 Subject: [PATCH 129/225] update __init__ update __init__ --- pymic/__init__.py | 2 ++ pymic/io/__init__.py | 2 ++ pymic/layer/__init__.py | 2 ++ pymic/loss/__init__.py | 2 ++ pymic/loss/cls/__init__.py | 2 ++ pymic/loss/seg/__init__.py | 2 ++ pymic/net/cls/__init__.py | 2 ++ pymic/net/net2d/__init__.py | 2 ++ pymic/net_run/__init__.py | 2 ++ pymic/net_run_nll/__init__.py | 2 ++ pymic/net_run_ssl/__init__.py | 2 ++ pymic/net_run_wsl/__init__.py | 2 ++ pymic/transform/__init__.py | 2 ++ 13 files changed, 26 insertions(+) diff --git a/pymic/__init__.py b/pymic/__init__.py index e69de29..72b8078 100644 --- a/pymic/__init__.py +++ b/pymic/__init__.py @@ -0,0 +1,2 @@ +from __future__ import absolute_import +from . import * \ No newline at end of file diff --git a/pymic/io/__init__.py b/pymic/io/__init__.py index e69de29..72b8078 100644 --- a/pymic/io/__init__.py +++ b/pymic/io/__init__.py @@ -0,0 +1,2 @@ +from __future__ import absolute_import +from . import * \ No newline at end of file diff --git a/pymic/layer/__init__.py b/pymic/layer/__init__.py index e69de29..72b8078 100644 --- a/pymic/layer/__init__.py +++ b/pymic/layer/__init__.py @@ -0,0 +1,2 @@ +from __future__ import absolute_import +from . import * \ No newline at end of file diff --git a/pymic/loss/__init__.py b/pymic/loss/__init__.py index e69de29..72b8078 100644 --- a/pymic/loss/__init__.py +++ b/pymic/loss/__init__.py @@ -0,0 +1,2 @@ +from __future__ import absolute_import +from . import * \ No newline at end of file diff --git a/pymic/loss/cls/__init__.py b/pymic/loss/cls/__init__.py index e69de29..72b8078 100644 --- a/pymic/loss/cls/__init__.py +++ b/pymic/loss/cls/__init__.py @@ -0,0 +1,2 @@ +from __future__ import absolute_import +from . import * \ No newline at end of file diff --git a/pymic/loss/seg/__init__.py b/pymic/loss/seg/__init__.py index e69de29..72b8078 100644 --- a/pymic/loss/seg/__init__.py +++ b/pymic/loss/seg/__init__.py @@ -0,0 +1,2 @@ +from __future__ import absolute_import +from . import * \ No newline at end of file diff --git a/pymic/net/cls/__init__.py b/pymic/net/cls/__init__.py index e69de29..72b8078 100644 --- a/pymic/net/cls/__init__.py +++ b/pymic/net/cls/__init__.py @@ -0,0 +1,2 @@ +from __future__ import absolute_import +from . import * \ No newline at end of file diff --git a/pymic/net/net2d/__init__.py b/pymic/net/net2d/__init__.py index e69de29..72b8078 100644 --- a/pymic/net/net2d/__init__.py +++ b/pymic/net/net2d/__init__.py @@ -0,0 +1,2 @@ +from __future__ import absolute_import +from . import * \ No newline at end of file diff --git a/pymic/net_run/__init__.py b/pymic/net_run/__init__.py index e69de29..72b8078 100644 --- a/pymic/net_run/__init__.py +++ b/pymic/net_run/__init__.py @@ -0,0 +1,2 @@ +from __future__ import absolute_import +from . import * \ No newline at end of file diff --git a/pymic/net_run_nll/__init__.py b/pymic/net_run_nll/__init__.py index e69de29..72b8078 100644 --- a/pymic/net_run_nll/__init__.py +++ b/pymic/net_run_nll/__init__.py @@ -0,0 +1,2 @@ +from __future__ import absolute_import +from . import * \ No newline at end of file diff --git a/pymic/net_run_ssl/__init__.py b/pymic/net_run_ssl/__init__.py index e69de29..72b8078 100644 --- a/pymic/net_run_ssl/__init__.py +++ b/pymic/net_run_ssl/__init__.py @@ -0,0 +1,2 @@ +from __future__ import absolute_import +from . import * \ No newline at end of file diff --git a/pymic/net_run_wsl/__init__.py b/pymic/net_run_wsl/__init__.py index e69de29..72b8078 100644 --- a/pymic/net_run_wsl/__init__.py +++ b/pymic/net_run_wsl/__init__.py @@ -0,0 +1,2 @@ +from __future__ import absolute_import +from . import * \ No newline at end of file diff --git a/pymic/transform/__init__.py b/pymic/transform/__init__.py index e69de29..72b8078 100644 --- a/pymic/transform/__init__.py +++ b/pymic/transform/__init__.py @@ -0,0 +1,2 @@ +from __future__ import absolute_import +from . import * \ No newline at end of file From 99145ff0b5c5d4219b447ea3ba9e0c2de33e4f39 Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 15 Feb 2023 09:54:18 +0800 Subject: [PATCH 130/225] update lr_scheduler add StepLR --- pymic/net_run/get_optimizer.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pymic/net_run/get_optimizer.py b/pymic/net_run/get_optimizer.py index 0575e99..5d9c608 100644 --- a/pymic/net_run/get_optimizer.py +++ b/pymic/net_run/get_optimizer.py @@ -41,7 +41,9 @@ def get_optimizer(name, net_params, optim_params): def get_lr_scheduler(optimizer, sched_params): name = sched_params["lr_scheduler"] val_it = sched_params["iter_valid"] - epoch_last = sched_params["last_iter"] / val_it + epoch_last = sched_params["last_iter"] + if(epoch_last > 0): + epoch_last = int(epoch_last / val_it) if(name is None): return None if(keyword_match(name, "ReduceLROnPlateau")): @@ -56,6 +58,11 @@ def get_lr_scheduler(optimizer, sched_params): lr_gamma = sched_params["lr_gamma"] scheduler = lr_scheduler.MultiStepLR(optimizer, lr_milestones, lr_gamma, epoch_last) + elif(keyword_match(name, "StepLR")): + lr_step = sched_params["lr_step"] / val_it + lr_gamma = sched_params["lr_gamma"] + scheduler = lr_scheduler.StepLR(optimizer, + lr_step, lr_gamma, epoch_last) elif(keyword_match(name, "CosineAnnealingLR")): epoch_max = sched_params["iter_max"] / val_it lr_min = sched_params.get("lr_min", 0) From 49e1c64f1045e6255a5ac992f53b9383d952924e Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 16 Feb 2023 12:50:56 +0800 Subject: [PATCH 131/225] update main files lr_schduler update after each validation for classification task; update main files for logging bug update init fix bug for alpha in mean teacher --- README.md | 4 +- pymic/net_run/agent_cls.py | 33 ++-- pymic/net_run/agent_seg.py | 21 ++- pymic/net_run/get_optimizer.py | 2 +- pymic/net_run_ssl/__init__.py | 8 +- pymic/net_run_ssl/ssl_main.py | 8 +- pymic/net_run_ssl/ssl_mt.py | 2 +- pymic/net_run_ssl/ssl_uamt.py | 2 +- pymic/net_run_wsl/wsl_main.py | 8 +- pymic/transform/__init__.py | 13 +- pymic/transform/intensity.py | 265 ++++++++++++++++++++++++++++++++- pymic/transform/trans_dict.py | 17 ++- requirements.txt | 1 + setup.py | 1 + 14 files changed, 347 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index 556e357..3a552b4 100644 --- a/README.md +++ b/README.md @@ -15,9 +15,9 @@ BibTeX entry: author = {Guotai Wang and Xiangde Luo and Ran Gu and Shuojue Yang and Yijie Qu and Shuwei Zhai and Qianfei Zhao and Kang Li and Shaoting Zhang}, title = {{PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation}}, year = {2023}, - url = {http://arxiv.org/abs/2208.09350}, + url = {https://doi.org/10.1016/j.cmpb.2023.107398}, journal = {Computer Methods and Programs in Biomedicine}, - volume = {February}, + volume = {231}, pages = {107398}, } diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index f9f7781..181aa53 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -157,9 +157,6 @@ def training(self): loss = self.get_loss_value(data, outputs, labels) loss.backward() self.optimizer.step() - if(self.scheduler is not None and \ - not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step() # statistics sample_num += labels.size(0) @@ -183,7 +180,7 @@ def validation(self): inputs = self.convert_tensor_type(data['image']) labels = self.convert_tensor_type(data['label_prob']) inputs, labels = inputs.to(self.device), labels.to(self.device) - self.optimizer.zero_grad() + # self.optimizer.zero_grad() # forward + backward + optimize outputs = self.net(inputs) loss = self.get_loss_value(data, outputs, labels) @@ -196,20 +193,17 @@ def validation(self): avg_loss = running_loss / sample_num avg_score= running_score.double() / sample_num metrics = self.config['training'].get("evaluation_metric", "accuracy") - if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step(avg_score) valid_scalers = {'loss': avg_loss, metrics: avg_score} return valid_scalers def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): - metrics =self.config['training'].get("evaluation_metric", "accuracy") + metrics = self.config['training'].get("evaluation_metric", "accuracy") loss_scalar ={'train':train_scalars['loss'], 'valid':valid_scalars['loss']} acc_scalar ={'train':train_scalars[metrics],'valid':valid_scalars[metrics]} self.summ_writer.add_scalars('loss', loss_scalar, glob_it) self.summ_writer.add_scalars(metrics, acc_scalar, glob_it) self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) - logging.info("{0:} it {1:}".format(str(datetime.now())[:-7], glob_it)) logging.info('train loss {0:.4f}, avg {1:} {2:.4f}'.format( train_scalars['loss'], metrics, train_scalars[metrics])) logging.info('valid loss {0:.4f}, avg {1:} {2:.4f}'.format( @@ -251,7 +245,10 @@ def train_valid(self): checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start) self.checkpoint = torch.load(checkpoint_file, map_location = self.device) assert(self.checkpoint['iteration'] == iter_start) - self.net.load_state_dict(self.checkpoint['model_state_dict']) + if(len(device_ids) > 1): + self.net.module.load_state_dict(self.checkpoint['model_state_dict']) + else: + self.net.load_state_dict(self.checkpoint['model_state_dict']) self.max_val_score = self.checkpoint.get('valid_pred', 0) self.max_val_it = self.checkpoint['iteration'] self.best_model_wts = self.checkpoint['model_state_dict'] @@ -266,15 +263,28 @@ def train_valid(self): self.glob_it = iter_start for it in range(iter_start, iter_max, iter_valid): lr_value = self.optimizer.param_groups[0]['lr'] + t0 = time.time() train_scalars = self.training() + t1 = time.time() valid_scalars = self.validation() + t2 = time.time() + if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step(valid_scalars[metrics]) + else: + self.scheduler.step() + self.glob_it = it + iter_valid + logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) + logging.info('learning rate {0:}'.format(lr_value)) + logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1)) self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it) - if(valid_scalars[metrics] > self.max_val_score): self.max_val_score = valid_scalars[metrics] self.max_val_it = self.glob_it - self.best_model_wts = copy.deepcopy(self.net.state_dict()) + if(len(device_ids) > 1): + self.best_model_wts = copy.deepcopy(self.net.module.state_dict()) + else: + self.best_model_wts = copy.deepcopy(self.net.state_dict()) stop_now = True if(early_stop_it is not None and \ self.glob_it - self.max_val_it > early_stop_it) else False @@ -306,7 +316,6 @@ def train_valid(self): self.max_val_it, metrics, self.max_val_score)) self.summ_writer.close() - def infer(self): device_ids = self.config['testing']['gpus'] device = torch.device("cuda:{0:}".format(device_ids[0])) diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 01a569a..83975f0 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -96,9 +96,12 @@ def create_loss_calculator(self): raise ValueError("Undefined loss function {0:}".format(loss_name)) else: base_loss = self.loss_dict[loss_name](self.config['training']) - if(self.config['network'].get('deep_supervise', False)): - weight = self.config['network'].get('deep_supervise_weight', None) - params = {'deep_supervise_weight': weight, 'base_loss':base_loss} + if(self.config['training'].get('deep_supervise', False)): + weight = self.config['training'].get('deep_supervise_weight', None) + mode = self.config['training'].get('deep_supervise_mode', 2) + params = {'deep_supervise_weight': weight, + 'deep_supervise_mode': mode, + 'base_loss':base_loss} self.loss_calculator = DeepSuperviseLoss(params) else: self.loss_calculator = base_loss @@ -106,7 +109,10 @@ def create_loss_calculator(self): def get_loss_value(self, data, pred, gt, param = None): loss_input_dict = {'prediction':pred, 'ground_truth': gt} if data.get('pixel_weight', None) is not None: - loss_input_dict['pixel_weight'] = data['pixel_weight'].to(pred.device) + if(isinstance(pred, tuple) or isinstance(pred, list)): + loss_input_dict['pixel_weight'] = data['pixel_weight'].to(pred[0].device) + else: + loss_input_dict['pixel_weight'] = data['pixel_weight'].to(pred.device) loss_value = self.loss_calculator(loss_input_dict) return loss_value @@ -122,7 +128,7 @@ def set_postprocessor(self, postprocessor): def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] - mixup_prob = self.config['training'].get('mixup_probability', 0.5) + mixup_prob = self.config['training'].get('mixup_probability', 0.0) train_loss = 0 train_dice_list = [] self.net.train() @@ -135,7 +141,7 @@ def training(self): # get the inputs inputs = self.convert_tensor_type(data['image']) labels_prob = self.convert_tensor_type(data['label_prob']) - if(random() < mixup_prob): + if(mixup_prob > 0 and random() < mixup_prob): inputs, labels_prob = mixup(inputs, labels_prob) # # for debug @@ -246,7 +252,10 @@ def train_valid(self): else: self.device = torch.device("cuda:{0:}".format(device_ids[0])) self.net.to(self.device) + ckpt_dir = self.config['training']['ckpt_save_dir'] + if(ckpt_dir[-1] == "/"): + ckpt_dir = ckpt_dir[:-1] ckpt_prefix = self.config['training'].get('ckpt_prefix', None) if(ckpt_prefix is None): ckpt_prefix = ckpt_dir.split('/')[-1] diff --git a/pymic/net_run/get_optimizer.py b/pymic/net_run/get_optimizer.py index 5d9c608..ad8fda0 100644 --- a/pymic/net_run/get_optimizer.py +++ b/pymic/net_run/get_optimizer.py @@ -59,7 +59,7 @@ def get_lr_scheduler(optimizer, sched_params): scheduler = lr_scheduler.MultiStepLR(optimizer, lr_milestones, lr_gamma, epoch_last) elif(keyword_match(name, "StepLR")): - lr_step = sched_params["lr_step"] / val_it + lr_step = sched_params["lr_step"] / val_it lr_gamma = sched_params["lr_gamma"] scheduler = lr_scheduler.StepLR(optimizer, lr_step, lr_gamma, epoch_last) diff --git a/pymic/net_run_ssl/__init__.py b/pymic/net_run_ssl/__init__.py index 72b8078..4309107 100644 --- a/pymic/net_run_ssl/__init__.py +++ b/pymic/net_run_ssl/__init__.py @@ -1,2 +1,8 @@ from __future__ import absolute_import -from . import * \ No newline at end of file +from pymic.net_run_ssl.ssl_abstract import * +from pymic.net_run_ssl.ssl_cct import * +from pymic.net_run_ssl.ssl_cps import * +from pymic.net_run_ssl.ssl_em import * +from pymic.net_run_ssl.ssl_mt import * +from pymic.net_run_ssl.ssl_uamt import * +from pymic.net_run_ssl.ssl_urpc import * \ No newline at end of file diff --git a/pymic/net_run_ssl/ssl_main.py b/pymic/net_run_ssl/ssl_main.py index 7c4f2b9..996f14a 100644 --- a/pymic/net_run_ssl/ssl_main.py +++ b/pymic/net_run_ssl/ssl_main.py @@ -35,8 +35,12 @@ def main(): log_dir = config['training']['ckpt_save_dir'] if(not os.path.exists(log_dir)): os.mkdir(log_dir) - logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, - format='%(message)s') + if sys.version.startswith("3.9"): + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, + format='%(message)s', force=True) # for python 3.9 + else: + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, + format='%(message)s') # for python 3.6 logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) ssl_method = config['semi_supervised_learning']['ssl_method'] diff --git a/pymic/net_run_ssl/ssl_mt.py b/pymic/net_run_ssl/ssl_mt.py index 6a5a95f..405ce41 100644 --- a/pymic/net_run_ssl/ssl_mt.py +++ b/pymic/net_run_ssl/ssl_mt.py @@ -104,7 +104,7 @@ def training(self): # update EMA alpha = ssl_cfg.get('ema_decay', 0.99) - alpha = min(1 - 1 / (iter_max + 1), alpha) + alpha = min(1 - 1 / (self.glob_it / iter_valid + 1), alpha) for ema_param, param in zip(self.net_ema.parameters(), self.net.parameters()): ema_param.data.mul_(alpha).add_(1 - alpha, param.data) diff --git a/pymic/net_run_ssl/ssl_uamt.py b/pymic/net_run_ssl/ssl_uamt.py index af6640f..18ddd7b 100644 --- a/pymic/net_run_ssl/ssl_uamt.py +++ b/pymic/net_run_ssl/ssl_uamt.py @@ -106,7 +106,7 @@ def training(self): # update EMA alpha = ssl_cfg.get('ema_decay', 0.99) - alpha = min(1 - 1 / (iter_max + 1), alpha) + alpha = min(1 - 1 / (self.glob_it / iter_valid + 1), alpha) for ema_param, param in zip(self.net_ema.parameters(), self.net.parameters()): ema_param.data.mul_(alpha).add_(1 - alpha, param.data) diff --git a/pymic/net_run_wsl/wsl_main.py b/pymic/net_run_wsl/wsl_main.py index d8e791d..540c535 100644 --- a/pymic/net_run_wsl/wsl_main.py +++ b/pymic/net_run_wsl/wsl_main.py @@ -34,8 +34,12 @@ def main(): log_dir = config['training']['ckpt_save_dir'] if(not os.path.exists(log_dir)): os.mkdir(log_dir) - logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, - format='%(message)s') + if sys.version.startswith("3.9"): + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, + format='%(message)s', force=True) # for python 3.9 + else: + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, + format='%(message)s') # for python 3.6 logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) wsl_method = config['weakly_supervised_learning']['wsl_method'] diff --git a/pymic/transform/__init__.py b/pymic/transform/__init__.py index 72b8078..b47e09d 100644 --- a/pymic/transform/__init__.py +++ b/pymic/transform/__init__.py @@ -1,2 +1,13 @@ +# -*- coding: utf-8 -*- from __future__ import absolute_import -from . import * \ No newline at end of file +from pymic.transform.intensity import * +from pymic.transform.flip import * +from pymic.transform.pad import * +from pymic.transform.rotate import * +from pymic.transform.rescale import * +from pymic.transform.transpose import * +from pymic.transform.threshold import * +from pymic.transform.normalize import * +from pymic.transform.crop import * +from pymic.transform.label_convert import * +from pymic.transform.trans_dict import TransformDict \ No newline at end of file diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index 171604b..3b5ee9d 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -1,15 +1,48 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division - -import torch +import copy import json import math import random import numpy as np -from scipy import ndimage from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * +try: # SciPy >= 0.19 + from scipy.special import comb +except ImportError: + from scipy.misc import comb + +def bernstein_poly(i, n, t): + """ + The Bernstein polynomial of n, i as a function of t + """ + + return comb(n, i) * ( t**(n-i) ) * (1 - t)**i + +def bezier_curve(points, nTimes=1000): + """ + Given a set of control points, return the + bezier curve defined by the control points. + Control points should be a list of lists, or list of tuples + such as [ [1,1], + [2,3], + [4,5], ..[Xn, Yn] ] + nTimes is the number of time steps, defaults to 1000 + See http://processingjs.nihongoresources.com/bezierinfo/ + """ + nPoints = len(points) + xPoints = np.array([p[0] for p in points]) + yPoints = np.array([p[1] for p in points]) + + t = np.linspace(0.0, 1.0, nTimes) + + polynomial_array = np.array([ bernstein_poly(i, nPoints-1, t) for i in range(0, nPoints) ]) + + xvals = np.dot(xPoints, polynomial_array) + yvals = np.dot(yPoints, polynomial_array) + + return xvals, yvals class GammaCorrection(AbstractTransform): """ @@ -98,4 +131,230 @@ def __call__(self, sample): assert(image.shape[0] == 1 or image.shape[0] == 3) if(image.shape[0] == 1): sample['image'] = np.concatenate([image, image, image]) + return sample + +class NonLinearTransform(AbstractTransform): + def __init__(self, params): + super(NonLinearTransform, self).__init__(params) + self.inverse = params.get('NonLinearTransform_inverse'.lower(), False) + self.prob = params.get('NonLinearTransform_probability'.lower(), 0.5) + + def __call__(self, sample): + if(random.random() > self.prob): + return sample + + image= sample['image'] + points = [[0, 0], [random.random(), random.random()], [random.random(), random.random()], [1, 1]] + xvals, yvals = bezier_curve(points, nTimes=100000) + if random.random() < 0.5: # Half change to get flip + xvals = np.sort(xvals) + else: + xvals, yvals = np.sort(xvals), np.sort(yvals) + image = np.interp(image, xvals, yvals) + sample['image'] = image + return sample + +class LocalShuffling(AbstractTransform): + """ + local pixel shuffling of an input image, used for self-supervised learning + """ + def __init__(self, params): + super(LocalShuffling, self).__init__(params) + self.inverse = params.get('LocalShuffling_inverse'.lower(), False) + self.prob = params.get('LocalShuffling_probability'.lower(), 0.5) + self.block_range = params.get('LocalShuffling_block_range'.lower(), (5000, 10000)) + self.block_size_min = params.get('LocalShuffling_block_size_min'.lower(), None) + self.block_size_max = params.get('LocalShuffling_block_size_max'.lower(), None) + + def __call__(self, sample): + if(random.random() > self.prob): + return sample + + image= sample['image'] + img_shape = image.shape + img_dim = len(img_shape) - 1 + assert(img_dim == 2 or img_dim == 3) + img_out = copy.deepcopy(image) + if(self.block_size_min is None): + block_size_min = [2] * img_dim + elif(isinstance(self.block_size_min, int)): + block_size_min = [self.block_size_min] * img_dim + else: + assert(len(self.block_size_min) == img_dim) + block_size_min = self.block_size_min + + if(self.block_size_max is None): + block_size_max = [img_shape[1+i]//10 for i in range(img_dim)] + elif(isinstance(self.block_size_min, int)): + block_size_max = [self.block_size_max] * img_dim + else: + assert(len(self.block_size_max) == img_dim) + block_size_max = self.block_size_max + block_num = random.randint(self.block_range[0], self.block_range[1]) + + for n in range(block_num): + block_size = [random.randint(block_size_min[i], block_size_max[i]) \ + for i in range(img_dim)] + coord_min = [random.randint(0, img_shape[1+i] - block_size[i]) \ + for i in range(img_dim)] + if(img_dim == 2): + window = image[:, coord_min[0]:coord_min[0] + block_size[0], + coord_min[1]:coord_min[1] + block_size[1]] + n_pixels = block_size[0] * block_size[1] + else: + window = image[:, coord_min[0]:coord_min[0] + block_size[0], + coord_min[1]:coord_min[1] + block_size[1], + coord_min[2]:coord_min[2] + block_size[2]] + n_pixels = block_size[0] * block_size[1] * block_size[2] + window = np.reshape(window, [-1, n_pixels]) + np.random.shuffle(np.transpose(window)) + window = np.transpose(window) + if(img_dim == 2): + window = np.reshape(window, [-1, block_size[0], block_size[1]]) + img_out[:, coord_min[0]:coord_min[0] + block_size[0], + coord_min[1]:coord_min[1] + block_size[1]] = window + else: + window = np.reshape(window, [-1, block_size[0], block_size[1], block_size[2]]) + img_out[:, coord_min[0]:coord_min[0] + block_size[0], + coord_min[1]:coord_min[1] + block_size[1], + coord_min[2]:coord_min[2] + block_size[2]] = window + sample['image'] = img_out + return sample + +class InPainting(AbstractTransform): + """ + In-painting of an input image, used for self-supervised learning + """ + def __init__(self, params): + super(InPainting, self).__init__(params) + self.inverse = params.get('InPainting_inverse'.lower(), False) + self.prob = params.get('InPainting_probability'.lower(), 0.5) + self.block_range = params.get('InPainting_block_range'.lower(), (1, 6)) + self.block_size_min = params.get('InPainting_block_size_min'.lower(), None) + self.block_size_max = params.get('InPainting_block_size_max'.lower(), None) + + def __call__(self, sample): + if(random.random() > self.prob): + return sample + + image= sample['image'] + img_shape = image.shape + img_dim = len(img_shape) - 1 + assert(img_dim == 2 or img_dim == 3) + + if(self.block_size_min is None): + block_size_min = [img_shape[1+i]//6 for i in range(img_dim)] + elif(isinstance(self.block_size_min, int)): + block_size_min = [self.block_size_min] * img_dim + else: + assert(len(self.block_size_min) == img_dim) + block_size_min = self.block_size_min + + if(self.block_size_max is None): + block_size_max = [img_shape[1+i]//3 for i in range(img_dim)] + elif(isinstance(self.block_size_min, int)): + block_size_max = [self.block_size_max] * img_dim + else: + assert(len(self.block_size_max) == img_dim) + block_size_max = self.block_size_max + block_num = random.randint(self.block_range[0], self.block_range[1]) + + for n in range(block_num): + block_size = [random.randint(block_size_min[i], block_size_max[i]) \ + for i in range(img_dim)] + coord_min = [random.randint(3, img_shape[1+i] - block_size[i] - 3) \ + for i in range(img_dim)] + if(img_dim == 2): + random_block = np.random.rand(img_shape[0], block_size[0], block_size[1]) + image[:, coord_min[0]:coord_min[0] + block_size[0], + coord_min[1]:coord_min[1] + block_size[1]] = random_block + else: + random_block = np.random.rand(img_shape[0], block_size[0], + block_size[1], block_size[2]) + image[:, coord_min[0]:coord_min[0] + block_size[0], + coord_min[1]:coord_min[1] + block_size[1], + coord_min[2]:coord_min[2] + block_size[2]] = random_block + sample['image'] = image + return sample + +class OutPainting(AbstractTransform): + """ + Out-painting of an input image, used for self-supervised learning + """ + def __init__(self, params): + super(OutPainting, self).__init__(params) + self.inverse = params.get('OutPainting_inverse'.lower(), False) + self.prob = params.get('OutPainting_probability'.lower(), 0.5) + self.block_range = params.get('OutPainting_block_range'.lower(), (1, 6)) + self.block_size_min = params.get('OutPainting_block_size_min'.lower(), None) + self.block_size_max = params.get('OutPainting_block_size_max'.lower(), None) + + def __call__(self, sample): + if(random.random() > self.prob): + return sample + + image= sample['image'] + img_shape = image.shape + img_dim = len(img_shape) - 1 + assert(img_dim == 2 or img_dim == 3) + img_out = np.random.rand(*img_shape) + + if(self.block_size_min is None): + block_size_min = [img_shape[1+i] - 4 * img_shape[1+i]//7 for i in range(img_dim)] + elif(isinstance(self.block_size_min, int)): + block_size_min = [self.block_size_min] * img_dim + else: + assert(len(self.block_size_min) == img_dim) + block_size_min = self.block_size_min + + if(self.block_size_max is None): + block_size_max = [img_shape[1+i] - 3 * img_shape[1+i]//7 for i in range(img_dim)] + elif(isinstance(self.block_size_min, int)): + block_size_max = [self.block_size_max] * img_dim + else: + assert(len(self.block_size_max) == img_dim) + block_size_max = self.block_size_max + block_num = random.randint(self.block_range[0], self.block_range[1]) + + for n in range(block_num): + block_size = [random.randint(block_size_min[i], block_size_max[i]) \ + for i in range(img_dim)] + coord_min = [random.randint(3, img_shape[1+i] - block_size[i] - 3) \ + for i in range(img_dim)] + if(img_dim == 2): + img_out[:, coord_min[0]:coord_min[0] + block_size[0], + coord_min[1]:coord_min[1] + block_size[1]] = \ + image[:, coord_min[0]:coord_min[0] + block_size[0], + coord_min[1]:coord_min[1] + block_size[1]] + else: + img_out[:, coord_min[0]:coord_min[0] + block_size[0], + coord_min[1]:coord_min[1] + block_size[1], + coord_min[2]:coord_min[2] + block_size[2]] = \ + image[:, coord_min[0]:coord_min[0] + block_size[0], + coord_min[1]:coord_min[1] + block_size[1], + coord_min[2]:coord_min[2] + block_size[2]] + sample['image'] = img_out + return sample + +class InOutPainting(AbstractTransform): + """ + Apply in-painting or out-patining randomly. They are mutually exclusive. + """ + def __init__(self, params): + super(InOutPainting, self).__init__(params) + self.inverse = params.get('InOutPainting_inverse'.lower(), False) + self.prob = params.get('InOutPainting_probability'.lower(), 0.5) + self.in_prob = params.get('InPainting_probability'.lower(), 0.5) + params['InPainting_probability'] = 1.0 + params['outPainting_probability'] = 1.0 + self.inpaint = InPainting(params) + self.outpaint = OutPainting(params) + + def __call__(self, sample): + if(random.random() > self.prob): + return sample + if(random.random() < self.in_prob): + sample = self.inpaint(sample) + else: + sample = self.outpaint(sample) return sample \ No newline at end of file diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index 864dae1..bc72c93 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -30,15 +30,15 @@ """ from __future__ import print_function, division from pymic.transform.intensity import * -from pymic.transform.flip import RandomFlip -from pymic.transform.pad import Pad -from pymic.transform.rotate import RandomRotate -from pymic.transform.rescale import Rescale, RandomRescale -from pymic.transform.transpose import RandomTranspose +from pymic.transform.flip import * +from pymic.transform.pad import * +from pymic.transform.rotate import * +from pymic.transform.rescale import * +from pymic.transform.transpose import * from pymic.transform.threshold import * from pymic.transform.normalize import * from pymic.transform.crop import * -from pymic.transform.label_convert import * +from pymic.transform.label_convert import * TransformDict = { 'ChannelWiseThreshold': ChannelWiseThreshold, @@ -48,9 +48,13 @@ 'GrayscaleToRGB': GrayscaleToRGB, 'GammaCorrection': GammaCorrection, 'GaussianNoise': GaussianNoise, + 'InPainting': InPainting, + 'InOutPainting': InOutPainting, 'LabelConvert': LabelConvert, 'LabelConvertNonzero': LabelConvertNonzero, 'LabelToProbability': LabelToProbability, + 'LocalShuffling': LocalShuffling, + 'NonLinearTransform': NonLinearTransform, 'NormalizeWithMeanStd': NormalizeWithMeanStd, 'NormalizeWithMinMax': NormalizeWithMinMax, 'NormalizeWithPercentiles': NormalizeWithPercentiles, @@ -63,5 +67,6 @@ 'RandomRotate': RandomRotate, 'ReduceLabelDim': ReduceLabelDim, 'Rescale': Rescale, + 'OutPainting': OutPainting, 'Pad': Pad, } diff --git a/requirements.txt b/requirements.txt index 49912a4..cac47f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +h5py matplotlib>=3.1.2 numpy>=1.17.4 pandas>=0.25.3 diff --git a/setup.py b/setup.py index 9a7f38b..6c05058 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,7 @@ license = 'Apache 2.0', packages = setuptools.find_packages(), install_requires=[ + "h5py", "matplotlib>=3.1.2", "numpy>=1.17.4", "pandas>=0.25.3", From 8c1579824dd5b27bf7dba6e4a11b1866e9ef7a3e Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 16 Feb 2023 14:01:05 +0800 Subject: [PATCH 132/225] update ssl and wsl fils update ssl and wsl, print average foreground dice for each epoch --- pymic/net/net2d/unet2d.py | 24 +++++++++++++----- pymic/net/net3d/__init__.py | 2 ++ pymic/net/net3d/unet3d.py | 36 ++++++++++++++++----------- pymic/net_run_ssl/ssl_abstract.py | 10 ++++---- pymic/net_run_ssl/ssl_cct.py | 4 +-- pymic/net_run_ssl/ssl_cps.py | 6 ++--- pymic/net_run_ssl/ssl_em.py | 4 +-- pymic/net_run_ssl/ssl_mt.py | 4 +-- pymic/net_run_ssl/ssl_uamt.py | 4 +-- pymic/net_run_ssl/ssl_urpc.py | 4 +-- pymic/net_run_wsl/wsl_abstract.py | 10 ++++---- pymic/net_run_wsl/wsl_dmpls.py | 8 +++--- pymic/net_run_wsl/wsl_em.py | 4 +-- pymic/net_run_wsl/wsl_gatedcrf.py | 4 +-- pymic/net_run_wsl/wsl_mumford_shah.py | 4 +-- pymic/net_run_wsl/wsl_tv.py | 4 +-- pymic/net_run_wsl/wsl_ustm.py | 4 +-- pymic/util/__init__.py | 2 ++ 18 files changed, 82 insertions(+), 56 deletions(-) diff --git a/pymic/net/net2d/unet2d.py b/pymic/net/net2d/unet2d.py index 89b0b4f..b9e25f5 100644 --- a/pymic/net/net2d/unet2d.py +++ b/pymic/net/net2d/unet2d.py @@ -131,6 +131,7 @@ class Decoder(nn.Module): :param class_num: (int) The class number for segmentation task. :param bilinear: (bool) Using bilinear for up-sampling or not. If False, deconvolution will be used for up-sampling. + :param multiscale_pred: (bool) Get multi-scale prediction. """ def __init__(self, params): super(Decoder, self).__init__() @@ -139,7 +140,8 @@ def __init__(self, params): self.ft_chns = self.params['feature_chns'] self.dropout = self.params['dropout'] self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] + self.bilinear = self.params.get('bilinear', True) + self.mul_pred = self.params.get('multiscale_pred', False) assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) @@ -149,6 +151,10 @@ def __init__(self, params): self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear) self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear) self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1) + if(self.mul_pred): + self.out_conv1 = nn.Conv2d(self.ft_chns[1], self.n_class, kernel_size = 1) + self.out_conv2 = nn.Conv2d(self.ft_chns[2], self.n_class, kernel_size = 1) + self.out_conv3 = nn.Conv2d(self.ft_chns[3], self.n_class, kernel_size = 1) def forward(self, x): if(len(self.ft_chns) == 5): @@ -163,6 +169,11 @@ def forward(self, x): x_d1 = self.up3(x_d2, x1) x_d0 = self.up4(x_d1, x0) output = self.out_conv(x_d0) + if(self.mul_pred): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) + output = [output, output1, output2, output3] return output class UNet2D(nn.Module): @@ -187,7 +198,7 @@ class UNet2D(nn.Module): :param class_num: (int) The class number for segmentation task. :param bilinear: (bool) Using bilinear for up-sampling or not. If False, deconvolution will be used for up-sampling. - :param deep_supervise: (bool) Using deep supervision for training or not. + :param multiscale_pred: (bool) Get multiscale prediction. """ def __init__(self, params): super(UNet2D, self).__init__() @@ -197,7 +208,7 @@ def __init__(self, params): self.dropout = self.params['dropout'] self.n_class = self.params['class_num'] self.bilinear = self.params['bilinear'] - self.deep_sup = self.params['deep_supervise'] + self.mul_pred = self.params['multiscale_pred'] assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) @@ -213,7 +224,7 @@ def __init__(self, params): self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear) self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1) - if(self.deep_sup): + if(self.mul_pred): self.out_conv1 = nn.Conv2d(self.ft_chns[1], self.n_class, kernel_size = 1) self.out_conv2 = nn.Conv2d(self.ft_chns[2], self.n_class, kernel_size = 1) self.out_conv3 = nn.Conv2d(self.ft_chns[3], self.n_class, kernel_size = 1) @@ -239,7 +250,7 @@ def forward(self, x): x_d1 = self.up3(x_d2, x1) x_d0 = self.up4(x_d1, x0) output = self.out_conv(x_d0) - if(self.deep_sup): + if(self.mul_pred): output1 = self.out_conv1(x_d1) output2 = self.out_conv2(x_d2) output3 = self.out_conv3(x_d3) @@ -261,7 +272,8 @@ def forward(self, x): 'feature_chns':[2, 8, 32, 48, 64], 'dropout': [0, 0, 0.3, 0.4, 0.5], 'class_num': 2, - 'bilinear': True} + 'bilinear': True, + 'multiscale_pred': False} Net = UNet2D(params) Net = Net.double() diff --git a/pymic/net/net3d/__init__.py b/pymic/net/net3d/__init__.py index e69de29..72b8078 100644 --- a/pymic/net/net3d/__init__.py +++ b/pymic/net/net3d/__init__.py @@ -0,0 +1,2 @@ +from __future__ import absolute_import +from . import * \ No newline at end of file diff --git a/pymic/net/net3d/unet3d.py b/pymic/net/net3d/unet3d.py index caf4e5b..a17bcb8 100644 --- a/pymic/net/net3d/unet3d.py +++ b/pymic/net/net3d/unet3d.py @@ -131,6 +131,7 @@ class Decoder(nn.Module): :param class_num: (int) The class number for segmentation task. :param trilinear: (bool) Using bilinear for up-sampling or not. If False, deconvolution will be used for up-sampling. + :param multiscale_pred: (bool) Get multi-scale prediction. """ def __init__(self, params): super(Decoder, self).__init__() @@ -139,16 +140,21 @@ def __init__(self, params): self.ft_chns = self.params['feature_chns'] self.dropout = self.params['dropout'] self.n_class = self.params['class_num'] - self.trilinear = self.params['trilinear'] + self.trilinear = self.params.get('trilinear', True) + self.mul_pred = self.params.get('multiscale_pred', False) assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) if(len(self.ft_chns) == 5): - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear) + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear) self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) + if(self.mul_pred): + self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) + self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) + self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) def forward(self, x): if(len(self.ft_chns) == 5): @@ -163,6 +169,11 @@ def forward(self, x): x_d1 = self.up3(x_d2, x1) x_d0 = self.up4(x_d1, x0) output = self.out_conv(x_d0) + if(self.mul_pred): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) + output = [output, output1, output2, output3] return output class UNet3D(nn.Module): @@ -187,7 +198,7 @@ class UNet3D(nn.Module): :param class_num: (int) The class number for segmentation task. :param trilinear: (bool) Using trilinear for up-sampling or not. If False, deconvolution will be used for up-sampling. - :param deep_supervise: (bool) Using deep supervision for training or not. + :param multiscale_pred: (bool) Get multi-scale prediction. """ def __init__(self, params): super(UNet3D, self).__init__() @@ -197,7 +208,7 @@ def __init__(self, params): self.dropout = self.params['dropout'] self.n_class = self.params['class_num'] self.trilinear = self.params['trilinear'] - self.deep_sup = self.params['deep_supervise'] + self.mul_pred = self.params['multiscale_pred'] assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) @@ -216,7 +227,7 @@ def __init__(self, params): dropout_p = self.dropout[0], trilinear=self.trilinear) self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) - if(self.deep_sup): + if(self.mul_pred): self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) @@ -235,14 +246,10 @@ def forward(self, x): x_d1 = self.up3(x_d2, x1) x_d0 = self.up4(x_d1, x0) output = self.out_conv(x_d0) - if(self.deep_sup): - out_shape = list(output.shape)[2:] + if(self.mul_pred): output1 = self.out_conv1(x_d1) - output1 = interpolate(output1, out_shape, mode = 'trilinear') output2 = self.out_conv2(x_d2) - output2 = interpolate(output2, out_shape, mode = 'trilinear') output3 = self.out_conv3(x_d3) - output3 = interpolate(output3, out_shape, mode = 'trilinear') output = [output, output1, output2, output3] return output @@ -251,7 +258,8 @@ def forward(self, x): 'class_num': 2, 'feature_chns':[2, 8, 32, 64], 'dropout' : [0, 0, 0, 0.5], - 'trilinear': True} + 'trilinear': True, + 'multiscale_pred': False} Net = UNet3D(params) Net = Net.double() diff --git a/pymic/net_run_ssl/ssl_abstract.py b/pymic/net_run_ssl/ssl_abstract.py index b3ab9cd..5a46257 100644 --- a/pymic/net_run_ssl/ssl_abstract.py +++ b/pymic/net_run_ssl/ssl_abstract.py @@ -83,7 +83,7 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): 'valid':valid_scalars['loss']} loss_sup_scalar = {'train':train_scalars['loss_sup']} loss_upsup_scalar = {'train':train_scalars['loss_reg']} - dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']} + dice_scalar ={'train':train_scalars['avg_fg_dice'], 'valid':valid_scalars['avg_fg_dice']} self.summ_writer.add_scalars('loss', loss_scalar, glob_it) self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it) self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it) @@ -95,11 +95,11 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ 'valid':valid_scalars['class_dice'][c]} self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) - logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format( - train_scalars['loss'], train_scalars['avg_dice']) + "[" + \ + logging.info('train loss {0:.4f}, avg foreground dice {1:.4f} '.format( + train_scalars['loss'], train_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") - logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( - valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ + logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format( + valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") def train_valid(self): diff --git a/pymic/net_run_ssl/ssl_cct.py b/pymic/net_run_ssl/ssl_cct.py index 0cb16aa..2fbf982 100644 --- a/pymic/net_run_ssl/ssl_cct.py +++ b/pymic/net_run_ssl/ssl_cct.py @@ -154,9 +154,9 @@ def training(self): train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice.mean() + train_avg_dice = train_cls_dice[1:].mean() train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} return train_scalers diff --git a/pymic/net_run_ssl/ssl_cps.py b/pymic/net_run_ssl/ssl_cps.py index 637fe2e..2a76a86 100644 --- a/pymic/net_run_ssl/ssl_cps.py +++ b/pymic/net_run_ssl/ssl_cps.py @@ -137,12 +137,12 @@ def training(self): train_avg_loss_pse_sup1 = train_loss_pseudo_sup1 / iter_valid train_avg_loss_pse_sup2 = train_loss_pseudo_sup2 / iter_valid train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice.mean() + train_avg_dice = train_cls_dice[1:].mean() train_scalers = {'loss': train_avg_loss, 'loss_sup1':train_avg_loss_sup1, 'loss_sup2': train_avg_loss_sup2, 'loss_pse_sup1':train_avg_loss_pse_sup1, 'loss_pse_sup2': train_avg_loss_pse_sup2, - 'regular_w':regular_w, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'regular_w':regular_w, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} return train_scalers def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): @@ -152,7 +152,7 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): 'net2':train_scalars['loss_sup2']} loss_pse_sup_scalar = {'net1':train_scalars['loss_pse_sup1'], 'net2':train_scalars['loss_pse_sup2']} - dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']} + dice_scalar ={'train':train_scalars['avg_fg_dice'], 'valid':valid_scalars['avg_fg_dice']} self.summ_writer.add_scalars('loss', loss_scalar, glob_it) self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it) self.summ_writer.add_scalars('loss_pseudo_sup', loss_pse_sup_scalar, glob_it) diff --git a/pymic/net_run_ssl/ssl_em.py b/pymic/net_run_ssl/ssl_em.py index 8aebada..47b9fd2 100644 --- a/pymic/net_run_ssl/ssl_em.py +++ b/pymic/net_run_ssl/ssl_em.py @@ -98,9 +98,9 @@ def training(self): train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice.mean() + train_avg_dice = train_cls_dice[1:].mean() train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} return train_scalers \ No newline at end of file diff --git a/pymic/net_run_ssl/ssl_mt.py b/pymic/net_run_ssl/ssl_mt.py index 405ce41..0f3146b 100644 --- a/pymic/net_run_ssl/ssl_mt.py +++ b/pymic/net_run_ssl/ssl_mt.py @@ -123,9 +123,9 @@ def training(self): train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice.mean() + train_avg_dice = train_cls_dice[1:].mean() train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} return train_scalers \ No newline at end of file diff --git a/pymic/net_run_ssl/ssl_uamt.py b/pymic/net_run_ssl/ssl_uamt.py index 18ddd7b..427bdff 100644 --- a/pymic/net_run_ssl/ssl_uamt.py +++ b/pymic/net_run_ssl/ssl_uamt.py @@ -125,9 +125,9 @@ def training(self): train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice.mean() + train_avg_dice = train_cls_dice[1:].mean() train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} return train_scalers \ No newline at end of file diff --git a/pymic/net_run_ssl/ssl_urpc.py b/pymic/net_run_ssl/ssl_urpc.py index f404e64..3c92735 100644 --- a/pymic/net_run_ssl/ssl_urpc.py +++ b/pymic/net_run_ssl/ssl_urpc.py @@ -111,9 +111,9 @@ def training(self): train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice.mean() + train_avg_dice = train_cls_dice[1:].mean() train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} return train_scalers diff --git a/pymic/net_run_wsl/wsl_abstract.py b/pymic/net_run_wsl/wsl_abstract.py index f317dd0..f290465 100644 --- a/pymic/net_run_wsl/wsl_abstract.py +++ b/pymic/net_run_wsl/wsl_abstract.py @@ -24,7 +24,7 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): 'valid':valid_scalars['loss']} loss_sup_scalar = {'train':train_scalars['loss_sup']} loss_upsup_scalar = {'train':train_scalars['loss_reg']} - dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']} + dice_scalar ={'train':train_scalars['avg_fg_dice'], 'valid':valid_scalars['avg_fg_dice']} self.summ_writer.add_scalars('loss', loss_scalar, glob_it) self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it) self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it) @@ -36,9 +36,9 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ 'valid':valid_scalars['class_dice'][c]} self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) - logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format( - train_scalars['loss'], train_scalars['avg_dice']) + "[" + \ + logging.info('train loss {0:.4f}, avg foreground dice {1:.4f} '.format( + train_scalars['loss'], train_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") - logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( - valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ + logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format( + valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") diff --git a/pymic/net_run_wsl/wsl_dmpls.py b/pymic/net_run_wsl/wsl_dmpls.py index 3081f8e..a62d61e 100644 --- a/pymic/net_run_wsl/wsl_dmpls.py +++ b/pymic/net_run_wsl/wsl_dmpls.py @@ -4,6 +4,7 @@ import numpy as np import random import torch +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice @@ -90,7 +91,7 @@ def training(self): loss.backward() self.optimizer.step() - + train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() train_loss_reg = train_loss_reg + loss_reg.item() @@ -106,10 +107,11 @@ def training(self): train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice.mean() + train_avg_dice = train_cls_dice[1:].mean() train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + return train_scalers \ No newline at end of file diff --git a/pymic/net_run_wsl/wsl_em.py b/pymic/net_run_wsl/wsl_em.py index 6002547..f23fc65 100644 --- a/pymic/net_run_wsl/wsl_em.py +++ b/pymic/net_run_wsl/wsl_em.py @@ -87,9 +87,9 @@ def training(self): train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice.mean() + train_avg_dice = train_cls_dice[1:].mean() train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} return train_scalers \ No newline at end of file diff --git a/pymic/net_run_wsl/wsl_gatedcrf.py b/pymic/net_run_wsl/wsl_gatedcrf.py index 41605db..576703a 100644 --- a/pymic/net_run_wsl/wsl_gatedcrf.py +++ b/pymic/net_run_wsl/wsl_gatedcrf.py @@ -113,10 +113,10 @@ def training(self): train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice.mean() + train_avg_dice = train_cls_dice[1:].mean() train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} return train_scalers \ No newline at end of file diff --git a/pymic/net_run_wsl/wsl_mumford_shah.py b/pymic/net_run_wsl/wsl_mumford_shah.py index 431ec7a..e278d3c 100644 --- a/pymic/net_run_wsl/wsl_mumford_shah.py +++ b/pymic/net_run_wsl/wsl_mumford_shah.py @@ -88,10 +88,10 @@ def training(self): train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice.mean() + train_avg_dice = train_cls_dice[1:].mean() train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} return train_scalers \ No newline at end of file diff --git a/pymic/net_run_wsl/wsl_tv.py b/pymic/net_run_wsl/wsl_tv.py index 1612037..b5b2334 100644 --- a/pymic/net_run_wsl/wsl_tv.py +++ b/pymic/net_run_wsl/wsl_tv.py @@ -83,10 +83,10 @@ def training(self): train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice.mean() + train_avg_dice = train_cls_dice[1:].mean() train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} return train_scalers \ No newline at end of file diff --git a/pymic/net_run_wsl/wsl_ustm.py b/pymic/net_run_wsl/wsl_ustm.py index 5f69b4a..e17d9b9 100644 --- a/pymic/net_run_wsl/wsl_ustm.py +++ b/pymic/net_run_wsl/wsl_ustm.py @@ -142,9 +142,9 @@ def training(self): train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice.mean() + train_avg_dice = train_cls_dice[1:].mean() train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} return train_scalers \ No newline at end of file diff --git a/pymic/util/__init__.py b/pymic/util/__init__.py index e69de29..72b8078 100644 --- a/pymic/util/__init__.py +++ b/pymic/util/__init__.py @@ -0,0 +1,2 @@ +from __future__ import absolute_import +from . import * \ No newline at end of file From fee0a47695fde6f325bfc67c6af58e0c9e14f2c7 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 21 Feb 2023 12:55:56 +0800 Subject: [PATCH 133/225] update ssl and wsl calculate mean dice for foreground classes allow loading pre-trained models --- pymic/__init__.py | 2 +- pymic/io/__init__.py | 4 ++- pymic/net/net2d/unet2d.py | 13 +------- pymic/net_run/agent_abstract.py | 8 ++--- pymic/net_run/agent_seg.py | 44 ++++++++++++++++++---------- pymic/net_run_nll/nll_co_teaching.py | 16 +++++----- pymic/net_run_nll/nll_dast.py | 5 ++-- pymic/net_run_nll/nll_main.py | 8 +++-- pymic/net_run_nll/nll_trinet.py | 14 ++++----- pymic/net_run_ssl/ssl_cps.py | 4 +-- pymic/net_run_ssl/ssl_urpc.py | 5 +++- pymic/net_run_wsl/wsl_ustm.py | 2 +- pymic/util/general.py | 19 ++++++++++++ 13 files changed, 86 insertions(+), 58 deletions(-) diff --git a/pymic/__init__.py b/pymic/__init__.py index 72b8078..d41dbb5 100644 --- a/pymic/__init__.py +++ b/pymic/__init__.py @@ -1,2 +1,2 @@ from __future__ import absolute_import -from . import * \ No newline at end of file +__version__ = "0.3.1" \ No newline at end of file diff --git a/pymic/io/__init__.py b/pymic/io/__init__.py index 72b8078..7f94f61 100644 --- a/pymic/io/__init__.py +++ b/pymic/io/__init__.py @@ -1,2 +1,4 @@ from __future__ import absolute_import -from . import * \ No newline at end of file +from pymic.io.image_read_write import * +from pymic.io.nifty_dataset import * +from pymic.io.h5_dataset import * \ No newline at end of file diff --git a/pymic/net/net2d/unet2d.py b/pymic/net/net2d/unet2d.py index b9e25f5..9acc0ad 100644 --- a/pymic/net/net2d/unet2d.py +++ b/pymic/net/net2d/unet2d.py @@ -131,7 +131,6 @@ class Decoder(nn.Module): :param class_num: (int) The class number for segmentation task. :param bilinear: (bool) Using bilinear for up-sampling or not. If False, deconvolution will be used for up-sampling. - :param multiscale_pred: (bool) Get multi-scale prediction. """ def __init__(self, params): super(Decoder, self).__init__() @@ -140,8 +139,7 @@ def __init__(self, params): self.ft_chns = self.params['feature_chns'] self.dropout = self.params['dropout'] self.n_class = self.params['class_num'] - self.bilinear = self.params.get('bilinear', True) - self.mul_pred = self.params.get('multiscale_pred', False) + self.bilinear = self.params['bilinear'] assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) @@ -151,10 +149,6 @@ def __init__(self, params): self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear) self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear) self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1) - if(self.mul_pred): - self.out_conv1 = nn.Conv2d(self.ft_chns[1], self.n_class, kernel_size = 1) - self.out_conv2 = nn.Conv2d(self.ft_chns[2], self.n_class, kernel_size = 1) - self.out_conv3 = nn.Conv2d(self.ft_chns[3], self.n_class, kernel_size = 1) def forward(self, x): if(len(self.ft_chns) == 5): @@ -169,11 +163,6 @@ def forward(self, x): x_d1 = self.up3(x_d2, x1) x_d0 = self.up4(x_d1, x0) output = self.out_conv(x_d0) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] return output class UNet2D(nn.Module): diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index 50e67e2..9131ba0 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -276,7 +276,7 @@ def worker_init_fn(worker_id): self.test_loader = torch.utils.data.DataLoader(self.test_set, batch_size = bn_test, shuffle=False, num_workers= bn_test) - def create_optimizer(self, params): + def create_optimizer(self, params, checkpoint = None): """ Create optimizer based on configuration. @@ -288,9 +288,9 @@ def create_optimizer(self, params): self.optimizer = get_optimizer(opt_params['optimizer'], params, opt_params) last_iter = -1 - if(self.checkpoint is not None): - self.optimizer.load_state_dict(self.checkpoint['optimizer_state_dict']) - last_iter = self.checkpoint['iteration'] - 1 + if(checkpoint is not None): + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + last_iter = checkpoint['iteration'] - 1 if(self.scheduler is None): opt_params["last_iter"] = last_iter self.scheduler = get_lr_scheduler(self.optimizer, opt_params) diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 83975f0..887a516 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -29,7 +29,7 @@ from pymic.transform.trans_dict import TransformDict from pymic.util.post_process import PostProcessDict from pymic.util.image_process import convert_label -from pymic.util.general import mixup +from pymic.util.general import mixup, tensor_shape_match class SegmentationAgent(NetRunAgent): def __init__(self, config, stage = 'train'): @@ -259,7 +259,8 @@ def train_valid(self): ckpt_prefix = self.config['training'].get('ckpt_prefix', None) if(ckpt_prefix is None): ckpt_prefix = ckpt_dir.split('/')[-1] - iter_start = self.config['training']['iter_start'] + # iter_start = self.config['training']['iter_start'] + iter_start = 0 iter_max = self.config['training']['iter_max'] iter_valid = self.config['training']['iter_valid'] iter_save = self.config['training'].get('iter_save', None) @@ -274,21 +275,32 @@ def train_valid(self): self.max_val_dice = 0.0 self.max_val_it = 0 self.best_model_wts = None - self.checkpoint = None - if(iter_start > 0): - checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start) - self.checkpoint = torch.load(checkpoint_file, map_location = self.device) - # assert(self.checkpoint['iteration'] == iter_start) - if(len(device_ids) > 1): - self.net.module.load_state_dict(self.checkpoint['model_state_dict']) + checkpoint = None + # initialize the network with pre-trained weights + ckpt_init_name = self.config['training'].get('ckpt_init_name', None) + ckpt_init_mode = self.config['training'].get('ckpt_init_mode', 0) + ckpt_for_optm = None + if(ckpt_init_name is not None): + checkpoint = torch.load(ckpt_dir + "/" + ckpt_init_name, map_location = self.device) + pretrained_dict = checkpoint['model_state_dict'] + model_dict = self.net.module.state_dict() if (len(device_ids) > 1) else self.net.state_dict() + pretrained_dict = {k: v for k, v in pretrained_dict.items() if \ + k in model_dict and tensor_shape_match(pretrained_dict[k], model_dict[k])} + logging.info("Initializing the following parameters with pre-trained model") + for k in pretrained_dict: + logging.info(k) + if (len(device_ids) > 1): + self.net.module.load_state_dict(pretrained_dict, strict = False) else: - self.net.load_state_dict(self.checkpoint['model_state_dict']) - self.max_val_dice = self.checkpoint.get('valid_pred', 0) - # self.max_val_it = self.checkpoint['iteration'] - self.max_val_it = iter_start - self.best_model_wts = self.checkpoint['model_state_dict'] - - self.create_optimizer(self.get_parameters_to_update()) + self.net.load_state_dict(pretrained_dict, strict = False) + + if(ckpt_init_mode > 0): # Load other information + self.max_val_dice = checkpoint.get('valid_pred', 0) + iter_start = checkpoint['iteration'] - 1 + self.max_val_it = iter_start + self.best_model_wts = checkpoint['model_state_dict'] + ckpt_for_optm = checkpoint + self.create_optimizer(self.get_parameters_to_update(), ckpt_for_optm) self.create_loss_calculator() self.trainIter = iter(self.train_loader) diff --git a/pymic/net_run_nll/nll_co_teaching.py b/pymic/net_run_nll/nll_co_teaching.py index 465a316..ec8e230 100644 --- a/pymic/net_run_nll/nll_co_teaching.py +++ b/pymic/net_run_nll/nll_co_teaching.py @@ -32,7 +32,7 @@ def forward(self, x): if(self.training): return out1, out2 else: - return (out1 + out2) / 3 + return (out1 + out2) / 2 class NLLCoTeaching(SegmentationAgent): """ @@ -144,13 +144,13 @@ def training(self): train_avg_loss1 = train_loss1 / iter_valid train_avg_loss2 = train_loss2 / iter_valid train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice.mean() + train_avg_dice = train_cls_dice[1:].mean() train_scalers = {'loss': (train_avg_loss1 + train_avg_loss2) / 2, 'loss1':train_avg_loss1, 'loss2': train_avg_loss2, 'loss_no_select1':train_avg_loss_no_select1, 'loss_no_select2':train_avg_loss_no_select2, - 'select_ratio':remb_ratio, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'select_ratio':remb_ratio, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} return train_scalers def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): @@ -159,7 +159,7 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): loss_no_select_scalar = {'net1':train_scalars['loss_no_select1'], 'net2':train_scalars['loss_no_select2']} - dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']} + dice_scalar ={'train':train_scalars['avg_fg_dice'], 'valid':valid_scalars['avg_fg_dice']} self.summ_writer.add_scalars('loss', loss_scalar, glob_it) self.summ_writer.add_scalars('loss_no_select', loss_no_select_scalar, glob_it) self.summ_writer.add_scalars('select_ratio', {'select_ratio':train_scalars['select_ratio']}, glob_it) @@ -171,9 +171,9 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): 'valid':valid_scalars['class_dice'][c]} self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) - logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format( - train_scalars['loss'], train_scalars['avg_dice']) + "[" + \ + logging.info('train loss {0:.4f}, avg foreground dice {1:.4f} '.format( + train_scalars['loss'], train_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") - logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( - valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ + logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format( + valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") diff --git a/pymic/net_run_nll/nll_dast.py b/pymic/net_run_nll/nll_dast.py index 07ed312..1921e9c 100644 --- a/pymic/net_run_nll/nll_dast.py +++ b/pymic/net_run_nll/nll_dast.py @@ -5,7 +5,6 @@ import numpy as np import torch.nn as nn import torchvision.transforms as transforms -from torch.optim import lr_scheduler from pymic.io.nifty_dataset import NiftyDataset from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth @@ -257,11 +256,11 @@ def training(self): train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice.mean() + train_avg_dice = train_cls_dice[1:].mean() train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':w_dbc, - 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} return train_scalers def train_valid(self): diff --git a/pymic/net_run_nll/nll_main.py b/pymic/net_run_nll/nll_main.py index a855f2b..a33a81b 100644 --- a/pymic/net_run_nll/nll_main.py +++ b/pymic/net_run_nll/nll_main.py @@ -28,8 +28,12 @@ def main(): log_dir = config['training']['ckpt_save_dir'] if(not os.path.exists(log_dir)): os.mkdir(log_dir) - logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, - format='%(message)s') + if sys.version.startswith("3.9"): + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, + format='%(message)s', force=True) # for python 3.9 + else: + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, + format='%(message)s') # for python 3.6 logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) nll_method = config['noisy_label_learning']['nll_method'] diff --git a/pymic/net_run_nll/nll_trinet.py b/pymic/net_run_nll/nll_trinet.py index 2694013..25c90cf 100644 --- a/pymic/net_run_nll/nll_trinet.py +++ b/pymic/net_run_nll/nll_trinet.py @@ -140,13 +140,13 @@ def training(self): train_avg_loss1 = train_loss1 / iter_valid train_avg_loss2 = train_loss2 / iter_valid train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice.mean() + train_avg_dice = train_cls_dice[1:].mean() train_scalers = {'loss': (train_avg_loss1 + train_avg_loss2) / 2, 'loss1':train_avg_loss1, 'loss2': train_avg_loss2, 'loss_no_select1':train_avg_loss_no_select1, 'loss_no_select2':train_avg_loss_no_select2, - 'select_ratio':remb_ratio, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'select_ratio':remb_ratio, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} return train_scalers def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): @@ -155,7 +155,7 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): loss_no_select_scalar = {'net1':train_scalars['loss_no_select1'], 'net2':train_scalars['loss_no_select2']} - dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']} + dice_scalar ={'train':train_scalars['avg_fg_dice'], 'valid':valid_scalars['avg_fg_dice']} self.summ_writer.add_scalars('loss', loss_scalar, glob_it) self.summ_writer.add_scalars('loss_no_select', loss_no_select_scalar, glob_it) self.summ_writer.add_scalars('select_ratio', {'select_ratio':train_scalars['select_ratio']}, glob_it) @@ -167,9 +167,9 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): 'valid':valid_scalars['class_dice'][c]} self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) - logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format( - train_scalars['loss'], train_scalars['avg_dice']) + "[" + \ + logging.info('train loss {0:.4f}, avg foregournd dice {1:.4f} '.format( + train_scalars['loss'], train_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") - logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( - valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ + logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format( + valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") diff --git a/pymic/net_run_ssl/ssl_cps.py b/pymic/net_run_ssl/ssl_cps.py index 2a76a86..a254b12 100644 --- a/pymic/net_run_ssl/ssl_cps.py +++ b/pymic/net_run_ssl/ssl_cps.py @@ -166,8 +166,8 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format( - train_scalars['loss'], train_scalars['avg_dice']) + "[" + \ + train_scalars['loss'], train_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( - valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ + valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") \ No newline at end of file diff --git a/pymic/net_run_ssl/ssl_urpc.py b/pymic/net_run_ssl/ssl_urpc.py index 3c92735..41a5bb2 100644 --- a/pymic/net_run_ssl/ssl_urpc.py +++ b/pymic/net_run_ssl/ssl_urpc.py @@ -71,8 +71,11 @@ def training(self): p0 = [output_i[:n0] for output_i in outputs_list] loss_sup = self.get_loss_value(data_lab, p0, y0) - # get average probability across scales + # resize to the same shape, and get average probability across scales outputs_soft_list = [torch.softmax(item, dim=1) for item in outputs_list] + for i in range(1, len(outputs_soft_list)): + outputs_soft_list[i] = nn.functional.interpolate(outputs_soft_list[i], + outputs_soft_list[0].shape[2:]) outputs_soft_avg = torch.mean(torch.stack(outputs_soft_list),dim = 0) p1_avg = outputs_soft_avg[n0:] * 0.99 + 0.005 # for unannotated images diff --git a/pymic/net_run_wsl/wsl_ustm.py b/pymic/net_run_wsl/wsl_ustm.py index e17d9b9..ea8d48c 100644 --- a/pymic/net_run_wsl/wsl_ustm.py +++ b/pymic/net_run_wsl/wsl_ustm.py @@ -123,7 +123,7 @@ def training(self): # update EMA alpha = wsl_cfg.get('ema_decay', 0.99) - alpha = min(1 - 1 / (iter_max + 1), alpha) + alpha = min(1 - 1 / (self.glob_it / iter_valid + 1), alpha) for ema_param, param in zip(self.net_ema.parameters(), self.net.parameters()): ema_param.data.mul_(alpha).add_(1 - alpha, param.data) diff --git a/pymic/util/general.py b/pymic/util/general.py index 99eb49f..075d2e1 100644 --- a/pymic/util/general.py +++ b/pymic/util/general.py @@ -9,6 +9,25 @@ def keyword_match(a,b): """ return a.lower() == b.lower() +def tensor_shape_match(a,b): + """ + Test if two tensors have the same shape""" + shape_a = list(a.shape) + shape_b = list(b.shape) + len_a = len(shape_a) + len_b = len(shape_b) + if(len_a != len_b): + return False + elif(len_a == 0): + return True + else: + for i in range(len_a): + if(shape_a[i] != shape_b[i]): + return False + return True + + + def get_one_hot_seg(label, class_num): """ Convert a segmentation label to one-hot. From 9ab3fca0a216a73a7249ae8bf8bfcb5339b01fbf Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 22 Feb 2023 15:02:06 +0800 Subject: [PATCH 134/225] Update nll_clslsr.py update nll_clsslsr --- pymic/net_run_nll/nll_clslsr.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/pymic/net_run_nll/nll_clslsr.py b/pymic/net_run_nll/nll_clslsr.py index 722b588..0148621 100644 --- a/pymic/net_run_nll/nll_clslsr.py +++ b/pymic/net_run_nll/nll_clslsr.py @@ -3,7 +3,6 @@ import logging import os import scipy -import sys import torch import numpy as np import pandas as pd @@ -27,7 +26,11 @@ def get_confident_map(gt, pred, CL_type = 'both'): :return: A tensor representing the noisiness of each pixel. """ - import cleanlab + try: + import cleanlab + assert(cleanlab.__version__ == '1.0.1') + except: + raise ValueError("Error: cleanlab 1.0.1 required. Please install it by `pip install cleanlab==1.0.1`") prob = scipy.special.softmax(pred, axis = 1) if CL_type in ['both', 'Qij']: noise = cleanlab.pruning.get_noise_indices(gt, prob, prune_method='both', n_jobs=1) @@ -146,15 +149,7 @@ def test_time_dropout(m): dst_path = os.path.join(save_dir, filename) conf_map.save(dst_path) -def get_confidence_map(): - """ - The main function to get the confidence map during inference. - """ - if(len(sys.argv) < 2): - print('Number of arguments should be 3. e.g.') - print(' python nll_clslsr.py config.cfg') - exit() - cfg_file = str(sys.argv[1]) +def get_confidence_map(cfg_file): config = parse_config(cfg_file) config = synchronize_config(config) @@ -173,7 +168,7 @@ def get_confidence_map(): one_transform = transform_dict[name](transform_param) transform_list.append(one_transform) data_transform = transforms.Compose(transform_list) - print('transform list', transform_list) + csv_file = config['dataset']['train_csv'] modal_num = config['dataset'].get('modal_num', 1) dataset = NiftyDataset(root_dir = config['dataset']['root_dir'], @@ -201,7 +196,4 @@ def get_confidence_map(): "label": df_train["label"]} train_cl_csv = csv_file.replace(".csv", "_clslsr.csv") df_cl = pd.DataFrame.from_dict(train_cl_dict) - df_cl.to_csv(train_cl_csv, index = False) - -if __name__ == "__main__": - get_confidence_map() \ No newline at end of file + df_cl.to_csv(train_cl_csv, index = False) \ No newline at end of file From d4d51dc5f5dfb0a83604be690ba7934bfabffe9d Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 22 Feb 2023 16:39:40 +0800 Subject: [PATCH 135/225] add self supervised learning add self supervised learning --- README.md | 2 +- pymic/__init__.py | 2 +- pymic/net_run_self_sl/self_sl_agent.py | 245 +++++++++++++++++++++++++ pymic/net_run_self_sl/self_sl_main.py | 87 +++++++++ setup.py | 5 +- 5 files changed, 337 insertions(+), 4 deletions(-) create mode 100644 pymic/net_run_self_sl/self_sl_agent.py create mode 100644 pymic/net_run_self_sl/self_sl_main.py diff --git a/README.md b/README.md index 3a552b4..dcceba8 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ BibTeX entry: # Features PyMIC provides flixible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions: -* Support for annotation-efficient image segmentation, especially for semi-supervised, weakly-supervised and noisy-label learning. +* Support for annotation-efficient image segmentation, especially for semi-supervised, self-supervised, weakly-supervised and noisy-label learning. * User friendly: For beginners, you only need to edit the configuration files for model training and inference, without writing code. For advanced users, you can customize different modules (networks, loss functions, training pipeline, etc) and easily integrate them into PyMIC. * Easy-to-use I/O interface to read and write different 2D and 3D images. * Various data pre-processing/transformation methods before sending a tensor into a network. diff --git a/pymic/__init__.py b/pymic/__init__.py index d41dbb5..789b690 100644 --- a/pymic/__init__.py +++ b/pymic/__init__.py @@ -1,2 +1,2 @@ from __future__ import absolute_import -__version__ = "0.3.1" \ No newline at end of file +__version__ = "0.3.1.1" \ No newline at end of file diff --git a/pymic/net_run_self_sl/self_sl_agent.py b/pymic/net_run_self_sl/self_sl_agent.py new file mode 100644 index 0000000..24a6e66 --- /dev/null +++ b/pymic/net_run_self_sl/self_sl_agent.py @@ -0,0 +1,245 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import copy +import logging +import time +import logging +import numpy as np +import random +import torch +import torch.nn as nn +import torchvision.transforms as transforms +from datetime import datetime +from random import random +from torch.optim import lr_scheduler +from tensorboardX import SummaryWriter +from pymic.io.nifty_dataset import NiftyDataset +from pymic.loss.seg.util import get_soft_label +from pymic.loss.seg.util import reshape_prediction_and_ground_truth +from pymic.loss.seg.util import get_classwise_dice +from pymic.net_run.infer_func import Inferer +from pymic.net_run.agent_seg import SegmentationAgent +from pymic.transform.trans_dict import TransformDict +from pymic.loss.seg.mse import MAELoss, MSELoss + +RegressionLossDict = { + 'MAELoss': MAELoss, + 'MSELoss': MSELoss + } + +class SelfSLSegAgent(SegmentationAgent): + """ + Abstract class for self-supervised segmentation. + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. + """ + def __init__(self, config, stage = 'train'): + super(SelfSLSegAgent, self).__init__(config, stage) + self.transform_dict = TransformDict + + def create_loss_calculator(self): + if(self.loss_dict is None): + self.loss_dict = RegressionLossDict + loss_name = self.config['training']['loss_type'] + if isinstance(loss_name, (list, tuple)): + raise ValueError("Undefined loss function {0:}".format(loss_name)) + elif (loss_name not in self.loss_dict): + raise ValueError("Undefined loss function {0:}".format(loss_name)) + else: + loss_param = self.config['training'] + loss_param['loss_softmax'] = False + base_loss = self.loss_dict[loss_name](self.config['training']) + if(self.config['training'].get('deep_supervise', False)): + raise ValueError("Deep supervised loss not implemented for self-supervised learning") + # weight = self.config['training'].get('deep_supervise_weight', None) + # mode = self.config['training'].get('deep_supervise_mode', 2) + # params = {'deep_supervise_weight': weight, + # 'deep_supervise_mode': mode, + # 'base_loss':base_loss} + # self.loss_calculator = DeepSuperviseLoss(params) + else: + self.loss_calculator = base_loss + + def training(self): + iter_valid = self.config['training']['iter_valid'] + train_loss = 0 + self.net.train() + for it in range(iter_valid): + try: + data = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data = next(self.trainIter) + # get the inputs + inputs = self.convert_tensor_type(data['image']) + label = self.convert_tensor_type(data['label']) + + # for debug + # from pymic.io.image_read_write import save_nd_array_as_image + # for i in range(inputs.shape[0]): + # image_i = inputs[i][0] + # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) + # save_nd_array_as_image(image_i, image_name, reference_name = None) + # return + + inputs, label = inputs.to(self.device), label.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + outputs = self.net(inputs) + outputs = nn.Sigmoid()(outputs) + loss = self.get_loss_value(data, outputs, label) + loss.backward() + self.optimizer.step() + train_loss = train_loss + loss.item() + # get dice evaluation for each class + if(isinstance(outputs, tuple) or isinstance(outputs, list)): + outputs = outputs[0] + + train_avg_loss = train_loss / iter_valid + train_scalers = {'loss': train_avg_loss} + return train_scalers + + def validation(self): + if(self.inferer is None): + infer_cfg = self.config['testing'] + self.inferer = Inferer(infer_cfg) + + valid_loss_list = [] + validIter = iter(self.valid_loader) + with torch.no_grad(): + self.net.eval() + for data in validIter: + inputs = self.convert_tensor_type(data['image']) + label = self.convert_tensor_type(data['label']) + inputs, label = inputs.to(self.device), label.to(self.device) + outputs = self.inferer.run(self.net, inputs) + outputs = nn.Sigmoid()(outputs) + # The tensors are on CPU when calculating loss for validation data + loss = self.get_loss_value(data, outputs, label) + valid_loss_list.append(loss.item()) + + valid_avg_loss = np.asarray(valid_loss_list).mean() + valid_scalers = {'loss': valid_avg_loss} + return valid_scalers + + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): + loss_scalar ={'train':train_scalars['loss'], + 'valid':valid_scalars['loss']} + self.summ_writer.add_scalars('loss', loss_scalar, glob_it) + self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) + logging.info('train loss {0:.4f}'.format(train_scalars['loss'])) + logging.info('valid loss {0:.4f}'.format(valid_scalars['loss'])) + + def train_valid(self): + device_ids = self.config['training']['gpus'] + if(len(device_ids) > 1): + self.device = torch.device("cuda:0") + self.net = nn.DataParallel(self.net, device_ids = device_ids) + else: + self.device = torch.device("cuda:{0:}".format(device_ids[0])) + self.net.to(self.device) + ckpt_dir = self.config['training']['ckpt_save_dir'] + ckpt_prefix = self.config['training'].get('ckpt_prefix', None) + if(ckpt_prefix is None): + ckpt_prefix = ckpt_dir.split('/')[-1] + iter_start = self.config['training']['iter_start'] + iter_max = self.config['training']['iter_max'] + iter_valid = self.config['training']['iter_valid'] + iter_save = self.config['training'].get('iter_save', None) + early_stop_it = self.config['training'].get('early_stop_patience', None) + if(iter_save is None): + iter_save_list = [iter_max] + elif(isinstance(iter_save, (tuple, list))): + iter_save_list = iter_save + else: + iter_save_list = range(0, iter_max + 1, iter_save) + + self.min_val_loss = 10000.0 + self.max_val_it = 0 + self.best_model_wts = None + self.checkpoint = None + if(iter_start > 0): + checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start) + self.checkpoint = torch.load(checkpoint_file, map_location = self.device) + # assert(self.checkpoint['iteration'] == iter_start) + if(len(device_ids) > 1): + self.net.module.load_state_dict(self.checkpoint['model_state_dict']) + else: + self.net.load_state_dict(self.checkpoint['model_state_dict']) + self.min_val_loss = self.checkpoint.get('valid_loss', 10000) + # self.max_val_it = self.checkpoint['iteration'] + self.max_val_it = iter_start + self.best_model_wts = self.checkpoint['model_state_dict'] + + self.create_optimizer(self.get_parameters_to_update()) + self.create_loss_calculator() + + self.trainIter = iter(self.train_loader) + + logging.info("{0:} training start".format(str(datetime.now())[:-7])) + self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) + self.glob_it = iter_start + for it in range(iter_start, iter_max, iter_valid): + lr_value = self.optimizer.param_groups[0]['lr'] + t0 = time.time() + train_scalars = self.training() + t1 = time.time() + valid_scalars = self.validation() + t2 = time.time() + if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step(-valid_scalars['loss']) + else: + self.scheduler.step() + + self.glob_it = it + iter_valid + logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) + logging.info('learning rate {0:}'.format(lr_value)) + logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1)) + self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it) + if(valid_scalars['loss'] < self.min_val_loss): + self.min_val_loss = valid_scalars['loss'] + self.max_val_it = self.glob_it + if(len(device_ids) > 1): + self.best_model_wts = copy.deepcopy(self.net.module.state_dict()) + else: + self.best_model_wts = copy.deepcopy(self.net.state_dict()) + + stop_now = True if(early_stop_it is not None and \ + self.glob_it - self.max_val_it > early_stop_it) else False + if ((self.glob_it in iter_save_list) or stop_now): + save_dict = {'iteration': self.glob_it, + 'valid_loss': valid_scalars['loss'], + 'model_state_dict': self.net.module.state_dict() \ + if len(device_ids) > 1 else self.net.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict()} + save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.glob_it) + torch.save(save_dict, save_name) + txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt') + txt_file.write(str(self.glob_it)) + txt_file.close() + if(stop_now): + logging.info("The training is early stopped") + break + # save the best performing checkpoint + save_dict = {'iteration': self.max_val_it, + 'valid_loss': self.min_val_loss, + 'model_state_dict': self.best_model_wts, + 'optimizer_state_dict': self.optimizer.state_dict()} + save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it) + torch.save(save_dict, save_name) + txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') + txt_file.write(str(self.max_val_it)) + txt_file.close() + logging.info('The best performing iter is {0:}, valid loss {1:}'.format(\ + self.max_val_it, self.min_val_loss)) + self.summ_writer.close() \ No newline at end of file diff --git a/pymic/net_run_self_sl/self_sl_main.py b/pymic/net_run_self_sl/self_sl_main.py new file mode 100644 index 0000000..4a795d5 --- /dev/null +++ b/pymic/net_run_self_sl/self_sl_main.py @@ -0,0 +1,87 @@ + +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import logging +import os +import sys +import shutil +from pymic.util.parse_config import * +from pymic.net_run_self_sl.self_sl_agent import SelfSLSegAgent + +def model_genesis(stage, cfg_file): + config = parse_config(cfg_file) + transforms = ['RandomFlip', 'LocalShuffling', 'NonLinearTransform', 'InOutPainting'] + genesis_cfg = { + 'randomflip_flip_depth': True, + 'randomflip_flip_height': True, + 'randomflip_flip_width': True, + 'localshuffling_probability': 0.5, + 'nonLineartransform_probability': 0.9, + 'inoutpainting_probability': 0.9, + 'inpainting_probability': 0.2 + } + config['dataset']['train_transform'].extend(transforms) + config['dataset']['valid_transform'].extend(transforms) + config['dataset'].update(genesis_cfg) + + config = synchronize_config(config) + log_dir = config['training']['ckpt_save_dir'] + if(not os.path.exists(log_dir)): + os.mkdir(log_dir) + if(stage == "train"): + dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] + shutil.copy(cfg_file, log_dir + "/" + dst_cfg) + if sys.version.startswith("3.9"): + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, + format='%(message)s', force=True) # for python 3.9 + else: + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, + format='%(message)s') # for python 3.6 + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging_config(config) + agent = SelfSLSegAgent(config, stage) + agent.run() + +def default_self_sl(stage, cfg_file): + config = parse_config(cfg_file) + config = synchronize_config(config) + log_dir = config['training']['ckpt_save_dir'] + if(not os.path.exists(log_dir)): + os.mkdir(log_dir) + if(stage == "train"): + dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] + shutil.copy(cfg_file, log_dir + "/" + dst_cfg) + if sys.version.startswith("3.9"): + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, + format='%(message)s', force=True) # for python 3.9 + else: + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, + format='%(message)s') # for python 3.6 + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging_config(config) + agent = SelfSLSegAgent(config, stage) + agent.run() + + +if __name__ == "__main__": + if(len(sys.argv) < 3): + print('Number of arguments should be 3. e.g.') + print(' pymic_self_sl train config.cfg') + exit() + stage = str(sys.argv[1]) + cfg_file = str(sys.argv[2]) + config = parse_config(cfg_file) + method = "default" + if 'self_supervised_learning' in config: + method = config['self_supervised_learning'].get('self_sl_method', 'default') + print("the self supervised method is ", method) + if(method == "default"): + default_self_sl(stage, cfg_file) + elif(method == 'model_genesis'): + model_genesis(stage, cfg_file) + else: + raise ValueError("The specified method {0:} is not implemented. ".format(method) + \ + "Consider to set `self_sl_method = default` and use customized" + \ + " transforms for self-supervised learning.") + + \ No newline at end of file diff --git a/setup.py b/setup.py index 6c05058..e0fa448 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.3.1", + version = "0.3.1.1", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, @@ -42,7 +42,8 @@ entry_points = { 'console_scripts': [ 'pymic_run = pymic.net_run.net_run:main', - 'pymic_ssl = pymic.net_run_ssl.ssl_main:main', + 'pymic_semi_sl = pymic.net_run_ssl.ssl_main:main', + 'pymic_self_sl = pymic.net_run_self_sl.self_sl_main:main', 'pymic_wsl = pymic.net_run_wsl.wsl_main:main', 'pymic_nll = pymic.net_run_nll.nll_main:main', 'pymic_eval_cls = pymic.util.evaluation_cls:main', From 3340cb0fc2008ea4f55bfc291af9e985b1ba5db6 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 23 Feb 2023 11:23:00 +0800 Subject: [PATCH 136/225] reorganize code add train.py and predict.py move semi_sup, self_sup, noisy_label and weak_sup to net_run --- pymic/net_run/noisy_label/__init__.py | 8 ++ .../noisy_label}/nll_clslsr.py | 0 .../noisy_label}/nll_co_teaching.py | 0 .../noisy_label}/nll_dast.py | 0 .../noisy_label}/nll_trinet.py | 0 pymic/net_run/{net_run.py => predict.py} | 29 +++--- pymic/net_run/self_sup/__init__.py | 2 + .../self_sup}/self_sl_agent.py | 0 pymic/net_run/semi_sup/__init__.py | 16 +++ .../semi_sup}/ssl_abstract.py | 0 .../semi_sup}/ssl_cct.py | 2 +- .../semi_sup}/ssl_cps.py | 2 +- .../semi_sup}/ssl_em.py | 2 +- .../semi_sup}/ssl_mt.py | 2 +- .../semi_sup}/ssl_uamt.py | 2 +- .../semi_sup}/ssl_urpc.py | 2 +- pymic/net_run/train.py | 98 +++++++++++++++++++ pymic/net_run/weak_sup/__init__.py | 15 +++ .../weak_sup}/wsl_abstract.py | 0 .../weak_sup}/wsl_dmpls.py | 2 +- .../weak_sup}/wsl_em.py | 2 +- .../weak_sup}/wsl_gatedcrf.py | 2 +- .../weak_sup}/wsl_mumford_shah.py | 2 +- .../weak_sup}/wsl_tv.py | 2 +- .../weak_sup}/wsl_ustm.py | 2 +- pymic/net_run_nll/__init__.py | 2 - pymic/net_run_nll/nll_main.py | 46 --------- pymic/net_run_self_sl/self_sl_main.py | 87 ---------------- pymic/net_run_ssl/__init__.py | 8 -- pymic/net_run_ssl/ssl_main.py | 53 ---------- pymic/net_run_wsl/__init__.py | 2 - pymic/net_run_wsl/wsl_main.py | 52 ---------- pymic/util/parse_config.py | 9 +- setup.py | 7 +- 34 files changed, 172 insertions(+), 286 deletions(-) create mode 100644 pymic/net_run/noisy_label/__init__.py rename pymic/{net_run_nll => net_run/noisy_label}/nll_clslsr.py (100%) rename pymic/{net_run_nll => net_run/noisy_label}/nll_co_teaching.py (100%) rename pymic/{net_run_nll => net_run/noisy_label}/nll_dast.py (100%) rename pymic/{net_run_nll => net_run/noisy_label}/nll_trinet.py (100%) rename pymic/net_run/{net_run.py => predict.py} (50%) create mode 100644 pymic/net_run/self_sup/__init__.py rename pymic/{net_run_self_sl => net_run/self_sup}/self_sl_agent.py (100%) create mode 100644 pymic/net_run/semi_sup/__init__.py rename pymic/{net_run_ssl => net_run/semi_sup}/ssl_abstract.py (100%) rename pymic/{net_run_ssl => net_run/semi_sup}/ssl_cct.py (99%) rename pymic/{net_run_ssl => net_run/semi_sup}/ssl_cps.py (99%) rename pymic/{net_run_ssl => net_run/semi_sup}/ssl_em.py (98%) rename pymic/{net_run_ssl => net_run/semi_sup}/ssl_mt.py (99%) rename pymic/{net_run_ssl => net_run/semi_sup}/ssl_uamt.py (99%) rename pymic/{net_run_ssl => net_run/semi_sup}/ssl_urpc.py (99%) create mode 100644 pymic/net_run/train.py create mode 100644 pymic/net_run/weak_sup/__init__.py rename pymic/{net_run_wsl => net_run/weak_sup}/wsl_abstract.py (100%) rename pymic/{net_run_wsl => net_run/weak_sup}/wsl_dmpls.py (98%) rename pymic/{net_run_wsl => net_run/weak_sup}/wsl_em.py (98%) rename pymic/{net_run_wsl => net_run/weak_sup}/wsl_gatedcrf.py (99%) rename pymic/{net_run_wsl => net_run/weak_sup}/wsl_mumford_shah.py (98%) rename pymic/{net_run_wsl => net_run/weak_sup}/wsl_tv.py (98%) rename pymic/{net_run_wsl => net_run/weak_sup}/wsl_ustm.py (99%) delete mode 100644 pymic/net_run_nll/__init__.py delete mode 100644 pymic/net_run_nll/nll_main.py delete mode 100644 pymic/net_run_self_sl/self_sl_main.py delete mode 100644 pymic/net_run_ssl/__init__.py delete mode 100644 pymic/net_run_ssl/ssl_main.py delete mode 100644 pymic/net_run_wsl/__init__.py delete mode 100644 pymic/net_run_wsl/wsl_main.py diff --git a/pymic/net_run/noisy_label/__init__.py b/pymic/net_run/noisy_label/__init__.py new file mode 100644 index 0000000..448a5e4 --- /dev/null +++ b/pymic/net_run/noisy_label/__init__.py @@ -0,0 +1,8 @@ +from __future__ import absolute_import +from pymic.net_run.noisy_label.nll_co_teaching import NLLCoTeaching +from pymic.net_run.noisy_label.nll_trinet import NLLTriNet +from pymic.net_run.noisy_label.nll_dast import NLLDAST + +NLLMethodDict = {'CoTeaching': NLLCoTeaching, + "TriNet": NLLTriNet, + "DAST": NLLDAST} \ No newline at end of file diff --git a/pymic/net_run_nll/nll_clslsr.py b/pymic/net_run/noisy_label/nll_clslsr.py similarity index 100% rename from pymic/net_run_nll/nll_clslsr.py rename to pymic/net_run/noisy_label/nll_clslsr.py diff --git a/pymic/net_run_nll/nll_co_teaching.py b/pymic/net_run/noisy_label/nll_co_teaching.py similarity index 100% rename from pymic/net_run_nll/nll_co_teaching.py rename to pymic/net_run/noisy_label/nll_co_teaching.py diff --git a/pymic/net_run_nll/nll_dast.py b/pymic/net_run/noisy_label/nll_dast.py similarity index 100% rename from pymic/net_run_nll/nll_dast.py rename to pymic/net_run/noisy_label/nll_dast.py diff --git a/pymic/net_run_nll/nll_trinet.py b/pymic/net_run/noisy_label/nll_trinet.py similarity index 100% rename from pymic/net_run_nll/nll_trinet.py rename to pymic/net_run/noisy_label/nll_trinet.py diff --git a/pymic/net_run/net_run.py b/pymic/net_run/predict.py similarity index 50% rename from pymic/net_run/net_run.py rename to pymic/net_run/predict.py index c8d8dcc..8ac4af9 100644 --- a/pymic/net_run/net_run.py +++ b/pymic/net_run/predict.py @@ -3,7 +3,7 @@ import logging import os import sys -import shutil +from datetime import datetime from pymic.util.parse_config import * from pymic.net_run.agent_cls import ClassificationAgent from pymic.net_run.agent_seg import SegmentationAgent @@ -12,34 +12,31 @@ def main(): """ The main function for running a network for training or inference. """ - if(len(sys.argv) < 3): - print('Number of arguments should be 3. e.g.') - print(' pymic_run train config.cfg') + if(len(sys.argv) < 2): + print('Number of arguments should be 2. e.g.') + print(' pymic_test config.cfg') exit() - stage = str(sys.argv[1]) - cfg_file = str(sys.argv[2]) + cfg_file = str(sys.argv[1]) config = parse_config(cfg_file) config = synchronize_config(config) - log_dir = config['training']['ckpt_save_dir'] + log_dir = config['testing']['output_dir'] if(not os.path.exists(log_dir)): os.makedirs(log_dir, exist_ok=True) - if(stage == "train"): - dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] - shutil.copy(cfg_file, log_dir + "/" + dst_cfg) + if sys.version.startswith("3.9"): - logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, - format='%(message)s', force=True) # for python 3.9 + logging.basicConfig(filename=log_dir+"/log_test.txt", + level=logging.INFO, format='%(message)s', force=True) # for python 3.9 else: - logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, - format='%(message)s') # for python 3.6 + logging.basicConfig(filename=log_dir+"/log_test.txt", + level=logging.INFO, format='%(message)s') # for python 3.6 logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) task = config['dataset']['task_type'] assert task in ['cls', 'cls_nexcl', 'seg'] if(task == 'cls' or task == 'cls_nexcl'): - agent = ClassificationAgent(config, stage) + agent = ClassificationAgent(config, 'test') else: - agent = SegmentationAgent(config, stage) + agent = SegmentationAgent(config, 'test') agent.run() if __name__ == "__main__": diff --git a/pymic/net_run/self_sup/__init__.py b/pymic/net_run/self_sup/__init__.py new file mode 100644 index 0000000..55f26bf --- /dev/null +++ b/pymic/net_run/self_sup/__init__.py @@ -0,0 +1,2 @@ +from __future__ import absolute_import +from pymic.net_run.self_sup.self_sl_agent import SelfSLSegAgent \ No newline at end of file diff --git a/pymic/net_run_self_sl/self_sl_agent.py b/pymic/net_run/self_sup/self_sl_agent.py similarity index 100% rename from pymic/net_run_self_sl/self_sl_agent.py rename to pymic/net_run/self_sup/self_sl_agent.py diff --git a/pymic/net_run/semi_sup/__init__.py b/pymic/net_run/semi_sup/__init__.py new file mode 100644 index 0000000..be753c2 --- /dev/null +++ b/pymic/net_run/semi_sup/__init__.py @@ -0,0 +1,16 @@ +from __future__ import absolute_import +from pymic.net_run.semi_sup.ssl_abstract import SSLSegAgent +from pymic.net_run.semi_sup.ssl_em import SSLEntropyMinimization +from pymic.net_run.semi_sup.ssl_mt import SSLMeanTeacher +from pymic.net_run.semi_sup.ssl_uamt import SSLUncertaintyAwareMeanTeacher +from pymic.net_run.semi_sup.ssl_cct import SSLCCT +from pymic.net_run.semi_sup.ssl_cps import SSLCPS +from pymic.net_run.semi_sup.ssl_urpc import SSLURPC + + +SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization, + 'MeanTeacher': SSLMeanTeacher, + 'UAMT': SSLUncertaintyAwareMeanTeacher, + 'CCT': SSLCCT, + 'CPS': SSLCPS, + 'URPC': SSLURPC} \ No newline at end of file diff --git a/pymic/net_run_ssl/ssl_abstract.py b/pymic/net_run/semi_sup/ssl_abstract.py similarity index 100% rename from pymic/net_run_ssl/ssl_abstract.py rename to pymic/net_run/semi_sup/ssl_abstract.py diff --git a/pymic/net_run_ssl/ssl_cct.py b/pymic/net_run/semi_sup/ssl_cct.py similarity index 99% rename from pymic/net_run_ssl/ssl_cct.py rename to pymic/net_run/semi_sup/ssl_cct.py index 2fbf982..1943608 100644 --- a/pymic/net_run_ssl/ssl_cct.py +++ b/pymic/net_run/semi_sup/ssl_cct.py @@ -8,7 +8,7 @@ from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice -from pymic.net_run_ssl.ssl_abstract import SSLSegAgent +from pymic.net_run.semi_sup import SSLSegAgent from pymic.util.ramps import get_rampup_ratio def softmax_mse_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False): diff --git a/pymic/net_run_ssl/ssl_cps.py b/pymic/net_run/semi_sup/ssl_cps.py similarity index 99% rename from pymic/net_run_ssl/ssl_cps.py rename to pymic/net_run/semi_sup/ssl_cps.py index a254b12..4a3be9c 100644 --- a/pymic/net_run_ssl/ssl_cps.py +++ b/pymic/net_run/semi_sup/ssl_cps.py @@ -7,7 +7,7 @@ from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice -from pymic.net_run_ssl.ssl_abstract import SSLSegAgent +from pymic.net_run.semi_sup import SSLSegAgent from pymic.net.net_dict_seg import SegNetDict from pymic.util.ramps import get_rampup_ratio diff --git a/pymic/net_run_ssl/ssl_em.py b/pymic/net_run/semi_sup/ssl_em.py similarity index 98% rename from pymic/net_run_ssl/ssl_em.py rename to pymic/net_run/semi_sup/ssl_em.py index 47b9fd2..fde941b 100644 --- a/pymic/net_run_ssl/ssl_em.py +++ b/pymic/net_run/semi_sup/ssl_em.py @@ -7,7 +7,7 @@ from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.ssl import EntropyLoss -from pymic.net_run_ssl.ssl_abstract import SSLSegAgent +from pymic.net_run.semi_sup import SSLSegAgent from pymic.transform.trans_dict import TransformDict from pymic.util.ramps import get_rampup_ratio diff --git a/pymic/net_run_ssl/ssl_mt.py b/pymic/net_run/semi_sup/ssl_mt.py similarity index 99% rename from pymic/net_run_ssl/ssl_mt.py rename to pymic/net_run/semi_sup/ssl_mt.py index 0f3146b..2a2abb8 100644 --- a/pymic/net_run_ssl/ssl_mt.py +++ b/pymic/net_run/semi_sup/ssl_mt.py @@ -6,7 +6,7 @@ from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice -from pymic.net_run_ssl.ssl_abstract import SSLSegAgent +from pymic.net_run.semi_sup import SSLSegAgent from pymic.net.net_dict_seg import SegNetDict from pymic.util.ramps import get_rampup_ratio diff --git a/pymic/net_run_ssl/ssl_uamt.py b/pymic/net_run/semi_sup/ssl_uamt.py similarity index 99% rename from pymic/net_run_ssl/ssl_uamt.py rename to pymic/net_run/semi_sup/ssl_uamt.py index 427bdff..6222fe3 100644 --- a/pymic/net_run_ssl/ssl_uamt.py +++ b/pymic/net_run/semi_sup/ssl_uamt.py @@ -6,7 +6,7 @@ from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice -from pymic.net_run_ssl.ssl_mt import SSLMeanTeacher +from pymic.net_run.semi_sup import SSLMeanTeacher from pymic.util.ramps import get_rampup_ratio class SSLUncertaintyAwareMeanTeacher(SSLMeanTeacher): diff --git a/pymic/net_run_ssl/ssl_urpc.py b/pymic/net_run/semi_sup/ssl_urpc.py similarity index 99% rename from pymic/net_run_ssl/ssl_urpc.py rename to pymic/net_run/semi_sup/ssl_urpc.py index 41a5bb2..56bb77e 100644 --- a/pymic/net_run_ssl/ssl_urpc.py +++ b/pymic/net_run/semi_sup/ssl_urpc.py @@ -7,7 +7,7 @@ from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice -from pymic.net_run_ssl.ssl_abstract import SSLSegAgent +from pymic.net_run.semi_sup import SSLSegAgent from pymic.util.ramps import get_rampup_ratio class SSLURPC(SSLSegAgent): diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py new file mode 100644 index 0000000..4afae77 --- /dev/null +++ b/pymic/net_run/train.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import logging +import os +import sys +import shutil +from datetime import datetime +from pymic.util.parse_config import * +from pymic.net_run.agent_cls import ClassificationAgent +from pymic.net_run.agent_seg import SegmentationAgent +from pymic.net_run.semi_sup import SSLMethodDict +from pymic.net_run.weak_sup import WSLMethodDict +from pymic.net_run.noisy_label import NLLMethodDict +from pymic.net_run.self_sup import SelfSLSegAgent + +def get_segmentation_agent(config, sup_type): + assert(sup_type in ['fully_sup', 'semi_sup', 'self_sup', 'weak_sup', 'noisy_label']) + if(sup_type == 'fully_sup'): + logging.info("\n********** Fully Supervised Learning **********\n") + agent = SegmentationAgent(config, 'train') + elif(sup_type == 'semi_sup'): + logging.info("\n********** Semi Supervised Learning **********\n") + method = config['semi_supervised_learning']['method_name'] + agent = SSLMethodDict[method](config, 'train') + elif(sup_type == 'weak_sup'): + logging.info("\n********** Weakly Supervised Learning **********\n") + method = config['weakly_supervised_learning']['method_name'] + agent = WSLMethodDict[method](config, 'train') + elif(sup_type == 'noisy_label'): + logging.info("\n********** Noisy Label Learning **********\n") + method = config['noisy_label_learning']['method_name'] + agent = NLLMethodDict[method](config, 'train') + elif(sup_type == 'self_sup'): + logging.info("\n********** Self Supervised Learning **********\n") + method = config['self_supervised_learning']['method_name'] + if(method == "custom"): + pass + elif(method == "model_genesis"): + transforms = ['RandomFlip', 'LocalShuffling', 'NonLinearTransform', 'InOutPainting'] + genesis_cfg = { + 'randomflip_flip_depth': True, + 'randomflip_flip_height': True, + 'randomflip_flip_width': True, + 'localshuffling_probability': 0.5, + 'nonLineartransform_probability': 0.9, + 'inoutpainting_probability': 0.9, + 'inpainting_probability': 0.2 + } + config['dataset']['train_transform'].extend(transforms) + config['dataset']['valid_transform'].extend(transforms) + config['dataset'].update(genesis_cfg) + logging_config(config['dataset']) + else: + raise ValueError("The specified method {0:} is not implemented. ".format(method) + \ + "Consider to set `self_sl_method = custom` and use customized" + \ + " transforms for self-supervised learning.") + agent = SelfSLSegAgent(config, 'train') + else: + raise ValueError("undefined supervision type: {0:}".format(sup_type)) + return agent + +def main(): + """ + The main function for running a network for training. + """ + if(len(sys.argv) < 2): + print('Number of arguments should be 2. e.g.') + print(' pymic_train config.cfg') + exit() + cfg_file = str(sys.argv[1]) + config = parse_config(cfg_file) + config = synchronize_config(config) + log_dir = config['training']['ckpt_save_dir'] + if(not os.path.exists(log_dir)): + os.makedirs(log_dir, exist_ok=True) + dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] + shutil.copy(cfg_file, log_dir + "/" + dst_cfg) + if sys.version.startswith("3.9"): + logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), + level=logging.INFO, format='%(message)s', force=True) # for python 3.9 + else: + logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), + level=logging.INFO, format='%(message)s') # for python 3.6 + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging_config(config) + task = config['dataset']['task_type'] + assert task in ['cls', 'cls_nexcl', 'seg'] + if(task == 'cls' or task == 'cls_nexcl'): + agent = ClassificationAgent(config, 'train') + else: + sup_type = config['dataset'].get('supervise_type', 'fully_sup') + agent = get_segmentation_agent(config, sup_type) + agent.run() + +if __name__ == "__main__": + main() + + diff --git a/pymic/net_run/weak_sup/__init__.py b/pymic/net_run/weak_sup/__init__.py new file mode 100644 index 0000000..b3c8332 --- /dev/null +++ b/pymic/net_run/weak_sup/__init__.py @@ -0,0 +1,15 @@ +from __future__ import absolute_import +from pymic.net_run.weak_sup.wsl_abstract import WSLSegAgent +from pymic.net_run.weak_sup.wsl_em import WSLEntropyMinimization +from pymic.net_run.weak_sup.wsl_gatedcrf import WSLGatedCRF +from pymic.net_run.weak_sup.wsl_mumford_shah import WSLMumfordShah +from pymic.net_run.weak_sup.wsl_tv import WSLTotalVariation +from pymic.net_run.weak_sup.wsl_ustm import WSLUSTM +from pymic.net_run.weak_sup.wsl_dmpls import WSLDMPLS + +WSLMethodDict = {'EntropyMinimization': WSLEntropyMinimization, + 'GatedCRF': WSLGatedCRF, + 'MumfordShah': WSLMumfordShah, + 'TotalVariation': WSLTotalVariation, + 'USTM': WSLUSTM, + 'DMPLS': WSLDMPLS} \ No newline at end of file diff --git a/pymic/net_run_wsl/wsl_abstract.py b/pymic/net_run/weak_sup/wsl_abstract.py similarity index 100% rename from pymic/net_run_wsl/wsl_abstract.py rename to pymic/net_run/weak_sup/wsl_abstract.py diff --git a/pymic/net_run_wsl/wsl_dmpls.py b/pymic/net_run/weak_sup/wsl_dmpls.py similarity index 98% rename from pymic/net_run_wsl/wsl_dmpls.py rename to pymic/net_run/weak_sup/wsl_dmpls.py index a62d61e..4212409 100644 --- a/pymic/net_run_wsl/wsl_dmpls.py +++ b/pymic/net_run/weak_sup/wsl_dmpls.py @@ -9,7 +9,7 @@ from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.dice import DiceLoss -from pymic.net_run_wsl.wsl_abstract import WSLSegAgent +from pymic.net_run.weak_sup import WSLSegAgent from pymic.util.ramps import get_rampup_ratio class WSLDMPLS(WSLSegAgent): diff --git a/pymic/net_run_wsl/wsl_em.py b/pymic/net_run/weak_sup/wsl_em.py similarity index 98% rename from pymic/net_run_wsl/wsl_em.py rename to pymic/net_run/weak_sup/wsl_em.py index f23fc65..adcd70c 100644 --- a/pymic/net_run_wsl/wsl_em.py +++ b/pymic/net_run/weak_sup/wsl_em.py @@ -8,7 +8,7 @@ from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.ssl import EntropyLoss from pymic.net_run.agent_seg import SegmentationAgent -from pymic.net_run_wsl.wsl_abstract import WSLSegAgent +from pymic.net_run.weak_sup import WSLSegAgent from pymic.util.ramps import get_rampup_ratio class WSLEntropyMinimization(WSLSegAgent): diff --git a/pymic/net_run_wsl/wsl_gatedcrf.py b/pymic/net_run/weak_sup/wsl_gatedcrf.py similarity index 99% rename from pymic/net_run_wsl/wsl_gatedcrf.py rename to pymic/net_run/weak_sup/wsl_gatedcrf.py index 576703a..2ce1f95 100644 --- a/pymic/net_run_wsl/wsl_gatedcrf.py +++ b/pymic/net_run/weak_sup/wsl_gatedcrf.py @@ -7,7 +7,7 @@ from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.gatedcrf import GatedCRFLoss -from pymic.net_run_wsl.wsl_abstract import WSLSegAgent +from pymic.net_run.weak_sup import WSLSegAgent from pymic.util.ramps import get_rampup_ratio class WSLGatedCRF(WSLSegAgent): diff --git a/pymic/net_run_wsl/wsl_mumford_shah.py b/pymic/net_run/weak_sup/wsl_mumford_shah.py similarity index 98% rename from pymic/net_run_wsl/wsl_mumford_shah.py rename to pymic/net_run/weak_sup/wsl_mumford_shah.py index e278d3c..2480fee 100644 --- a/pymic/net_run_wsl/wsl_mumford_shah.py +++ b/pymic/net_run/weak_sup/wsl_mumford_shah.py @@ -7,7 +7,7 @@ from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.mumford_shah import MumfordShahLoss -from pymic.net_run_wsl.wsl_abstract import WSLSegAgent +from pymic.net_run.weak_sup import WSLSegAgent from pymic.util.ramps import get_rampup_ratio class WSLMumfordShah(WSLSegAgent): diff --git a/pymic/net_run_wsl/wsl_tv.py b/pymic/net_run/weak_sup/wsl_tv.py similarity index 98% rename from pymic/net_run_wsl/wsl_tv.py rename to pymic/net_run/weak_sup/wsl_tv.py index b5b2334..9d13c5d 100644 --- a/pymic/net_run_wsl/wsl_tv.py +++ b/pymic/net_run/weak_sup/wsl_tv.py @@ -7,7 +7,7 @@ from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.ssl import TotalVariationLoss -from pymic.net_run_wsl.wsl_abstract import WSLSegAgent +from pymic.net_run.weak_sup import WSLSegAgent from pymic.util.ramps import get_rampup_ratio from pymic.util.general import keyword_match diff --git a/pymic/net_run_wsl/wsl_ustm.py b/pymic/net_run/weak_sup/wsl_ustm.py similarity index 99% rename from pymic/net_run_wsl/wsl_ustm.py rename to pymic/net_run/weak_sup/wsl_ustm.py index ea8d48c..0ea3fbc 100644 --- a/pymic/net_run_wsl/wsl_ustm.py +++ b/pymic/net_run/weak_sup/wsl_ustm.py @@ -9,7 +9,7 @@ from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.net.net_dict_seg import SegNetDict -from pymic.net_run_wsl.wsl_abstract import WSLSegAgent +from pymic.net_run.weak_sup import WSLSegAgent from pymic.util.ramps import get_rampup_ratio from pymic.util.general import keyword_match diff --git a/pymic/net_run_nll/__init__.py b/pymic/net_run_nll/__init__.py deleted file mode 100644 index 72b8078..0000000 --- a/pymic/net_run_nll/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from __future__ import absolute_import -from . import * \ No newline at end of file diff --git a/pymic/net_run_nll/nll_main.py b/pymic/net_run_nll/nll_main.py deleted file mode 100644 index a33a81b..0000000 --- a/pymic/net_run_nll/nll_main.py +++ /dev/null @@ -1,46 +0,0 @@ - -# -*- coding: utf-8 -*- -from __future__ import print_function, division -import logging -import os -import sys -from pymic.util.parse_config import * -from pymic.net_run_nll.nll_co_teaching import NLLCoTeaching -from pymic.net_run_nll.nll_trinet import NLLTriNet -from pymic.net_run_nll.nll_dast import NLLDAST - -NLLMethodDict = {'CoTeaching': NLLCoTeaching, - "TriNet": NLLTriNet, - "DAST": NLLDAST} - -def main(): - """ - The main function for noisy label learning methods. - """ - if(len(sys.argv) < 3): - print('Number of arguments should be 3. e.g.') - print(' pymic_nll train config.cfg') - exit() - stage = str(sys.argv[1]) - cfg_file = str(sys.argv[2]) - config = parse_config(cfg_file) - config = synchronize_config(config) - log_dir = config['training']['ckpt_save_dir'] - if(not os.path.exists(log_dir)): - os.mkdir(log_dir) - if sys.version.startswith("3.9"): - logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, - format='%(message)s', force=True) # for python 3.9 - else: - logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, - format='%(message)s') # for python 3.6 - logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging_config(config) - nll_method = config['noisy_label_learning']['nll_method'] - agent = NLLMethodDict[nll_method](config, stage) - agent.run() - -if __name__ == "__main__": - main() - - \ No newline at end of file diff --git a/pymic/net_run_self_sl/self_sl_main.py b/pymic/net_run_self_sl/self_sl_main.py deleted file mode 100644 index 4a795d5..0000000 --- a/pymic/net_run_self_sl/self_sl_main.py +++ /dev/null @@ -1,87 +0,0 @@ - -# -*- coding: utf-8 -*- -from __future__ import print_function, division -import logging -import os -import sys -import shutil -from pymic.util.parse_config import * -from pymic.net_run_self_sl.self_sl_agent import SelfSLSegAgent - -def model_genesis(stage, cfg_file): - config = parse_config(cfg_file) - transforms = ['RandomFlip', 'LocalShuffling', 'NonLinearTransform', 'InOutPainting'] - genesis_cfg = { - 'randomflip_flip_depth': True, - 'randomflip_flip_height': True, - 'randomflip_flip_width': True, - 'localshuffling_probability': 0.5, - 'nonLineartransform_probability': 0.9, - 'inoutpainting_probability': 0.9, - 'inpainting_probability': 0.2 - } - config['dataset']['train_transform'].extend(transforms) - config['dataset']['valid_transform'].extend(transforms) - config['dataset'].update(genesis_cfg) - - config = synchronize_config(config) - log_dir = config['training']['ckpt_save_dir'] - if(not os.path.exists(log_dir)): - os.mkdir(log_dir) - if(stage == "train"): - dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] - shutil.copy(cfg_file, log_dir + "/" + dst_cfg) - if sys.version.startswith("3.9"): - logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, - format='%(message)s', force=True) # for python 3.9 - else: - logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, - format='%(message)s') # for python 3.6 - logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging_config(config) - agent = SelfSLSegAgent(config, stage) - agent.run() - -def default_self_sl(stage, cfg_file): - config = parse_config(cfg_file) - config = synchronize_config(config) - log_dir = config['training']['ckpt_save_dir'] - if(not os.path.exists(log_dir)): - os.mkdir(log_dir) - if(stage == "train"): - dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] - shutil.copy(cfg_file, log_dir + "/" + dst_cfg) - if sys.version.startswith("3.9"): - logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, - format='%(message)s', force=True) # for python 3.9 - else: - logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, - format='%(message)s') # for python 3.6 - logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging_config(config) - agent = SelfSLSegAgent(config, stage) - agent.run() - - -if __name__ == "__main__": - if(len(sys.argv) < 3): - print('Number of arguments should be 3. e.g.') - print(' pymic_self_sl train config.cfg') - exit() - stage = str(sys.argv[1]) - cfg_file = str(sys.argv[2]) - config = parse_config(cfg_file) - method = "default" - if 'self_supervised_learning' in config: - method = config['self_supervised_learning'].get('self_sl_method', 'default') - print("the self supervised method is ", method) - if(method == "default"): - default_self_sl(stage, cfg_file) - elif(method == 'model_genesis'): - model_genesis(stage, cfg_file) - else: - raise ValueError("The specified method {0:} is not implemented. ".format(method) + \ - "Consider to set `self_sl_method = default` and use customized" + \ - " transforms for self-supervised learning.") - - \ No newline at end of file diff --git a/pymic/net_run_ssl/__init__.py b/pymic/net_run_ssl/__init__.py deleted file mode 100644 index 4309107..0000000 --- a/pymic/net_run_ssl/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from __future__ import absolute_import -from pymic.net_run_ssl.ssl_abstract import * -from pymic.net_run_ssl.ssl_cct import * -from pymic.net_run_ssl.ssl_cps import * -from pymic.net_run_ssl.ssl_em import * -from pymic.net_run_ssl.ssl_mt import * -from pymic.net_run_ssl.ssl_uamt import * -from pymic.net_run_ssl.ssl_urpc import * \ No newline at end of file diff --git a/pymic/net_run_ssl/ssl_main.py b/pymic/net_run_ssl/ssl_main.py deleted file mode 100644 index 996f14a..0000000 --- a/pymic/net_run_ssl/ssl_main.py +++ /dev/null @@ -1,53 +0,0 @@ - -# -*- coding: utf-8 -*- -from __future__ import print_function, division -import logging -import os -import sys -from pymic.util.parse_config import * -from pymic.net_run_ssl.ssl_em import SSLEntropyMinimization -from pymic.net_run_ssl.ssl_mt import SSLMeanTeacher -from pymic.net_run_ssl.ssl_uamt import SSLUncertaintyAwareMeanTeacher -from pymic.net_run_ssl.ssl_cct import SSLCCT -from pymic.net_run_ssl.ssl_cps import SSLCPS -from pymic.net_run_ssl.ssl_urpc import SSLURPC - - -SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization, - 'MeanTeacher': SSLMeanTeacher, - 'UAMT': SSLUncertaintyAwareMeanTeacher, - 'CCT': SSLCCT, - 'CPS': SSLCPS, - 'URPC': SSLURPC} - -def main(): - """ - Main function for running a semi-supervised method. - """ - if(len(sys.argv) < 3): - print('Number of arguments should be 3. e.g.') - print(' pymic_ssl train config.cfg') - exit() - stage = str(sys.argv[1]) - cfg_file = str(sys.argv[2]) - config = parse_config(cfg_file) - config = synchronize_config(config) - log_dir = config['training']['ckpt_save_dir'] - if(not os.path.exists(log_dir)): - os.mkdir(log_dir) - if sys.version.startswith("3.9"): - logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, - format='%(message)s', force=True) # for python 3.9 - else: - logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, - format='%(message)s') # for python 3.6 - logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging_config(config) - ssl_method = config['semi_supervised_learning']['ssl_method'] - agent = SSLMethodDict[ssl_method](config, stage) - agent.run() - -if __name__ == "__main__": - main() - - \ No newline at end of file diff --git a/pymic/net_run_wsl/__init__.py b/pymic/net_run_wsl/__init__.py deleted file mode 100644 index 72b8078..0000000 --- a/pymic/net_run_wsl/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from __future__ import absolute_import -from . import * \ No newline at end of file diff --git a/pymic/net_run_wsl/wsl_main.py b/pymic/net_run_wsl/wsl_main.py deleted file mode 100644 index 540c535..0000000 --- a/pymic/net_run_wsl/wsl_main.py +++ /dev/null @@ -1,52 +0,0 @@ - -# -*- coding: utf-8 -*- -from __future__ import print_function, division -import logging -import os -import sys -from pymic.util.parse_config import * -from pymic.net_run_wsl.wsl_em import WSLEntropyMinimization -from pymic.net_run_wsl.wsl_gatedcrf import WSLGatedCRF -from pymic.net_run_wsl.wsl_mumford_shah import WSLMumfordShah -from pymic.net_run_wsl.wsl_tv import WSLTotalVariation -from pymic.net_run_wsl.wsl_ustm import WSLUSTM -from pymic.net_run_wsl.wsl_dmpls import WSLDMPLS - -WSLMethodDict = {'EntropyMinimization': WSLEntropyMinimization, - 'GatedCRF': WSLGatedCRF, - 'MumfordShah': WSLMumfordShah, - 'TotalVariation': WSLTotalVariation, - 'USTM': WSLUSTM, - 'DMPLS': WSLDMPLS} - -def main(): - """ - The main function for training and inference of weakly supervised segmentation. - """ - if(len(sys.argv) < 3): - print('Number of arguments should be 3. e.g.') - print(' pymic_wsl train config.cfg') - exit() - stage = str(sys.argv[1]) - cfg_file = str(sys.argv[2]) - config = parse_config(cfg_file) - config = synchronize_config(config) - log_dir = config['training']['ckpt_save_dir'] - if(not os.path.exists(log_dir)): - os.mkdir(log_dir) - if sys.version.startswith("3.9"): - logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, - format='%(message)s', force=True) # for python 3.9 - else: - logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, - format='%(message)s') # for python 3.6 - logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging_config(config) - wsl_method = config['weakly_supervised_learning']['wsl_method'] - agent = WSLMethodDict[wsl_method](config, stage) - agent.run() - -if __name__ == "__main__": - main() - - \ No newline at end of file diff --git a/pymic/util/parse_config.py b/pymic/util/parse_config.py index 5e81312..232762f 100644 --- a/pymic/util/parse_config.py +++ b/pymic/util/parse_config.py @@ -112,9 +112,12 @@ def synchronize_config(config): def logging_config(config): for section in config: - for key in config[section]: - value = config[section][key] - logging.info("{0:} {1:} = {2:}".format(section, key, value)) + if(isinstance(config[section], dict)): + for key in config[section]: + value = config[section][key] + logging.info("{0:} {1:} = {2:}".format(section, key, value)) + else: + logging.info("{0:} = {1:}".format(section, config[section])) if __name__ == "__main__": print(is_int('555')) diff --git a/setup.py b/setup.py index e0fa448..31f72e0 100644 --- a/setup.py +++ b/setup.py @@ -41,11 +41,8 @@ python_requires = '>=3.6', entry_points = { 'console_scripts': [ - 'pymic_run = pymic.net_run.net_run:main', - 'pymic_semi_sl = pymic.net_run_ssl.ssl_main:main', - 'pymic_self_sl = pymic.net_run_self_sl.self_sl_main:main', - 'pymic_wsl = pymic.net_run_wsl.wsl_main:main', - 'pymic_nll = pymic.net_run_nll.nll_main:main', + 'pymic_train = pymic.net_run.train:main', + 'pymic_test = pymic.net_run.predict:main', 'pymic_eval_cls = pymic.util.evaluation_cls:main', 'pymic_eval_seg = pymic.util.evaluation_seg:main' ], From cae464b5bfcddaf08a5480e29413b1de7948bc21 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 23 Feb 2023 11:33:56 +0800 Subject: [PATCH 137/225] Delete pyproject.toml --- pyproject.toml | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index 4bbff29..0000000 --- a/pyproject.toml +++ /dev/null @@ -1,8 +0,0 @@ -[build-system] -requires = ["flit_core >=3.2,<4"] -build-backend = "flit_core.buildapi" - -[project] -name = "PyMIC" -authors = [{name = "Graziella", email = "graziella@lumache"}] -dynamic = ["version", "description"] From 8e387a7225bc86e2fc6c92f03572724182fee5fb Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 23 Feb 2023 13:33:23 +0800 Subject: [PATCH 138/225] update train.py and predict.py Confirm that the configuration file exists --- pymic/net_run/agent_cls.py | 4 ++-- pymic/net_run/predict.py | 2 ++ pymic/net_run/train.py | 2 ++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index 181aa53..5610982 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -327,8 +327,8 @@ def infer(self): if(self.config['testing'].get('evaluation_mode', True)): self.net.eval() - - output_csv = self.config['testing']['output_csv'] + + output_csv = self.config['testing']['output_dir'] + '/' + self.config['testing']['output_csv'] class_num = self.config['network']['class_num'] save_probability = self.config['testing'].get('save_probability', False) diff --git a/pymic/net_run/predict.py b/pymic/net_run/predict.py index 8ac4af9..ca4ef25 100644 --- a/pymic/net_run/predict.py +++ b/pymic/net_run/predict.py @@ -17,6 +17,8 @@ def main(): print(' pymic_test config.cfg') exit() cfg_file = str(sys.argv[1]) + if(not os.path.isfile(cfg_file)): + raise ValueError("The config file does not exist: " + cfg_file) config = parse_config(cfg_file) config = synchronize_config(config) log_dir = config['testing']['output_dir'] diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index 4afae77..1478527 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -68,6 +68,8 @@ def main(): print(' pymic_train config.cfg') exit() cfg_file = str(sys.argv[1]) + if(not os.path.isfile(cfg_file)): + raise ValueError("The config file does not exist: " + cfg_file) config = parse_config(cfg_file) config = synchronize_config(config) log_dir = config['training']['ckpt_save_dir'] From ea8d4057baae4a88885c205326c910c968d987be Mon Sep 17 00:00:00 2001 From: taigw Date: Sun, 26 Feb 2023 16:06:03 +0800 Subject: [PATCH 139/225] release v0.4.0 release v0.4.0 --- README.md | 4 ++-- pymic/__init__.py | 2 +- setup.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index dcceba8..ac7f9b6 100644 --- a/README.md +++ b/README.md @@ -47,10 +47,10 @@ Run the following command to install the latest released version of PyMIC: ```bash pip install PYMIC ``` -To install a specific version of PYMIC such as 0.3.1, run: +To install a specific version of PYMIC such as 0.4.0, run: ```bash -pip install PYMIC==0.3.1 +pip install PYMIC==0.4.0 ``` Alternatively, you can download the source code for the latest version. Run the following command to compile and install: diff --git a/pymic/__init__.py b/pymic/__init__.py index 789b690..cb6356a 100644 --- a/pymic/__init__.py +++ b/pymic/__init__.py @@ -1,2 +1,2 @@ from __future__ import absolute_import -__version__ = "0.3.1.1" \ No newline at end of file +__version__ = "0.4.0" \ No newline at end of file diff --git a/setup.py b/setup.py index 31f72e0..36daf9a 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.3.1.1", + version = "0.4.0", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, @@ -41,8 +41,8 @@ python_requires = '>=3.6', entry_points = { 'console_scripts': [ - 'pymic_train = pymic.net_run.train:main', - 'pymic_test = pymic.net_run.predict:main', + 'pymic_train = pymic.net_run.train:main', + 'pymic_test = pymic.net_run.predict:main', 'pymic_eval_cls = pymic.util.evaluation_cls:main', 'pymic_eval_seg = pymic.util.evaluation_seg:main' ], From 713a56e948c0362f56c73f850e79d6a37c7c0421 Mon Sep 17 00:00:00 2001 From: taigw Date: Sun, 26 Feb 2023 16:39:21 +0800 Subject: [PATCH 140/225] Update index.rst update reference --- docs/source/index.rst | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index f9b62ba..1724b3a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -28,9 +28,10 @@ Citation If you use PyMIC for your research, please acknowledge it accordingly by citing our paper: -`G. Wang, X. Luo, R. Gu, S. Yang, Y. Qu, S. Zhai, Q. Zhao, K. Li, S. Zhang. (2022). +`G. Wang, X. Luo, R. Gu, S. Yang, Y. Qu, S. Zhai, Q. Zhao, K. Li, S. Zhang. PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation. -arXiv, 2208.09350. `_ + Computer Methods and Programs in Biomedicine (CMPB). 231 (2023): 107398. + `_ BibTeX entry: @@ -41,8 +42,8 @@ BibTeX entry: author = {Guotai Wang and Xiangde Luo and Ran Gu and Shuojue Yang and Yijie Qu and Shuwei Zhai and Qianfei Zhao and Kang Li and Shaoting Zhang}, title = {{PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation}}, year = {2022}, - url = {http://arxiv.org/abs/2208.09350}, - journal = {arXiv}, - volume = {2208.09350}, - pages = {1-10}, + url = {https://doi.org/10.1016/j.cmpb.2023.107398}, + journal = {Computer Methods and Programs in Biomedicine}, + volume = {231}, + pages = {107398}, } From 4ff4088401618a06eac3a09de205f6227514674c Mon Sep 17 00:00:00 2001 From: taigw Date: Sun, 26 Feb 2023 17:23:34 +0800 Subject: [PATCH 141/225] update structure of docs --- docs/source/conf.py | 4 +- docs/source/index.rst | 7 +-- docs/source/installation.rst | 3 +- docs/source/pymic.net_run.noisy_label.rst | 45 +++++++++++++ docs/source/pymic.net_run.rst | 14 ++++- docs/source/pymic.net_run.self_sup.rst | 21 +++++++ docs/source/pymic.net_run.semi_sup.rst | 69 ++++++++++++++++++++ docs/source/pymic.net_run.weak_sup.rst | 69 ++++++++++++++++++++ docs/source/pymic.net_run_nll.rst | 53 ---------------- docs/source/pymic.net_run_ssl.rst | 77 ----------------------- docs/source/pymic.net_run_wsl.rst | 77 ----------------------- docs/source/pymic.rst | 3 - docs/source/setup.rst | 7 +++ docs/source/usage.quickstart.rst | 36 +++++++---- 14 files changed, 254 insertions(+), 231 deletions(-) create mode 100644 docs/source/pymic.net_run.noisy_label.rst create mode 100644 docs/source/pymic.net_run.self_sup.rst create mode 100644 docs/source/pymic.net_run.semi_sup.rst create mode 100644 docs/source/pymic.net_run.weak_sup.rst delete mode 100644 docs/source/pymic.net_run_nll.rst delete mode 100644 docs/source/pymic.net_run_ssl.rst delete mode 100644 docs/source/pymic.net_run_wsl.rst create mode 100644 docs/source/setup.rst diff --git a/docs/source/conf.py b/docs/source/conf.py index cf1b568..09f2b50 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -9,8 +9,8 @@ copyright = '2021, HiLab' author = 'HiLab' -release = '0.1' -version = '0.1.0' +release = '0.4' +version = '0.4.0' # -- General configuration diff --git a/docs/source/index.rst b/docs/source/index.rst index 1724b3a..c1b6523 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -28,10 +28,9 @@ Citation If you use PyMIC for your research, please acknowledge it accordingly by citing our paper: -`G. Wang, X. Luo, R. Gu, S. Yang, Y. Qu, S. Zhai, Q. Zhao, K. Li, S. Zhang. -PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation. - Computer Methods and Programs in Biomedicine (CMPB). 231 (2023): 107398. - `_ +`G. Wang, X. Luo, R. Gu, S. Yang, Y. Qu, S. Zhai, Q. Zhao, K. Li, S. Zhang. +PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation. +Computer Methods and Programs in Biomedicine (CMPB). 231 (2023): 107398. `_ BibTeX entry: diff --git a/docs/source/installation.rst b/docs/source/installation.rst index ced640f..c1055e1 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -17,7 +17,8 @@ Alternatively, you can download or clone the code from `GitHub `_ - `h5py `_ diff --git a/docs/source/pymic.net_run.noisy_label.rst b/docs/source/pymic.net_run.noisy_label.rst new file mode 100644 index 0000000..04d38bb --- /dev/null +++ b/docs/source/pymic.net_run.noisy_label.rst @@ -0,0 +1,45 @@ +pymic.net\_run.noisy\_label package +=================================== + +Submodules +---------- + +pymic.net\_run.noisy\_label.nll\_clslsr module +---------------------------------------------- + +.. automodule:: pymic.net_run.noisy_label.nll_clslsr + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.noisy\_label.nll\_co\_teaching module +---------------------------------------------------- + +.. automodule:: pymic.net_run.noisy_label.nll_co_teaching + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.noisy\_label.nll\_dast module +-------------------------------------------- + +.. automodule:: pymic.net_run.noisy_label.nll_dast + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.noisy\_label.nll\_trinet module +---------------------------------------------- + +.. automodule:: pymic.net_run.noisy_label.nll_trinet + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.net_run.noisy_label + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.net_run.rst b/docs/source/pymic.net_run.rst index 74aab12..5f9ed95 100644 --- a/docs/source/pymic.net_run.rst +++ b/docs/source/pymic.net_run.rst @@ -1,6 +1,16 @@ pymic.net\_run package ====================== +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + pymic.net_run.semi_sup + pymic.net_run.weak_sup + pymic.net_run.noisy_label + Submodules ---------- @@ -44,10 +54,10 @@ pymic.net\_run.infer\_func module :undoc-members: :show-inheritance: -pymic.net\_run.net\_run module +pymic.net\_run.semi\_sup module ------------------------------ -.. automodule:: pymic.net_run.net_run +.. automodule:: pymic.net_run.semi_sup :members: :undoc-members: :show-inheritance: diff --git a/docs/source/pymic.net_run.self_sup.rst b/docs/source/pymic.net_run.self_sup.rst new file mode 100644 index 0000000..a4568d1 --- /dev/null +++ b/docs/source/pymic.net_run.self_sup.rst @@ -0,0 +1,21 @@ +pymic.net\_run.self\_sup package +================================ + +Submodules +---------- + +pymic.net\_run.self\_sup.self\_sl\_agent module +----------------------------------------------- + +.. automodule:: pymic.net_run.self_sup.self_sl_agent + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.net_run.self_sup + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.net_run.semi_sup.rst b/docs/source/pymic.net_run.semi_sup.rst new file mode 100644 index 0000000..6ed157d --- /dev/null +++ b/docs/source/pymic.net_run.semi_sup.rst @@ -0,0 +1,69 @@ +pymic.net\_run.semi\_sup package +================================ + +Submodules +---------- + +pymic.net\_run.semi\_sup.ssl\_abstract module +--------------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_abstract + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.semi\_sup.ssl\_cct module +---------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_cct + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.semi\_sup.ssl\_cps module +---------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_cps + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.semi\_sup.ssl\_em module +--------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_em + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.semi\_sup.ssl\_mt module +--------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_mt + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.semi\_sup.ssl\_uamt module +----------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_uamt + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.semi\_sup.ssl\_urpc module +----------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_urpc + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.net_run.semi_sup + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.net_run.weak_sup.rst b/docs/source/pymic.net_run.weak_sup.rst new file mode 100644 index 0000000..b906f72 --- /dev/null +++ b/docs/source/pymic.net_run.weak_sup.rst @@ -0,0 +1,69 @@ +pymic.net\_run.weak\_sup package +================================ + +Submodules +---------- + +pymic.net\_run.weak\_sup.wsl\_abstract module +--------------------------------------------- + +.. automodule:: pymic.net_run.weak_sup.wsl_abstract + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.weak\_sup.wsl\_dmpls module +------------------------------------------ + +.. automodule:: pymic.net_run.weak_sup.wsl_dmpls + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.weak\_sup.wsl\_em module +--------------------------------------- + +.. automodule:: pymic.net_run.weak_sup.wsl_em + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.weak\_sup.wsl\_gatedcrf module +--------------------------------------------- + +.. automodule:: pymic.net_run.weak_sup.wsl_gatedcrf + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.weak\_sup.wsl\_mumford\_shah module +-------------------------------------------------- + +.. automodule:: pymic.net_run.weak_sup.wsl_mumford_shah + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.weak\_sup.wsl\_tv module +--------------------------------------- + +.. automodule:: pymic.net_run.weak_sup.wsl_tv + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.weak\_sup.wsl\_ustm module +----------------------------------------- + +.. automodule:: pymic.net_run.weak_sup.wsl_ustm + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.net_run.weak_sup + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.net_run_nll.rst b/docs/source/pymic.net_run_nll.rst deleted file mode 100644 index 40ee9ef..0000000 --- a/docs/source/pymic.net_run_nll.rst +++ /dev/null @@ -1,53 +0,0 @@ -pymic.net\_run\_nll package -=========================== - -Submodules ----------- - -pymic.net\_run\_nll.nll\_clslsr module --------------------------------------- - -.. automodule:: pymic.net_run_nll.nll_clslsr - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_nll.nll\_co\_teaching module --------------------------------------------- - -.. automodule:: pymic.net_run_nll.nll_co_teaching - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_nll.nll\_dast module ------------------------------------- - -.. automodule:: pymic.net_run_nll.nll_dast - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_nll.nll\_main module ------------------------------------- - -.. automodule:: pymic.net_run_nll.nll_main - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_nll.nll\_trinet module --------------------------------------- - -.. automodule:: pymic.net_run_nll.nll_trinet - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: pymic.net_run_nll - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/pymic.net_run_ssl.rst b/docs/source/pymic.net_run_ssl.rst deleted file mode 100644 index 236e2d6..0000000 --- a/docs/source/pymic.net_run_ssl.rst +++ /dev/null @@ -1,77 +0,0 @@ -pymic.net\_run\_ssl package -=========================== - -Submodules ----------- - -pymic.net\_run\_ssl.ssl\_abstract module ----------------------------------------- - -.. automodule:: pymic.net_run_ssl.ssl_abstract - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_ssl.ssl\_cct module ------------------------------------ - -.. automodule:: pymic.net_run_ssl.ssl_cct - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_ssl.ssl\_cps module ------------------------------------ - -.. automodule:: pymic.net_run_ssl.ssl_cps - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_ssl.ssl\_em module ----------------------------------- - -.. automodule:: pymic.net_run_ssl.ssl_em - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_ssl.ssl\_main module ------------------------------------- - -.. automodule:: pymic.net_run_ssl.ssl_main - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_ssl.ssl\_mt module ----------------------------------- - -.. automodule:: pymic.net_run_ssl.ssl_mt - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_ssl.ssl\_uamt module ------------------------------------- - -.. automodule:: pymic.net_run_ssl.ssl_uamt - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_ssl.ssl\_urpc module ------------------------------------- - -.. automodule:: pymic.net_run_ssl.ssl_urpc - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: pymic.net_run_ssl - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/pymic.net_run_wsl.rst b/docs/source/pymic.net_run_wsl.rst deleted file mode 100644 index 5eda921..0000000 --- a/docs/source/pymic.net_run_wsl.rst +++ /dev/null @@ -1,77 +0,0 @@ -pymic.net\_run\_wsl package -=========================== - -Submodules ----------- - -pymic.net\_run\_wsl.wsl\_abstract module ----------------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_abstract - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_wsl.wsl\_dmpls module -------------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_dmpls - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_wsl.wsl\_em module ----------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_em - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_wsl.wsl\_gatedcrf module ----------------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_gatedcrf - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_wsl.wsl\_main module ------------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_main - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_wsl.wsl\_mumford\_shah module ---------------------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_mumford_shah - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_wsl.wsl\_tv module ----------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_tv - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_wsl.wsl\_ustm module ------------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_ustm - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: pymic.net_run_wsl - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/pymic.rst b/docs/source/pymic.rst index 7545740..dd180f0 100644 --- a/docs/source/pymic.rst +++ b/docs/source/pymic.rst @@ -12,9 +12,6 @@ Subpackages pymic.loss pymic.net pymic.net_run - pymic.net_run_nll - pymic.net_run_ssl - pymic.net_run_wsl pymic.transform pymic.util diff --git a/docs/source/setup.rst b/docs/source/setup.rst new file mode 100644 index 0000000..552eb49 --- /dev/null +++ b/docs/source/setup.rst @@ -0,0 +1,7 @@ +setup module +============ + +.. automodule:: setup + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/usage.quickstart.rst b/docs/source/usage.quickstart.rst index 95cf20d..bced277 100644 --- a/docs/source/usage.quickstart.rst +++ b/docs/source/usage.quickstart.rst @@ -12,13 +12,13 @@ for segmentation with full supervision, run the fullowing command: .. code-block:: bash - pymic_run train myconfig.cfg + pymic_train myconfig.cfg After training, run the following command for testing: .. code-block:: bash - pymic_run test myconfig.cfg + pymic_test myconfig.cfg .. tip:: @@ -51,11 +51,13 @@ file used for segmentation of lung from radiograph, which can be find in [dataset] # tensor type (float or double) tensor_type = float + task_type = seg root_dir = ../../PyMIC_data/JSRT train_csv = config/jsrt_train.csv valid_csv = config/jsrt_valid.csv test_csv = config/jsrt_test.csv + train_batch_size = 4 # data transforms @@ -69,19 +71,26 @@ file used for segmentation of lung from radiograph, which can be find in LabelConvert_source_list = [0, 255] LabelConvert_target_list = [0, 1] + [network] + # this section gives parameters for network + # the keys may be different for different networks + + # type of network net_type = UNet2D - # Parameters for UNet2D + + # number of class, required for segmentation task class_num = 2 in_chns = 1 feature_chns = [16, 32, 64, 128, 256] dropout = [0, 0, 0.3, 0.4, 0.5] bilinear = False - deep_supervise= False + multiscale_pred = False [training] # list of gpus gpus = [0] + loss_type = DiceLoss # for optimizers @@ -95,8 +104,8 @@ file used for segmentation of lung from radiograph, which can be find in lr_gamma = 0.5 lr_milestones = [2000, 4000, 6000] - ckpt_save_dir = model/unet_dice_loss - ckpt_prefix = unet + ckpt_save_dir = model/unet + ckpt_prefix = unet # start iter iter_start = 0 @@ -107,9 +116,10 @@ file used for segmentation of lung from radiograph, which can be find in [testing] # list of gpus gpus = [0] + # checkpoint mode can be [0-latest, 1-best, 2-specified] - ckpt_mode = 0 - output_dir = result + ckpt_mode = 0 + output_dir = result/unet # convert the label of prediction output label_source = [0, 1] @@ -131,17 +141,18 @@ For example, for segmentation tasks, run: pymic_eval_seg evaluation.cfg -The configuration file is like (an example from ``PYMIC_examples/seg_ssl/ACDC``): +The configuration file is like (an example from +`PyMIC_examples/seg_ssl/ACDC `_): .. code-block:: none [evaluation] - metric = dice + metric_list = [dice, hd95] label_list = [1,2,3] organ_name = heart ground_truth_folder_root = ../../PyMIC_data/ACDC/preprocess - segmentation_folder_root = result/unet2d_em + segmentation_folder_root = result/unet2d_urpc evaluation_image_pair = config/data/image_test_gt_seg.csv See :mod:`pymic.util.evaluation_seg.evaluation` for details of the configuration required. @@ -152,7 +163,8 @@ For classification tasks, run: pymic_eval_cls evaluation.cfg -The configuration file is like (an example from ``PYMIC_examples/classification/CHNCXR``): +The configuration file is like (an example from +`PyMIC_examples/classification/CHNCXR `_): .. code-block:: none From 6afeb8697d4fa9fb5dd1fb5cea12ca82f37baf07 Mon Sep 17 00:00:00 2001 From: taigw Date: Sun, 26 Feb 2023 21:45:56 +0800 Subject: [PATCH 142/225] update docs for v0.4.0 update docs for v0.4.0 --- docs/source/pymic.net_run.rst | 8 ------ docs/source/usage.fsl.rst | 29 +++++++++---------- docs/source/usage.nll.rst | 48 ++++++++++++-------------------- docs/source/usage.ssl.rst | 49 ++++++++++++++------------------- docs/source/usage.wsl.rst | 43 ++++++++++------------------- pymic/net_run/agent_abstract.py | 1 + pymic/net_run/get_optimizer.py | 16 +++++++++++ pymic/transform/intensity.py | 17 ++++++------ 8 files changed, 90 insertions(+), 121 deletions(-) diff --git a/docs/source/pymic.net_run.rst b/docs/source/pymic.net_run.rst index 5f9ed95..9f3bc26 100644 --- a/docs/source/pymic.net_run.rst +++ b/docs/source/pymic.net_run.rst @@ -54,14 +54,6 @@ pymic.net\_run.infer\_func module :undoc-members: :show-inheritance: -pymic.net\_run.semi\_sup module ------------------------------- - -.. automodule:: pymic.net_run.semi_sup - :members: - :undoc-members: - :show-inheritance: - Module contents --------------- diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index 053daf7..937bb01 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -28,8 +28,8 @@ configuration file for running. .. tip:: If you use the built-in modules such as ``UNet`` and ``Dice`` + ``CrossEntropy`` loss - for segmentation, you don't need to write the above code. Just just use the ``pymic_run`` - command. + for segmentation, you don't need to write the above code. Just just use the ``pymic_train`` + command. See examples in `PyMIC_examples/segmentation/ `_. Dataset ------- @@ -207,7 +207,7 @@ hyper-parameters. For example, the following is a configuration for using ``2DUN feature_chns = [16, 32, 64, 128, 256] dropout = [0, 0, 0.3, 0.4, 0.5] bilinear = False - deep_supervise= False + multiscale_pred = False The ``SegNetDict`` in :mod:`pymic.net.net_dict_seg` lists all the built-in network structures currently implemented in PyMIC. @@ -299,9 +299,6 @@ Itreations For training iterations, the following parameters need to be specified in the configuration file: -* ``iter_start``: the start iteration, by default is 0. None zero value means the - iteration where a pre-trained model stopped for continuing with the trainnig. - * ``iter_max``: the maximal allowed iteration for training. * ``iter_valid``: if the value is K, it means evaluating the performance on the @@ -321,9 +318,9 @@ Optimizer For optimizer, users need to set ``optimizer``, ``learning_rate``, ``momentum`` and ``weight_decay``. The built-in optimizers include ``SGD``, ``Adam``, ``SparseAdam``, ``Adadelta``, ``Adagrad``, ``Adamax``, ``ASGD``, -``LBFGS``, ``RMSprop`` and ``Rprop`` that are implemented in :mod:`torch.optim`. +``LBFGS``, ``RMSprop`` and ``Rprop`` that are implemented in `torch.optim`. -You can also use customized optimizers via :mod:`SegmentationAgent.set_optimizer()`. +You can also use customized optimizers via `SegmentationAgent.set_optimizer()`. Learning Rate Scheduler ^^^^^^^^^^^^^^^^^^^^^^^ @@ -335,7 +332,7 @@ the configuration file. Parameters related to ``ReduceLROnPlateau`` include ``lr_gamma``. Parameters related to ``MultiStepLR`` include ``lr_gamma`` and ``lr_milestones``. -You can also use customized lr schedulers via :mod:`SegmentationAgent.set_scheduler()`. +You can also use customized lr schedulers via `SegmentationAgent.set_scheduler()`. Other Options ^^^^^^^^^^^^^ @@ -373,8 +370,8 @@ test-time augmentation, etc. The following is a list of options availble for inf * ``ckpt_name`` (string, optinal): the full path to the checkpoint if ckpt_mode = 2. * ``post_process`` (string, default is None): the post process method after inference. - The current available post processing is :mod:`PostKeepLargestComponent`. Uses can also - specify customized post process methods via :mod:`SegmentationAgent.set_postprocessor()`. + The current available post processing is :mod:`pymic.util.post_process.PostKeepLargestComponent`. + Uses can also specify customized post process methods via `SegmentationAgent.set_postprocessor()`. * ``sliding_window_enable`` (bool, default is False): use sliding window for inference or not. @@ -390,14 +387,14 @@ test-time augmentation, etc. The following is a list of options availble for inf * ``ignore_dir`` (bool, default is True): if the input image name has a `/`, it will be replaced with `_` in the output file name. -* ``save_probability`` (boold, default is False): save the output probability for each class. +* ``save_probability`` (bool, default is False): save the output probability for each class. * ``label_source`` (list, default is None): a list of label to be converted after prediction. For example, - :mod:`label_source` = [0, 1] and :mod:`label_target` = [0, 255] will convert label value from 1 to 255. + `label_source` = [0, 1] and `label_target` = [0, 255] will convert label value from 1 to 255. -* ``label_target`` (list, default is None): a list of label after conversion. Use this with :mod:`label_source`. +* ``label_target`` (list, default is None): a list of label after conversion. Use this with `label_source`. * ``filename_replace_source`` (string, default is None): the substring in the filename will be replaced with - a new substring specified by :mod:`filename_replace_target`. + a new substring specified by `filename_replace_target`. -* ``filename_replace_target`` (string, default is None): work with :mod:`filename_replace_source`. \ No newline at end of file +* ``filename_replace_target`` (string, default is None): work with `filename_replace_source`. \ No newline at end of file diff --git a/docs/source/usage.nll.rst b/docs/source/usage.nll.rst index f6edd33..0a87be9 100644 --- a/docs/source/usage.nll.rst +++ b/docs/source/usage.nll.rst @@ -3,43 +3,28 @@ Noisy Label Learning ==================== -pymic_nll ---------- - -:mod:`pymic_nll` is the command for using built-in NLL methods for training. -Similarly to :mod:`pymic_run`, it should be followed by two parameters, specifying the -stage and configuration file, respectively. The training and testing commands are: - -.. code-block:: bash - - pymic_nll train myconfig_nll.cfg - pymic_nll test myconfig_nll.cfg - -.. tip:: - - If the NLL method only involves one network, either ``pymic_nll`` or ``pymic_run`` - can be used for inference. Their difference only exists in the training stage. - .. note:: Some NLL methods only use noise-robust loss functions without complex training process, and just combining the standard :mod:`SegmentationAgent` with such - loss function works for training. ``pymic_run`` instead of ``pymic_nll`` should - be used for these methods. + loss function works for training. NLL Configurations ------------------ -In the configuration file for ``pymic_nll``, in addition to those used in standard fully +In the configuration file for noisy label learning, in addition to those used in standard fully supervised learning, there is a ``noisy_label_learning`` section that is specifically designed -for NLL methods. In that section, users need to specify the ``nll_method`` and configurations -related to the NLL method. For example, the correspoinding configuration for CoTeaching is: +for NLL methods. In that section, users need to specify the ``method_name`` and configurations +related to the NLL method. ``supervise_type`` should be set as "`noisy_label`" in the ``dataset`` section. + For example, the correspoinding configuration for CoTeaching is: .. code-block:: none [dataset] ... + supervise_type = noisy_label + ... [network] ... @@ -48,7 +33,7 @@ related to the NLL method. For example, the correspoinding configuration for CoT ... [noisy_label_learning] - nll_method = CoTeaching + method_name = CoTeaching co_teaching_select_ratio = 0.8 rampup_start = 1000 rampup_end = 8000 @@ -60,13 +45,14 @@ related to the NLL method. For example, the correspoinding configuration for CoT The configuration items vary with different NLL methods. Please refer to the API of each built-in NLL method for details of the correspoinding configuration. + See examples in `PyMIC_examples/seg_nll/ `_. + Built-in NLL Methods -------------------- -Some NLL methods only use noise-robust loss functions. They are used with ``pymic_run`` -for training. Just set ``loss_type`` to one of them in the configuration file, similarly -to the fully supervised learning. +Some NLL methods only use noise-robust loss functions. They are used with a standard fully supervised training +paradigm. Just set ``supervise_type`` = `fully_sup`, and use ``loss_type`` to one of them in the configuration file: * ``GCELoss``: (`NeurIPS 2018 `_) Generalized cross entropy loss. @@ -78,7 +64,7 @@ to the fully supervised learning. Noise-robust Dice loss. The other NLL methods are implemented in child classes of -:mod:`pymic.net_run_nll.nll_abstract.NLLSegAgent`, and they are: +:mod:`pymic.net_run.agent_seg.SegmentationAgent`, and they are: * ``CLSLSR``: (`MICCAI 2020 `_) Confident learning with spatial label smoothing regularization. @@ -95,15 +81,15 @@ The other NLL methods are implemented in child classes of Customized NLL Methods ---------------------- -PyMIC alo supports customizing NLL methods by inheriting the :mod:`NLLSegAgent` class. -You may only need to rewrite the :mod:`training()` method and reuse most part of the +PyMIC alo supports customized NLL methods by inheriting the `SegmentationAgent` class. +You may only need to rewrite the `training()` method and reuse most part of the existing pipeline, such as data loading, validation and inference methods. For example: .. code-block:: none - from pymic.net_run_nll.nll_abstract import NLLSegAgent + from pymic.net_run.agent_seg import SegmentationAgent - class MyNLLMethod(NLLSegAgent): + class MyNLLMethod(SegmentationAgent): def __init__(self, config, stage = 'train'): super(MyNLLMethod, self).__init__(config, stage) ... diff --git a/docs/source/usage.ssl.rst b/docs/source/usage.ssl.rst index 143d3f8..0fd8dc6 100644 --- a/docs/source/usage.ssl.rst +++ b/docs/source/usage.ssl.rst @@ -3,37 +3,22 @@ Semi-Supervised Learning ========================= -pymic_ssl ---------- - -:mod:`pymic_ssl` is the command for using built-in semi-supervised methods for training. -Similarly to :mod:`pymic_run`, it should be followed by two parameters, specifying the -stage and configuration file, respectively. The training and testing commands are: - -.. code-block:: bash - - pymic_ssl train myconfig_ssl.cfg - pymic_ssl test myconfig_ssl.cfg - -.. tip:: - - If the SSL method only involves one network, either ``pymic_ssl`` or ``pymic_run`` - can be used for inference. Their difference only exists in the training stage. - SSL Configurations ------------------ -In the configuration file for ``pymic_ssl``, in addition to those used in fully +In the configuration file for semi-supervised segmentation, in addition to those used in fully supervised learning, there are some items specificalized for semi-supervised learning. Users should provide values for the following items in ``dataset`` section of the configuration file: +* ``supervise_type`` (string): The value should be "`semi_sup`". + * ``train_csv_unlab`` (string): the csv file for unlabeled dataset. Note that ``train_csv`` is only used for labeled dataset. * ``train_batch_size_unlab`` (int): the batch size for unlabeled dataset. - Note that ``train_batch_size`` means the batch size for the labeled dataset. + Note that `train_batch_size` means the batch size for the labeled dataset. * ``train_transform_unlab`` (list): a list of transforms used for unlabeled data. @@ -43,7 +28,12 @@ The following is an example of the ``dataset`` section for semi-supervised learn .. code-block:: none ... - root_dir =../../PyMIC_data/ACDC/preprocess/ + + tensor_type = float + task_type = seg + supervise_type = semi_sup + + root_dir = ../../PyMIC_data/ACDC/preprocess/ train_csv = config/data/image_train_r10_lab.csv train_csv_unlab = config/data/image_train_r10_unlab.csv valid_csv = config/data/image_valid.csv @@ -60,14 +50,14 @@ The following is an example of the ``dataset`` section for semi-supervised learn ... In addition, there is a ``semi_supervised_learning`` section that is specifically designed -for SSL methods. In that section, users need to specify the ``ssl_method`` and configurations +for SSL methods. In that section, users need to specify the ``method_name`` and configurations related to the SSL method. For example, the correspoinding configuration for CPS is: .. code-block:: none ... [semi_supervised_learning] - ssl_method = CPS + method_name = CPS regularize_w = 0.1 rampup_start = 1000 rampup_end = 20000 @@ -76,14 +66,15 @@ related to the SSL method. For example, the correspoinding configuration for CPS .. note:: The configuration items vary with different SSL methods. Please refer to the API - of each built-in SSL method for details of the correspoinding configuration. + of each built-in SSL method for details of the correspoinding configuration. + See examples in `PyMIC_examples/seg_ssl/ `_. Built-in SSL Methods -------------------- -:mod:`pymic.net_run_ssl.ssl_abstract.SSLSegAgent` is the abstract class used for -semi-supervised learning. The built-in SSL methods are child classes of :mod:`SSLSegAgent`. -The available SSL methods implemnted in PyMIC are listed in :mod:`pymic.net_run_ssl.ssl_main.SSLMethodDict`, +:mod:`pymic.net_run.semi_sup.ssl_abstract.SSLSegAgent` is the abstract class used for +semi-supervised learning. The built-in SSL methods are child classes of `SSLSegAgent`. +The available SSL methods implemnted in PyMIC are listed in `pymic.net_run.semi_sup.SSLMethodDict`, and they are: * ``EntropyMinimization``: (`NeurIPS 2005 `_) @@ -103,13 +94,13 @@ and they are: Customized SSL Methods ---------------------- -PyMIC alo supports customizing SSL methods by inheriting the :mod:`SSLSegAgent` class. -You may only need to rewrite the :mod:`training()` method and reuse most part of the +PyMIC alo supports customizing SSL methods by inheriting the `SSLSegAgent` class. +You may only need to rewrite the `training()` method and reuse most part of the existing pipeline, such as data loading, validation and inference methods. For example: .. code-block:: none - from pymic.net_run_ssl.ssl_abstract import SSLSegAgent + from pymic.net_run.semi_sup import SSLSegAgent class MySSLMethod(SSLSegAgent): def __init__(self, config, stage = 'train'): diff --git a/docs/source/usage.wsl.rst b/docs/source/usage.wsl.rst index 00471f6..10d6da2 100644 --- a/docs/source/usage.wsl.rst +++ b/docs/source/usage.wsl.rst @@ -3,23 +3,6 @@ Weakly-Supervised Learning ========================== -pymic_wsl ---------- - -:mod:`pymic_wsl` is the command for using built-in weakly-supervised methods for training. -Similarly to :mod:`pymic_run`, it should be followed by two parameters, specifying the -stage and configuration file, respectively. The training and testing commands are: - -.. code-block:: bash - - pymic_wsl train myconfig_wsl.cfg - pymic_wsl test myconfig_wsl.cfg - -.. tip:: - - If the WSL method only involves one network, either ``pymic_wsl`` or ``pymic_run`` - can be used for inference. Their difference only exists in the training stage. - .. note:: Currently, the weakly supervised methods supported by PyMIC are only for learning @@ -31,17 +14,19 @@ stage and configuration file, respectively. The training and testing commands ar WSL Configurations ------------------ -In the configuration file for ``pymic_wsl``, in addition to those used in fully +In the configuration file for weakly supervised learning, in addition to those used in fully supervised learning, there are some items specificalized for weakly-supervised learning. -First, in the :mod:`train_transform` list, a special transform named :mod:`PartialLabelToProbability` +First, ``supervise_type`` should be set as "`weak_sup`" in the ``dataset`` section. + +Second, in the ``train_transform`` list, a special transform named `PartialLabelToProbability` should be used to transform patial labels into a one-hot probability map and a weighting map of pixels (i.e., the weight of a pixel is 1 if labeled and 0 otherwise). The patial cross entropy loss on labeled pixels is actually implemented by a weighted cross entropy loss. The loss setting is `loss_type = CrossEntropyLoss`. -Second, there is a ``weakly_supervised_learning`` section that is specifically designed -for WSL methods. In that section, users need to specify the ``wsl_method`` and configurations +Thirdly, there is a ``weakly_supervised_learning`` section that is specifically designed +for WSL methods. In that section, users need to specify the ``method_name`` and configurations related to the WSL method. For example, the correspoinding configuration for GatedCRF is: @@ -50,6 +35,7 @@ related to the WSL method. For example, the correspoinding configuration for Gat [dataset] ... + supervise_type = weak_sup root_dir = ../../PyMIC_data/ACDC/preprocess train_csv = config/data/image_train.csv valid_csv = config/data/image_valid.csv @@ -72,7 +58,7 @@ related to the WSL method. For example, the correspoinding configuration for Gat ... [weakly_supervised_learning] - wsl_method = GatedCRF + method_name = GatedCRF regularize_w = 0.1 rampup_start = 2000 rampup_end = 15000 @@ -90,13 +76,14 @@ related to the WSL method. For example, the correspoinding configuration for Gat The configuration items vary with different WSL methods. Please refer to the API of each built-in WSL method for details of the correspoinding configuration. + See examples in `PyMIC_examples/seg_wsl/ `_. Built-in WSL Methods -------------------- -:mod:`pymic.net_run_wsl.wsl_abstract.WSLSegAgent` is the abstract class used for -weakly-supervised learning. The built-in WSL methods are child classes of :mod:`WSLSegAgent`. -The available WSL methods implemnted in PyMIC are listed in :mod:`pymic.net_run_wsl.wsl_main.WSLMethodDict`, +:mod:`pymic.net_run.weak_sup.wsl_abstract.WSLSegAgent` is the abstract class used for +weakly-supervised learning. The built-in WSL methods are child classes of `WSLSegAgent`. +The available WSL methods implemnted in PyMIC are listed in `pymic.net_run.weak_sup.WSLMethodDict`, and they are: * ``EntropyMinimization``: (`NeurIPS 2005 `_) @@ -120,13 +107,13 @@ and they are: Customized WSL Methods ---------------------- -PyMIC alo supports customizing WSL methods by inheriting the :mod:`WSLSegAgent` class. -You may only need to rewrite the :mod:`training()` method and reuse most part of the +PyMIC alo supports customizing WSL methods by inheriting the `WSLSegAgent` class. +You may only need to rewrite the `training()` method and reuse most part of the existing pipeline, such as data loading, validation and inference methods. For example: .. code-block:: none - from pymic.net_run_wsl.wsl_abstract import WSLSegAgent + from pymic.net_run.weak_sup import WSLSegAgent class MyWSLMethod(WSLSegAgent): def __init__(self, config, stage = 'train'): diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index 9131ba0..7a509d1 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -282,6 +282,7 @@ def create_optimizer(self, params, checkpoint = None): :param params: network parameters for optimization. Usually it is obtained by `self.get_parameters_to_update()`. + :param checkpoint: A previous checkpoint to load. Default is `None`. """ opt_params = self.config['training'] if(self.optimizer is None): diff --git a/pymic/net_run/get_optimizer.py b/pymic/net_run/get_optimizer.py index ad8fda0..755ebcf 100644 --- a/pymic/net_run/get_optimizer.py +++ b/pymic/net_run/get_optimizer.py @@ -7,6 +7,15 @@ from pymic.util.general import keyword_match def get_optimizer(name, net_params, optim_params): + """ + Create an optimizer for learnable parameters. + + :param name: (string) Name of the optimizer. Should be one of {`SGD`, `Adam`, + `SparseAdam`, `Adadelta`, `Adagrad`, `Adamax`, `ASGD`, `LBFGS`, `RMSprop`, `Rprop`}. + :param net_params: Learnable parameters that need to be set for an optimizer. + :param optim_params: (dict) The parameters required for the target optimizer. + :return: An instance of the target optimizer. + """ lr = optim_params['learning_rate'] momentum = optim_params['momentum'] weight_decay = optim_params['weight_decay'] @@ -39,6 +48,13 @@ def get_optimizer(name, net_params, optim_params): def get_lr_scheduler(optimizer, sched_params): + """ + Create learning rate scheduler for an optimizer + + :param optimizer: An optimizer instance. + :param sched_params: (dict) The parameters required for the scheduler. + :return: An instance of the target learning rate scheduler. + """ name = sched_params["lr_scheduler"] val_it = sched_params["iter_valid"] epoch_last = sched_params["last_iter"] diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index 3b5ee9d..96458fa 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -14,21 +14,20 @@ def bernstein_poly(i, n, t): """ - The Bernstein polynomial of n, i as a function of t + The Bernstein polynomial of n, i as a function of t. """ return comb(n, i) * ( t**(n-i) ) * (1 - t)**i def bezier_curve(points, nTimes=1000): """ - Given a set of control points, return the - bezier curve defined by the control points. - Control points should be a list of lists, or list of tuples - such as [ [1,1], - [2,3], - [4,5], ..[Xn, Yn] ] - nTimes is the number of time steps, defaults to 1000 - See http://processingjs.nihongoresources.com/bezierinfo/ + Given a set of control points, return the + bezier curve defined by the control points. + Control points should be a list of lists, or list of tuples + such as [ [1,1], [2,3], [4,5], ..[Xn, Yn] ]. + + `nTimes` is the number of time steps, defaults to 1000. + See http://processingjs.nihongoresources.com/bezierinfo/ """ nPoints = len(points) From e17a2fa6699710c074026f880979ce203bc8483f Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 15 Mar 2023 16:29:58 +0800 Subject: [PATCH 143/225] update task type: reconstruction update task type: reconstruction --- pymic/net_run/__init__.py | 2 - pymic/net_run/agent_abstract.py | 3 +- pymic/net_run/agent_cls.py | 2 +- pymic/net_run/agent_rec.py | 274 ++++++++++++++++++++++++ pymic/net_run/agent_seg.py | 2 +- pymic/net_run/get_optimizer.py | 16 -- pymic/net_run/predict.py | 7 +- pymic/net_run/self_sup/self_sl_agent.py | 227 +------------------- pymic/net_run/train.py | 7 +- pymic/transform/crop.py | 8 +- pymic/transform/flip.py | 4 +- pymic/transform/intensity.py | 17 +- pymic/transform/label_convert.py | 50 +++++ pymic/transform/pad.py | 4 +- pymic/transform/rescale.py | 8 +- pymic/transform/rotate.py | 4 +- pymic/transform/trans_dict.py | 2 + pymic/transform/transpose.py | 5 +- 18 files changed, 367 insertions(+), 275 deletions(-) create mode 100644 pymic/net_run/agent_rec.py diff --git a/pymic/net_run/__init__.py b/pymic/net_run/__init__.py index 72b8078..e69de29 100644 --- a/pymic/net_run/__init__.py +++ b/pymic/net_run/__init__.py @@ -1,2 +0,0 @@ -from __future__ import absolute_import -from . import * \ No newline at end of file diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index 7a509d1..2e6c062 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -57,7 +57,7 @@ def __init__(self, config, stage = 'train'): self.transform_dict = None self.inferer = None self.tensor_type = config['dataset']['tensor_type'] - self.task_type = config['dataset']['task_type'] #cls, cls_mtbc, seg + self.task_type = config['dataset']['task_type'] #cls, cls_mtbc, seg, rec self.deterministic = config['training'].get('deterministic', True) self.random_seed = config['training'].get('random_seed', 1) if(self.deterministic): @@ -282,7 +282,6 @@ def create_optimizer(self, params, checkpoint = None): :param params: network parameters for optimization. Usually it is obtained by `self.get_parameters_to_update()`. - :param checkpoint: A previous checkpoint to load. Default is `None`. """ opt_params = self.config['training'] if(self.optimizer is None): diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index 5610982..862a703 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -58,7 +58,7 @@ def get_stage_dataset_from_config(self, stage): data_transform = None else: transform_param = self.config['dataset'] - transform_param['task'] = 'classification' + transform_param['task'] = self.task_type for name in transform_names: if(name not in self.transform_dict): raise(ValueError("Undefined transform {0:}".format(name))) diff --git a/pymic/net_run/agent_rec.py b/pymic/net_run/agent_rec.py new file mode 100644 index 0000000..607623d --- /dev/null +++ b/pymic/net_run/agent_rec.py @@ -0,0 +1,274 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import copy +import logging +import time +import logging +import numpy as np +import os +import scipy +import torch +import torch.nn as nn +from datetime import datetime +from torch.optim import lr_scheduler +from tensorboardX import SummaryWriter +from pymic.io.image_read_write import save_nd_array_as_image +from pymic.net_run.infer_func import Inferer +from pymic.net_run.agent_seg import SegmentationAgent +from pymic.loss.seg.mse import MAELoss, MSELoss + +ReconstructionLossDict = { + 'MAELoss': MAELoss, + 'MSELoss': MSELoss + } + +class ReconstructionAgent(SegmentationAgent): + """ + An agent for image reconstruction (pixel-level intensity prediction). + """ + def __init__(self, config, stage = 'train'): + super(ReconstructionAgent, self).__init__(config, stage) + output_act_name = config['network'].get('output_activation', 'sigmoid') + if(output_act_name == "sigmoid"): + self.out_act = nn.Sigmoid() + elif(output_act_name == "tanh"): + self.out_act = nn.Tanh() + else: + raise ValueError("For reconstruction task, only sigmoid and tanh are " + \ + "supported for output_activation.") + + def create_loss_calculator(self): + if(self.loss_dict is None): + self.loss_dict = ReconstructionLossDict + loss_name = self.config['training']['loss_type'] + if isinstance(loss_name, (list, tuple)): + raise ValueError("Undefined loss function {0:}".format(loss_name)) + elif (loss_name not in self.loss_dict): + raise ValueError("Undefined loss function {0:}".format(loss_name)) + else: + loss_param = self.config['training'] + loss_param['loss_softmax'] = False + base_loss = self.loss_dict[loss_name](self.config['training']) + if(self.config['training'].get('deep_supervise', False)): + raise ValueError("Deep supervised loss not implemented for reconstruction tasks") + # weight = self.config['training'].get('deep_supervise_weight', None) + # mode = self.config['training'].get('deep_supervise_mode', 2) + # params = {'deep_supervise_weight': weight, + # 'deep_supervise_mode': mode, + # 'base_loss':base_loss} + # self.loss_calculator = DeepSuperviseLoss(params) + else: + self.loss_calculator = base_loss + + def training(self): + iter_valid = self.config['training']['iter_valid'] + train_loss = 0 + self.net.train() + for it in range(iter_valid): + try: + data = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data = next(self.trainIter) + # get the inputs + inputs = self.convert_tensor_type(data['image']) + label = self.convert_tensor_type(data['label']) + + # for debug + # from pymic.io.image_read_write import save_nd_array_as_image + # for i in range(inputs.shape[0]): + # image_i = inputs[i][0] + # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) + # save_nd_array_as_image(image_i, image_name, reference_name = None) + # return + + inputs, label = inputs.to(self.device), label.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + outputs = self.net(inputs) + outputs = self.out_act(outputs) + loss = self.get_loss_value(data, outputs, label) + loss.backward() + self.optimizer.step() + train_loss = train_loss + loss.item() + # get dice evaluation for each class + if(isinstance(outputs, tuple) or isinstance(outputs, list)): + outputs = outputs[0] + + train_avg_loss = train_loss / iter_valid + train_scalers = {'loss': train_avg_loss} + return train_scalers + + def validation(self): + class_num = self.config['network']['class_num'] + if(self.inferer is None): + infer_cfg = self.config['testing'] + infer_cfg['class_num'] = class_num + self.inferer = Inferer(infer_cfg) + + valid_loss_list = [] + validIter = iter(self.valid_loader) + with torch.no_grad(): + self.net.eval() + for data in validIter: + inputs = self.convert_tensor_type(data['image']) + label = self.convert_tensor_type(data['label']) + inputs, label = inputs.to(self.device), label.to(self.device) + outputs = self.inferer.run(self.net, inputs) + outputs = self.out_act(outputs) + # The tensors are on CPU when calculating loss for validation data + loss = self.get_loss_value(data, outputs, label) + valid_loss_list.append(loss.item()) + + valid_avg_loss = np.asarray(valid_loss_list).mean() + valid_scalers = {'loss': valid_avg_loss} + return valid_scalers + + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): + loss_scalar ={'train':train_scalars['loss'], + 'valid':valid_scalars['loss']} + self.summ_writer.add_scalars('loss', loss_scalar, glob_it) + self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) + logging.info('train loss {0:.4f}'.format(train_scalars['loss'])) + logging.info('valid loss {0:.4f}'.format(valid_scalars['loss'])) + + def train_valid(self): + device_ids = self.config['training']['gpus'] + if(len(device_ids) > 1): + self.device = torch.device("cuda:0") + self.net = nn.DataParallel(self.net, device_ids = device_ids) + else: + self.device = torch.device("cuda:{0:}".format(device_ids[0])) + self.net.to(self.device) + ckpt_dir = self.config['training']['ckpt_save_dir'] + ckpt_prefix = self.config['training'].get('ckpt_prefix', None) + if(ckpt_prefix is None): + ckpt_prefix = ckpt_dir.split('/')[-1] + iter_start = self.config['training']['iter_start'] + iter_max = self.config['training']['iter_max'] + iter_valid = self.config['training']['iter_valid'] + iter_save = self.config['training'].get('iter_save', None) + early_stop_it = self.config['training'].get('early_stop_patience', None) + if(iter_save is None): + iter_save_list = [iter_max] + elif(isinstance(iter_save, (tuple, list))): + iter_save_list = iter_save + else: + iter_save_list = range(0, iter_max + 1, iter_save) + + self.min_val_loss = 10000.0 + self.max_val_it = 0 + self.best_model_wts = None + self.checkpoint = None + if(iter_start > 0): + checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start) + self.checkpoint = torch.load(checkpoint_file, map_location = self.device) + # assert(self.checkpoint['iteration'] == iter_start) + if(len(device_ids) > 1): + self.net.module.load_state_dict(self.checkpoint['model_state_dict']) + else: + self.net.load_state_dict(self.checkpoint['model_state_dict']) + self.min_val_loss = self.checkpoint.get('valid_loss', 10000) + # self.max_val_it = self.checkpoint['iteration'] + self.max_val_it = iter_start + self.best_model_wts = self.checkpoint['model_state_dict'] + + self.create_optimizer(self.get_parameters_to_update()) + self.create_loss_calculator() + + self.trainIter = iter(self.train_loader) + + logging.info("{0:} training start".format(str(datetime.now())[:-7])) + self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) + self.glob_it = iter_start + for it in range(iter_start, iter_max, iter_valid): + lr_value = self.optimizer.param_groups[0]['lr'] + t0 = time.time() + train_scalars = self.training() + t1 = time.time() + valid_scalars = self.validation() + t2 = time.time() + if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step(-valid_scalars['loss']) + else: + self.scheduler.step() + + self.glob_it = it + iter_valid + logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) + logging.info('learning rate {0:}'.format(lr_value)) + logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1)) + self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it) + if(valid_scalars['loss'] < self.min_val_loss): + self.min_val_loss = valid_scalars['loss'] + self.max_val_it = self.glob_it + if(len(device_ids) > 1): + self.best_model_wts = copy.deepcopy(self.net.module.state_dict()) + else: + self.best_model_wts = copy.deepcopy(self.net.state_dict()) + + stop_now = True if(early_stop_it is not None and \ + self.glob_it - self.max_val_it > early_stop_it) else False + if ((self.glob_it in iter_save_list) or stop_now): + save_dict = {'iteration': self.glob_it, + 'valid_loss': valid_scalars['loss'], + 'model_state_dict': self.net.module.state_dict() \ + if len(device_ids) > 1 else self.net.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict()} + save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.glob_it) + torch.save(save_dict, save_name) + txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt') + txt_file.write(str(self.glob_it)) + txt_file.close() + if(stop_now): + logging.info("The training is early stopped") + break + # save the best performing checkpoint + save_dict = {'iteration': self.max_val_it, + 'valid_loss': self.min_val_loss, + 'model_state_dict': self.best_model_wts, + 'optimizer_state_dict': self.optimizer.state_dict()} + save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it) + torch.save(save_dict, save_name) + txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') + txt_file.write(str(self.max_val_it)) + txt_file.close() + logging.info('The best performing iter is {0:}, valid loss {1:}'.format(\ + self.max_val_it, self.min_val_loss)) + self.summ_writer.close() + + def save_outputs(self, data): + """ + Save prediction output. + + :param data: (dictionary) A data dictionary with prediciton result and other + information such as input image name. + """ + output_dir = self.config['testing']['output_dir'] + ignore_dir = self.config['testing'].get('filename_ignore_dir', True) + filename_replace_source = self.config['testing'].get('filename_replace_source', None) + filename_replace_target = self.config['testing'].get('filename_replace_target', None) + if(not os.path.exists(output_dir)): + os.makedirs(output_dir, exist_ok=True) + + names, pred = data['names'], data['predict'] + if(isinstance(pred, (list, tuple))): + pred = pred[0] + if(isinstance(self.out_act, nn.Sigmoid)): + pred = scipy.special.expit(pred) + else: + pred = np.tanh(pred) + # save the output predictions + root_dir = self.config['dataset']['root_dir'] + for i in range(len(names)): + save_name = names[i].split('/')[-1] if ignore_dir else \ + names[i].replace('/', '_') + if((filename_replace_source is not None) and (filename_replace_target is not None)): + save_name = save_name.replace(filename_replace_source, filename_replace_target) + print(save_name) + save_name = "{0:}/{1:}".format(output_dir, save_name) + save_nd_array_as_image(pred[i][i], save_name, root_dir + '/' + names[i]) + + \ No newline at end of file diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 887a516..b85c716 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -54,7 +54,7 @@ def get_stage_dataset_from_config(self, stage): data_transform = None else: transform_param = self.config['dataset'] - transform_param['task'] = 'segmentation' + transform_param['task'] = self.task_type for name in transform_names: if(name not in self.transform_dict): raise(ValueError("Undefined transform {0:}".format(name))) diff --git a/pymic/net_run/get_optimizer.py b/pymic/net_run/get_optimizer.py index 755ebcf..ad8fda0 100644 --- a/pymic/net_run/get_optimizer.py +++ b/pymic/net_run/get_optimizer.py @@ -7,15 +7,6 @@ from pymic.util.general import keyword_match def get_optimizer(name, net_params, optim_params): - """ - Create an optimizer for learnable parameters. - - :param name: (string) Name of the optimizer. Should be one of {`SGD`, `Adam`, - `SparseAdam`, `Adadelta`, `Adagrad`, `Adamax`, `ASGD`, `LBFGS`, `RMSprop`, `Rprop`}. - :param net_params: Learnable parameters that need to be set for an optimizer. - :param optim_params: (dict) The parameters required for the target optimizer. - :return: An instance of the target optimizer. - """ lr = optim_params['learning_rate'] momentum = optim_params['momentum'] weight_decay = optim_params['weight_decay'] @@ -48,13 +39,6 @@ def get_optimizer(name, net_params, optim_params): def get_lr_scheduler(optimizer, sched_params): - """ - Create learning rate scheduler for an optimizer - - :param optimizer: An optimizer instance. - :param sched_params: (dict) The parameters required for the scheduler. - :return: An instance of the target learning rate scheduler. - """ name = sched_params["lr_scheduler"] val_it = sched_params["iter_valid"] epoch_last = sched_params["last_iter"] diff --git a/pymic/net_run/predict.py b/pymic/net_run/predict.py index ca4ef25..31fff86 100644 --- a/pymic/net_run/predict.py +++ b/pymic/net_run/predict.py @@ -7,6 +7,7 @@ from pymic.util.parse_config import * from pymic.net_run.agent_cls import ClassificationAgent from pymic.net_run.agent_seg import SegmentationAgent +from pymic.net_run.agent_rec import ReconstructionAgent def main(): """ @@ -34,11 +35,13 @@ def main(): logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) task = config['dataset']['task_type'] - assert task in ['cls', 'cls_nexcl', 'seg'] + assert task in ['cls', 'cls_nexcl', 'seg', 'rec'] if(task == 'cls' or task == 'cls_nexcl'): agent = ClassificationAgent(config, 'test') - else: + elif(task == 'seg'): agent = SegmentationAgent(config, 'test') + else: + agent = ReconstructionAgent(config, 'test') agent.run() if __name__ == "__main__": diff --git a/pymic/net_run/self_sup/self_sl_agent.py b/pymic/net_run/self_sup/self_sl_agent.py index 24a6e66..c352adf 100644 --- a/pymic/net_run/self_sup/self_sl_agent.py +++ b/pymic/net_run/self_sup/self_sl_agent.py @@ -3,31 +3,10 @@ import copy import logging import time -import logging -import numpy as np -import random -import torch -import torch.nn as nn -import torchvision.transforms as transforms -from datetime import datetime -from random import random -from torch.optim import lr_scheduler -from tensorboardX import SummaryWriter -from pymic.io.nifty_dataset import NiftyDataset -from pymic.loss.seg.util import get_soft_label -from pymic.loss.seg.util import reshape_prediction_and_ground_truth -from pymic.loss.seg.util import get_classwise_dice -from pymic.net_run.infer_func import Inferer -from pymic.net_run.agent_seg import SegmentationAgent -from pymic.transform.trans_dict import TransformDict -from pymic.loss.seg.mse import MAELoss, MSELoss +from pymic.net_run.agent_rec import ReconstructionAgent -RegressionLossDict = { - 'MAELoss': MAELoss, - 'MSELoss': MSELoss - } -class SelfSLSegAgent(SegmentationAgent): +class SelfSLSegAgent(ReconstructionAgent): """ Abstract class for self-supervised segmentation. @@ -42,204 +21,4 @@ class SelfSLSegAgent(SegmentationAgent): """ def __init__(self, config, stage = 'train'): super(SelfSLSegAgent, self).__init__(config, stage) - self.transform_dict = TransformDict - - def create_loss_calculator(self): - if(self.loss_dict is None): - self.loss_dict = RegressionLossDict - loss_name = self.config['training']['loss_type'] - if isinstance(loss_name, (list, tuple)): - raise ValueError("Undefined loss function {0:}".format(loss_name)) - elif (loss_name not in self.loss_dict): - raise ValueError("Undefined loss function {0:}".format(loss_name)) - else: - loss_param = self.config['training'] - loss_param['loss_softmax'] = False - base_loss = self.loss_dict[loss_name](self.config['training']) - if(self.config['training'].get('deep_supervise', False)): - raise ValueError("Deep supervised loss not implemented for self-supervised learning") - # weight = self.config['training'].get('deep_supervise_weight', None) - # mode = self.config['training'].get('deep_supervise_mode', 2) - # params = {'deep_supervise_weight': weight, - # 'deep_supervise_mode': mode, - # 'base_loss':base_loss} - # self.loss_calculator = DeepSuperviseLoss(params) - else: - self.loss_calculator = base_loss - - def training(self): - iter_valid = self.config['training']['iter_valid'] - train_loss = 0 - self.net.train() - for it in range(iter_valid): - try: - data = next(self.trainIter) - except StopIteration: - self.trainIter = iter(self.train_loader) - data = next(self.trainIter) - # get the inputs - inputs = self.convert_tensor_type(data['image']) - label = self.convert_tensor_type(data['label']) - - # for debug - # from pymic.io.image_read_write import save_nd_array_as_image - # for i in range(inputs.shape[0]): - # image_i = inputs[i][0] - # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) - # save_nd_array_as_image(image_i, image_name, reference_name = None) - # return - - inputs, label = inputs.to(self.device), label.to(self.device) - - # zero the parameter gradients - self.optimizer.zero_grad() - - # forward + backward + optimize - outputs = self.net(inputs) - outputs = nn.Sigmoid()(outputs) - loss = self.get_loss_value(data, outputs, label) - loss.backward() - self.optimizer.step() - train_loss = train_loss + loss.item() - # get dice evaluation for each class - if(isinstance(outputs, tuple) or isinstance(outputs, list)): - outputs = outputs[0] - - train_avg_loss = train_loss / iter_valid - train_scalers = {'loss': train_avg_loss} - return train_scalers - - def validation(self): - if(self.inferer is None): - infer_cfg = self.config['testing'] - self.inferer = Inferer(infer_cfg) - - valid_loss_list = [] - validIter = iter(self.valid_loader) - with torch.no_grad(): - self.net.eval() - for data in validIter: - inputs = self.convert_tensor_type(data['image']) - label = self.convert_tensor_type(data['label']) - inputs, label = inputs.to(self.device), label.to(self.device) - outputs = self.inferer.run(self.net, inputs) - outputs = nn.Sigmoid()(outputs) - # The tensors are on CPU when calculating loss for validation data - loss = self.get_loss_value(data, outputs, label) - valid_loss_list.append(loss.item()) - - valid_avg_loss = np.asarray(valid_loss_list).mean() - valid_scalers = {'loss': valid_avg_loss} - return valid_scalers - - def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): - loss_scalar ={'train':train_scalars['loss'], - 'valid':valid_scalars['loss']} - self.summ_writer.add_scalars('loss', loss_scalar, glob_it) - self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) - logging.info('train loss {0:.4f}'.format(train_scalars['loss'])) - logging.info('valid loss {0:.4f}'.format(valid_scalars['loss'])) - - def train_valid(self): - device_ids = self.config['training']['gpus'] - if(len(device_ids) > 1): - self.device = torch.device("cuda:0") - self.net = nn.DataParallel(self.net, device_ids = device_ids) - else: - self.device = torch.device("cuda:{0:}".format(device_ids[0])) - self.net.to(self.device) - ckpt_dir = self.config['training']['ckpt_save_dir'] - ckpt_prefix = self.config['training'].get('ckpt_prefix', None) - if(ckpt_prefix is None): - ckpt_prefix = ckpt_dir.split('/')[-1] - iter_start = self.config['training']['iter_start'] - iter_max = self.config['training']['iter_max'] - iter_valid = self.config['training']['iter_valid'] - iter_save = self.config['training'].get('iter_save', None) - early_stop_it = self.config['training'].get('early_stop_patience', None) - if(iter_save is None): - iter_save_list = [iter_max] - elif(isinstance(iter_save, (tuple, list))): - iter_save_list = iter_save - else: - iter_save_list = range(0, iter_max + 1, iter_save) - - self.min_val_loss = 10000.0 - self.max_val_it = 0 - self.best_model_wts = None - self.checkpoint = None - if(iter_start > 0): - checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start) - self.checkpoint = torch.load(checkpoint_file, map_location = self.device) - # assert(self.checkpoint['iteration'] == iter_start) - if(len(device_ids) > 1): - self.net.module.load_state_dict(self.checkpoint['model_state_dict']) - else: - self.net.load_state_dict(self.checkpoint['model_state_dict']) - self.min_val_loss = self.checkpoint.get('valid_loss', 10000) - # self.max_val_it = self.checkpoint['iteration'] - self.max_val_it = iter_start - self.best_model_wts = self.checkpoint['model_state_dict'] - - self.create_optimizer(self.get_parameters_to_update()) - self.create_loss_calculator() - - self.trainIter = iter(self.train_loader) - - logging.info("{0:} training start".format(str(datetime.now())[:-7])) - self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) - self.glob_it = iter_start - for it in range(iter_start, iter_max, iter_valid): - lr_value = self.optimizer.param_groups[0]['lr'] - t0 = time.time() - train_scalars = self.training() - t1 = time.time() - valid_scalars = self.validation() - t2 = time.time() - if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step(-valid_scalars['loss']) - else: - self.scheduler.step() - - self.glob_it = it + iter_valid - logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) - logging.info('learning rate {0:}'.format(lr_value)) - logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1)) - self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it) - if(valid_scalars['loss'] < self.min_val_loss): - self.min_val_loss = valid_scalars['loss'] - self.max_val_it = self.glob_it - if(len(device_ids) > 1): - self.best_model_wts = copy.deepcopy(self.net.module.state_dict()) - else: - self.best_model_wts = copy.deepcopy(self.net.state_dict()) - - stop_now = True if(early_stop_it is not None and \ - self.glob_it - self.max_val_it > early_stop_it) else False - if ((self.glob_it in iter_save_list) or stop_now): - save_dict = {'iteration': self.glob_it, - 'valid_loss': valid_scalars['loss'], - 'model_state_dict': self.net.module.state_dict() \ - if len(device_ids) > 1 else self.net.state_dict(), - 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.glob_it) - torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt') - txt_file.write(str(self.glob_it)) - txt_file.close() - if(stop_now): - logging.info("The training is early stopped") - break - # save the best performing checkpoint - save_dict = {'iteration': self.max_val_it, - 'valid_loss': self.min_val_loss, - 'model_state_dict': self.best_model_wts, - 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it) - torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') - txt_file.write(str(self.max_val_it)) - txt_file.close() - logging.info('The best performing iter is {0:}, valid loss {1:}'.format(\ - self.max_val_it, self.min_val_loss)) - self.summ_writer.close() \ No newline at end of file + \ No newline at end of file diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index 1478527..107a519 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -13,7 +13,7 @@ from pymic.net_run.noisy_label import NLLMethodDict from pymic.net_run.self_sup import SelfSLSegAgent -def get_segmentation_agent(config, sup_type): +def get_seg_rec_agent(config, sup_type): assert(sup_type in ['fully_sup', 'semi_sup', 'self_sup', 'weak_sup', 'noisy_label']) if(sup_type == 'fully_sup'): logging.info("\n********** Fully Supervised Learning **********\n") @@ -86,12 +86,13 @@ def main(): logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) task = config['dataset']['task_type'] - assert task in ['cls', 'cls_nexcl', 'seg'] + assert task in ['cls', 'cls_nexcl', 'seg', 'rec'] if(task == 'cls' or task == 'cls_nexcl'): agent = ClassificationAgent(config, 'train') else: sup_type = config['dataset'].get('supervise_type', 'fully_sup') - agent = get_segmentation_agent(config, sup_type) + agent = get_seg_rec_agent(config, sup_type) + agent.run() if __name__ == "__main__": diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index a27288d..60f5c9b 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -55,12 +55,12 @@ def __call__(self, sample): image_t = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) sample['image'] = image_t - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and self.task in ['seg', 'rec']): label = sample['label'] crop_max[0] = label.shape[0] label = crop_ND_volume_with_bounding_box(label, crop_min, crop_max) sample['label'] = label - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and self.task in ['seg', 'rec']): weight = sample['pixel_weight'] crop_max[0] = weight.shape[0] weight = crop_ND_volume_with_bounding_box(weight, crop_min, crop_max) @@ -300,13 +300,13 @@ def __call__(self, sample): image_t = ndimage.interpolation.zoom(image_t, scale, order = 1) sample['image'] = image_t - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and self.task in ['seg', 'rec']): label = sample['label'] crop_max[0] = label.shape[0] label = crop_ND_volume_with_bounding_box(label, crop_min, crop_max) label = ndimage.interpolation.zoom(label, scale, order = 0) sample['label'] = label - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and self.task ['seg', 'rec']): weight = sample['pixel_weight'] crop_max[0] = weight.shape[0] weight = crop_ND_volume_with_bounding_box(weight, crop_min, crop_max) diff --git a/pymic/transform/flip.py b/pymic/transform/flip.py index ca0915e..0462935 100644 --- a/pymic/transform/flip.py +++ b/pymic/transform/flip.py @@ -52,9 +52,9 @@ def __call__(self, sample): # current pytorch does not support negative strides image_t = np.flip(image, flip_axis).copy() sample['image'] = image_t - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and self.task in ['seg', 'rec']): sample['label'] = np.flip(sample['label'] , flip_axis).copy() - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and self.task in ['seg', 'rec']): sample['pixel_weight'] = np.flip(sample['pixel_weight'] , flip_axis).copy() return sample diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index 96458fa..3b5ee9d 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -14,20 +14,21 @@ def bernstein_poly(i, n, t): """ - The Bernstein polynomial of n, i as a function of t. + The Bernstein polynomial of n, i as a function of t """ return comb(n, i) * ( t**(n-i) ) * (1 - t)**i def bezier_curve(points, nTimes=1000): """ - Given a set of control points, return the - bezier curve defined by the control points. - Control points should be a list of lists, or list of tuples - such as [ [1,1], [2,3], [4,5], ..[Xn, Yn] ]. - - `nTimes` is the number of time steps, defaults to 1000. - See http://processingjs.nihongoresources.com/bezierinfo/ + Given a set of control points, return the + bezier curve defined by the control points. + Control points should be a list of lists, or list of tuples + such as [ [1,1], + [2,3], + [4,5], ..[Xn, Yn] ] + nTimes is the number of time steps, defaults to 1000 + See http://processingjs.nihongoresources.com/bezierinfo/ """ nPoints = len(points) diff --git a/pymic/transform/label_convert.py b/pymic/transform/label_convert.py index 0dcae37..1efaab6 100644 --- a/pymic/transform/label_convert.py +++ b/pymic/transform/label_convert.py @@ -93,6 +93,27 @@ def __call__(self, sample): sample['label_prob'] = label_prob return sample +class LabelSmooth(AbstractTransform): + """ + Apply label smoothing to one-hot labels. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `LabelSmooth_alpha`: (float) Alpha value for label smoothing. + :param `LabelSmooth_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. + """ + def __init__(self, params): + super(LabelSmooth, self).__init__(params) + self.alpha = params['LabelSmooth_alpha'.lower()] + self.inverse = params.get('LabelSmooth_inverse'.lower(), False) + + def __call__(self, sample): + label_prob = sample['label_prob'] + K = list(label_prob.shape)[1] + sample['label_prob'] = label_prob * (1.0 - self.alpha) + self.alpha / K + return sample class PartialLabelToProbability(AbstractTransform): """ @@ -130,5 +151,34 @@ def __call__(self, sample): return sample +class SelfSuperviseLabel(AbstractTransform): + """ + Convert one-channel partial label map to one-hot multi-channel probability map. + This is used for segmentation tasks only. In the input label map, 0 represents the + background class, 1 to C-1 represent the foreground classes, and C represents + unlabeled pixels. In the output dictionary, `label_prob` is the one-hot probability + map, and `pixel_weight` represents a weighting map, where the weight for a pixel + is 0 if the label is unkown. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `PartialLabelToProbability_class_num`: (int) The class number for the + segmentation task. + :param `PartialLabelToProbability_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. + """ + def __init__(self, params): + """ + class_num (int): the class number in the label map + """ + super(SelfSuperviseLabel, self).__init__(params) + self.inverse = params.get('SelfSuperviseLabel_inverse'.lower(), False) + + def __call__(self, sample): + image = sample['image'] + label = image * 1.0 + sample['label'] = label + return sample diff --git a/pymic/transform/pad.py b/pymic/transform/pad.py index 0ec196c..4d292be 100644 --- a/pymic/transform/pad.py +++ b/pymic/transform/pad.py @@ -59,11 +59,11 @@ def __call__(self, sample): sample['image'] = image_t - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and self.task in ['seg', 'rec']): label = sample['label'] label = np.pad(label, pad, 'reflect') if(max(margin) > 0) else label sample['label'] = label - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and self.task in ['seg', 'rec']): weight = sample['pixel_weight'] weight = np.pad(weight, pad, 'reflect') if(max(margin) > 0) else weight sample['pixel_weight'] = weight diff --git a/pymic/transform/rescale.py b/pymic/transform/rescale.py index 2a671fd..660b156 100644 --- a/pymic/transform/rescale.py +++ b/pymic/transform/rescale.py @@ -48,11 +48,11 @@ def __call__(self, sample): sample['image'] = image_t sample['Rescale_origin_shape'] = json.dumps(input_shape) - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and self.task in ['seg', 'rec']): label = sample['label'] label = ndimage.interpolation.zoom(label, scale, order = 0) sample['label'] = label - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and self.task in ['seg', 'rec']): weight = sample['pixel_weight'] weight = ndimage.interpolation.zoom(weight, scale, order = 1) sample['pixel_weight'] = weight @@ -126,11 +126,11 @@ def __call__(self, sample): sample['image'] = image_t sample['RandomRescale_Param'] = json.dumps(input_shape) - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and self.task in ['seg', 'rec']): label = sample['label'] label = ndimage.interpolation.zoom(label, scale, order = 0) sample['label'] = label - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and self.task in ['seg', 'rec']): weight = sample['pixel_weight'] weight = ndimage.interpolation.zoom(weight, scale, order = 1) sample['pixel_weight'] = weight diff --git a/pymic/transform/rotate.py b/pymic/transform/rotate.py index 2aa06d4..bd68e6a 100644 --- a/pymic/transform/rotate.py +++ b/pymic/transform/rotate.py @@ -78,10 +78,10 @@ def __call__(self, sample): sample['RandomRotate_Param'] = json.dumps(transform_param_list) image_t = self.__apply_transformation(image, transform_param_list, 1) sample['image'] = image_t - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and self.task in ['seg', 'rec']): sample['label'] = self.__apply_transformation(sample['label'] , transform_param_list, 0) - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and self.task in ['seg', 'rec']): sample['pixel_weight'] = self.__apply_transformation(sample['pixel_weight'] , transform_param_list, 1) return sample diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index bc72c93..a7d96fd 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -25,6 +25,7 @@ 'RandomRotate': RandomRotate, 'ReduceLabelDim': ReduceLabelDim, 'Rescale': Rescale, + 'SelfSuperviseLabel': SelfSuperviseLabel, 'Pad': Pad. """ @@ -67,6 +68,7 @@ 'RandomRotate': RandomRotate, 'ReduceLabelDim': ReduceLabelDim, 'Rescale': Rescale, + 'SelfSuperviseLabel': SelfSuperviseLabel, 'OutPainting': OutPainting, 'Pad': Pad, } diff --git a/pymic/transform/transpose.py b/pymic/transform/transpose.py index 9c73bda..67e8611 100644 --- a/pymic/transform/transpose.py +++ b/pymic/transform/transpose.py @@ -39,10 +39,11 @@ def __call__(self, sample): if(transpose_axis is not None): image_t = np.transpose(image, transpose_axis) sample['image'] = image_t - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and self.task in ['seg', 'rec']): sample['label'] = np.transpose(sample['label'] , transpose_axis) - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and self.task in ['seg', 'rec']): sample['pixel_weight'] = np.transpose(sample['pixel_weight'] , transpose_axis) + return sample From 0015adefba9895c9591d7952511f444c088d2611 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 20 Mar 2023 10:54:48 +0800 Subject: [PATCH 144/225] update task type update task type --- pymic/__init__.py | 19 ++++++++++++++++++- pymic/io/nifty_dataset.py | 22 ++++++++++++++++++---- pymic/net_run/agent_abstract.py | 2 +- pymic/net_run/agent_cls.py | 12 ++++++------ pymic/net_run/agent_rec.py | 21 +++++++++++++++++++++ pymic/net_run/agent_seg.py | 3 ++- pymic/net_run/predict.py | 12 +++++++----- pymic/net_run/semi_sup/ssl_abstract.py | 2 +- pymic/net_run/train.py | 6 +++--- pymic/transform/crop.py | 13 +++++++++---- pymic/transform/flip.py | 7 +++++-- pymic/transform/label_convert.py | 5 +++-- pymic/transform/pad.py | 7 +++++-- pymic/transform/rescale.py | 13 +++++++++---- pymic/transform/rotate.py | 7 +++++-- pymic/transform/transpose.py | 11 ++++++----- pymic/util/parse_config.py | 2 ++ pymic/util/preprocess.py | 1 + 18 files changed, 122 insertions(+), 43 deletions(-) diff --git a/pymic/__init__.py b/pymic/__init__.py index cb6356a..1520d82 100644 --- a/pymic/__init__.py +++ b/pymic/__init__.py @@ -1,2 +1,19 @@ from __future__ import absolute_import -__version__ = "0.4.0" \ No newline at end of file +from enum import Enum + +__version__ = "0.4.0" + +class TaskType(Enum): + CLASSIFICATION_ONE_HOT = 1 + CLASSIFICATION_COEXIST = 2 + REGRESSION = 3 + SEGMENTATION = 4 + RECONSTRUCTION = 5 + +TaskDict = { + 'cls': TaskType.CLASSIFICATION_ONE_HOT, + 'cls_coexist': TaskType.CLASSIFICATION_COEXIST, + 'regress': TaskType.REGRESSION, + 'seg': TaskType.SEGMENTATION, + 'rec': TaskType.RECONSTRUCTION +} \ No newline at end of file diff --git a/pymic/io/nifty_dataset.py b/pymic/io/nifty_dataset.py index bb1ff23..9812d13 100644 --- a/pymic/io/nifty_dataset.py +++ b/pymic/io/nifty_dataset.py @@ -1,12 +1,14 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import logging import os import torch import pandas as pd import numpy as np from torch.utils.data import Dataset, DataLoader from torchvision import transforms, utils +from pymic import TaskType from pymic.io.image_read_write import load_image_as_nd_array class NiftyDataset(Dataset): @@ -23,14 +25,21 @@ class NiftyDataset(Dataset): The built-in transforms can listed in :mod:`pymic.transform.trans_dict`. """ def __init__(self, root_dir, csv_file, modal_num = 1, - with_label = False, transform=None): + with_label = False, transform=None, task = TaskType.SEGMENTATION): self.root_dir = root_dir self.csv_items = pd.read_csv(csv_file) self.modal_num = modal_num self.with_label = with_label self.transform = transform + self.task = task + assert self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION] csv_keys = list(self.csv_items.keys()) + if('label' not in csv_keys): + logging.warning("`label` section is not found in the csv file {0:}".format( + csv_file) + "\n -- This is only allowed for self-supervised learning" + + "\n -- when `SelfSuperviseLabel` is used in the transform.") + self.with_label = False self.image_weight_idx = None self.pixel_weight_idx = None if('image_weight' in csv_keys): @@ -42,12 +51,15 @@ def __len__(self): return len(self.csv_items) def __getlabel__(self, idx): - csv_keys = list(self.csv_items.keys()) + csv_keys = list(self.csv_items.keys()) label_idx = csv_keys.index('label') label_name = "{0:}/{1:}".format(self.root_dir, self.csv_items.iloc[idx, label_idx]) label = load_image_as_nd_array(label_name)['data_array'] - label = np.asarray(label, np.int32) + if(self.task == TaskType.SEGMENTATION): + label = np.asarray(label, np.int32) + elif(self.task == TaskType.RECONSTRUCTION): + label = np.asarray(label, np.float32) return label def __get_pixel_weight__(self, idx): @@ -101,10 +113,12 @@ class ClassificationDataset(NiftyDataset): The built-in transforms can listed in :mod:`pymic.transform.trans_dict`. """ def __init__(self, root_dir, csv_file, modal_num = 1, class_num = 2, - with_label = False, transform=None): + with_label = False, transform=None, task = TaskType.CLASSIFICATION_ONE_HOT): super(ClassificationDataset, self).__init__(root_dir, csv_file, modal_num, with_label, transform) self.class_num = class_num + self.task = task + assert self.task in [TaskType.CLASSIFICATION_ONE_HOT, TaskType.CLASSIFICATION_COEXIST] def __getlabel__(self, idx): csv_keys = list(self.csv_items.keys()) diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index 2e6c062..109eabc 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -57,7 +57,7 @@ def __init__(self, config, stage = 'train'): self.transform_dict = None self.inferer = None self.tensor_type = config['dataset']['tensor_type'] - self.task_type = config['dataset']['task_type'] #cls, cls_mtbc, seg, rec + self.task_type = config['dataset']['task_type'] self.deterministic = config['training'].get('deterministic', True) self.random_seed = config['training'].get('random_seed', 1) if(self.deterministic): diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index 862a703..1ded3b1 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -13,6 +13,7 @@ from torch.optim import lr_scheduler from torchvision import transforms from tensorboardX import SummaryWriter +from pymic import TaskType from pymic.io.nifty_dataset import ClassificationDataset from pymic.loss.loss_dict_cls import PyMICClsLossDict from pymic.net.net_dict_cls import TorchClsNetDict @@ -38,7 +39,6 @@ class ClassificationAgent(NetRunAgent): def __init__(self, config, stage = 'train'): super(ClassificationAgent, self).__init__(config, stage) self.transform_dict = TransformDict - assert(self.task_type in ["cls", "cls_nexcl"]) def get_stage_dataset_from_config(self, stage): assert(stage in ['train', 'valid', 'test']) @@ -119,12 +119,12 @@ def get_evaluation_score(self, outputs, labels): metrics = self.config['training'].get("evaluation_metric", "accuracy") if(metrics != "accuracy"): # default classification accuracy raise ValueError("Not implemeted for metric {0:}".format(metrics)) - if(self.task_type == "cls"): + if(self.task_type == TaskType.CLASSIFICATION_ONE_HOT): out_argmax = torch.argmax(outputs, 1) lab_argmax = torch.argmax(labels, 1) consis = self.convert_tensor_type(out_argmax == lab_argmax) score = torch.mean(consis) - elif(self.task_type == "cls_nexcl"): #nonexclusive classification + elif(self.task_type == TaskType.CLASSIFICATION_COEXIST): preds = self.convert_tensor_type(outputs > 0.5) consis= self.convert_tensor_type(preds == labels.data) score = torch.mean(consis) @@ -346,15 +346,15 @@ def infer(self): infer_time = time.time() - start_time infer_time_list.append(infer_time) - if (self.task_type == "cls"): + if (self.task_type == TaskType.CLASSIFICATION_ONE_HOT): out_prob = nn.Softmax(dim = 1)(out_digit).detach().cpu().numpy() out_lab = np.argmax(out_prob, axis=1) - else: #self.task_type == "cls_nexcl" + else: #self.task_type == TaskType.CLASSIFICATION_COEXIST out_prob = nn.Sigmoid()(out_digit).detach().cpu().numpy() out_lab = np.asarray(out_prob > 0.5, np.uint8) for i in range(len(names)): print(names[i], out_lab[i]) - if(self.task_type == "cls"): + if(self.task_type == TaskType.CLASSIFICATION_ONE_HOT): out_lab_list.append([names[i]] + [out_lab[i]]) else: out_lab_list.append([names[i]] + out_lab[i].tolist()) diff --git a/pymic/net_run/agent_rec.py b/pymic/net_run/agent_rec.py index 607623d..188a5fc 100644 --- a/pymic/net_run/agent_rec.py +++ b/pymic/net_run/agent_rec.py @@ -113,6 +113,9 @@ def validation(self): validIter = iter(self.valid_loader) with torch.no_grad(): self.net.eval() + + # for debug + # save_num = 0 for data in validIter: inputs = self.convert_tensor_type(data['image']) label = self.convert_tensor_type(data['label']) @@ -123,6 +126,24 @@ def validation(self): loss = self.get_loss_value(data, outputs, label) valid_loss_list.append(loss.item()) + # for debug + # print(inputs.shape, label.shape, outputs.shape) + # inputs = inputs.cpu().numpy() + # label = label.cpu().numpy() + # outputs = outputs.cpu().numpy() + # for i in range(inputs.shape[0]): + # image_i = inputs[i][0] + # label_i = label[i][0] + # output_i = outputs[i][0] + # image_name = "temp/case{0:}_image.nii.gz".format(save_num + i) + # label_name = "temp/case{0:}_label.nii.gz".format(save_num + i) + # output_name= "temp/case{0:}_output.nii.gz".format(save_num + i) + # save_nd_array_as_image(image_i, image_name, reference_name = None) + # save_nd_array_as_image(label_i, label_name, reference_name = None) + # save_nd_array_as_image(output_i, output_name, reference_name = None) + # save_num += inputs.shape[0] + # if(save_num > 20): + # break valid_avg_loss = np.asarray(valid_loss_list).mean() valid_scalers = {'loss': valid_avg_loss} return valid_scalers diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index b85c716..fb99075 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -67,7 +67,8 @@ def get_stage_dataset_from_config(self, stage): csv_file = csv_file, modal_num = modal_num, with_label= not (stage == 'test'), - transform = data_transform ) + transform = data_transform, + task = self.task_type) return dataset def create_network(self): diff --git a/pymic/net_run/predict.py b/pymic/net_run/predict.py index 31fff86..80134d8 100644 --- a/pymic/net_run/predict.py +++ b/pymic/net_run/predict.py @@ -4,6 +4,7 @@ import os import sys from datetime import datetime +from pymic import TaskType from pymic.util.parse_config import * from pymic.net_run.agent_cls import ClassificationAgent from pymic.net_run.agent_seg import SegmentationAgent @@ -34,14 +35,15 @@ def main(): level=logging.INFO, format='%(message)s') # for python 3.6 logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) - task = config['dataset']['task_type'] - assert task in ['cls', 'cls_nexcl', 'seg', 'rec'] - if(task == 'cls' or task == 'cls_nexcl'): + task = config['dataset']['task_type'] + if(task == TaskType.CLASSIFICATION_ONE_HOT or task == TaskType.CLASSIFICATION_COEXIST): agent = ClassificationAgent(config, 'test') - elif(task == 'seg'): + elif(task == TaskType.SEGMENTATION): agent = SegmentationAgent(config, 'test') - else: + elif(task == TaskType.RECONSTRUCTION): agent = ReconstructionAgent(config, 'test') + else: + raise ValueError("Undefined task for inference: {0:}".format(task)) agent.run() if __name__ == "__main__": diff --git a/pymic/net_run/semi_sup/ssl_abstract.py b/pymic/net_run/semi_sup/ssl_abstract.py index 5a46257..b27edc9 100644 --- a/pymic/net_run/semi_sup/ssl_abstract.py +++ b/pymic/net_run/semi_sup/ssl_abstract.py @@ -44,7 +44,7 @@ def get_unlabeled_dataset_from_config(self): data_transform = None else: transform_param = self.config['dataset'] - transform_param['task'] = 'segmentation' + transform_param['task'] = self.task_type for name in transform_names: if(name not in self.transform_dict): raise(ValueError("Undefined transform {0:}".format(name))) diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index 107a519..0167f2f 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -5,6 +5,7 @@ import sys import shutil from datetime import datetime +from pymic import TaskType from pymic.util.parse_config import * from pymic.net_run.agent_cls import ClassificationAgent from pymic.net_run.agent_seg import SegmentationAgent @@ -47,7 +48,7 @@ def get_seg_rec_agent(config, sup_type): 'inpainting_probability': 0.2 } config['dataset']['train_transform'].extend(transforms) - config['dataset']['valid_transform'].extend(transforms) + # config['dataset']['valid_transform'].extend(transforms) config['dataset'].update(genesis_cfg) logging_config(config['dataset']) else: @@ -86,8 +87,7 @@ def main(): logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) task = config['dataset']['task_type'] - assert task in ['cls', 'cls_nexcl', 'seg', 'rec'] - if(task == 'cls' or task == 'cls_nexcl'): + if(task == TaskType.CLASSIFICATION_ONE_HOT or task == TaskType.CLASSIFICATION_COEXIST): agent = ClassificationAgent(config, 'train') else: sup_type = config['dataset'].get('supervise_type', 'fully_sup') diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index 60f5c9b..acadc49 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -7,6 +7,7 @@ import random import numpy as np from scipy import ndimage +from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * @@ -55,12 +56,14 @@ def __call__(self, sample): image_t = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) sample['image'] = image_t - if('label' in sample and self.task in ['seg', 'rec']): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): label = sample['label'] crop_max[0] = label.shape[0] label = crop_ND_volume_with_bounding_box(label, crop_min, crop_max) sample['label'] = label - if('pixel_weight' in sample and self.task in ['seg', 'rec']): + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): weight = sample['pixel_weight'] crop_max[0] = weight.shape[0] weight = crop_ND_volume_with_bounding_box(weight, crop_min, crop_max) @@ -300,13 +303,15 @@ def __call__(self, sample): image_t = ndimage.interpolation.zoom(image_t, scale, order = 1) sample['image'] = image_t - if('label' in sample and self.task in ['seg', 'rec']): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): label = sample['label'] crop_max[0] = label.shape[0] label = crop_ND_volume_with_bounding_box(label, crop_min, crop_max) label = ndimage.interpolation.zoom(label, scale, order = 0) sample['label'] = label - if('pixel_weight' in sample and self.task ['seg', 'rec']): + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): weight = sample['pixel_weight'] crop_max[0] = weight.shape[0] weight = crop_ND_volume_with_bounding_box(weight, crop_min, crop_max) diff --git a/pymic/transform/flip.py b/pymic/transform/flip.py index 0462935..24cafb4 100644 --- a/pymic/transform/flip.py +++ b/pymic/transform/flip.py @@ -7,6 +7,7 @@ import random import numpy as np from scipy import ndimage +from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * @@ -52,9 +53,11 @@ def __call__(self, sample): # current pytorch does not support negative strides image_t = np.flip(image, flip_axis).copy() sample['image'] = image_t - if('label' in sample and self.task in ['seg', 'rec']): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): sample['label'] = np.flip(sample['label'] , flip_axis).copy() - if('pixel_weight' in sample and self.task in ['seg', 'rec']): + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): sample['pixel_weight'] = np.flip(sample['pixel_weight'] , flip_axis).copy() return sample diff --git a/pymic/transform/label_convert.py b/pymic/transform/label_convert.py index 1efaab6..e3accdf 100644 --- a/pymic/transform/label_convert.py +++ b/pymic/transform/label_convert.py @@ -7,6 +7,7 @@ import random import numpy as np from scipy import ndimage +from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * @@ -80,13 +81,13 @@ def __init__(self, params): self.inverse = params.get('LabelToProbability_inverse'.lower(), False) def __call__(self, sample): - if(self.task == 'segmentation'): + if(self.task == TaskType.SEGMENTATION): label = sample['label'][0] # sample['label'] is (1, h, w) label_prob = np.zeros((self.class_num, *label.shape), dtype = np.float32) for i in range(self.class_num): label_prob[i] = label == i*np.ones_like(label) sample['label_prob'] = label_prob - elif(self.task == 'classification'): + elif(self.task == TaskType.CLASSIFICATION_ONE_HOT): label_idx = sample['label'] label_prob = np.zeros((self.class_num,), np.float32) label_prob[label_idx] = 1.0 diff --git a/pymic/transform/pad.py b/pymic/transform/pad.py index 4d292be..91cf6da 100644 --- a/pymic/transform/pad.py +++ b/pymic/transform/pad.py @@ -7,6 +7,7 @@ import random import numpy as np from scipy import ndimage +from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * @@ -59,11 +60,13 @@ def __call__(self, sample): sample['image'] = image_t - if('label' in sample and self.task in ['seg', 'rec']): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): label = sample['label'] label = np.pad(label, pad, 'reflect') if(max(margin) > 0) else label sample['label'] = label - if('pixel_weight' in sample and self.task in ['seg', 'rec']): + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): weight = sample['pixel_weight'] weight = np.pad(weight, pad, 'reflect') if(max(margin) > 0) else weight sample['pixel_weight'] = weight diff --git a/pymic/transform/rescale.py b/pymic/transform/rescale.py index 660b156..355712e 100644 --- a/pymic/transform/rescale.py +++ b/pymic/transform/rescale.py @@ -5,6 +5,7 @@ import random import numpy as np from scipy import ndimage +from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * @@ -48,11 +49,13 @@ def __call__(self, sample): sample['image'] = image_t sample['Rescale_origin_shape'] = json.dumps(input_shape) - if('label' in sample and self.task in ['seg', 'rec']): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): label = sample['label'] label = ndimage.interpolation.zoom(label, scale, order = 0) sample['label'] = label - if('pixel_weight' in sample and self.task in ['seg', 'rec']): + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): weight = sample['pixel_weight'] weight = ndimage.interpolation.zoom(weight, scale, order = 1) sample['pixel_weight'] = weight @@ -126,11 +129,13 @@ def __call__(self, sample): sample['image'] = image_t sample['RandomRescale_Param'] = json.dumps(input_shape) - if('label' in sample and self.task in ['seg', 'rec']): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): label = sample['label'] label = ndimage.interpolation.zoom(label, scale, order = 0) sample['label'] = label - if('pixel_weight' in sample and self.task in ['seg', 'rec']): + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): weight = sample['pixel_weight'] weight = ndimage.interpolation.zoom(weight, scale, order = 1) sample['pixel_weight'] = weight diff --git a/pymic/transform/rotate.py b/pymic/transform/rotate.py index bd68e6a..65e5328 100644 --- a/pymic/transform/rotate.py +++ b/pymic/transform/rotate.py @@ -5,6 +5,7 @@ import random import numpy as np from scipy import ndimage +from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * @@ -78,10 +79,12 @@ def __call__(self, sample): sample['RandomRotate_Param'] = json.dumps(transform_param_list) image_t = self.__apply_transformation(image, transform_param_list, 1) sample['image'] = image_t - if('label' in sample and self.task in ['seg', 'rec']): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): sample['label'] = self.__apply_transformation(sample['label'] , transform_param_list, 0) - if('pixel_weight' in sample and self.task in ['seg', 'rec']): + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): sample['pixel_weight'] = self.__apply_transformation(sample['pixel_weight'] , transform_param_list, 1) return sample diff --git a/pymic/transform/transpose.py b/pymic/transform/transpose.py index 67e8611..6f5d5fa 100644 --- a/pymic/transform/transpose.py +++ b/pymic/transform/transpose.py @@ -4,6 +4,7 @@ import json import random import numpy as np +from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform @@ -39,12 +40,12 @@ def __call__(self, sample): if(transpose_axis is not None): image_t = np.transpose(image, transpose_axis) sample['image'] = image_t - if('label' in sample and self.task in ['seg', 'rec']): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): sample['label'] = np.transpose(sample['label'] , transpose_axis) - if('pixel_weight' in sample and self.task in ['seg', 'rec']): - sample['pixel_weight'] = np.transpose(sample['pixel_weight'] , transpose_axis) - - + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + sample['pixel_weight'] = np.transpose(sample['pixel_weight'] , transpose_axis) return sample def inverse_transform_for_prediction(self, sample): diff --git a/pymic/util/parse_config.py b/pymic/util/parse_config.py index 232762f..18afd08 100644 --- a/pymic/util/parse_config.py +++ b/pymic/util/parse_config.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, print_function +from pymic import TaskDict import configparser import logging @@ -103,6 +104,7 @@ def synchronize_config(config): data_cfg = config['dataset'] net_cfg = config['network'] # data_cfg["modal_num"] = net_cfg["in_chns"] + data_cfg["task_type"] = TaskDict[data_cfg["task_type"]] data_cfg["LabelToProbability_class_num".lower()] = net_cfg["class_num"] if "PartialLabelToProbability" in data_cfg['train_transform']: data_cfg["PartialLabelToProbability_class_num".lower()] = net_cfg["class_num"] diff --git a/pymic/util/preprocess.py b/pymic/util/preprocess.py index 5f20372..c0dc9a1 100644 --- a/pymic/util/preprocess.py +++ b/pymic/util/preprocess.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import os import numpy as np import SimpleITK as sitk From 462e8a5daaca6f36ef0f8cd2ca49304649f641a8 Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 12 Apr 2023 14:14:10 +0800 Subject: [PATCH 145/225] update transform add IntensityClip and edit Normalization --- pymic/net_run/agent_seg.py | 12 +- pymic/transform/crop.py | 64 +++++------ pymic/transform/intensity.py | 39 ++++++- pymic/transform/normalize.py | 7 +- pymic/transform/trans_dict.py | 2 + pymic/util/evaluation_seg.py | 199 +++++++++++++++++++--------------- pymic/util/general.py | 10 +- 7 files changed, 199 insertions(+), 134 deletions(-) diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index fb99075..1bdad06 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -149,15 +149,15 @@ def training(self): # for i in range(inputs.shape[0]): # image_i = inputs[i][0] # label_i = labels_prob[i][1] - # pixw_i = pix_w[i][0] - # print(image_i.shape, label_i.shape, pixw_i.shape) + # # pixw_i = pix_w[i][0] + # print(image_i.shape, label_i.shape) # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) - # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) + # # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) # save_nd_array_as_image(image_i, image_name, reference_name = None) # save_nd_array_as_image(label_i, label_name, reference_name = None) - # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) - # continue + # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) + # # continue inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) @@ -297,7 +297,7 @@ def train_valid(self): if(ckpt_init_mode > 0): # Load other information self.max_val_dice = checkpoint.get('valid_pred', 0) - iter_start = checkpoint['iteration'] - 1 + iter_start = checkpoint['iteration'] self.max_val_it = iter_start self.best_model_wts = checkpoint['model_state_dict'] ckpt_for_optm = checkpoint diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index acadc49..1012712 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -243,62 +243,51 @@ def _get_param_for_inverse_transform(self, sample): class RandomResizedCrop(CenterCrop): """ - Randomly crop the input image (shape [C, H, W]). Only 2D images are supported. - + Randomly resize and crop the input image (shape [C, D, H, W]). The arguments should be written in the `params` dictionary, and it has the following fields: - :param `RandomResizedCrop_output_size`: (list/tuple) Desired output size [H, W]. + :param `RandomResizedCrop_output_size`: (list/tuple) Desired output size [D, H, W]. The output channel is the same as the input channel. - :param `RandomResizedCrop_scale`: (list/tuple) Range of scale, e.g. (0.08, 1.0). - :param `RandomResizedCrop_ratio`: (list/tuple) Range of aspect ratio, e.g. (0.75, 1.33). + :param `RandomResizedCrop_scale_range`: (list/tuple) Range of scale, e.g. (0.08, 1.0). :param `RandomResizedCrop_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `False`. Currently, the inverse transform is not supported, and this transform is assumed to be used only during training stage. """ def __init__(self, params): self.output_size = params['RandomResizedCrop_output_size'.lower()] - self.scale = params['RandomResizedCrop_scale'.lower()] - self.ratio = params['RandomResizedCrop_ratio'.lower()] + self.scale = params['RandomResizedCrop_scale_range'.lower()] self.inverse = params.get('RandomResizedCrop_inverse'.lower(), False) self.task = params['Task'.lower()] assert isinstance(self.output_size, (list, tuple)) assert isinstance(self.scale, (list, tuple)) - assert isinstance(self.ratio, (list, tuple)) - def _get_crop_param(self, sample): + def __call__(self, sample): image = sample['image'] - input_shape = image.shape - input_dim = len(input_shape) - 1 - assert(input_dim == 2) + channel, input_size = image.shape[0], image.shape[1:] + input_dim = len(input_size) assert(input_dim == len(self.output_size)) - scale = self.scale[0] + random.random()*(self.scale[1] - self.scale[0]) - ratio = self.ratio[0] + random.random()*(self.ratio[1] - self.ratio[0]) - crop_w = input_shape[-1] * scale - crop_h = crop_w * ratio - crop_h = min(crop_h, input_shape[-2]) - output_shape = [int(crop_h), int(crop_w)] - - crop_margin = [input_shape[i + 1] - output_shape[i]\ - for i in range(input_dim)] + crop_size = [int(self.output_size[i] * scale) for i in range(input_dim)] + crop_margin = [input_size[i] - crop_size[i] for i in range(input_dim)] + pad_image = False + if(min(crop_margin) < 0): + pad_image = True + pad_size = [max(0, -crop_margin[i]) for i in range(input_dim)] + pad_lower = [int(pad_size[i] / 2) for i in range(input_dim)] + pad_upper = [pad_size[i] - pad_lower[i] for i in range(input_dim)] + pad = [(pad_lower[i], pad_upper[i]) for i in range(input_dim)] + pad = tuple([(0, 0)] + pad) + image = np.pad(image, pad, 'reflect') + crop_margin = [max(0, crop_margin[i]) for i in range(input_dim)] + crop_min = [random.randint(0, item) for item in crop_margin] - crop_max = [crop_min[i] + output_shape[i] \ - for i in range(input_dim)] + crop_max = [crop_min[i] + crop_size[i] for i in range(input_dim)] crop_min = [0] + crop_min - crop_max = list(input_shape[0:1]) + crop_max - sample['RandomResizedCrop_Param'] = json.dumps((input_shape, crop_min, crop_max)) - return sample, crop_min, crop_max - - def __call__(self, sample): - image = sample['image'] - input_shape = image.shape - input_dim = len(input_shape) - 1 - sample, crop_min, crop_max = self._get_crop_param(sample) + crop_max = [channel] + crop_max image_t = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) - crp_shape = image_t.shape - scale = [(self.output_size[i] + 0.0)/crp_shape[1:][i] for i in range(input_dim)] + scale = [(self.output_size[i] + 0.0)/crop_size[i] for i in range(input_dim)] scale = [1.0] + scale image_t = ndimage.interpolation.zoom(image_t, scale, order = 1) sample['image'] = image_t @@ -306,13 +295,18 @@ def __call__(self, sample): if('label' in sample and \ self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): label = sample['label'] + if(pad_image): + label = np.pad(label, pad, 'reflect') crop_max[0] = label.shape[0] label = crop_ND_volume_with_bounding_box(label, crop_min, crop_max) - label = ndimage.interpolation.zoom(label, scale, order = 0) + order = 0 if(self.task == TaskType.SEGMENTATION) else 1 + label = ndimage.interpolation.zoom(label, scale, order = order) sample['label'] = label if('pixel_weight' in sample and \ self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): weight = sample['pixel_weight'] + if(pad_image): + weight = np.pad(weight, pad, 'reflect') crop_max[0] = weight.shape[0] weight = crop_ND_volume_with_bounding_box(weight, crop_min, crop_max) weight = ndimage.interpolation.zoom(weight, scale, order = 1) diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index 3b5ee9d..1a13190 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -44,6 +44,40 @@ def bezier_curve(points, nTimes=1000): return xvals, yvals +class IntensityClip(AbstractTransform): + """ + Clip the intensity for input image + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `IntensityClip_channels`: (list) A list of int for specifying the channels. + :param `IntensityClip_lower`: (list) The lower bound for clip in each channel. + :param `IntensityClip_upper`: (list) The upper bound for clip in each channel. + :param `IntensityClip_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. + """ + def __init__(self, params): + super(IntensityClip, self).__init__(params) + self.channels = params['IntensityClip_channels'.lower()] + self.lower = params.get('IntensityClip_lower'.lower(), None) + self.upper = params.get('IntensityClip_upper'.lower(), None) + self.inverse = params.get('IntensityClip_inverse'.lower(), False) + + def __call__(self, sample): + image = sample['image'] + lower = self.lower if self.lower is not None else [None] * len(self.channels) + upper = self.upper if self.upper is not None else [None] * len(self.channels) + for chn in self.channels: + lower_c, upper_c = lower[chn], upper[chn] + if(lower_c is None): + lower_c = np.percentile(image[chn], 0.05) + if(upper_c is None): + upper_c = np.percentile(image[chn, 99.95]) + image[chn] = np.clip(image[chn], lower_c, upper_c) + sample['image'] = image + return sample + class GammaCorrection(AbstractTransform): """ Apply random gamma correction to given channels. @@ -76,8 +110,9 @@ def __call__(self, sample): img_c = image[chn] v_min = img_c.min() v_max = img_c.max() - img_c = (img_c - v_min)/(v_max - v_min) - img_c = np.power(img_c, gamma_c)*(v_max - v_min) + v_min + if(v_min < v_max): + img_c = (img_c - v_min)/(v_max - v_min) + img_c = np.power(img_c, gamma_c)*(v_max - v_min) + v_min image[chn] = img_c sample['image'] = image diff --git a/pymic/transform/normalize.py b/pymic/transform/normalize.py index 4e493dd..77852d2 100644 --- a/pymic/transform/normalize.py +++ b/pymic/transform/normalize.py @@ -53,9 +53,12 @@ def __call__(self, sample): if(chn_mean is None): if(self.ingore_np): pixels = image[chn][image[chn] > 0] - chn_mean, chn_std = pixels.mean(), pixels.std() + if(len(pixels) > 0): + chn_mean, chn_std = pixels.mean(), pixels.std() + 1e-5 + else: + chn_mean, chn_std = 0.0, 1.0 else: - chn_mean, chn_std = image[chn].mean(), image[chn].std() + chn_mean, chn_std = image[chn].mean(), image[chn].std() + 1e-5 chn_norm = (image[chn] - chn_mean)/chn_std diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index a7d96fd..e8f1c37 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -14,6 +14,7 @@ 'LabelConvert': LabelConvert, 'LabelConvertNonzero': LabelConvertNonzero, 'LabelToProbability': LabelToProbability, + 'IntensityClip': IntensityClip, 'NormalizeWithMeanStd': NormalizeWithMeanStd, 'NormalizeWithMinMax': NormalizeWithMinMax, 'NormalizeWithPercentiles': NormalizeWithPercentiles, @@ -55,6 +56,7 @@ 'LabelConvertNonzero': LabelConvertNonzero, 'LabelToProbability': LabelToProbability, 'LocalShuffling': LocalShuffling, + 'IntensityClip': IntensityClip, 'NonLinearTransform': NonLinearTransform, 'NormalizeWithMeanStd': NormalizeWithMeanStd, 'NormalizeWithMinMax': NormalizeWithMinMax, diff --git a/pymic/util/evaluation_seg.py b/pymic/util/evaluation_seg.py index ba04a73..836401d 100644 --- a/pymic/util/evaluation_seg.py +++ b/pymic/util/evaluation_seg.py @@ -3,15 +3,19 @@ Evaluation module for segmenation tasks. """ from __future__ import absolute_import, print_function +import argparse import csv import os import sys import pandas as pd import numpy as np +from os.path import join from scipy import ndimage from pymic.io.image_read_write import * from pymic.util.image_process import * -from pymic.util.parse_config import parse_config +from pymic.util.general import is_image_name +from pymic.util.parse_config import parse_config, parse_value_from_string + def binary_dice(s, g, resize = False): @@ -257,108 +261,127 @@ def evaluation(config): When a list is given, each list element should be the root dir of the results of one method. :param evaluation_image_pair: (str) The csv file that provide the segmentation images and the corresponding ground truth images. - :param ground_truth_label_convert_source: (optional, list) The list of source - labels for label conversion in the ground truth. - :param ground_truth_label_convert_target: (optional, list) The list of target - labels for label conversion in the ground truth. - :param segmentation_label_convert_source: (optional, list) The list of source - labels for label conversion in the segmentation. - :param segmentation_label_convert_target: (optional, list) The list of target - labels for label conversion in the segmentation. """ metric_list = config['metric_list'] - label_list = config['label_list'] - label_fuse = config.get('label_fuse', False) - organ_name = config['organ_name'] - gt_root = config['ground_truth_folder_root'] - seg_root = config['segmentation_folder_root'] - if(not(isinstance(seg_root, tuple) or isinstance(seg_root, list))): - seg_root = [seg_root] - image_pair_csv = config['evaluation_image_pair'] - ground_truth_label_convert_source = config.get('ground_truth_label_convert_source', None) - ground_truth_label_convert_target = config.get('ground_truth_label_convert_target', None) - segmentation_label_convert_source = config.get('segmentation_label_convert_source', None) - segmentation_label_convert_target = config.get('segmentation_label_convert_target', None) - - image_items = pd.read_csv(image_pair_csv) - item_num = len(image_items) - - for seg_root_n in seg_root: # for each segmentation method - for metric in metric_list: - score_all_data = [] - name_score_list= [] - for i in range(item_num): - gt_name = image_items.iloc[i, 0] - seg_name = image_items.iloc[i, 1] - # seg_name = seg_name.replace(".nii.gz", "_pred.nii.gz") - gt_full_name = gt_root + '/' + gt_name - seg_full_name = seg_root_n + '/' + seg_name - - s_dict = load_image_as_nd_array(seg_full_name) - g_dict = load_image_as_nd_array(gt_full_name) - s_volume = s_dict["data_array"]; s_spacing = s_dict["spacing"] - g_volume = g_dict["data_array"]; g_spacing = g_dict["spacing"] - # for dim in range(len(s_spacing)): - # assert(s_spacing[dim] == g_spacing[dim]) - if((ground_truth_label_convert_source is not None) and \ - ground_truth_label_convert_target is not None): - g_volume = convert_label(g_volume, ground_truth_label_convert_source, \ - ground_truth_label_convert_target) - - if((segmentation_label_convert_source is not None) and \ - segmentation_label_convert_target is not None): - s_volume = convert_label(s_volume, segmentation_label_convert_source, \ - segmentation_label_convert_target) - - score_vector = get_multi_class_evaluation_score(s_volume, g_volume, label_list, - label_fuse, s_spacing, metric ) - if(len(label_list) > 1): - score_vector.append(np.asarray(score_vector).mean()) - score_all_data.append(score_vector) - name_score_list.append([seg_name] + score_vector) - print(seg_name, score_vector) - score_all_data = np.asarray(score_all_data) - score_mean = score_all_data.mean(axis = 0) - score_std = score_all_data.std(axis = 0) - name_score_list.append(['mean'] + list(score_mean)) - name_score_list.append(['std'] + list(score_std)) + if(not isinstance(metric_list, list)): + metric_list = [metric_list] + label_list = config.get('label_list', None) + if(label_list is None): + label_list = range(1, config["class_number"]) + elif(not isinstance(label_list, list)): + label_list = [label_list] + label_fuse = config.get('label_fuse', False) + output_name = config.get('output_name', None) + gt_root = config['ground_truth_folder_root'] + seg_root = config['segmentation_folder_root'] + image_pair_csv = config.get('evaluation_image_pair', None) + + if(image_pair_csv is not None): + image_pair = pd.read_csv(image_pair_csv) + gt_names, seg_names = image_pair.iloc[:, 0], image_pair.iloc[:, 1] + else: + seg_names = sorted(os.listdir(seg_root)) + seg_names = [item for item in seg_names if is_image_name(item)] + gt_names = seg_names + - # save the result as csv - score_csv = "{0:}/{1:}_{2:}_all.csv".format(seg_root_n, organ_name, metric) - with open(score_csv, mode='w') as csv_file: - csv_writer = csv.writer(csv_file, delimiter=',', - quotechar='"',quoting=csv.QUOTE_MINIMAL) - head = ['image'] + ["class_{0:}".format(i) for i in label_list] - if(len(label_list) > 1): - head = head + ["average"] - csv_writer.writerow(head) - for item in name_score_list: - csv_writer.writerow(item) - - print("{0:} mean ".format(metric), score_mean) - print("{0:} std ".format(metric), score_std) + for metric in metric_list: + print(metric) + score_all_data = [] + name_score_list= [] + for i in range(len(gt_names)): + gt_full_name = join(gt_root, gt_names[i]) + seg_full_name = join(seg_root, seg_names[i]) + s_dict = load_image_as_nd_array(seg_full_name) + g_dict = load_image_as_nd_array(gt_full_name) + s_volume = s_dict["data_array"]; s_spacing = s_dict["spacing"] + g_volume = g_dict["data_array"]; g_spacing = g_dict["spacing"] + # for dim in range(len(s_spacing)): + # assert(s_spacing[dim] == g_spacing[dim]) + + score_vector = get_multi_class_evaluation_score(s_volume, g_volume, label_list, + label_fuse, s_spacing, metric ) + if(len(label_list) > 1): + score_vector.append(np.asarray(score_vector).mean()) + score_all_data.append(score_vector) + name_score_list.append([seg_names[i]] + score_vector) + print(seg_names[i], score_vector) + score_all_data = np.asarray(score_all_data) + score_mean = score_all_data.mean(axis = 0) + score_std = score_all_data.std(axis = 0) + name_score_list.append(['mean'] + list(score_mean)) + name_score_list.append(['std'] + list(score_std)) + + # save the result as csv + if(output_name is None): + output_name = "{0:}/eval_{1:}.csv".format(seg_root, metric) + with open(output_name, mode='w') as csv_file: + csv_writer = csv.writer(csv_file, delimiter=',', + quotechar='"',quoting=csv.QUOTE_MINIMAL) + head = ['image'] + ["class_{0:}".format(i) for i in label_list] + if(len(label_list) > 1): + head = head + ["average"] + csv_writer.writerow(head) + for item in name_score_list: + csv_writer.writerow(item) + + print("{0:} mean ".format(metric), score_mean) + print("{0:} std ".format(metric), score_std) def main(): """ Main function for evaluation of segmentation results. - A configuration file is needed for runing. e.g., + You can use a configuration file for runing. e.g., .. code-block:: none - pymic_evaluate_cls config.cfg + pymic_evaluate_seg -cfg config.cfg The configuration file should have an `evaluation` section. See :mod:`pymic.util.evaluation_seg.evaluation` for details of the configuration required. + + In addition, you can also provide a list of args in the command if -cfg is not used. For example: + + .. code-block:: none + + pymic_evaluate_seg -metric dice -cls_index 255 -gt_dir ground_truth_dir -seg_dir segmentation_dir + """ - if(len(sys.argv) < 2): - print('Number of arguments should be 2. e.g.') - print(' pymic_evaluate_seg config.cfg') - exit() - config_file = str(sys.argv[1]) - assert(os.path.isfile(config_file)) - config = parse_config(config_file)['evaluation'] + parser = argparse.ArgumentParser() + parser.add_argument("-cfg", help="configuration file for evaluation", + required=False, default=None) + parser.add_argument("-metric", help="evaluation metrics, e.g., dice, or [dice, assd]", + required=False, default=None) + parser.add_argument("-cls_num", help="number of classes", + required=False, default=None) + parser.add_argument("-cls_index", help="The class index for evaluation, e.g., 255, [1, 2]", + required=False, default=None) + parser.add_argument("-gt_dir", help="path of folder for ground truth", + required=False, default=None) + parser.add_argument("-seg_dir", help="path of folder for segmentation", + required=False, default=None) + parser.add_argument("-name_pair", help="the .csv file for name mapping in case" + " the names of one case are different in the gt_dir " + " and seg_dir", + required=False, default=None) + parser.add_argument("-out", help="the output .csv file name", + required=False, default=None) + args = parser.parse_args() + print(args) + if(args.cfg is not None): + config = parse_config(args.cfg)['evaluation'] + else: + config = {} + config['metric_list'] = parse_value_from_string(args.metric) + config['label_list'] = None if args.cls_index is None else parse_value_from_string(args.cls_index) + config['class_number']= None if args.cls_num is None else parse_value_from_string(args.cls_num) + config['ground_truth_folder_root'] = args.gt_dir + config['segmentation_folder_root'] = args.seg_dir + config['evaluation_image_pair'] = args.name_pair + config['output_name'] = args.out + print(config) evaluation(config) - + if __name__ == '__main__': main() diff --git a/pymic/util/general.py b/pymic/util/general.py index 075d2e1..cf52746 100644 --- a/pymic/util/general.py +++ b/pymic/util/general.py @@ -26,7 +26,15 @@ def tensor_shape_match(a,b): return False return True - +def is_image_name(x): + valid_names = ["jpg", "jpeg", "png", "bmp", "nii.gz", + "tif", "nii", "nii.gz", "mha"] + valid = False + for item in valid_names: + if(x.endswith(item)): + valid = True + break + return valid def get_one_hot_seg(label, class_num): """ From 64cbca070df47dcbd8458ff4217af63270185692 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 4 May 2023 21:59:08 +0800 Subject: [PATCH 146/225] update train and inference mode 1, allow unlabeled validation dataset (for self-supervised training) 2, add gaussian weight map for inference --- pymic/net_run/agent_seg.py | 15 +++++++++++-- pymic/net_run/infer_func.py | 43 ++++++++++++++++++++++++------------- 2 files changed, 41 insertions(+), 17 deletions(-) diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 1bdad06..2e703d4 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -63,10 +63,14 @@ def get_stage_dataset_from_config(self, stage): data_transform = transforms.Compose(self.transform_list) csv_file = self.config['dataset'].get(stage + '_csv', None) + if(stage == 'test'): + with_label = False + else: + with_label = self.config['dataset'].get(stage + '_label', True) dataset = NiftyDataset(root_dir = root_dir, csv_file = csv_file, modal_num = modal_num, - with_label= not (stage == 'test'), + with_label= with_label, transform = data_transform, task = self.task_type) return dataset @@ -189,8 +193,15 @@ def training(self): def validation(self): class_num = self.config['network']['class_num'] if(self.inferer is None): - infer_cfg = self.config['testing'] + infer_cfg = {} infer_cfg['class_num'] = class_num + infer_cfg['sliding_window_enable'] = self.config['testing'].get('sliding_window_enable', False) + if(infer_cfg['sliding_window_enable']): + patch_size = self.config['dataset'].get('patch_size', None) + if(patch_size is None): + patch_size = self.config['testing']['sliding_window_size'] + infer_cfg['sliding_window_size'] = patch_size + infer_cfg['sliding_window_stride'] = [i//2 for i in patch_size] self.inferer = Inferer(infer_cfg) valid_loss_list = [] diff --git a/pymic/net_run/infer_func.py b/pymic/net_run/infer_func.py index 43162d0..f6ed515 100644 --- a/pymic/net_run/infer_func.py +++ b/pymic/net_run/infer_func.py @@ -2,6 +2,8 @@ from __future__ import print_function, division import torch +import numpy as np +from scipy.ndimage.filters import gaussian_filter from torch.nn.functional import interpolate class Inferer(object): @@ -48,9 +50,19 @@ def __get_prediction_number_and_scales(self, tempx): output_num, scales = 1, None return output_num, scales + def __get_gaussian_weight_map(self, window_size, sigma_scale = 1.0/8): + w = np.zeros(window_size) + center = [i//2 for i in window_size] + sigmas = [i*sigma_scale for i in window_size] + w[tuple(center)] = 1.0 + w = gaussian_filter(w, sigmas, 0, mode='constant', cval=0) + return w + def __infer_with_sliding_window(self, image): """ - Use sliding window to predict segmentation for large images. + Use sliding window to predict segmentation for large images. The outupt of each + sliding window is weighted by a Gaussian map that hihglights contributions of windows + with a centroid closer to a given pixel. Note that the network may output a list of tensors with difference sizes. """ window_size = [x for x in self.config['sliding_window_size']] @@ -86,9 +98,10 @@ def __infer_with_sliding_window(self, image): crop_start_list.append([d_min, h_min, w_min]) output_shape = [batch_size, class_num] + img_shape - mask_shape = [batch_size, class_num] + window_size - counter = torch.zeros(output_shape).to(image.device) - temp_mask = torch.ones(mask_shape).to(image.device) + weight = torch.zeros(output_shape).to(image.device) + temp_w = self.__get_gaussian_weight_map(window_size) + temp_w = np.broadcast_to(temp_w, [batch_size, class_num] + window_size) + temp_w = torch.from_numpy(temp_w).to(image.device) temp_in_shape = img_full_shape[:2] + window_size tempx = torch.ones(temp_in_shape).to(image.device) out_num, scale_list = self.__get_prediction_number_and_scales(tempx) @@ -104,12 +117,12 @@ def __infer_with_sliding_window(self, image): if(isinstance(patch_out, (tuple, list))): patch_out = patch_out[0] if(img_dim == 2): - output[:, :, c0[0]:c1[0], c0[1]:c1[1]] += patch_out - counter[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_mask + output[:, :, c0[0]:c1[0], c0[1]:c1[1]] += patch_out * temp_w + weight[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_w else: - output[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += patch_out - counter[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_mask - return output/counter + output[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += patch_out * temp_w + weight[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_w + return output/weight else: # for multiple prediction output_list= [] for i in range(out_num): @@ -129,14 +142,14 @@ def __infer_with_sliding_window(self, image): c0_i = [int(c0[d] * scale_list[i][d]) for d in range(img_dim)] c1_i = [int(c1[d] * scale_list[i][d]) for d in range(img_dim)] if(img_dim == 2): - output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1]] += patch_out[i] - counter[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_mask + output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1]] += patch_out[i] * temp_w + weight[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_w else: - output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1], c0_i[2]:c1_i[2]] += patch_out[i] - counter[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_mask + output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1], c0_i[2]:c1_i[2]] += patch_out[i] * temp_w + weight[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_w for i in range(out_num): - counter_i = interpolate(counter, scale_factor = scale_list[i]) - output_list[i] = output_list[i] / counter_i + weight_i = interpolate(weight, scale_factor = scale_list[i]) + output_list[i] = output_list[i] / weight_i return output_list def run(self, model, image): From c29631ce7f3aa51b7fced92c54b34fec69417d80 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 4 May 2023 22:05:03 +0800 Subject: [PATCH 147/225] update transform and image process update transform and image process --- pymic/transform/crop.py | 39 +++++--------- pymic/transform/pad.py | 2 +- pymic/util/image_process.py | 101 ++++++++++++++++++++++++++++++++++-- pymic/util/parse_config.py | 10 ++++ 4 files changed, 121 insertions(+), 31 deletions(-) diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index 1012712..b9345c8 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -198,39 +198,28 @@ def __init__(self, params): def _get_crop_param(self, sample): image = sample['image'] - input_shape = image.shape - input_dim = len(input_shape) - 1 + chns = image.shape[0] + input_shape = image.shape[1:] + input_dim = len(input_shape) assert(input_dim == len(self.output_size)) - temp_output_size = self.output_size - if(input_dim == 3 and self.output_size[0] is None): - # note that output size is [D, H, W] and input is [C, D, H, W] - temp_output_size = [input_shape[1]] + self.output_size[1:] - crop_margin = [input_shape[i + 1] - temp_output_size[i]\ - for i in range(input_dim)] + crop_margin = [input_shape[i] - self.output_size[i] for i in range(input_dim)] crop_min = [0 if item == 0 else random.randint(0, item) for item in crop_margin] + crop_max = [crop_min[i] + self.output_size[i] for i in range(input_dim)] if(self.fg_focus and random.random() < self.fg_ratio): - label = sample['label'] + label = sample['label'][0] mask = np.zeros_like(label) for temp_lab in self.mask_label: mask = np.maximum(mask, label == temp_lab) - if(mask.sum() == 0): - bb_min = [0] * (input_dim + 1) - bb_max = mask.shape - else: - bb_min, bb_max = get_ND_bounding_box(mask) - bb_min, bb_max = bb_min[1:], bb_max[1:] - crop_min = [random.randint(bb_min[i], bb_max[i]) - int(temp_output_size[i]/2) \ - for i in range(input_dim)] - crop_min = [max(0, item) for item in crop_min] - crop_min = [min(crop_min[i], input_shape[i+1] - temp_output_size[i]) \ - for i in range(input_dim)] - - crop_max = [crop_min[i] + temp_output_size[i] \ - for i in range(input_dim)] + if(mask.max() > 0): + crop_min, crop_max = get_random_box_from_mask(mask, self.output_size) + # to avoid Typeerror: object of type int64 is not json serializable + crop_min = [int(i) for i in crop_min] + crop_max = [int(i) for i in crop_max] crop_min = [0] + crop_min - crop_max = list(input_shape[0:1]) + crop_max - sample['RandomCrop_Param'] = json.dumps((input_shape, crop_min, crop_max)) + crop_max = [chns] + crop_max + + sample['RandomCrop_Param'] = json.dumps((image.shape, crop_min, crop_max)) return sample, crop_min, crop_max def _get_param_for_inverse_transform(self, sample): diff --git a/pymic/transform/pad.py b/pymic/transform/pad.py index 91cf6da..c9b75fe 100644 --- a/pymic/transform/pad.py +++ b/pymic/transform/pad.py @@ -22,7 +22,7 @@ class Pad(AbstractTransform): following fields: :param `Pad_output_size`: (list/tuple) The output size along each spatial axis. - :param `Pad_ceil_mode`: (optional, bool) If true (by default), the real output size will + :param `Pad_ceil_mode`: (optional, bool) If true, the real output size will be the minimal integer multiples of output_size higher than the input size. For example, the input image has a shape of [3, 100, 100], `Pad_output_size` = [32, 32], and the real output size will be [3, 128, 128] if `Pad_ceil_mode` = True. diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index 9a484e8..d5d7a7e 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, print_function +import random import numpy as np import SimpleITK as sitk from scipy import ndimage @@ -96,7 +97,7 @@ def set_ND_volume_roi_with_bounding_box_range(volume, bb_min, bb_max, sub_volume raise ValueError("array dimension should be 2 to 5") return out -def crop_and_pad_ND_array_to_desired_shape(image, out_shape, pad_mod): +def crop_and_pad_ND_array_to_desired_shape(image, out_shape, pad_mod='reflect'): """ Crop and pad an image to a given shape. @@ -136,6 +137,96 @@ def crop_and_pad_ND_array_to_desired_shape(image, out_shape, pad_mod): return image_pad +def random_crop_ND_volume(volume, out_shape): + """ + randomly crop a volume with to a given shape. + + :param volume: The input ND array. + :param out_shape: (list) The desired output shape. + """ + in_shape = volume.shape + dim = len(in_shape) + + # pad the image first if the input size is smaller than the output size + pad_shape = [max(out_shape[i], in_shape[i]) for i in range(dim)] + mgnp = [pad_shape[i] - in_shape[i] for i in range(dim)] + if(max(mgnp) == 0): + image_pad = volume + else: + ml = [int(mgnp[i]/2) for i in range(dim)] + mr = [mgnp[i] - ml[i] for i in range(dim)] + pad = [(ml[i], mr[i]) for i in range(dim)] + pad = tuple(pad) + image_pad = np.pad(volume, pad, 'reflect') + + bb_min = [random.randint(0, pad_shape[i] - out_shape[i]) for i in range(dim)] + bb_max = [bb_min[i] + out_shape[i] for i in range(dim)] + crop_volume = crop_ND_volume_with_bounding_box(image_pad, bb_min, bb_max) + return crop_volume + +def get_random_box_from_mask(mask, out_shape): + mask_shape = mask.shape + dim = len(out_shape) + left_margin = [int(out_shape[i]/2) for i in range(dim)] + right_margin= [mask_shape[i] - (out_shape[i] - left_margin[i]) + 1 for i in range(dim)] + + valid_center_shape = [right_margin[i] - left_margin[i] for i in range(dim)] + valid_mask = np.zeros(mask_shape) + valid_mask = set_ND_volume_roi_with_bounding_box_range(valid_mask, + left_margin, right_margin, np.ones(valid_center_shape)) + valid_mask = valid_mask * mask + + indexes = np.where(valid_mask) + voxel_num = len(indexes[0]) + j = random.randint(0, voxel_num - 1) + bb_c = [indexes[i][j] for i in range(dim)] + bb_min = [bb_c[i] - left_margin[i] for i in range(dim)] + bb_max = [bb_min[i] + out_shape[i] for i in range(dim)] + return bb_min, bb_max + +def random_crop_ND_volume_with_mask(volume, out_shape, mask): + """ + randomly crop a volume with to a given shape. + + :param volume: The input ND array. + :param out_shape: (list) The desired output shape. + :param mask: A binary ND array. Default is None. If not None, + the center of the cropped region should be limited to the mask region. + """ + in_shape = volume.shape + dim = len(in_shape) + # pad the image first if the input size is smaller than the output size + pad_shape = [max(out_shape[i], in_shape[i]) for i in range(dim)] + mgnp = [pad_shape[i] - in_shape[i] for i in range(dim)] + if(max(mgnp) == 0): + image_pad, mask_pad = volume, mask + else: + ml = [int(mgnp[i]/2) for i in range(dim)] + mr = [mgnp[i] - ml[i] for i in range(dim)] + pad = [(ml[i], mr[i]) for i in range(dim)] + pad = tuple(pad) + image_pad = np.pad(volume, pad, 'reflect') + mask_pad = np.pad(mask, pad, 'reflect') + + bb_min, bb_max = get_random_box_from_mask(mask_pad, out_shape) + # left_margin = [int(out_shape[i]/2) for i in range(dim)] + # right_margin= [pad_shape[i] - (out_shape[i] - left_margin[i]) + 1 for i in range(dim)] + + # valid_center_shape = [right_margin[i] - left_margin[i] for i in range(dim)] + # valid_mask = np.zeros(pad_shape) + # valid_mask = set_ND_volume_roi_with_bounding_box_range(valid_mask, + # left_margin, right_margin, np.ones(valid_center_shape)) + # valid_mask = valid_mask * mask_pad + + # indexes = np.where(valid_mask) + # voxel_num = len(indexes[0]) + # j = random.randint(0, voxel_num) + # bb_c = [indexes[i][j] for i in range(dim)] + # bb_min = [bb_c[i] - left_margin[i] for i in range(dim)] + # bb_max = [bb_min[i] + out_shape[i] for i in range(dim)] + crop_volume = crop_ND_volume_with_bounding_box(image_pad, bb_min, bb_max) + return crop_volume + def get_largest_k_components(image, k = 1): """ Get the largest K components from 2D or 3D binary image. @@ -200,11 +291,11 @@ def convert_label(label, source_list, target_list): :param target_list: A list of target labels, e.g. [0, 1, 2, 3] """ assert(len(source_list) == len(target_list)) - label_converted = np.zeros_like(label) + label_converted = label * 1 for i in range(len(source_list)): - label_temp = np.asarray(label == source_list[i], label.dtype) - label_temp = label_temp * target_list[i] - label_converted = label_converted + label_temp + label_s = np.asarray(label == source_list[i], label.dtype) + label_t = label_s * target_list[i] + label_converted[label_s > 0] = label_t[label_s > 0] return label_converted def resample_sitk_image_to_given_spacing(image, spacing, order): diff --git a/pymic/util/parse_config.py b/pymic/util/parse_config.py index 18afd08..a12cc76 100644 --- a/pymic/util/parse_config.py +++ b/pymic/util/parse_config.py @@ -108,6 +108,16 @@ def synchronize_config(config): data_cfg["LabelToProbability_class_num".lower()] = net_cfg["class_num"] if "PartialLabelToProbability" in data_cfg['train_transform']: data_cfg["PartialLabelToProbability_class_num".lower()] = net_cfg["class_num"] + patch_size = data_cfg.get('patch_size', None) + if(patch_size is not None): + if('Pad' in data_cfg['train_transform']): + data_cfg['Pad_output_size'.lower()] = patch_size + if('CenterCrop' in data_cfg['train_transform']): + data_cfg['CenterCrop_output_size'.lower()] = patch_size + if('RandomCrop' in data_cfg['train_transform']): + data_cfg['RandomCrop_output_size'.lower()] = patch_size + if('RandomResizedCrop' in data_cfg['train_transform']): + data_cfg['RandomResizedCrop_output_size'.lower()] = patch_size config['dataset'] = data_cfg config['network'] = net_cfg return config From d0e800aec25123f7856b1b21f518c24d900202c2 Mon Sep 17 00:00:00 2001 From: taigw Date: Sun, 7 May 2023 21:26:10 +0800 Subject: [PATCH 148/225] update transform and inference mode 1, add CropwithForeground 2, all large batch size for sliding window inference 3, add BinaryDice loss and GroupDiceLoss --- pymic/loss/loss_dict_seg.py | 5 +- pymic/loss/seg/dice.py | 53 +++++++++++++++++++++ pymic/net_run/infer_func.py | 87 ++++++++++++++++++++++------------- pymic/transform/crop.py | 40 ++++++++++++++++ pymic/transform/trans_dict.py | 1 + 5 files changed, 154 insertions(+), 32 deletions(-) diff --git a/pymic/loss/loss_dict_seg.py b/pymic/loss/loss_dict_seg.py index 97c537e..fd72ce4 100644 --- a/pymic/loss/loss_dict_seg.py +++ b/pymic/loss/loss_dict_seg.py @@ -23,7 +23,8 @@ from __future__ import print_function, division import torch.nn as nn from pymic.loss.seg.ce import CrossEntropyLoss, GeneralizedCELoss -from pymic.loss.seg.dice import DiceLoss, FocalDiceLoss, NoiseRobustDiceLoss +from pymic.loss.seg.dice import DiceLoss, FocalDiceLoss, \ + NoiseRobustDiceLoss, BinaryDiceLoss, GroupDiceLoss from pymic.loss.seg.exp_log import ExpLogLoss from pymic.loss.seg.mse import MSELoss, MAELoss from pymic.loss.seg.slsr import SLSRLoss @@ -32,8 +33,10 @@ 'CrossEntropyLoss': CrossEntropyLoss, 'GeneralizedCELoss': GeneralizedCELoss, 'DiceLoss': DiceLoss, + 'BinaryDiceLoss': BinaryDiceLoss, 'FocalDiceLoss': FocalDiceLoss, 'NoiseRobustDiceLoss': NoiseRobustDiceLoss, + 'GroupDiceLoss': GroupDiceLoss, 'ExpLogLoss': ExpLogLoss, 'MAELoss': MAELoss, 'MSELoss': MSELoss, diff --git a/pymic/loss/seg/dice.py b/pymic/loss/seg/dice.py index c3a1134..350e0c4 100644 --- a/pymic/loss/seg/dice.py +++ b/pymic/loss/seg/dice.py @@ -31,6 +31,59 @@ def forward(self, loss_input_dict): dice_loss = 1.0 - dice_score.mean() return dice_loss +class BinaryDiceLoss(AbstractSegLoss): + ''' + Fuse all the foreground classes together and calculate the Dice value. + ''' + def __init__(self, params = None): + super(BinaryDiceLoss, self).__init__(params) + + def forward(self, loss_input_dict): + predict = loss_input_dict['prediction'] + soft_y = loss_input_dict['ground_truth'] + + if(isinstance(predict, (list, tuple))): + predict = predict[0] + if(self.softmax): + predict = nn.Softmax(dim = 1)(predict) + predict = 1.0 - predict[:, :1, :, :, :] + soft_y = 1.0 - soft_y[:, :1, :, :, :] + predict = reshape_tensor_to_2D(predict) + soft_y = reshape_tensor_to_2D(soft_y) + dice_score = get_classwise_dice(predict, soft_y) + dice_loss = 1.0 - dice_score.mean() + return dice_loss + +class GroupDiceLoss(AbstractSegLoss): + ''' + Fuse all the foreground classes together and calculate the Dice value. + ''' + def __init__(self, params = None): + super(GroupDiceLoss, self).__init__(params) + self.group = 2 + + def forward(self, loss_input_dict): + predict = loss_input_dict['prediction'] + soft_y = loss_input_dict['ground_truth'] + + if(isinstance(predict, (list, tuple))): + predict = predict[0] + if(self.softmax): + predict = nn.Softmax(dim = 1)(predict) + predict = reshape_tensor_to_2D(predict) + soft_y = reshape_tensor_to_2D(soft_y) + num_class = list(predict.size())[1] + cls_per_group = (num_class - 1) // self.group + loss_all = 0.0 + for g in range(self.group): + c0 = 1 + g*cls_per_group + c1 = min(num_class, c0 + cls_per_group) + pred_g = torch.sum(predict[:, c0:c1], dim = 1, keepdim = True) + y_g = torch.sum( soft_y[:, c0:c1], dim = 1, keepdim = True) + loss_all += 1.0 - get_classwise_dice(pred_g, y_g)[0] + avg_loss = loss_all / self.group + return avg_loss + class FocalDiceLoss(AbstractSegLoss): """ Focal Dice according to the following paper: diff --git a/pymic/net_run/infer_func.py b/pymic/net_run/infer_func.py index f6ed515..d2b9af2 100644 --- a/pymic/net_run/infer_func.py +++ b/pymic/net_run/infer_func.py @@ -63,13 +63,16 @@ def __infer_with_sliding_window(self, image): Use sliding window to predict segmentation for large images. The outupt of each sliding window is weighted by a Gaussian map that hihglights contributions of windows with a centroid closer to a given pixel. - Note that the network may output a list of tensors with difference sizes. + Note that the network may output a list of tensors with difference sizes for multi-scale prediction. """ window_size = [x for x in self.config['sliding_window_size']] window_stride = [x for x in self.config['sliding_window_stride']] + window_batch = self.config.get('sliding_window_batch', 1) class_num = self.config['class_num'] img_full_shape = list(image.shape) batch_size = img_full_shape[0] + assert(batch_size == 1 or window_batch == 1) + img_chns = img_full_shape[1] img_shape = img_full_shape[2:] img_dim = len(img_shape) if(img_dim != 2 and img_dim !=3): @@ -105,23 +108,37 @@ def __infer_with_sliding_window(self, image): temp_in_shape = img_full_shape[:2] + window_size tempx = torch.ones(temp_in_shape).to(image.device) out_num, scale_list = self.__get_prediction_number_and_scales(tempx) + + window_num = len(crop_start_list) + assert(window_num >= window_batch) + patches_shape = [window_batch, img_chns] + window_size + patches_in = torch.ones(patches_shape).to(image.device) if(out_num == 1): # for a single prediction output = torch.zeros(output_shape).to(image.device) - for c0 in crop_start_list: - c1 = [c0[d] + window_size[d] for d in range(img_dim)] - if(img_dim == 2): - patch_in = image[:, :, c0[0]:c1[0], c0[1]:c1[1]] - else: - patch_in = image[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] - patch_out = self.model(patch_in) - if(isinstance(patch_out, (tuple, list))): - patch_out = patch_out[0] - if(img_dim == 2): - output[:, :, c0[0]:c1[0], c0[1]:c1[1]] += patch_out * temp_w - weight[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_w - else: - output[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += patch_out * temp_w - weight[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_w + for w_i in range(0, window_num, window_batch): + for k in range(window_batch): + if(w_i + k >= window_num): + break + c0 = crop_start_list[w_i + k] + c1 = [c0[d] + window_size[d] for d in range(img_dim)] + if(img_dim == 2): + patches_in[k] = image[:, :, c0[0]:c1[0], c0[1]:c1[1]] + else: + patches_in[k] = image[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] + patches_out = self.model(patches_in) + if(isinstance(patches_out, (tuple, list))): + patches_out = patches_out[0] + for k in range(window_batch): + if(w_i + k >= window_num): + break + c0 = crop_start_list[w_i + k] + c1 = [c0[d] + window_size[d] for d in range(img_dim)] + if(img_dim == 2): + output[:, :, c0[0]:c1[0], c0[1]:c1[1]] += patches_out[k] * temp_w + weight[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_w + else: + output[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += patches_out[k] * temp_w + weight[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_w return output/weight else: # for multiple prediction output_list= [] @@ -130,23 +147,31 @@ def __infer_with_sliding_window(self, image): [int(img_shape[d] * scale_list[i][d]) for d in range(img_dim)] output_list.append(torch.zeros(output_shape_i).to(image.device)) - for c0 in crop_start_list: - c1 = [c0[d] + window_size[d] for d in range(img_dim)] - if(img_dim == 2): - patch_in = image[:, :, c0[0]:c1[0], c0[1]:c1[1]] - else: - patch_in = image[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] - patch_out = self.model(patch_in) - - for i in range(out_num): - c0_i = [int(c0[d] * scale_list[i][d]) for d in range(img_dim)] - c1_i = [int(c1[d] * scale_list[i][d]) for d in range(img_dim)] + for w_i in range(0, window_num, window_batch): + for k in range(window_batch): + if(w_i + k >= window_num): + break + c0 = crop_start_list[w_i + k] + c1 = [c0[d] + window_size[d] for d in range(img_dim)] if(img_dim == 2): - output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1]] += patch_out[i] * temp_w - weight[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_w + patches_in[k] = image[:, :, c0[0]:c1[0], c0[1]:c1[1]] else: - output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1], c0_i[2]:c1_i[2]] += patch_out[i] * temp_w - weight[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_w + patches_in[k] = image[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] + patches_out = self.model(patches_in) + + for i in range(out_num): + for k in range(window_batch): + if(w_i + k >= window_num): + break + c0 = crop_start_list[w_i + k] + c0_i = [int(c0[d] * scale_list[i][d]) for d in range(img_dim)] + c1_i = [int(c1[d] * scale_list[i][d]) for d in range(img_dim)] + if(img_dim == 2): + output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1]] += patches_out[i][k] * temp_w + weight[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_w + else: + output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1], c0_i[2]:c1_i[2]] += patches_out[i][k] * temp_w + weight[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_w for i in range(out_num): weight_i = interpolate(weight, scale_factor = scale_list[i]) output_list[i] = output_list[i] / weight_i diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index b9345c8..95e3489 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -164,7 +164,47 @@ def _get_param_for_inverse_transform(self, sample): params = json.loads(sample['CropWithBoundingBox_Param']) return params +class CropWithForeground(CenterCrop): + """ + Crop the image (shape [C, D, H, W] or [C, H, W]) based on a bounding box. + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `CropWithBoundingBox_start`: (None, or list/tuple) The start index + along each spatial axis. If None, calculate the start index automatically + so that the cropped region is centered at the non-zero region. + :param `CropWithBoundingBox_output_size`: (None or tuple/list): + Desired spatial output size. + If None, set it as the size of bounding box of non-zero region. + :param `CropWithBoundingBox_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `True`. + """ + def __init__(self, params): + self.labels = params.get('CropWithForeground_labels'.lower(), None) + self.margin = params.get('CropWithForeground_margin'.lower(), [5, 10, 10]) + self.inverse = params.get('CropWithForeground_inverse'.lower(), True) + self.task = params['task'] + + def _get_crop_param(self, sample): + image = sample['image'] + label = sample['label'] + input_shape = sample['image'].shape + bb_min, bb_max = get_ND_bounding_box(label, margin=[0] + self.margin) + bb_max[0] = input_shape[0] + + sample['CropWithForeground_Param'] = json.dumps((input_shape, bb_min, bb_max)) + + return sample, bb_min, bb_max + + def _get_param_for_inverse_transform(self, sample): + if(isinstance(sample['CropWithForeground_Param'], list) or \ + isinstance(sample['CropWithForeground_Param'], tuple)): + params = json.loads(sample['CropWithForeground_Param'][0]) + else: + params = json.loads(sample['CropWithForeground_Param']) + return params + class RandomCrop(CenterCrop): """Randomly crop the input image (shape [C, D, H, W] or [C, H, W]). diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index e8f1c37..c1ecdc2 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -46,6 +46,7 @@ 'ChannelWiseThreshold': ChannelWiseThreshold, 'ChannelWiseThresholdWithNormalize': ChannelWiseThresholdWithNormalize, 'CropWithBoundingBox': CropWithBoundingBox, + 'CropWithForeground': CropWithForeground, 'CenterCrop': CenterCrop, 'GrayscaleToRGB': GrayscaleToRGB, 'GammaCorrection': GammaCorrection, From f6be3eafa7b79e8f40cc18708610ab33c7adc69e Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 16 May 2023 09:38:14 +0800 Subject: [PATCH 149/225] update checkpoint save code Save the best checkpoint immediately --- pymic/net/net_dict_seg.py | 31 +++++++++++++++++++++++++++++-- pymic/net_run/agent_abstract.py | 2 ++ pymic/net_run/agent_cls.py | 19 +++++++++---------- pymic/net_run/agent_seg.py | 21 +++++++++++---------- 4 files changed, 51 insertions(+), 22 deletions(-) diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index fc7692f..e6c10bd 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -21,10 +21,24 @@ from pymic.net.net2d.unet2d_attention import AttentionUNet2D from pymic.net.net2d.unet2d_nest import NestedUNet2D from pymic.net.net2d.unet2d_scse import UNet2D_ScSE +from pymic.net.net2d.trans2d.transunet import TransUNet +from pymic.net.net2d.trans2d.swinunet import SwinUNet from pymic.net.net3d.unet2d5 import UNet2D5 from pymic.net.net3d.unet3d import UNet3D from pymic.net.net3d.unet3d_scse import UNet3D_ScSE from pymic.net.net3d.unet3d_dual_branch import UNet3D_DualBranch +from pymic.net.net3d.trans3d.nnFormer_wrap import nnFormer_wrap +from pymic.net.net3d.trans3d.unetr import UNETR +from pymic.net.net3d.trans3d.unetr_pp import UNETR_PP +from pymic.net.net3d.trans3d.MedFormer_v1 import MedFormerV1 +from pymic.net.net3d.trans3d.MedFormer_v2 import MedFormerV2 +from pymic.net.net3d.trans3d.MedFormer_v3 import MedFormerV3 +from pymic.net.net3d.trans3d.MedFormer_va1 import MedFormerVA1 +from pymic.net.net3d.trans3d.HiFormer_v1 import HiFormer_v1 +from pymic.net.net3d.trans3d.HiFormer_v2 import HiFormer_v2 +from pymic.net.net3d.trans3d.HiFormer_v3 import HiFormer_v3 +from pymic.net.net3d.trans3d.HiFormer_v4 import HiFormer_v4 +from pymic.net.net3d.trans3d.HiFormer_v5 import HiFormer_v5 SegNetDict = { 'UNet2D': UNet2D, @@ -34,9 +48,22 @@ 'AttentionUNet2D': AttentionUNet2D, 'NestedUNet2D': NestedUNet2D, 'UNet2D_ScSE': UNet2D_ScSE, + 'TransUNet': TransUNet, + 'SwinUNet': SwinUNet, 'UNet2D5': UNet2D5, 'UNet3D': UNet3D, 'UNet3D_ScSE': UNet3D_ScSE, - 'UNet3D_DualBranch': UNet3D_DualBranch - + 'UNet3D_DualBranch': UNet3D_DualBranch, + 'nnFormer': nnFormer_wrap, + 'UNETR': UNETR, + 'UNETR_PP': UNETR_PP, + 'MedFormerV1': MedFormerV1, + 'MedFormerV2': MedFormerV2, + 'MedFormerV3': MedFormerV3, + 'MedFormerVA1':MedFormerVA1, + 'HiFormer_v1': HiFormer_v1, + 'HiFormer_v2': HiFormer_v2, + 'HiFormer_v3': HiFormer_v3, + 'HiFormer_v4': HiFormer_v4, + 'HiFormer_v5': HiFormer_v5 } diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index 109eabc..7a49a2b 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -148,6 +148,8 @@ def get_checkpoint_name(self): with open(txt_name, 'r') as txt_file: it_num = txt_file.read().replace('\n', '') ckpt_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, it_num) + if(ckpt_mode == 1 and not os.path.isfile(ckpt_name)): + ckpt_name = "{0:}/{1:}_best.pt".format(ckpt_dir, ckpt_prefix) else: ckpt_name = self.config['testing']['ckpt_name'] return ckpt_name diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index 1ded3b1..ee4e25b 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -285,6 +285,15 @@ def train_valid(self): self.best_model_wts = copy.deepcopy(self.net.module.state_dict()) else: self.best_model_wts = copy.deepcopy(self.net.state_dict()) + save_dict = {'iteration': self.max_val_it, + 'valid_pred': self.max_val_score, + 'model_state_dict': self.best_model_wts, + 'optimizer_state_dict': self.optimizer.state_dict()} + save_name = "{0:}/{1:}_best.pt".format(ckpt_dir, ckpt_prefix) + torch.save(save_dict, save_name) + txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') + txt_file.write(str(self.max_val_it)) + txt_file.close() stop_now = True if(early_stop_it is not None and \ self.glob_it - self.max_val_it > early_stop_it) else False @@ -302,16 +311,6 @@ def train_valid(self): if(stop_now): logging.info("The training is early stopped") break - # save the best performing checkpoint - save_dict = {'iteration': self.max_val_it, - 'valid_pred': self.max_val_score, - 'model_state_dict': self.best_model_wts, - 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it) - torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') - txt_file.write(str(self.max_val_it)) - txt_file.close() logging.info('The best perfroming iter is {0:}, valid {1:} {2:}'.format(\ self.max_val_it, metrics, self.max_val_score)) self.summ_writer.close() diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 2e703d4..b007574 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -337,6 +337,7 @@ def train_valid(self): logging.info('learning rate {0:}'.format(lr_value)) logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1)) self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it) + if(valid_scalars['avg_fg_dice'] > self.max_val_dice): self.max_val_dice = valid_scalars['avg_fg_dice'] self.max_val_it = self.glob_it @@ -344,8 +345,17 @@ def train_valid(self): self.best_model_wts = copy.deepcopy(self.net.module.state_dict()) else: self.best_model_wts = copy.deepcopy(self.net.state_dict()) + save_dict = {'iteration': self.max_val_it, + 'valid_pred': self.max_val_dice, + 'model_state_dict': self.best_model_wts, + 'optimizer_state_dict': self.optimizer.state_dict()} + save_name = "{0:}/{1:}_best.pt".format(ckpt_dir, ckpt_prefix) + torch.save(save_dict, save_name) + txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') + txt_file.write(str(self.max_val_it)) + txt_file.close() - stop_now = True if(early_stop_it is not None and \ + stop_now = True if (early_stop_it is not None and \ self.glob_it - self.max_val_it > early_stop_it) else False if ((self.glob_it in iter_save_list) or stop_now): save_dict = {'iteration': self.glob_it, @@ -362,15 +372,6 @@ def train_valid(self): logging.info("The training is early stopped") break # save the best performing checkpoint - save_dict = {'iteration': self.max_val_it, - 'valid_pred': self.max_val_dice, - 'model_state_dict': self.best_model_wts, - 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it) - torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') - txt_file.write(str(self.max_val_it)) - txt_file.close() logging.info('The best performing iter is {0:}, valid dice {1:}'.format(\ self.max_val_it, self.max_val_dice)) self.summ_writer.close() From a205f31e5d4d44d562d83afc9617337820b0e841 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 18 May 2023 12:28:16 +0800 Subject: [PATCH 150/225] update code for MIM add transform MaskedImageModelingLabel --- pymic/loss/seg/mse.py | 6 +- pymic/net/net3d/trans3d/HiFormer_v1.py | 1010 +++++++++++++++++ pymic/net/net3d/trans3d/HiFormer_v2.py | 381 +++++++ pymic/net/net3d/trans3d/HiFormer_v3.py | 455 ++++++++ pymic/net/net3d/trans3d/HiFormer_v4.py | 455 ++++++++ pymic/net/net3d/trans3d/HiFormer_v5.py | 308 +++++ pymic/net/net3d/trans3d/MedFormer_v1.py | 173 +++ pymic/net/net3d/trans3d/MedFormer_v2.py | 464 ++++++++ pymic/net/net3d/trans3d/MedFormer_v3.py | 255 +++++ pymic/net/net3d/trans3d/MedFormer_va1.py | 105 ++ pymic/net/net3d/trans3d/__init__.py | 0 pymic/net/net3d/trans3d/nnFormer_wrap.py | 43 + pymic/net/net3d/trans3d/unetr.py | 227 ++++ pymic/net/net3d/trans3d/unetr_pp.py | 460 ++++++++ pymic/net/net3d/trans3d/unetr_pp_block.py | 278 +++++ pymic/net_run/agent_rec.py | 6 +- pymic/net_run/self_sup/__init__.py | 3 +- .../net_run/self_sup/self_patch_mix_agent.py | 144 +++ pymic/net_run/self_sup/util.py | 167 +++ pymic/transform/label_convert.py | 59 +- pymic/transform/mix.py | 66 ++ pymic/transform/trans_dict.py | 3 +- 22 files changed, 5046 insertions(+), 22 deletions(-) create mode 100644 pymic/net/net3d/trans3d/HiFormer_v1.py create mode 100644 pymic/net/net3d/trans3d/HiFormer_v2.py create mode 100644 pymic/net/net3d/trans3d/HiFormer_v3.py create mode 100644 pymic/net/net3d/trans3d/HiFormer_v4.py create mode 100644 pymic/net/net3d/trans3d/HiFormer_v5.py create mode 100644 pymic/net/net3d/trans3d/MedFormer_v1.py create mode 100644 pymic/net/net3d/trans3d/MedFormer_v2.py create mode 100644 pymic/net/net3d/trans3d/MedFormer_v3.py create mode 100644 pymic/net/net3d/trans3d/MedFormer_va1.py create mode 100644 pymic/net/net3d/trans3d/__init__.py create mode 100644 pymic/net/net3d/trans3d/nnFormer_wrap.py create mode 100644 pymic/net/net3d/trans3d/unetr.py create mode 100644 pymic/net/net3d/trans3d/unetr_pp.py create mode 100644 pymic/net/net3d/trans3d/unetr_pp_block.py create mode 100644 pymic/net_run/self_sup/self_patch_mix_agent.py create mode 100644 pymic/net_run/self_sup/util.py create mode 100644 pymic/transform/mix.py diff --git a/pymic/loss/seg/mse.py b/pymic/loss/seg/mse.py index ad83899..5b657c5 100644 --- a/pymic/loss/seg/mse.py +++ b/pymic/loss/seg/mse.py @@ -40,11 +40,15 @@ def __init__(self, params = None): def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] soft_y = loss_input_dict['ground_truth'] + weight = loss_input_dict.get('pixel_weight', None) if(isinstance(predict, (list, tuple))): predict = predict[0] if(self.softmax): predict = nn.Softmax(dim = 1)(predict) mae = torch.abs(predict - soft_y) - mae = torch.mean(mae) + if(weight is None): + mae = torch.mean(mae) + else: + mae = torch.sum(mae * weight) / weight.sum() return mae diff --git a/pymic/net/net3d/trans3d/HiFormer_v1.py b/pymic/net/net3d/trans3d/HiFormer_v1.py new file mode 100644 index 0000000..af73683 --- /dev/null +++ b/pymic/net/net3d/trans3d/HiFormer_v1.py @@ -0,0 +1,1010 @@ +from einops import rearrange +from copy import deepcopy +from nnformer.utilities.nd_softmax import softmax_helper +from torch import nn +import torch +import numpy as np +import torch.nn.functional +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_3tuple, trunc_normal_ +from pymic.net.net3d.unet3d import ConvBlock, DownBlock +# from nnFormer +class ContiguousGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x + @staticmethod + def backward(ctx, grad_out): + return grad_out.contiguous() + +# from nnFormer +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +# from nnFormer +def window_partition(x, window_size): + + B, S, H, W, C = x.shape + x = x.view(B, S // window_size, window_size, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, C) + return windows + +# from nnFormer +def window_reverse(windows, window_size, S, H, W): + + B = int(windows.shape[0] / (S * H * W / window_size / window_size / window_size)) + x = windows.view(B, S // window_size, H // window_size, W // window_size, window_size, window_size, window_size, -1) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, S, H, W, -1) + return x + + +# from nnFormer +class SwinTransformerBlock_kv(nn.Module): + + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention_kv( + dim, window_size=to_3tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + #self.window_size=to_3tuple(self.window_size) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, mask_matrix,skip=None,x_up=None): + + B, L, C = x.shape + S, H, W = self.input_resolution + + assert L == S * H * W, "input feature has wrong size" + + shortcut = x + skip = self.norm1(skip) + x_up = self.norm1(x_up) + + skip = skip.view(B, S, H, W, C) + x_up = x_up.view(B, S, H, W, C) + x = x.view(B, S, H, W, C) + # pad feature maps to multiples of window size + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + pad_g = (self.window_size - S % self.window_size) % self.window_size + + skip = F.pad(skip, (0, 0, 0, pad_r, 0, pad_b, 0, pad_g)) + x_up = F.pad(x_up, (0, 0, 0, pad_r, 0, pad_b, 0, pad_g)) + _, Sp, Hp, Wp, _ = skip.shape + + + + # cyclic shift + if self.shift_size > 0: + skip = torch.roll(skip, shifts=(-self.shift_size, -self.shift_size,-self.shift_size), dims=(1, 2,3)) + x_up = torch.roll(x_up, shifts=(-self.shift_size, -self.shift_size,-self.shift_size), dims=(1, 2,3)) + attn_mask = mask_matrix + else: + skip = skip + x_up=x_up + attn_mask = None + # partition windows + skip = window_partition(skip, self.window_size) + skip = skip.view(-1, self.window_size * self.window_size * self.window_size, + C) + x_up = window_partition(x_up, self.window_size) + x_up = x_up.view(-1, self.window_size * self.window_size * self.window_size, + C) + attn_windows=self.attn(skip,x_up,mask=attn_mask,pos_embed=None) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Sp, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size, self.shift_size), dims=(1, 2, 3)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0 or pad_g > 0: + x = x[:, :S, :H, :W, :].contiguous() + + x = x.view(B, S * H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + +# from nnFormer +class WindowAttention_kv(nn.Module): + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), + num_heads)) + + # get pair-wise relative position index for each token inside the window + coords_s = torch.arange(self.window_size[0]) + coords_h = torch.arange(self.window_size[1]) + coords_w = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid([coords_s, coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + + relative_coords[:, :, 0] *= 3 * self.window_size[1] - 1 + relative_coords[:, :, 1] *= 2 * self.window_size[1] - 1 + + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + trunc_normal_(self.relative_position_bias_table, std=.02) + + + def forward(self, skip,x_up,pos_embed=None, mask=None): + + B_, N, C = skip.shape + + kv = self.kv(skip) + q = x_up + + kv=kv.reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q = q.reshape(B_,N,self.num_heads,C//self.num_heads).permute(0,2,1,3).contiguous() + k,v = kv[0], kv[1] + q = q * self.scale + attn = (q @ k.transpose(-2, -1).contiguous()) + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] * self.window_size[2], + self.window_size[0] * self.window_size[1] * self.window_size[2], -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous() + if pos_embed is not None: + x = x + pos_embed + x = self.proj(x) + x = self.proj_drop(x) + return x + +# from nnFormer +class WindowAttention(nn.Module): + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), + num_heads)) + + # get pair-wise relative position index for each token inside the window + coords_s = torch.arange(self.window_size[0]) + coords_h = torch.arange(self.window_size[1]) + coords_w = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid([coords_s, coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + + relative_coords[:, :, 0] *= 3 * self.window_size[1] - 1 + relative_coords[:, :, 1] *= 2 * self.window_size[1] - 1 + + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + trunc_normal_(self.relative_position_bias_table, std=.02) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None,pos_embed=None): + + B_, N, C = x.shape + + qkv = self.qkv(x) + + qkv=qkv.reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1).contiguous()) + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] * self.window_size[2], + self.window_size[0] * self.window_size[1] * self.window_size[2], -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous() + if pos_embed is not None: + x = x+pos_embed + x = self.proj(x) + x = self.proj_drop(x) + return x + +# from nnFormer +class SwinTransformerBlock(nn.Module): + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + + self.attn = WindowAttention( + dim, window_size=to_3tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + + def forward(self, x, mask_matrix): + + B, L, C = x.shape + S, H, W = self.input_resolution + + assert L == S * H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, S, H, W, C) + + # pad feature maps to multiples of window size + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + pad_g = (self.window_size - S % self.window_size) % self.window_size + + x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b, 0, pad_g)) + _, Sp, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size,-self.shift_size), dims=(1, 2,3)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size * self.window_size, + C) + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask,pos_embed=None) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Sp, Hp, Wp) + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size, self.shift_size), dims=(1, 2, 3)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0 or pad_g > 0: + x = x[:, :S, :H, :W, :].contiguous() + + x = x.view(B, S * H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + +# from nnFormer +class PatchMerging(nn.Module): + + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Conv3d(dim,dim*2,kernel_size=3,stride=2,padding=1) + + self.norm = norm_layer(dim) + + def forward(self, x, S, H, W): + + B, L, C = x.shape + assert L == H * W * S, "input feature has wrong size" + x = x.view(B, S, H, W, C) + + x = F.gelu(x) + x = self.norm(x) + x=x.permute(0,4,1,2,3).contiguous() + x=self.reduction(x) + x=x.permute(0,2,3,4,1).contiguous().view(B,-1,2*C) + + return x + +# from nnFormer +class Patch_Expanding(nn.Module): + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + + self.norm = norm_layer(dim) + self.up=nn.ConvTranspose3d(dim,dim//2,2,2) + def forward(self, x, S, H, W): + + + B, L, C = x.shape + assert L == H * W * S, "input feature has wrong size" + + x = x.view(B, S, H, W, C) + + + + x = self.norm(x) + x=x.permute(0,4,1,2,3).contiguous() + x = self.up(x) + x = ContiguousGrad.apply(x) + x=x.permute(0,2,3,4,1).contiguous().view(B,-1,C//2) + + return x + +# from nnFormer +class BasicLayer(nn.Module): + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=True + ): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + # build blocks + + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, S, H, W): + + + # calculate attention mask for SW-MSA + Sp = int(np.ceil(S / self.window_size)) * self.window_size + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Sp, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + s_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for s in s_slices: + for h in h_slices: + for w in w_slices: + img_mask[:, s, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, + self.window_size * self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + for blk in self.blocks: + + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, S, H, W) + Ws, Wh, Ww = (S + 1) // 2, (H + 1) // 2, (W + 1) // 2 + return x, S, H, W, x_down, Ws, Wh, Ww + else: + return x, S, H, W, x, S, H, W + +# from nnFormer +class BasicLayer_up(nn.Module): + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + upsample=True + ): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + + + # build blocks + self.blocks = nn.ModuleList() + self.blocks.append( + SwinTransformerBlock_kv( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 , + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[0] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) + ) + for i in range(depth-1): + self.blocks.append( + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=window_size // 2 , + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i+1] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) + ) + + + + self.Upsample = upsample(dim=2*dim, norm_layer=norm_layer) + def forward(self, x,skip, S, H, W): + + + x_up = self.Upsample(x, S, H, W) + + x = x_up + skip + S, H, W = S * 2, H * 2, W * 2 + # calculate attention mask for SW-MSA + Sp = int(np.ceil(S / self.window_size)) * self.window_size + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Sp, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + s_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for s in s_slices: + for h in h_slices: + for w in w_slices: + img_mask[:, s, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, + self.window_size * self.window_size * self.window_size) # 3d��3��winds�˻�����Ŀ�Ǻܴ�ģ�����winds����̫�� + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + x = self.blocks[0](x, attn_mask,skip=skip,x_up=x_up) + for i in range(self.depth-1): + x = self.blocks[i+1](x,attn_mask) + + return x, S, H, W + + +# from nnFormer +class project(nn.Module): + def __init__(self,in_dim,out_dim,stride,padding,activate,norm,last=False): + super().__init__() + self.out_dim=out_dim + self.conv1=nn.Conv3d(in_dim,out_dim,kernel_size=3,stride=stride,padding=padding) + self.conv2=nn.Conv3d(out_dim,out_dim,kernel_size=3,stride=1,padding=1) + self.activate=activate() + self.norm1=norm(out_dim) + self.last=last + if not last: + self.norm2=norm(out_dim) + + def forward(self,x): + x=self.conv1(x) + x=self.activate(x) + #norm1 + Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.norm1(x) + x = x.transpose(1, 2).contiguous().view(-1, self.out_dim, Ws, Wh, Ww) + + + x=self.conv2(x) + if not self.last: + x=self.activate(x) + #norm2 + Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.norm2(x) + x = x.transpose(1, 2).contiguous().view(-1, self.out_dim, Ws, Wh, Ww) + return x + + +# from nnFormer +class PatchEmbed_backup(nn.Module): + def __init__(self, patch_size=4, in_chans=4, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_3tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + stride1=[patch_size[0]//2,patch_size[1]//2,patch_size[2]//2] + stride2=[patch_size[0]//2,patch_size[1]//2,patch_size[2]//2] + self.proj1 = project(in_chans,embed_dim//2,stride1,1,nn.GELU,nn.LayerNorm,False) + self.proj2 = project(embed_dim//2,embed_dim,stride2,1,nn.GELU,nn.LayerNorm,True) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, S, H, W = x.size() + if W % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) + if H % self.patch_size[1] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) + if S % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - S % self.patch_size[0])) + x = self.proj1(x) # B C Ws Wh Ww + x = self.proj2(x) # B C Ws Wh Ww + if self.norm is not None: + Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.norm(x) + x = x.transpose(1, 2).contiguous().view(-1, self.embed_dim, Ws, Wh, Ww) + + return x + + +class PatchEmbed(nn.Module): + """ + replace patch embed with conv layers""" + def __init__(self, in_chns=1, ft_chns = [32, 64, 128], dropout = [0, 0, 0.2]): + super().__init__() + self.in_conv= ConvBlock(in_chns, ft_chns[0], dropout[0]) + self.down1 = DownBlock(ft_chns[0], ft_chns[1], dropout[1]) + self.down2 = DownBlock(ft_chns[1], ft_chns[2], dropout[2]) + + + def forward(self, x): + """Forward function.""" + x0 = self.in_conv(x) + x1 = self.down1(x0) + x2 = self.down2(x1) + return x2 + +# from nnFormer +class Encoder(nn.Module): + + def __init__(self, + pretrain_img_size=224, + patch_size=4, + in_chans=1 , + embed_dim=96, + depths=[2, 2, 2, 2], + num_heads=[4, 8, 16, 32], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=True, + out_indices=(0, 1, 2, 3) + ): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + + self.num_layers = len(depths) + print("number of layers in encoder", self.num_layers, depths) + self.embed_dim = embed_dim + self.patch_norm = patch_norm + self.out_indices = out_indices + + # split image into non-overlapping patches + # self.patch_embed = PatchEmbed( + # patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + # norm_layer=norm_layer if self.patch_norm else None) + self.patch_embed = PatchEmbed(in_chans, ft_chns=[embed_dim // 4, embed_dim //2, embed_dim]) + + + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + input_resolution=( + pretrain_img_size[0] // patch_size[0] // 2 ** i_layer, pretrain_img_size[1] // patch_size[1] // 2 ** i_layer, + pretrain_img_size[2] // patch_size[2] // 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size[i_layer], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum( + depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging + if (i_layer < self.num_layers - 1) else None + ) + self.layers.append(layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + + def forward(self, x): + """Forward function.""" + + x = self.patch_embed(x) + down=[] + Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) + + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.pos_drop(x) + + + for i in range(self.num_layers): + layer = self.layers[i] + x_out, S, H, W, x, Ws, Wh, Ww = layer(x, Ws, Wh, Ww) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + + out = x_out.view(-1, S, H, W, self.num_features[i]).permute(0, 4, 1, 2, 3).contiguous() + + down.append(out) + return down + + +# from nnFormer +class Decoder(nn.Module): + def __init__(self, + pretrain_img_size, + embed_dim, + patch_size=4, + depths=[2,2,2], + num_heads=[24,12,6], + window_size=4, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm + ): + super().__init__() + + + self.num_layers = len(depths) + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers)[::-1]: + + layer = BasicLayer_up( + dim=int(embed_dim * 2 ** (len(depths)-i_layer-1)), + input_resolution=( + pretrain_img_size[0] // patch_size[0] // 2 ** (len(depths)-i_layer-1), pretrain_img_size[1] // patch_size[1] // 2 ** (len(depths)-i_layer-1), + pretrain_img_size[2] // patch_size[2] // 2 ** (len(depths)-i_layer-1)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size[i_layer], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum( + depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + upsample=Patch_Expanding + ) + self.layers.append(layer) + self.num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + def forward(self,x,skips): + + outs=[] + S, H, W = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2).contiguous() + for index,i in enumerate(skips): + i = i.flatten(2).transpose(1, 2).contiguous() + skips[index]=i + x = self.pos_drop(x) + + for i in range(self.num_layers)[::-1]: + + layer = self.layers[i] + + x, S, H, W, = layer(x,skips[i], S, H, W) + out = x.view(-1, S, H, W, self.num_features[i]) + outs.append(out) + return outs + + +class final_patch_expanding(nn.Module): + def __init__(self,dim,num_class,patch_size): + super().__init__() + self.up=nn.ConvTranspose3d(dim,num_class,patch_size,patch_size) + + def forward(self,x): + x=x.permute(0,4,1,2,3).contiguous() + x=self.up(x) + + + return x + + + + + +class HiFormer_v1(nn.Module): + def __init__(self, params): + """ + replace the embedding layer with convolutional blocks + """ + super(HiFormer_v1, self).__init__() + # crop_size=[96,96,96], + # embedding_dim=192, + # input_channels=1, + # num_classes=9, + # conv_op=nn.Conv3d, + # depths=[2,2,2,2], + # num_heads=[6, 12, 24, 48], + # patch_size=[4,4,4], + # window_size=[4,4,8,4], + # deep_supervision=False): + + crop_size = params["input_size"] + embed_dim = params.get("embedding_dim", 192) + input_channels = params["in_chns"] + num_classes = params["class_num"] + self.conv_op = nn.Conv3d + depths = params.get("depths", [2, 2, 2, 2]) + num_heads = params.get("num_heads", [6, 12, 24, 48]) + patch_size = params.get("patch_size", [4, 4, 4]) # for patch embedding + window_size = params.get("window_size", [4, 4, 8, 4]) # for swin transformer window + self._deep_supervision = params.get("deep_supervision", False) + self.do_ds = params.get("deep_supervision", False) + + + self.num_classes = num_classes + self.upscale_logits_ops = [] + self.upscale_logits_ops.append(lambda x: x) + + self.model_down=Encoder(pretrain_img_size=crop_size,window_size=window_size,embed_dim=embed_dim, + patch_size=patch_size,depths=depths,num_heads=num_heads,in_chans=input_channels, out_indices=range(len(depths))) + self.decoder=Decoder(pretrain_img_size=crop_size,embed_dim=embed_dim,window_size=window_size[::-1][1:],patch_size=patch_size,num_heads=num_heads[::-1][:-1],depths=depths[::-1][1:]) + + self.final=[] + if self.do_ds: + + for i in range(len(depths)-1): + self.final.append(final_patch_expanding(embed_dim*2**i,num_classes,patch_size=patch_size)) + + else: + self.final.append(final_patch_expanding(embed_dim,num_classes,patch_size=patch_size)) + + self.final=nn.ModuleList(self.final) + + + def forward(self, x): + + + seg_outputs=[] + skips = self.model_down(x) + neck=skips[-1] + + out=self.decoder(neck,skips) + + + + if self.do_ds: + for i in range(len(out)): + seg_outputs.append(self.final[-(i+1)](out[i])) + + + return seg_outputs[::-1] + else: + seg_outputs.append(self.final[0](out[-1])) + return seg_outputs[-1] + + +if __name__ == "__main__": + # params = {"input_size": [96, 96, 96], + # "in_chns": 1, + # "depth": [2, 2, 2, 2], + # "num_heads": [6, 12, 24, 48], + # "window_size": [6, 6, 6, 3], + # "class_num": 5} + params = {"input_size": [96, 96, 96], + "in_chns": 1, + "depths": [2, 2, 2], + "num_heads": [6, 12, 24], + "window_size": [6, 6, 6], + "class_num": 5} + Net = HiFormer_v1(params) + Net = Net.double() + + x = np.random.rand(1, 1, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + print(y.shape) + + + diff --git a/pymic/net/net3d/trans3d/HiFormer_v2.py b/pymic/net/net3d/trans3d/HiFormer_v2.py new file mode 100644 index 0000000..7d4c440 --- /dev/null +++ b/pymic/net/net3d/trans3d/HiFormer_v2.py @@ -0,0 +1,381 @@ + +import torch +import numpy as np +import torch.utils.checkpoint as checkpoint +from einops import rearrange +from copy import deepcopy +from torch import nn +from pymic.net.net3d.trans3d.HiFormer_v1 import BasicLayer + +class ConvBlock(nn.Module): + """ + 2D or 3D convolutional block + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + """ + def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0): + super(ConvBlock, self).__init__() + assert(dim == 2 or dim == 3) + if(dim == 2): + kernel_size = [1, 3, 3] + padding = [0, 1, 1] + else: + kernel_size = 3 + padding = 1 + + self.conv_conv = nn.Sequential( + nn.BatchNorm3d(in_channels), + nn.PReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), + nn.BatchNorm3d(out_channels), + nn.PReLU(), + nn.Dropout(dropout_p), + nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding), + ) + + def forward(self, x): + return self.conv_conv(x) + + +class DownSample(nn.Module): + def __init__(self, in_channels, out_channels, dim = 2, first_layer = False): + super(DownSample, self).__init__() + assert(dim == 2 or dim == 3) + if(dim == 2): + kernel_size = [1, 3, 3] + stride = [1, 2, 2] + padding = [0, 1, 1] + else: + kernel_size = 3 + stride = 2 + padding = 1 + + if(first_layer): + self.down = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, + padding=padding, stride = stride) + else: + self.down = nn.Sequential( + nn.BatchNorm3d(in_channels), + nn.PReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, + padding=padding, stride = stride), + ) + + def forward(self, x): + return self.down(x) + + + +class ConvTransBlock(nn.Module): + def __init__(self, + input_resolution= [32, 32, 32], + chns=96, + depth=2, + num_head=4, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=True, + ): + super().__init__() + self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) + self.trans = BasicLayer( + dim= chns, + input_resolution= input_resolution, + depth=depth, + num_heads=num_head, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + norm_layer=norm_layer, + downsample= None + ) + self.norm_layer = nn.LayerNorm(chns) + self.pos_drop = nn.Dropout(p=drop_rate) + + def forward(self, x): + """Forward function.""" + x1 = self.conv(x) + C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.pos_drop(x) + x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) + # x2 = self.norm_layer(x2) + x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() + return x1 + x2 + + +class UpCatBlock(nn.Module): + """ + 3D upsampling followed by ConvBlock + + :param in_channels1: (int) Channel number of high-level features. + :param in_channels2: (int) Channel number of low-level features. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + :param trilinear: (bool) Use trilinear for up-sampling (by default). + If False, deconvolution is used for up-sampling. + """ + def __init__(self, chns_l, chns_h, up_dim = 3, conv_dim = 3): + super(UpCatBlock, self).__init__() + assert(up_dim == 2 or up_dim == 3) + if(up_dim == 2): + kernel_size, stride = [1, 2, 2], [1, 2, 2] + else: + kernel_size, stride = 2, 2 + self.up = nn.ConvTranspose3d(chns_h, chns_l, + kernel_size = kernel_size, stride=stride) + + if(conv_dim == 2): + kernel_size, padding = [1, 3, 3], [0, 1, 1] + else: + kernel_size, padding = 3, 1 + self.conv = nn.Sequential( + nn.BatchNorm3d(chns_l*2), + nn.PReLU(), + nn.Conv3d(chns_l*2, chns_l, kernel_size=kernel_size, padding=padding), + ) + + def forward(self, x_l, x_h): + # print("input shapes", x1.shape, x2.shape) + # print("after upsample", x1.shape) + y = torch.cat([x_l, self.up(x_h)], dim=1) + return self.conv(y) + +class Encoder(nn.Module): + def __init__(self, + in_chns = 1 , + ft_chns = [48, 192, 384, 768], + input_size= [32, 128, 128], + down_dims = [2, 2, 3, 3], + conv_dims = [2, 3, 3, 3], + dropout = [0, 0.2, 0.2, 0.2], + depths = [2, 2, 2], + num_heads = [4, 8, 16], + window_sizes = [6, 6, 6], + ): + super().__init__() + + self.down1 = DownSample(in_chns, ft_chns[0], down_dims[0], first_layer=True) + self.down2 = DownSample(ft_chns[0], ft_chns[1], down_dims[1]) + self.down3 = DownSample(ft_chns[1], ft_chns[2], down_dims[2]) + self.down4 = DownSample(ft_chns[2], ft_chns[3], down_dims[3]) + + self.conv1 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) + self.conv2 = ConvBlock(ft_chns[1], ft_chns[1], conv_dims[1], dropout[1]) + + down_scales = [] + for i in range(4): + down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] + down_scales.append(down_scale) + + r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] + r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] + r_t4 = [r_t3[i] // down_scales[3][i] for i in range(3)] + + self.conv_t2 = ConvTransBlock(chns = ft_chns[1], + input_resolution = r_t2, + window_size = window_sizes[0], + depth = depths[0], + num_head = num_heads[0], + drop_rate = dropout[1], + attn_drop_rate=dropout[1] + ) + self.conv_t3 = ConvTransBlock(chns = ft_chns[2], + input_resolution = r_t3, + window_size = window_sizes[1], + depth = depths[1], + num_head = num_heads[1], + drop_rate = dropout[2], + attn_drop_rate=dropout[2] + ) + self.conv_t4 = ConvTransBlock(chns = ft_chns[3], + input_resolution = r_t4, + window_size = window_sizes[2], + depth = depths[2], + num_head = num_heads[2], + drop_rate = dropout[3], + attn_drop_rate=dropout[3] + ) + + + + def forward(self, x): + """Forward function.""" + x1 = self.conv1(self.down1(x)) + x2 = self.conv2(self.down2(x1)) + x2 = self.conv_t2(x2) + x3 = self.conv_t3(self.down3(x2)) + x4 = self.conv_t4(self.down4(x3)) + + return x1, x2, x3, x4 + +class Decoder(nn.Module): + """ + Decoder of 3D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, + ft_chns = [48, 192, 384, 768], + input_size = [32, 128, 128], + down_dims = [2, 2, 3, 3], + conv_dims = [2, 3, 3, 3], + dropout = [0, 0, 0.2, 0.2], + depths = [2, 2, 2], + num_heads = [4, 8, 16], + window_sizes = [6, 6, 6], + class_num = 2, + multiscale_pred = False + ): + super(Decoder, self).__init__() + + self.up1 = UpCatBlock(ft_chns[0], ft_chns[1], down_dims[1], conv_dims[0]) + self.up2 = UpCatBlock(ft_chns[1], ft_chns[2], down_dims[2], conv_dims[1]) + self.up3 = UpCatBlock(ft_chns[2], ft_chns[3], down_dims[3], conv_dims[2]) + + down_scales = [] + for i in range(4): + down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] + down_scales.append(down_scale) + + r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] + r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] + + self.conv1 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) + self.conv2 = ConvTransBlock(chns = ft_chns[1], + input_resolution = r_t2, + window_size = window_sizes[0], + depth = depths[0], + num_head = num_heads[0], + drop_rate = dropout[1], + attn_drop_rate=dropout[1] + ) + self.conv3 = ConvTransBlock(chns = ft_chns[2], + input_resolution = r_t3, + window_size = window_sizes[1], + depth = depths[1], + num_head = num_heads[1], + drop_rate = dropout[2], + attn_drop_rate=dropout[2] + ) + + kernel_size, stride = 2, 2 + if down_dims[0] == 2: + kernel_size, stride = [1, 2, 2], [1, 2, 2] + self.out_conv0 = nn.ConvTranspose3d(ft_chns[0], class_num, + kernel_size = kernel_size, stride= stride) + + self.mul_pred = multiscale_pred + if(self.mul_pred): + self.out_conv1 = nn.Conv3d(ft_chns[0], class_num, kernel_size = 1) + self.out_conv2 = nn.Conv3d(ft_chns[1], class_num, kernel_size = 1) + self.out_conv3 = nn.Conv3d(ft_chns[2], class_num, kernel_size = 1) + + def forward(self, x): + x1, x2, x3, x4 = x + x_d3 = self.conv3(self.up3(x3, x4)) + x_d2 = self.conv2(self.up2(x2, x_d3)) + x_d1 = self.conv1(self.up1(x1, x_d2)) + + output = self.out_conv0(x_d1) + if(self.mul_pred): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) + output = [output, output1, output2, output3] + return output + +class HiFormer_v2(nn.Module): + def __init__(self, params): + """ + replace the embedding layer with convolutional blocks + """ + super(HiFormer_v2, self).__init__() + in_chns = params["in_chns"] + class_num = params["class_num"] + input_size = params["input_size"] + ft_chns = params.get("feature_chns", [48, 192, 384, 764]) + down_dims = params.get("down_dims", [2, 2, 3, 3]) + conv_dims = params.get("conv_dims", [2, 3, 3, 3]) + dropout = params.get('dropout', [0, 0.2, 0.2, 0.2]) + depths = params.get("depths", [2, 2, 2]) + num_heads = params.get("num_heads", [4, 8, 16]) + window_sizes= params.get("window_sizes", [6, 6, 6]) + multiscale_pred = params.get("multiscale_pred", False) + + self.encoder = Encoder(in_chns, + ft_chns = ft_chns, + input_size = input_size, + down_dims = down_dims, + conv_dims = conv_dims, + dropout = dropout, + depths = depths, + num_heads = num_heads, + window_sizes= window_sizes) + + self.decoder = Decoder(ft_chns = ft_chns, + input_size = input_size, + down_dims = down_dims, + conv_dims = conv_dims, + dropout = dropout, + depths = depths, + num_heads = num_heads, + window_sizes= window_sizes, + class_num = class_num, + multiscale_pred = multiscale_pred + ) + + def forward(self, x): + x = self.encoder(x) + x = self.decoder(x) + return x + + +if __name__ == "__main__": + params = {"input_size": [32, 128, 128], + "in_chns": 1, + "down_dims": [2, 2, 3, 3], + "conv_dims": [2, 3, 3, 3], + "feature_chns": [96, 192, 384, 768], + "class_num": 5, + "multiscale_pred": True} + Net = HiFormer_v2(params) + Net = Net.double() + + x = np.random.rand(1, 1, 32, 128, 128) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + if(params['multiscale_pred']): + for yi in y: + print(yi.shape) + else: + print(y.shape) + + + diff --git a/pymic/net/net3d/trans3d/HiFormer_v3.py b/pymic/net/net3d/trans3d/HiFormer_v3.py new file mode 100644 index 0000000..2f8c831 --- /dev/null +++ b/pymic/net/net3d/trans3d/HiFormer_v3.py @@ -0,0 +1,455 @@ + +import torch +import numpy as np +import torch.utils.checkpoint as checkpoint +from einops import rearrange +from copy import deepcopy +from torch import nn +from pymic.net.net3d.trans3d.HiFormer_v1 import BasicLayer + +class ConvBlock(nn.Module): + """ + 2D or 3D convolutional block + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + """ + def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0): + super(ConvBlock, self).__init__() + assert(dim == 2 or dim == 3) + if(dim == 2): + kernel_size = [1, 3, 3] + padding = [0, 1, 1] + else: + kernel_size = 3 + padding = 1 + + self.conv_conv = nn.Sequential( + nn.BatchNorm3d(in_channels), + nn.PReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), + nn.BatchNorm3d(out_channels), + nn.PReLU(), + nn.Dropout(dropout_p), + nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding), + ) + + def forward(self, x): + return self.conv_conv(x) + + +class DownSample(nn.Module): + def __init__(self, in_channels, out_channels, dim = 2, first_layer = False): + super(DownSample, self).__init__() + assert(dim == 2 or dim == 3) + if(dim == 2): + kernel_size = [1, 3, 3] + stride = [1, 2, 2] + padding = [0, 1, 1] + else: + kernel_size = 3 + stride = 2 + padding = 1 + + if(first_layer): + self.down = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, + padding=padding, stride = stride) + else: + self.down = nn.Sequential( + nn.BatchNorm3d(in_channels), + nn.PReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, + padding=padding, stride = stride), + ) + + def forward(self, x): + return self.down(x) + + + +class ConvTransBlock_backup(nn.Module): + def __init__(self, + input_resolution= [32, 32, 32], + chns=96, + depth=2, + num_head=4, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=True, + ): + super().__init__() + self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) + self.trans = BasicLayer( + dim= chns, + input_resolution= input_resolution, + depth=depth, + num_heads=num_head, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + norm_layer=norm_layer, + downsample= None + ) + self.norm_layer = nn.LayerNorm(chns) + self.pos_drop = nn.Dropout(p=drop_rate) + + def forward(self, x): + """Forward function.""" + x1 = self.conv(x) + C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.pos_drop(x) + x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) + # x2 = self.norm_layer(x2) + x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() + return x1 + x2 + +# only using the conv block +class ConvTransBlock(nn.Module): + def __init__(self, + input_resolution= [32, 32, 32], + chns=96, + depth=2, + num_head=4, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=True, + ): + super().__init__() + self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) + # self.trans = BasicLayer( + # dim= chns, + # input_resolution= input_resolution, + # depth=depth, + # num_heads=num_head, + # window_size=window_size, + # mlp_ratio=mlp_ratio, + # qkv_bias=qkv_bias, + # qk_scale=qk_scale, + # drop=drop_rate, + # attn_drop=attn_drop_rate, + # drop_path=drop_path_rate, + # norm_layer=norm_layer, + # downsample= None + # ) + # self.norm_layer = nn.LayerNorm(chns) + # self.pos_drop = nn.Dropout(p=drop_rate) + + def forward(self, x): + """Forward function.""" + x1 = self.conv(x) + return x1 + # C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) + # x = x.flatten(2).transpose(1, 2).contiguous() + # x = self.pos_drop(x) + # x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) + # # x2 = self.norm_layer(x2) + # x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() + # return x1 + x2 + +class UpCatBlock(nn.Module): + """ + 3D upsampling followed by ConvBlock + + :param in_channels1: (int) Channel number of high-level features. + :param in_channels2: (int) Channel number of low-level features. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + :param trilinear: (bool) Use trilinear for up-sampling (by default). + If False, deconvolution is used for up-sampling. + """ + def __init__(self, chns_l, chns_h, up_dim = 3, conv_dim = 3): + super(UpCatBlock, self).__init__() + assert(up_dim == 2 or up_dim == 3) + if(up_dim == 2): + kernel_size, stride = [1, 2, 2], [1, 2, 2] + else: + kernel_size, stride = 2, 2 + self.up = nn.ConvTranspose3d(chns_h, chns_l, + kernel_size = kernel_size, stride=stride) + + if(conv_dim == 2): + kernel_size, padding = [1, 3, 3], [0, 1, 1] + else: + kernel_size, padding = 3, 1 + self.conv = nn.Sequential( + nn.BatchNorm3d(chns_l*2), + nn.PReLU(), + nn.Conv3d(chns_l*2, chns_l, kernel_size=kernel_size, padding=padding), + ) + + def forward(self, x_l, x_h): + # print("input shapes", x1.shape, x2.shape) + # print("after upsample", x1.shape) + y = torch.cat([x_l, self.up(x_h)], dim=1) + return self.conv(y) + +class Encoder(nn.Module): + def __init__(self, + in_chns = 1 , + ft_chns = [48, 192, 384, 768], + input_size= [32, 128, 128], + down_dims = [2, 2, 3, 3], + conv_dims = [2, 3, 3, 3], + dropout = [0, 0.2, 0.2, 0.2], + depths = [2, 2, 2], + num_heads = [4, 8, 16], + window_sizes = [6, 6, 6], + high_res = False, + ): + super().__init__() + self.high_res = high_res + + self.down1 = DownSample(in_chns, ft_chns[0], down_dims[0], first_layer=True) + self.down2 = DownSample(ft_chns[0], ft_chns[1], down_dims[1]) + self.down3 = DownSample(ft_chns[1], ft_chns[2], down_dims[2]) + self.down4 = DownSample(ft_chns[2], ft_chns[3], down_dims[3]) + + if(high_res): + self.conv0 = ConvBlock(in_chns, ft_chns[0] // 2, 3, 0) + self.conv1 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) + self.conv2 = ConvBlock(ft_chns[1], ft_chns[1], conv_dims[1], dropout[1]) + + down_scales = [] + for i in range(4): + down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] + down_scales.append(down_scale) + + r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] + r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] + r_t4 = [r_t3[i] // down_scales[3][i] for i in range(3)] + + self.conv_t2 = ConvTransBlock(chns = ft_chns[1], + input_resolution = r_t2, + window_size = window_sizes[0], + depth = depths[0], + num_head = num_heads[0], + drop_rate = dropout[1], + attn_drop_rate=dropout[1] + ) + self.conv_t3 = ConvTransBlock(chns = ft_chns[2], + input_resolution = r_t3, + window_size = window_sizes[1], + depth = depths[1], + num_head = num_heads[1], + drop_rate = dropout[2], + attn_drop_rate=dropout[2] + ) + self.conv_t4 = ConvTransBlock(chns = ft_chns[3], + input_resolution = r_t4, + window_size = window_sizes[2], + depth = depths[2], + num_head = num_heads[2], + drop_rate = dropout[3], + attn_drop_rate=dropout[3] + ) + + + + def forward(self, x): + """Forward function.""" + if(self.high_res): + x0 = self.conv0(x) + x1 = self.conv1(self.down1(x)) + x2 = self.conv2(self.down2(x1)) + x2 = self.conv_t2(x2) + x3 = self.conv_t3(self.down3(x2)) + x4 = self.conv_t4(self.down4(x3)) + if(self.high_res): + return x0, x1, x2, x3, x4 + else: + return x1, x2, x3, x4 + +class Decoder(nn.Module): + """ + Decoder of 3D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, + ft_chns = [48, 192, 384, 768], + input_size = [32, 128, 128], + down_dims = [2, 2, 3, 3], + conv_dims = [2, 3, 3, 3], + dropout = [0, 0, 0.2, 0.2], + depths = [2, 2, 2], + num_heads = [4, 8, 16], + window_sizes = [6, 6, 6], + high_res = False, + class_num = 2, + multiscale_pred = False + ): + super(Decoder, self).__init__() + self.high_res = high_res + if(self.high_res): + self.up0 = UpCatBlock(ft_chns[0] // 2, ft_chns[0], down_dims[0], 3) + self.conv0 = ConvBlock(ft_chns[0] // 2, ft_chns[0] // 2, 3, 0) + self.up1 = UpCatBlock(ft_chns[0], ft_chns[1], down_dims[1], conv_dims[0]) + self.up2 = UpCatBlock(ft_chns[1], ft_chns[2], down_dims[2], conv_dims[1]) + self.up3 = UpCatBlock(ft_chns[2], ft_chns[3], down_dims[3], conv_dims[2]) + + down_scales = [] + for i in range(4): + down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] + down_scales.append(down_scale) + + r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] + r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] + + self.conv1 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) + self.conv2 = ConvTransBlock(chns = ft_chns[1], + input_resolution = r_t2, + window_size = window_sizes[0], + depth = depths[0], + num_head = num_heads[0], + drop_rate = dropout[1], + attn_drop_rate=dropout[1] + ) + self.conv3 = ConvTransBlock(chns = ft_chns[2], + input_resolution = r_t3, + window_size = window_sizes[1], + depth = depths[1], + num_head = num_heads[1], + drop_rate = dropout[2], + attn_drop_rate=dropout[2] + ) + + kernel_size, stride = 2, 2 + if down_dims[0] == 2: + kernel_size, stride = [1, 2, 2], [1, 2, 2] + if(self.high_res): + self.out_conv0 = nn.Conv3d(ft_chns[0] // 2, class_num, + kernel_size = [1, 3, 3], padding = [0, 1, 1]) + else: + self.out_conv0 = nn.ConvTranspose3d(ft_chns[0], class_num, + kernel_size = kernel_size, stride= stride) + + self.mul_pred = multiscale_pred + if(self.mul_pred): + self.out_conv1 = nn.Conv3d(ft_chns[0], class_num, kernel_size = 1) + self.out_conv2 = nn.Conv3d(ft_chns[1], class_num, kernel_size = 1) + self.out_conv3 = nn.Conv3d(ft_chns[2], class_num, kernel_size = 1) + + def forward(self, x): + if(self.high_res): + x0, x1, x2, x3, x4 = x + else: + x1, x2, x3, x4 = x + x_d3 = self.conv3(self.up3(x3, x4)) + x_d2 = self.conv2(self.up2(x2, x_d3)) + x_d1 = self.conv1(self.up1(x1, x_d2)) + if(self.high_res): + x_d0 = self.conv0(self.up0(x0, x_d1)) + output = self.out_conv0(x_d0) + else: + output = self.out_conv0(x_d1) + if(self.mul_pred): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) + output = [output, output1, output2, output3] + return output + +class HiFormer_v3(nn.Module): + def __init__(self, params): + """ + replace the embedding layer with convolutional blocks + """ + super(HiFormer_v3, self).__init__() + in_chns = params["in_chns"] + class_num = params["class_num"] + input_size = params["input_size"] + ft_chns = params.get("feature_chns", [48, 192, 384, 764]) + down_dims = params.get("down_dims", [2, 2, 3, 3]) + conv_dims = params.get("conv_dims", [2, 3, 3, 3]) + dropout = params.get('dropout', [0, 0.2, 0.2, 0.2]) + high_res = params.get("high_res", False) + depths = params.get("depths", [2, 2, 2]) + num_heads = params.get("num_heads", [4, 8, 16]) + window_sizes= params.get("window_sizes", [6, 6, 6]) + multiscale_pred = params.get("multiscale_pred", False) + + self.encoder = Encoder(in_chns, + ft_chns = ft_chns, + input_size = input_size, + down_dims = down_dims, + conv_dims = conv_dims, + dropout = dropout, + depths = depths, + num_heads = num_heads, + window_sizes= window_sizes, + high_res = high_res) + + self.decoder = Decoder(ft_chns = ft_chns, + input_size = input_size, + down_dims = down_dims, + conv_dims = conv_dims, + dropout = dropout, + depths = depths, + num_heads = num_heads, + window_sizes= window_sizes, + high_res = high_res, + class_num = class_num, + multiscale_pred = multiscale_pred + ) + + def forward(self, x): + x = self.encoder(x) + x = self.decoder(x) + return x + + +if __name__ == "__main__": + params = {"input_size": [64, 96, 96], + "in_chns": 1, + "down_dims": [3, 3, 3, 3], + "conv_dims": [3, 3, 3, 3], + "feature_chns": [96, 192, 384, 768], + "high_res": True, + "class_num": 5, + "multiscale_pred": True} + Net = HiFormer_v3(params) + Net = Net.double() + + x = np.random.rand(1, 1, 64, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + if(params['multiscale_pred']): + for yi in y: + print(yi.shape) + else: + print(y.shape) + + + diff --git a/pymic/net/net3d/trans3d/HiFormer_v4.py b/pymic/net/net3d/trans3d/HiFormer_v4.py new file mode 100644 index 0000000..f0c6087 --- /dev/null +++ b/pymic/net/net3d/trans3d/HiFormer_v4.py @@ -0,0 +1,455 @@ + +import torch +import numpy as np +import torch.utils.checkpoint as checkpoint +from einops import rearrange +from copy import deepcopy +from torch import nn +from pymic.net.net3d.trans3d.HiFormer_v1 import BasicLayer + +class ConvBlock(nn.Module): + """ + 2D or 3D convolutional block + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + """ + def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0): + super(ConvBlock, self).__init__() + assert(dim == 2 or dim == 3) + if(dim == 2): + kernel_size = [1, 3, 3] + padding = [0, 1, 1] + else: + kernel_size = 3 + padding = 1 + + self.conv_conv = nn.Sequential( + nn.BatchNorm3d(in_channels), + nn.PReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), + nn.BatchNorm3d(out_channels), + nn.PReLU(), + nn.Dropout(dropout_p), + nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding), + ) + + def forward(self, x): + return self.conv_conv(x) + + +class DownSample(nn.Module): + def __init__(self, in_channels, out_channels, down_dim = 3, conv_dim = 3): + super(DownSample, self).__init__() + assert(down_dim == 2 or down_dim == 3) + assert(conv_dim == 2 or conv_dim == 3) + + kernel_size = [1, 2, 2] if(down_dim == 2) else 2 + self.pool = nn.MaxPool3d(kernel_size) + + if(conv_dim == 2): + kernel_size = [1, 3, 3] + padding = [0, 1, 1] + else: + kernel_size = 3 + padding = 1 + + self.conv = nn.Sequential( + nn.BatchNorm3d(in_channels), + nn.PReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), + ) + + def forward(self, x): + return self.conv(self.pool(x)) + + + +# class ConvTransBlock(nn.Module): +# def __init__(self, +# input_resolution= [32, 32, 32], +# chns=96, +# depth=2, +# num_head=4, +# window_size=7, +# mlp_ratio=4., +# qkv_bias=True, +# qk_scale=None, +# drop_rate=0., +# attn_drop_rate=0., +# drop_path_rate=0.2, +# norm_layer=nn.LayerNorm, +# patch_norm=True, +# ): +# super().__init__() +# self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) +# self.trans = BasicLayer( +# dim= chns, +# input_resolution= input_resolution, +# depth=depth, +# num_heads=num_head, +# window_size=window_size, +# mlp_ratio=mlp_ratio, +# qkv_bias=qkv_bias, +# qk_scale=qk_scale, +# drop=drop_rate, +# attn_drop=attn_drop_rate, +# drop_path=drop_path_rate, +# norm_layer=norm_layer, +# downsample= None +# ) +# self.norm_layer = nn.LayerNorm(chns) +# self.pos_drop = nn.Dropout(p=drop_rate) + +# def forward(self, x): +# """Forward function.""" +# x1 = self.conv(x) +# C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) +# x = x.flatten(2).transpose(1, 2).contiguous() +# x = self.pos_drop(x) +# x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) +# # x2 = self.norm_layer(x2) +# x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() +# return x1 + x2 + +# only using the conv block +class ConvTransBlock(nn.Module): + def __init__(self, + input_resolution= [32, 32, 32], + chns=96, + depth=2, + num_head=4, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=True, + ): + super().__init__() + self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) + # self.trans = BasicLayer( + # dim= chns, + # input_resolution= input_resolution, + # depth=depth, + # num_heads=num_head, + # window_size=window_size, + # mlp_ratio=mlp_ratio, + # qkv_bias=qkv_bias, + # qk_scale=qk_scale, + # drop=drop_rate, + # attn_drop=attn_drop_rate, + # drop_path=drop_path_rate, + # norm_layer=norm_layer, + # downsample= None + # ) + # self.norm_layer = nn.LayerNorm(chns) + # self.pos_drop = nn.Dropout(p=drop_rate) + + def forward(self, x): + """Forward function.""" + x1 = self.conv(x) + return x1 + # C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) + # x = x.flatten(2).transpose(1, 2).contiguous() + # x = self.pos_drop(x) + # x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) + # # x2 = self.norm_layer(x2) + # x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() + # return x1 + x2 + +class ConvLayer(nn.Module): + """ + 2D or 3D convolutional block + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + """ + def __init__(self, in_channels, out_channels, kernel = 1, padding = 0): + super(ConvLayer, self).__init__() + + self.conv = nn.Sequential( + nn.BatchNorm3d(in_channels), + nn.PReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=kernel, padding=padding), + ) + + def forward(self, x): + return self.conv(x) + +class UpCatBlock(nn.Module): + """ + 3D upsampling followed by ConvBlock + + :param in_channels1: (int) Channel number of high-level features. + :param in_channels2: (int) Channel number of low-level features. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + :param trilinear: (bool) Use trilinear for up-sampling (by default). + If False, deconvolution is used for up-sampling. + """ + def __init__(self, chns_l, chns_h, up_dim = 3, conv_dim = 3): + super(UpCatBlock, self).__init__() + assert(up_dim == 2 or up_dim == 3) + if(up_dim == 2): + kernel_size, stride = [1, 2, 2], [1, 2, 2] + else: + kernel_size, stride = 2, 2 + + self.up = nn.Sequential( + nn.BatchNorm3d(chns_h), + nn.PReLU(), + nn.ConvTranspose3d(chns_h, chns_l, kernel_size = kernel_size, stride=stride) + ) + + if(conv_dim == 2): + kernel_size, padding = [1, 3, 3], [0, 1, 1] + else: + kernel_size, padding = 3, 1 + self.conv = nn.Sequential( + nn.BatchNorm3d(chns_l*2), + nn.PReLU(), + nn.Conv3d(chns_l*2, chns_l, kernel_size=kernel_size, padding=padding), + ) + + def forward(self, x_l, x_h): + # print("input shapes", x1.shape, x2.shape) + # print("after upsample", x1.shape) + y = torch.cat([x_l, self.up(x_h)], dim=1) + return self.conv(y) + +class Encoder(nn.Module): + def __init__(self, + in_chns = 1 , + ft_chns = [24, 48, 192, 384, 768], + input_size= [32, 128, 128], + down_dims = [3, 3, 3, 3, 3], + conv_dims = [3, 3, 3, 3, 3], + dropout = [0, 0, 0.2, 0.2, 0.2], + depths = [2, 2, 2], + num_heads = [4, 8, 16], + window_sizes = [6, 6, 6], + ): + super().__init__() + self.proj = nn.Conv3d(in_chns, ft_chns[0], kernel_size=3, padding=1) + self.conv0 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) + self.conv1 = ConvBlock(ft_chns[1], ft_chns[1], conv_dims[1], dropout[1]) + self.conv2 = ConvBlock(ft_chns[2], ft_chns[2], conv_dims[2], dropout[2]) + + self.down1 = DownSample(ft_chns[0], ft_chns[1], down_dims[0], conv_dims[1]) + self.down2 = DownSample(ft_chns[1], ft_chns[2], down_dims[1], conv_dims[2]) + self.down3 = DownSample(ft_chns[2], ft_chns[3], down_dims[2], conv_dims[3]) + self.down4 = DownSample(ft_chns[3], ft_chns[4], down_dims[3], conv_dims[4]) + + down_scales = [] + for i in range(4): + down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] + down_scales.append(down_scale) + + r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] + r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] + r_t4 = [r_t3[i] // down_scales[3][i] for i in range(3)] + + self.conv_t2 = ConvTransBlock(chns = ft_chns[2], + input_resolution = r_t2, + window_size = window_sizes[0], + depth = depths[0], + num_head = num_heads[0], + drop_rate = dropout[2], + attn_drop_rate=dropout[2] + ) + self.conv_t3 = ConvTransBlock(chns = ft_chns[3], + input_resolution = r_t3, + window_size = window_sizes[1], + depth = depths[1], + num_head = num_heads[1], + drop_rate = dropout[3], + attn_drop_rate=dropout[3] + ) + self.conv_t4 = ConvTransBlock(chns = ft_chns[4], + input_resolution = r_t4, + window_size = window_sizes[2], + depth = depths[2], + num_head = num_heads[2], + drop_rate = dropout[4], + attn_drop_rate=dropout[4] + ) + + + + def forward(self, x): + """Forward function.""" + x0 = self.conv0(self.proj(x)) + x1 = self.conv1(self.down1(x0)) + x2 = self.conv2(self.down2(x1)) + x2 = self.conv_t2(x2) + x3 = self.conv_t3(self.down3(x2)) + x4 = self.conv_t4(self.down4(x3)) + return x0, x1, x2, x3, x4 + + +class Decoder(nn.Module): + """ + Decoder of 3D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, + ft_chns = [24, 48, 192, 384, 768], + input_size= [32, 128, 128], + down_dims = [3, 3, 3, 3, 3], + conv_dims = [3, 3, 3, 3, 3], + dropout = [0, 0, 0.2, 0.2, 0.2], + depths = [2, 2, 2], + num_heads = [4, 8, 16], + window_sizes = [6, 6, 6], + class_num = 2, + multiscale_pred = False + ): + super(Decoder, self).__init__() + # self.up0 = UpCatBlock(ft_chns[0] // 2, ft_chns[0], down_dims[0], 3) + # self.conv0 = ConvBlock(ft_chns[0] // 2, ft_chns[0] // 2, 3, 0) + self.up1 = UpCatBlock(ft_chns[0], ft_chns[1], down_dims[0], conv_dims[0]) + self.up2 = UpCatBlock(ft_chns[1], ft_chns[2], down_dims[1], conv_dims[1]) + self.up3 = UpCatBlock(ft_chns[2], ft_chns[3], down_dims[2], conv_dims[2]) + self.up4 = UpCatBlock(ft_chns[3], ft_chns[4], down_dims[3], conv_dims[3]) + + down_scales = [] + for i in range(4): + down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] + down_scales.append(down_scale) + + r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] + r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] + + self.conv0 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) + self.conv1 = ConvBlock(ft_chns[1], ft_chns[1], conv_dims[1], dropout[1]) + self.conv2 = ConvTransBlock(chns = ft_chns[2], + input_resolution = r_t2, + window_size = window_sizes[0], + depth = depths[0], + num_head = num_heads[0], + drop_rate = dropout[2], + attn_drop_rate=dropout[2] + ) + self.conv3 = ConvTransBlock(chns = ft_chns[3], + input_resolution = r_t3, + window_size = window_sizes[1], + depth = depths[1], + num_head = num_heads[1], + drop_rate = dropout[3], + attn_drop_rate=dropout[3] + ) + + self.out_conv0 = ConvLayer(ft_chns[0], class_num) + + self.mul_pred = multiscale_pred + if(self.mul_pred): + self.out_conv1 = ConvLayer(ft_chns[1], class_num) + self.out_conv2 = ConvLayer(ft_chns[2], class_num) + self.out_conv3 = ConvLayer(ft_chns[3], class_num) + + def forward(self, x): + x0, x1, x2, x3, x4 = x + + x_d3 = self.conv3(self.up4(x3, x4)) + x_d2 = self.conv2(self.up3(x2, x_d3)) + x_d1 = self.conv1(self.up2(x1, x_d2)) + x_d0 = self.conv0(self.up1(x0, x_d1)) + output = self.out_conv0(x_d0) + + if(self.mul_pred): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) + output = [output, output1, output2, output3] + return output + +class HiFormer_v4(nn.Module): + def __init__(self, params): + """ + replace the embedding layer with convolutional blocks + """ + super(HiFormer_v4, self).__init__() + in_chns = params["in_chns"] + class_num = params["class_num"] + input_size = params["input_size"] + ft_chns = params.get("feature_chns", [32, 64, 128, 256, 512]) + down_dims = params.get("down_dims", [3, 3, 3, 3, 3]) + conv_dims = params.get("conv_dims", [3, 3, 3, 3, 3]) + dropout = params.get('dropout', [0, 0, 0.2, 0.2, 0.2]) + depths = params.get("depths", [2, 2, 2]) + num_heads = params.get("num_heads", [4, 8, 16]) + window_sizes= params.get("window_sizes", [6, 6, 6]) + multiscale_pred = params.get("multiscale_pred", False) + + self.encoder = Encoder(in_chns, + ft_chns = ft_chns, + input_size = input_size, + down_dims = down_dims, + conv_dims = conv_dims, + dropout = dropout, + depths = depths, + num_heads = num_heads, + window_sizes= window_sizes) + + self.decoder = Decoder(ft_chns = ft_chns, + input_size = input_size, + down_dims = down_dims, + conv_dims = conv_dims, + dropout = dropout, + depths = depths, + num_heads = num_heads, + window_sizes= window_sizes, + class_num = class_num, + multiscale_pred = multiscale_pred + ) + + def forward(self, x): + x = self.encoder(x) + x = self.decoder(x) + return x + + +if __name__ == "__main__": + params = {"input_size": [64, 96, 96], + "in_chns": 1, + "down_dims": [3, 3, 3, 3, 3], + "conv_dims": [3, 3, 3, 3, 3], + "feature_chns": [32, 64, 128, 256, 512], + "class_num": 5, + "multiscale_pred": True} + Net = HiFormer_v4(params) + Net = Net.double() + + x = np.random.rand(1, 1, 64, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + if(params['multiscale_pred']): + for yi in y: + print(yi.shape) + else: + print(y.shape) + + + diff --git a/pymic/net/net3d/trans3d/HiFormer_v5.py b/pymic/net/net3d/trans3d/HiFormer_v5.py new file mode 100644 index 0000000..5fcef5a --- /dev/null +++ b/pymic/net/net3d/trans3d/HiFormer_v5.py @@ -0,0 +1,308 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import torch.nn as nn +import numpy as np +from torch.nn.functional import interpolate + + +class ConvBlock(nn.Module): + """ + 2D or 3D convolutional block + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + """ + def __init__(self, in_channels, out_channels, dropout_p = 0.0, dim = 3): + super(ConvBlock, self).__init__() + assert(dim == 2 or dim == 3) + if(dim == 2): + kernel_size = [1, 3, 3] + padding = [0, 1, 1] + else: + kernel_size = 3 + padding = 1 + + self.conv_conv = nn.Sequential( + nn.BatchNorm3d(in_channels), + nn.LeakyReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), + nn.BatchNorm3d(out_channels), + nn.LeakyReLU(), + nn.Dropout(dropout_p), + nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding), + ) + + def forward(self, x): + return self.conv_conv(x) + +class ConvLayer(nn.Module): + """ + 2D or 3D convolutional block + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + """ + def __init__(self, in_channels, out_channels, kernel = 1, padding = 0): + super(ConvLayer, self).__init__() + + self.conv = nn.Sequential( + nn.BatchNorm3d(in_channels), + nn.LeakyReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=kernel, padding=padding), + ) + + def forward(self, x): + return self.conv(x) + +class DownBlock(nn.Module): + """ + 3D downsampling followed by ConvBlock + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + """ + def __init__(self, in_channels, out_channels, dropout_p): + super(DownBlock, self).__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool3d(2), + ConvBlock(in_channels, out_channels, dropout_p) + ) + + def forward(self, x): + return self.maxpool_conv(x) + +class UpBlock(nn.Module): + """ + 3D upsampling followed by ConvBlock + + :param in_channels1: (int) Channel number of high-level features. + :param in_channels2: (int) Channel number of low-level features. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + :param trilinear: (bool) Use trilinear for up-sampling (by default). + If False, deconvolution is used for up-sampling. + """ + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, + trilinear=True): + super(UpBlock, self).__init__() + self.trilinear = trilinear + if trilinear: + self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) + self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) + else: + self.up = nn.Sequential( + nn.BatchNorm3d(in_channels1), + nn.LeakyReLU(), + nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) + ) + self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) + + def forward(self, x1, x2): + if self.trilinear: + x1 = self.conv1x1(x1) + x1 = self.up(x1) + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + +class Encoder(nn.Module): + """ + Encoder of 3D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + """ + def __init__(self, params): + super(Encoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + + self.proj = nn.Conv3d(self.in_chns, self.ft_chns[0], kernel_size=3, padding=1) + self.in_conv= ConvBlock(self.ft_chns[0], self.ft_chns[0], self.dropout[0]) + self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) + self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) + self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) + if(len(self.ft_chns) == 5): + self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) + + def forward(self, x): + x0 = self.in_conv(self.proj(x)) + x1 = self.down1(x0) + x2 = self.down2(x1) + x3 = self.down3(x2) + output = [x0, x1, x2, x3] + if(len(self.ft_chns) == 5): + x4 = self.down4(x3) + output.append(x4) + return output + +class Decoder(nn.Module): + """ + Decoder of 3D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, params): + super(Decoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + self.n_class = self.params['class_num'] + self.trilinear = self.params.get('trilinear', True) + self.mul_pred = self.params.get('multiscale_pred', False) + + assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + + if(len(self.ft_chns) == 5): + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear) + self.out_conv = ConvLayer(self.ft_chns[0], self.n_class) + if(self.mul_pred): + self.out_conv1 = ConvLayer(self.ft_chns[1], self.n_class) + self.out_conv2 = ConvLayer(self.ft_chns[2], self.n_class) + self.out_conv3 = ConvLayer(self.ft_chns[3], self.n_class) + + def forward(self, x): + if(len(self.ft_chns) == 5): + assert(len(x) == 5) + x0, x1, x2, x3, x4 = x + x_d3 = self.up1(x4, x3) + else: + assert(len(x) == 4) + x0, x1, x2, x3 = x + x_d3 = x3 + x_d2 = self.up2(x_d3, x2) + x_d1 = self.up3(x_d2, x1) + x_d0 = self.up4(x_d1, x0) + output = self.out_conv(x_d0) + if(self.mul_pred): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) + output = [output, output1, output2, output3] + return output + +class HiFormer_v5(nn.Module): + """ + An implementation of the U-Net. + + * Reference: Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: + 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. + `MICCAI (2) 2016: 424-432. `_ + + Note that there are some modifications from the original paper, such as + the use of batch normalization, dropout, leaky relu and deep supervision. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using trilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, params): + super(HiFormer_v5, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + self.n_class = self.params['class_num'] + self.trilinear = self.params['trilinear'] + self.mul_pred = self.params['multiscale_pred'] + assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + + self.proj = nn.Conv3d(self.in_chns, self.ft_chns[0], kernel_size=3, padding=1) + self.in_conv= ConvBlock(self.ft_chns[0], self.ft_chns[0], self.dropout[0]) + self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) + self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) + self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) + if(len(self.ft_chns) == 5): + self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], + dropout_p = self.dropout[3], trilinear=self.trilinear) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], + dropout_p = self.dropout[2], trilinear=self.trilinear) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], + dropout_p = self.dropout[1], trilinear=self.trilinear) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], + dropout_p = self.dropout[0], trilinear=self.trilinear) + + self.out_conv = ConvLayer(self.ft_chns[0], self.n_class) + if(self.mul_pred): + self.out_conv1 = ConvLayer(self.ft_chns[1], self.n_class) + self.out_conv2 = ConvLayer(self.ft_chns[2], self.n_class) + self.out_conv3 = ConvLayer(self.ft_chns[3], self.n_class) + + def forward(self, x): + x0 = self.in_conv(self.proj(x)) + x1 = self.down1(x0) + x2 = self.down2(x1) + x3 = self.down3(x2) + if(len(self.ft_chns) == 5): + x4 = self.down4(x3) + x_d3 = self.up1(x4, x3) + else: + x_d3 = x3 + x_d2 = self.up2(x_d3, x2) + x_d1 = self.up3(x_d2, x1) + x_d0 = self.up4(x_d1, x0) + output = self.out_conv(x_d0) + if(self.mul_pred): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) + output = [output, output1, output2, output3] + return output + +if __name__ == "__main__": + params = {'in_chns':4, + 'class_num': 2, + 'feature_chns':[32, 64, 128, 256, 512], + 'dropout' : [0, 0, 0, 0, 0.5], + 'trilinear': False, + 'multiscale_pred': False} + Net = HiFormer_v5(params) + Net = Net.double() + + x = np.random.rand(4, 4, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + y = y.detach().numpy() + print(y.shape) diff --git a/pymic/net/net3d/trans3d/MedFormer_v1.py b/pymic/net/net3d/trans3d/MedFormer_v1.py new file mode 100644 index 0000000..1f2ed54 --- /dev/null +++ b/pymic/net/net3d/trans3d/MedFormer_v1.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import math +import torch +import torch.nn as nn +import numpy as np +from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm +from pymic.net.net3d.unet3d import Encoder, Decoder + +class Attention(nn.Module): + def __init__(self, params): + super(Attention, self).__init__() + hidden_size = params["attention_hidden_size"] + self.num_attention_heads = params["attention_num_heads"] + self.attention_head_size = int(hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = Linear(hidden_size, self.all_head_size) + self.key = Linear(hidden_size, self.all_head_size) + self.value = Linear(hidden_size, self.all_head_size) + + self.out = Linear(hidden_size, hidden_size) + self.attn_dropout = Dropout(params["attention_dropout_rate"]) + self.proj_dropout = Dropout(params["attention_dropout_rate"]) + + self.softmax = Softmax(dim=-1) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.softmax(attention_scores) + # weights = attention_probs if self.vis else None + attention_probs = self.attn_dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + attention_output = self.out(context_layer) + attention_output = self.proj_dropout(attention_output) + return attention_output + +class MLP(nn.Module): + def __init__(self, params): + super(MLP, self).__init__() + hidden_size = params["attention_hidden_size"] + mlp_dim = params["attention_mlp_dim"] + self.fc1 = Linear(hidden_size, mlp_dim) + self.fc2 = Linear(mlp_dim, hidden_size) + self.act_fn = torch.nn.functional.gelu + self.dropout = Dropout(params["attention_dropout_rate"]) + + self._init_weights() + + def _init_weights(self): + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + nn.init.normal_(self.fc1.bias, std=1e-6) + nn.init.normal_(self.fc2.bias, std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x = self.act_fn(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + +class Block(nn.Module): + def __init__(self, params): + super(Block, self).__init__() + hidden_size = params["attention_hidden_size"] + self.attention_norm = LayerNorm(hidden_size, eps=1e-6) + self.ffn_norm = LayerNorm(hidden_size, eps=1e-6) + self.ffn = MLP(params) + self.attn = Attention(params) + + def forward(self, x): + # convert the tensor shape from [B, C, D, H, W] to [B, DHW, C] + [B, C, D, H, W] = list(x.shape) + new_shape = [B, C, D*H*W] + x = torch.reshape(x, new_shape) + x = torch.transpose(x, 1, 2) + + h = x + x = self.attention_norm(x) + x = self.attn(x) + x = x + h + + h = x + x = self.ffn_norm(x) + x = self.ffn(x) + x = x + h + + # convert the result back to [B, C, D, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, [B, C, D, H, W]) + return x + +class MedFormerV1(nn.Module): + """ + An implementation of the U-Net. + + * Reference: Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: + 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. + `MICCAI (2) 2016: 424-432. `_ + + Note that there are some modifications from the original paper, such as + the use of batch normalization, dropout, leaky relu and deep supervision. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using trilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + :param deep_supervise: (bool) Using deep supervision for training or not. + """ + def __init__(self, params): + super(MedFormerV1, self).__init__() + self.params = params + self.encoder = Encoder(params) + self.decoder = Decoder(params) + self.attn = Block(params) + + def forward(self, x): + f = self.encoder(x) + f[-1] = self.attn(f[-1]) + output = self.decoder(f) + return output + +if __name__ == "__main__": + params = {'in_chns':4, + 'class_num': 2, + 'feature_chns':[16, 32, 64, 128], + 'dropout' : [0, 0, 0, 0.5], + 'trilinear': True, + 'deep_supervise': True, + 'attention_hidden_size': 128, + 'attention_num_heads': 4, + 'attention_mlp_dim': 256, + 'attention_dropout_rate': 0.2} + Net = MedFormerV1(params) + Net = Net.double() + + x = np.random.rand(1, 4, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + print("output length", len(y)) + for yi in y: + yi = yi.detach().numpy() + print(yi.shape) diff --git a/pymic/net/net3d/trans3d/MedFormer_v2.py b/pymic/net/net3d/trans3d/MedFormer_v2.py new file mode 100644 index 0000000..00cb295 --- /dev/null +++ b/pymic/net/net3d/trans3d/MedFormer_v2.py @@ -0,0 +1,464 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import math +import copy +import torch +import torch.nn as nn +import numpy as np +import torch.nn.functional as F +from pymic.net.net3d.unet3d import ConvBlock, Encoder, Decoder +from pymic.net.net3d.trans3d.MedFormer_v1 import Block +from timm.models.layers import DropPath, to_3tuple, trunc_normal_ + + +# code from nnFormer +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + + B, S, H, W, C = x.shape + x = x.view(B, S // window_size, window_size, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, S, H, W): + + B = int(windows.shape[0] / (S * H * W / window_size / window_size / window_size)) + x = windows.view(B, S // window_size, H // window_size, W // window_size, window_size, window_size, window_size, -1) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, S, H, W, -1) + return x + + +class WindowAttention(nn.Module): + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), + num_heads)) + + # get pair-wise relative position index for each token inside the window + coords_s = torch.arange(self.window_size[0]) + coords_h = torch.arange(self.window_size[1]) + coords_w = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid([coords_s, coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + + relative_coords[:, :, 0] *= 3 * self.window_size[1] - 1 + relative_coords[:, :, 1] *= 2 * self.window_size[1] - 1 + + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + trunc_normal_(self.relative_position_bias_table, std=.02) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None,pos_embed=None): + + B_, N, C = x.shape + + qkv = self.qkv(x) + + qkv=qkv.reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1).contiguous()) + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] * self.window_size[2], + self.window_size[0] * self.window_size[1] * self.window_size[2], -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous() + if pos_embed is not None: + x = x+pos_embed + x = self.proj(x) + x = self.proj_drop(x) + return x + +class SwinTransformerBlock(nn.Module): + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + + self.attn = WindowAttention( + dim, window_size=to_3tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + + def forward(self, x, mask_matrix): + + B, L, C = x.shape + S, H, W = self.input_resolution + + assert L == S * H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, S, H, W, C) + + # pad feature maps to multiples of window size + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + pad_g = (self.window_size - S % self.window_size) % self.window_size + + x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b, 0, pad_g)) + _, Sp, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size,-self.shift_size), dims=(1, 2,3)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size * self.window_size, + C) + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask,pos_embed=None) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Sp, Hp, Wp) + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size, self.shift_size), dims=(1, 2, 3)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0 or pad_g > 0: + x = x[:, :S, :H, :W, :].contiguous() + + x = x.view(B, S * H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + +class BasicLayer(nn.Module): + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=True + ): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + # build blocks + + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, S, H, W): + + + # calculate attention mask for SW-MSA + Sp = int(np.ceil(S / self.window_size)) * self.window_size + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Sp, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + s_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for s in s_slices: + for h in h_slices: + for w in w_slices: + img_mask[:, s, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, + self.window_size * self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + for blk in self.blocks: + + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, S, H, W) + Ws, Wh, Ww = (S + 1) // 2, (H + 1) // 2, (W + 1) // 2 + return x, S, H, W, x_down, Ws, Wh, Ww + else: + return x, S, H, W, x, S, H, W + + +class AttUpBlock(nn.Module): + """ + 3D upsampling followed by ConvBlock + + :param in_channels1: (int) Channel number of high-level features. + :param in_channels2: (int) Channel number of low-level features. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + :param trilinear: (bool) Use trilinear for up-sampling (by default). + If False, deconvolution is used for up-sampling. + """ + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, + trilinear=True, with_att = False, att_params = None): + super(AttUpBlock, self).__init__() + self.trilinear = trilinear + self.with_att = with_att + if trilinear: + self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) + self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) + else: + self.up = nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) + self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) + if(self.with_att): + input_resolution = att_params['input_resolution'] + depth = att_params['depth'] + num_heads = att_params['num_heads'] + self.attn = BasicLayer(out_channels, input_resolution, depth, num_heads, downsample=None) + + def forward(self, x1, x2): + if self.trilinear: + x1 = self.conv1x1(x1) + x1 = self.up(x1) + x = torch.cat([x2, x1], dim=1) + x = self.conv(x) + if(self.with_att): + [B, C, D, H, W] = list(x.shape) + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.attn(x, D, H, W)[0] + x = x.view(-1, D, H, W, C).permute(0, 4, 1, 2, 3).contiguous() + return x + +class AttDecoder(nn.Module): + """ + Decoder of 3D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + """ + def __init__(self, params): + super(AttDecoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + self.n_class = self.params['class_num'] + self.trilinear = self.params.get('trilinear', True) + self.mul_pred = self.params['multiscale_pred'] + + assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + + if(len(self.ft_chns) == 5): + self.up1 = AttUpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear) + att_params = {"input_resolution": [24, 24, 24], "depth": 2, "num_heads": 4} + self.up2 = AttUpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear, True, att_params) + att_params = {"input_resolution": [48, 48, 48], "depth": 2, "num_heads": 4} + self.up3 = AttUpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear, True, att_params) + self.up4 = AttUpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear) + self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) + if(self.mul_pred): + self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) + self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) + self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) + + def forward(self, x): + if(len(self.ft_chns) == 5): + assert(len(x) == 5) + x0, x1, x2, x3, x4 = x + x_d3 = self.up1(x4, x3) + else: + assert(len(x) == 4) + x0, x1, x2, x3 = x + x_d3 = x3 + x_d2 = self.up2(x_d3, x2) + x_d1 = self.up3(x_d2, x1) + x_d0 = self.up4(x_d1, x0) + output = self.out_conv(x_d0) + if(self.mul_pred): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) + output = [output, output1, output2, output3] + return output + +class MedFormerV2(nn.Module): + """ + An implementation of the U-Net. + + * Reference: Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: + 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. + `MICCAI (2) 2016: 424-432. `_ + + Note that there are some modifications from the original paper, such as + the use of batch normalization, dropout, leaky relu and deep supervision. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using trilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + """ + def __init__(self, params): + super(MedFormerV2, self).__init__() + self.params = params + self.encoder = Encoder(params) + self.decoder = AttDecoder(params) + self.attn = Block(params) + + def forward(self, x): + f = self.encoder(x) + f[-1] = self.attn(f[-1]) + output = self.decoder(f) + return output + +if __name__ == "__main__": + params = {'in_chns':4, + 'class_num': 2, + 'feature_chns':[16, 32, 64, 128], + 'dropout' : [0, 0, 0, 0.5], + 'trilinear': True, + 'multiscale_pred': True, + 'attention_hidden_size': 128, + 'attention_num_heads': 4, + 'attention_mlp_dim': 256, + 'attention_dropout_rate': 0.2} + + Net = MedFormerV2(params) + Net = Net.double() + + x = np.random.rand(1, 4, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + print("output length", len(y)) + for yi in y: + yi = yi.detach().numpy() + print(yi.shape) diff --git a/pymic/net/net3d/trans3d/MedFormer_v3.py b/pymic/net/net3d/trans3d/MedFormer_v3.py new file mode 100644 index 0000000..f119a9c --- /dev/null +++ b/pymic/net/net3d/trans3d/MedFormer_v3.py @@ -0,0 +1,255 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import numpy as np +import torch +import torch.nn as nn +from torch.nn.functional import interpolate +from pymic.net.net3d.unet3d import ConvBlock, Encoder +from pymic.net.net3d.trans3d.MedFormer_v1 import Block +from pymic.net.net3d.trans3d.MedFormer_v2 import SwinTransformerBlock, window_partition + +class GLAttLayer(nn.Module): + def __init__(self, + dim, + input_resolution, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + # build blocks + + self.lcl_att = SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path) + self.adpool = nn.AdaptiveAvgPool3d([12, 12, 12]) + + params = {'attention_hidden_size': dim, + 'attention_num_heads': 4, + 'attention_mlp_dim': dim, + 'attention_dropout_rate': 0.2} + self.glb_att = Block(params) + self.conv1x1 = nn.Sequential( + nn.Conv3d(2*dim, dim, kernel_size=1), + nn.BatchNorm3d(dim), + nn.LeakyReLU()) + + def forward(self, x): + [B, C, S, H, W] = list(x.shape) + # calculate attention mask for SW-MSA + Sp = int(np.ceil(S / self.window_size)) * self.window_size + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Sp, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + s_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for s in s_slices: + for h in h_slices: + for w in w_slices: + img_mask[:, s, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, + self.window_size * self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + # for local attention + xl = x.flatten(2).transpose(1, 2).contiguous() + xl = self.lcl_att(xl, attn_mask) + xl = xl.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() + + # for global attention + xg = self.adpool(x) + xg = self.glb_att(xg) + xg = interpolate(xg, [S, H, W], mode = 'trilinear') + out = torch.cat([xl, xg], dim=1) + out = self.conv1x1(out) + return out + +class AttUpBlock(nn.Module): + """ + 3D upsampling followed by ConvBlock + + :param in_channels1: (int) Channel number of high-level features. + :param in_channels2: (int) Channel number of low-level features. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + :param trilinear: (bool) Use trilinear for up-sampling (by default). + If False, deconvolution is used for up-sampling. + """ + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, + trilinear=True, with_att = False, att_params = None): + super(AttUpBlock, self).__init__() + self.trilinear = trilinear + self.with_att = with_att + if trilinear: + self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) + self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) + else: + self.up = nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) + self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) + if(self.with_att): + input_resolution = att_params['input_resolution'] + num_heads = att_params['num_heads'] + window_size = att_params['window_size'] + self.attn = GLAttLayer(out_channels, input_resolution, num_heads, window_size, 2.0) + + def forward(self, x1, x2): + if self.trilinear: + x1 = self.conv1x1(x1) + x1 = self.up(x1) + x = torch.cat([x2, x1], dim=1) + x = self.conv(x) + if(self.with_att): + x = self.attn(x) + return x + + +class AttDecoder(nn.Module): + """ + Decoder of 3D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + """ + def __init__(self, params): + super(AttDecoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + self.n_class = self.params['class_num'] + self.trilinear = self.params.get('trilinear', True) + self.mul_pred = self.params['multiscale_pred'] + + assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + + if(len(self.ft_chns) == 5): + self.up1 = AttUpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear) + att_params = {"input_resolution": [24, 24, 24], "num_heads": 4, "window_size": 7} + self.up2 = AttUpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear, True, att_params) + att_params = {"input_resolution": [48, 48, 48], "num_heads": 4, "window_size": 7} + self.up3 = AttUpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear, True, att_params) + self.up4 = AttUpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear) + self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) + if(self.mul_pred): + self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) + self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) + self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) + + def forward(self, x): + if(len(self.ft_chns) == 5): + assert(len(x) == 5) + x0, x1, x2, x3, x4 = x + x_d3 = self.up1(x4, x3) + else: + assert(len(x) == 4) + x0, x1, x2, x3 = x + x_d3 = x3 + x_d2 = self.up2(x_d3, x2) + x_d1 = self.up3(x_d2, x1) + x_d0 = self.up4(x_d1, x0) + output = self.out_conv(x_d0) + if(self.mul_pred): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) + output = [output, output1, output2, output3] + return output + +class MedFormerV3(nn.Module): + """ + An implementation of the U-Net. + + * Reference: Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: + 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. + `MICCAI (2) 2016: 424-432. `_ + + Note that there are some modifications from the original paper, such as + the use of batch normalization, dropout, leaky relu and deep supervision. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using trilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + """ + def __init__(self, params): + super(MedFormerV3, self).__init__() + self.params = params + self.encoder = Encoder(params) + self.decoder = AttDecoder(params) + params["attention_hidden_size"] = params['feature_chns'][-1] + params["attention_mlp_dim"] = params['feature_chns'][-1] + self.attn = Block(params) + + def forward(self, x): + f = self.encoder(x) + f[-1] = self.attn(f[-1]) + output = self.decoder(f) + return output + +if __name__ == "__main__": + params = {'in_chns':4, + 'class_num': 2, + 'feature_chns':[16, 32, 64, 128], + 'dropout' : [0, 0, 0, 0.5], + 'trilinear': True, + 'multiscale_pred': True, + 'attention_num_heads': 4, + 'attention_dropout_rate': 0.2} + + Net = MedFormerV3(params) + Net = Net.double() + + x = np.random.rand(2, 4, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + print("output length", len(y)) + for yi in y: + yi = yi.detach().numpy() + print(yi.shape) diff --git a/pymic/net/net3d/trans3d/MedFormer_va1.py b/pymic/net/net3d/trans3d/MedFormer_va1.py new file mode 100644 index 0000000..27dfa3e --- /dev/null +++ b/pymic/net/net3d/trans3d/MedFormer_va1.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import math +import torch +import torch.nn as nn +import numpy as np +from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm +from pymic.net.net3d.unet3d import Decoder + +class EmbeddingBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, padding, stride): + super(EmbeddingBlock, self).__init__() + self.out_channels = out_channels + self.conv1 = nn.Conv3d(in_channels, out_channels//2, kernel_size=kernel_size, padding=padding, stride = stride) + self.conv2 = nn.Conv3d(out_channels//2, out_channels, kernel_size=1) + self.act = nn.GELU() + self.norm1 = nn.LayerNorm(out_channels//2) + self.norm2 = nn.LayerNorm(out_channels) + + + def forward(self, x): + x = self.act(self.conv1(x)) + # norm 1 + Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.norm1(x) + x = x.transpose(1, 2).contiguous().view(-1, self.out_channels // 2, Ws, Wh, Ww) + + x = self.act(self.conv2(x)) + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.norm2(x) + x = x.transpose(1, 2).contiguous().view(-1, self.out_channels, Ws, Wh, Ww) + + return x + +class Encoder(nn.Module): + """ + Encoder of 3D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + """ + def __init__(self, params): + super(Encoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + assert(len(self.ft_chns) == 4) + + self.down0 = EmbeddingBlock(self.in_chns, self.ft_chns[0], 3, 1, 1) + self.down1 = EmbeddingBlock(self.in_chns, self.ft_chns[1], 2, 0, 2) + self.down2 = EmbeddingBlock(self.in_chns, self.ft_chns[2], 4, 0, 4) + self.down3 = EmbeddingBlock(self.in_chns, self.ft_chns[3], 8, 0, 8) + + def forward(self, x): + x0 = self.down0(x) + x1 = self.down1(x) + x2 = self.down2(x) + x3 = self.down3(x) + output = [x0, x1, x2, x3] + return output + +class MedFormerVA1(nn.Module): + def __init__(self, params): + super(MedFormerVA1, self).__init__() + self.params = params + self.encoder = Encoder(params) + self.decoder = Decoder(params) + + def forward(self, x): + f = self.encoder(x) + output = self.decoder(f) + return output + + +if __name__ == "__main__": + params = {'in_chns':1, + 'class_num': 8, + 'feature_chns':[16, 32, 64, 128], + 'dropout' : [0, 0, 0, 0.5], + 'trilinear': True, + 'deep_supervise': True, + 'attention_hidden_size': 128, + 'attention_num_heads': 4, + 'attention_mlp_dim': 256, + 'attention_dropout_rate': 0.2} + Net = MedFormerVA1(params) + Net = Net.double() + + x = np.random.rand(1, 1, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + print("output length", len(y)) + for yi in y: + yi = yi.detach().numpy() + print(yi.shape) \ No newline at end of file diff --git a/pymic/net/net3d/trans3d/__init__.py b/pymic/net/net3d/trans3d/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pymic/net/net3d/trans3d/nnFormer_wrap.py b/pymic/net/net3d/trans3d/nnFormer_wrap.py new file mode 100644 index 0000000..35593a4 --- /dev/null +++ b/pymic/net/net3d/trans3d/nnFormer_wrap.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import math +import torch +import torch.nn as nn +import numpy as np +from nnformer.network_architecture.nnFormer_tumor import nnFormer + +class nnFormer_wrap(nn.Module): + def __init__(self, params): + super(nnFormer_wrap, self).__init__() + patch_size = params["patch_size"] # 96x96x96 + n_class = params['class_num'] + in_chns = params['in_chns'] + # https://github.com/282857341/nnFormer/blob/main/nnformer/network_architecture/nnFormer_tumor.py + self.nnformer = nnFormer(crop_size = patch_size, + embedding_dim=192, + input_channels = in_chns, + num_classes = n_class, + conv_op=nn.Conv3d, + depths =[2,2,2,2], + num_heads = [6, 12, 24, 48], + patch_size = [4,4,4], + window_size= [4,4,8,4], + deep_supervision=False) + + def forward(self, x): + return self.nnformer(x) + +if __name__ == "__main__": + params = {"patch_size": [96, 96, 96], + "in_chns": 1, + "class_num": 5} + Net = nnFormer_wrap(params) + Net = Net.double() + + x = np.random.rand(1, 1, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + print(y.shape) diff --git a/pymic/net/net3d/trans3d/unetr.py b/pymic/net/net3d/trans3d/unetr.py new file mode 100644 index 0000000..ea90b2f --- /dev/null +++ b/pymic/net/net3d/trans3d/unetr.py @@ -0,0 +1,227 @@ +from __future__ import print_function, division + +import torch +import torch.nn as nn + +from monai.networks.blocks import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock +from monai.networks.blocks.dynunet_block import UnetOutBlock +from monai.networks.nets import ViT + + +class UNETR(nn.Module): + """ + UNETR based on: "Hatamizadeh et al., + UNETR: Transformers for 3D Medical Image Segmentation " + """ + + def __init__(self, params): + # in_channels: int, + # out_channels: int, + # img_size: Tuple[int, int, int], + # feature_size: int = 16, + # hidden_size: int = 768, + # mlp_dim: int = 3072, + # num_heads: int = 12, + # pos_embed: str = "perceptron", + # norm_name: Union[Tuple, str] = "instance", + # conv_block: bool = False, + # res_block: bool = True, + # dropout_rate: float = 0.0, + # ) -> None: + """ + Args: + in_channels: dimension of input channels. + out_channels: dimension of output channels. + img_size: dimension of input image. + feature_size: dimension of network feature size. + hidden_size: dimension of hidden layer. + mlp_dim: dimension of feedforward layer. + num_heads: number of attention heads. + pos_embed: position embedding layer type. + norm_name: feature normalization type and arguments. + conv_block: bool argument to determine if convolutional block is used. + res_block: bool argument to determine if residual block is used. + dropout_rate: faction of the input units to drop. + Examples:: + # for single channel input 4-channel output with patch size of (96,96,96), feature size of 32 and batch norm + >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch') + # for 4-channel input 3-channel output with patch size of (128,128,128), conv position embedding and instance norm + >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance') + """ + + super().__init__() + in_channels = params['in_chns'] + out_channels = params['class_num'] + img_size = params['img_size'] + feature_size = 16 + hidden_size = 768 + mlp_dim = 3072 + num_heads = 12 + pos_embed = "perceptron" + norm_name = "instance" + conv_block = False + res_block = True + dropout_rate = 0.0 + + if not (0 <= dropout_rate <= 1): + raise AssertionError("dropout_rate should be between 0 and 1.") + + if hidden_size % num_heads != 0: + raise AssertionError("hidden size should be divisible by num_heads.") + + if pos_embed not in ["conv", "perceptron"]: + raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") + + self.num_layers = 12 + self.patch_size = (16, 16, 16) + self.feat_size = ( + img_size[0] // self.patch_size[0], + img_size[1] // self.patch_size[1], + img_size[2] // self.patch_size[2], + ) + self.hidden_size = hidden_size + self.classification = False + self.vit = ViT( + in_channels=in_channels, + img_size=img_size, + patch_size=self.patch_size, + hidden_size=hidden_size, + mlp_dim=mlp_dim, + num_layers=self.num_layers, + num_heads=num_heads, + pos_embed=pos_embed, + classification=self.classification, + dropout_rate=dropout_rate, + ) + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=in_channels, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=res_block, + ) + self.encoder2 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=hidden_size, + out_channels=feature_size * 2, + num_layer=2, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name=norm_name, + conv_block=conv_block, + res_block=res_block, + ) + self.encoder3 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=hidden_size, + out_channels=feature_size * 4, + num_layer=1, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name=norm_name, + conv_block=conv_block, + res_block=res_block, + ) + self.encoder4 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=hidden_size, + out_channels=feature_size * 8, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name=norm_name, + conv_block=conv_block, + res_block=res_block, + ) + self.decoder5 = UnetrUpBlock( + spatial_dims=3, + in_channels=hidden_size, + out_channels=feature_size * 8, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=res_block, + ) + self.decoder4 = UnetrUpBlock( + spatial_dims=3, + in_channels=feature_size * 8, + out_channels=feature_size * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=res_block, + ) + self.decoder3 = UnetrUpBlock( + spatial_dims=3, + in_channels=feature_size * 4, + out_channels=feature_size * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=res_block, + ) + self.decoder2 = UnetrUpBlock( + spatial_dims=3, + in_channels=feature_size * 2, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=res_block, + ) + self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) # type: ignore + + def proj_feat(self, x, hidden_size, feat_size): + x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) + x = x.permute(0, 4, 1, 2, 3).contiguous() + return x + + def load_from(self, weights): + with torch.no_grad(): + res_weight = weights + # copy weights from patch embedding + for i in weights["state_dict"]: + print(i) + self.vit.patch_embedding.position_embeddings.copy_( + weights["state_dict"]["module.transformer.patch_embedding.position_embeddings_3d"] + ) + self.vit.patch_embedding.cls_token.copy_( + weights["state_dict"]["module.transformer.patch_embedding.cls_token"] + ) + self.vit.patch_embedding.patch_embeddings[1].weight.copy_( + weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings.1.weight"] + ) + self.vit.patch_embedding.patch_embeddings[1].bias.copy_( + weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings.1.bias"] + ) + + # copy weights from encoding blocks (default: num of blocks: 12) + for bname, block in self.vit.blocks.named_children(): + print(block) + block.loadFrom(weights, n_block=bname) + # last norm layer of transformer + self.vit.norm.weight.copy_(weights["state_dict"]["module.transformer.norm.weight"]) + self.vit.norm.bias.copy_(weights["state_dict"]["module.transformer.norm.bias"]) + + def forward(self, x_in): + x, hidden_states_out = self.vit(x_in) + enc1 = self.encoder1(x_in) + x2 = hidden_states_out[3] + enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size)) + x3 = hidden_states_out[6] + enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size)) + x4 = hidden_states_out[9] + enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size)) + dec4 = self.proj_feat(x, self.hidden_size, self.feat_size) + dec3 = self.decoder5(dec4, enc4) + dec2 = self.decoder4(dec3, enc3) + dec1 = self.decoder3(dec2, enc2) + out = self.decoder2(dec1, enc1) + logits = self.out(out) + return logits + diff --git a/pymic/net/net3d/trans3d/unetr_pp.py b/pymic/net/net3d/trans3d/unetr_pp.py new file mode 100644 index 0000000..3ef6736 --- /dev/null +++ b/pymic/net/net3d/trans3d/unetr_pp.py @@ -0,0 +1,460 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Sequence, Tuple, Union +from pymic.net.net3d.trans3d.unetr_pp_block import UnetOutBlock, UnetResBlock, get_conv_layer +from timm.models.layers import trunc_normal_ +from monai.utils import optional_import +from monai.networks.blocks.convolutions import Convolution +from monai.networks.layers.factories import Act, Norm +from monai.networks.layers.utils import get_act_layer, get_norm_layer + +einops, _ = optional_import("einops") + +class LayerNorm(nn.Module): + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape,) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + +class EPA(nn.Module): + """ + Efficient Paired Attention Block, based on: "Shaker et al., + UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation" + """ + def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False, + channel_attn_drop=0.1, spatial_attn_drop=0.1): + super().__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1)) + + # qkvv are 4 linear layers (query_shared, key_shared, value_spatial, value_channel) + self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias) + + # E and F are projection matrices with shared weights used in spatial attention module to project + # keys and values from HWD-dimension to P-dimension + self.E = self.F = nn.Linear(input_size, proj_size) + + self.attn_drop = nn.Dropout(channel_attn_drop) + self.attn_drop_2 = nn.Dropout(spatial_attn_drop) + + self.out_proj = nn.Linear(hidden_size, int(hidden_size // 2)) + self.out_proj2 = nn.Linear(hidden_size, int(hidden_size // 2)) + + def forward(self, x): + B, N, C = x.shape + + qkvv = self.qkvv(x).reshape(B, N, 4, self.num_heads, C // self.num_heads) + + qkvv = qkvv.permute(2, 0, 3, 1, 4) + + q_shared, k_shared, v_CA, v_SA = qkvv[0], qkvv[1], qkvv[2], qkvv[3] + + q_shared = q_shared.transpose(-2, -1) + k_shared = k_shared.transpose(-2, -1) + v_CA = v_CA.transpose(-2, -1) + v_SA = v_SA.transpose(-2, -1) + + k_shared_projected = self.E(k_shared) + + v_SA_projected = self.F(v_SA) + + q_shared = torch.nn.functional.normalize(q_shared, dim=-1) + k_shared = torch.nn.functional.normalize(k_shared, dim=-1) + + attn_CA = (q_shared @ k_shared.transpose(-2, -1)) * self.temperature + + attn_CA = attn_CA.softmax(dim=-1) + attn_CA = self.attn_drop(attn_CA) + + x_CA = (attn_CA @ v_CA).permute(0, 3, 1, 2).reshape(B, N, C) + + attn_SA = (q_shared.permute(0, 1, 3, 2) @ k_shared_projected) * self.temperature2 + + attn_SA = attn_SA.softmax(dim=-1) + attn_SA = self.attn_drop_2(attn_SA) + + x_SA = (attn_SA @ v_SA_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(B, N, C) + + # Concat fusion + x_SA = self.out_proj(x_SA) + x_CA = self.out_proj2(x_CA) + x = torch.cat((x_SA, x_CA), dim=-1) + return x + + @torch.jit.ignore + def no_weight_decay(self): + return {'temperature', 'temperature2'} + + +class TransformerBlock(nn.Module): + """ + A transformer block, based on: "Shaker et al., + UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation" + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + proj_size: int, + num_heads: int, + dropout_rate: float = 0.0, + pos_embed=False, + ) -> None: + """ + Args: + input_size: the size of the input for each stage. + hidden_size: dimension of hidden layer. + proj_size: projection size for keys and values in the spatial attention module. + num_heads: number of attention heads. + dropout_rate: faction of the input units to drop. + pos_embed: bool argument to determine if positional embedding is used. + + """ + + super().__init__() + + if not (0 <= dropout_rate <= 1): + raise ValueError("dropout_rate should be between 0 and 1.") + + if hidden_size % num_heads != 0: + print("Hidden size is ", hidden_size) + print("Num heads is ", num_heads) + raise ValueError("hidden_size should be divisible by num_heads.") + + self.norm = nn.LayerNorm(hidden_size) + self.gamma = nn.Parameter(1e-6 * torch.ones(hidden_size), requires_grad=True) + self.epa_block = EPA(input_size=input_size, hidden_size=hidden_size, proj_size=proj_size, num_heads=num_heads, channel_attn_drop=dropout_rate,spatial_attn_drop=dropout_rate) + self.conv51 = UnetResBlock(3, hidden_size, hidden_size, kernel_size=3, stride=1, norm_name="batch") + self.conv8 = nn.Sequential(nn.Dropout3d(0.1, False), nn.Conv3d(hidden_size, hidden_size, 1)) + + self.pos_embed = None + if pos_embed: + self.pos_embed = nn.Parameter(torch.zeros(1, input_size, hidden_size)) + + def forward(self, x): + B, C, H, W, D = x.shape + + x = x.reshape(B, C, H * W * D).permute(0, 2, 1) + + if self.pos_embed is not None: + x = x + self.pos_embed + attn = x + self.gamma * self.epa_block(self.norm(x)) + + attn_skip = attn.reshape(B, H, W, D, C).permute(0, 4, 1, 2, 3) # (B, C, H, W, D) + attn = self.conv51(attn_skip) + x = attn_skip + self.conv8(attn) + + return x + +class UnetrPPEncoder(nn.Module): + def __init__(self, input_size=[32 * 32 * 32, 16 * 16 * 16, 8 * 8 * 8, 4 * 4 * 4],dims=[32, 64, 128, 256], + proj_size =[64,64,64,32], depths=[3, 3, 3, 3], num_heads=4, spatial_dims=3, in_channels=1, dropout=0.0, transformer_dropout_rate=0.15 ,**kwargs): + super().__init__() + + self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers + stem_layer = nn.Sequential( + get_conv_layer(spatial_dims, in_channels, dims[0], kernel_size=(2, 4, 4), stride=(2, 4, 4), + dropout=dropout, conv_only=True, ), + get_norm_layer(name=("group", {"num_groups": in_channels}), channels=dims[0]), + ) + self.downsample_layers.append(stem_layer) + for i in range(3): + downsample_layer = nn.Sequential( + get_conv_layer(spatial_dims, dims[i], dims[i + 1], kernel_size=(2, 2, 2), stride=(2, 2, 2), + dropout=dropout, conv_only=True, ), + get_norm_layer(name=("group", {"num_groups": dims[i]}), channels=dims[i + 1]), + ) + self.downsample_layers.append(downsample_layer) + + self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple Transformer blocks + for i in range(4): + stage_blocks = [] + for j in range(depths[i]): + stage_blocks.append(TransformerBlock(input_size=input_size[i], hidden_size=dims[i], proj_size=proj_size[i], num_heads=num_heads, + dropout_rate=transformer_dropout_rate, pos_embed=True)) + self.stages.append(nn.Sequential(*stage_blocks)) + self.hidden_states = [] + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, (LayerNorm, nn.LayerNorm)): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward_features(self, x): + hidden_states = [] + + x = self.downsample_layers[0](x) + x = self.stages[0](x) + + hidden_states.append(x) + + for i in range(1, 4): + x = self.downsample_layers[i](x) + x = self.stages[i](x) + if i == 3: # Reshape the output of the last stage + x = einops.rearrange(x, "b c h w d -> b (h w d) c") + hidden_states.append(x) + return x, hidden_states + + def forward(self, x): + x, hidden_states = self.forward_features(x) + return x, hidden_states + + +class UnetrUpBlock(nn.Module): + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + upsample_kernel_size: Union[Sequence[int], int], + norm_name: Union[Tuple, str], + proj_size: int = 64, + num_heads: int = 4, + out_size: int = 0, + depth: int = 3, + conv_decoder: bool = False, + ) -> None: + """ + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + kernel_size: convolution kernel size. + upsample_kernel_size: convolution kernel size for transposed convolution layers. + norm_name: feature normalization type and arguments. + proj_size: projection size for keys and values in the spatial attention module. + num_heads: number of heads inside each EPA module. + out_size: spatial size for each decoder. + depth: number of blocks for the current decoder stage. + """ + + super().__init__() + upsample_stride = upsample_kernel_size + self.transp_conv = get_conv_layer( + spatial_dims, + in_channels, + out_channels, + kernel_size=upsample_kernel_size, + stride=upsample_stride, + conv_only=True, + is_transposed=True, + ) + + # 4 feature resolution stages, each consisting of multiple residual blocks + self.decoder_block = nn.ModuleList() + + # If this is the last decoder, use ConvBlock(UnetResBlock) instead of EPA_Block (see suppl. material in the paper) + if conv_decoder == True: + self.decoder_block.append( + UnetResBlock(spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, + norm_name=norm_name, )) + else: + stage_blocks = [] + for j in range(depth): + stage_blocks.append(TransformerBlock(input_size=out_size, hidden_size= out_channels, proj_size=proj_size, num_heads=num_heads, + dropout_rate=0.15, pos_embed=True)) + self.decoder_block.append(nn.Sequential(*stage_blocks)) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, (nn.LayerNorm)): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, inp, skip): + + out = self.transp_conv(inp) + out = out + skip + out = self.decoder_block[0](out) + + return out + + +class UNETR_PP(nn.Module): + """ + UNETR++ based on: "Shaker et al., + UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation" + """ + + def __init__(self, params): + """ + Args: + in_channels: dimension of input channels. + out_channels: dimension of output channels. + img_size: dimension of input image. + feature_size: dimension of network feature size. + hidden_size: dimension of the last encoder. + num_heads: number of attention heads. + pos_embed: position embedding layer type. + norm_name: feature normalization type and arguments. + dropout_rate: faction of the input units to drop. + depths: number of blocks for each stage. + dims: number of channel maps for the stages. + conv_op: type of convolution operation. + do_ds: use deep supervision to compute the loss. + + """ + super().__init__() + in_channels = params['in_chns'] + out_channels = params['class_num'] + img_size = params['img_size'] + feature_size = params.get('feature_size', 16) + hidden_size = params.get('hidden_size', 256) + num_heads = params.get('num_heads', 4) + pos_embed = params.get('pos_embed', "perceptron") + norm_name = params.get('norm_name', "instance") + dropout_rate = params.get('dropout_rate', 0.0) + depths = params.get('depths', [3, 3, 3, 3]) + dims = params.get('dims', [32, 64, 128, 256]) + conv_op = nn.Conv3d + do_ds = params.get('deep_supervise', True) + + self.do_ds = do_ds + self.conv_op = conv_op + self.num_classes = out_channels + if not (0 <= dropout_rate <= 1): + raise AssertionError("dropout_rate should be between 0 and 1.") + + if pos_embed not in ["conv", "perceptron"]: + raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") + + self.patch_size = (2, 4, 4) + self.feat_size = ( + img_size[0] // self.patch_size[0] // 8, # 8 is the downsampling happened through the four encoders stages + img_size[1] // self.patch_size[1] // 8, # 8 is the downsampling happened through the four encoders stages + img_size[2] // self.patch_size[2] // 8, # 8 is the downsampling happened through the four encoders stages + ) + self.hidden_size = hidden_size + + self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads) + + self.encoder1 = UnetResBlock( + spatial_dims=3, + in_channels=in_channels, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + ) + self.decoder5 = UnetrUpBlock( + spatial_dims=3, + in_channels=feature_size * 16, + out_channels=feature_size * 8, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + out_size=8 * 8 * 8, + ) + self.decoder4 = UnetrUpBlock( + spatial_dims=3, + in_channels=feature_size * 8, + out_channels=feature_size * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + out_size=16 * 16 * 16, + ) + self.decoder3 = UnetrUpBlock( + spatial_dims=3, + in_channels=feature_size * 4, + out_channels=feature_size * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + out_size=32 * 32 * 32, + ) + self.decoder2 = UnetrUpBlock( + spatial_dims=3, + in_channels=feature_size * 2, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=(2, 4, 4), + norm_name=norm_name, + out_size=64 * 128 * 128, + conv_decoder=True, + ) + self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) + if self.do_ds: + self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels) + self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels) + + def proj_feat(self, x, hidden_size, feat_size): + x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) + x = x.permute(0, 4, 1, 2, 3).contiguous() + return x + + def forward(self, x_in): + x_output, hidden_states = self.unetr_pp_encoder(x_in) + + convBlock = self.encoder1(x_in) + + # Four encoders + enc1 = hidden_states[0] + enc2 = hidden_states[1] + enc3 = hidden_states[2] + enc4 = hidden_states[3] + + # Four decoders + dec4 = self.proj_feat(enc4, self.hidden_size, self.feat_size) + dec3 = self.decoder5(dec4, enc3) + dec2 = self.decoder4(dec3, enc2) + dec1 = self.decoder3(dec2, enc1) + + out = self.decoder2(dec1, convBlock) + if self.do_ds: + logits = [self.out1(out), self.out2(dec1), self.out3(dec2)] + else: + logits = self.out1(out) + + return logits + + +if __name__ == "__main__": + params = {'in_chns': 1, + 'class_num': 2, + 'img_size': [64, 128, 128] + } + net = UNETR_PP(params) + net.double() + + x = np.random.rand(2, 1, 64, 128, 128) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = net(xt) + print(len(y)) + for yi in y: + yi = yi.detach().numpy() + print(yi.shape) \ No newline at end of file diff --git a/pymic/net/net3d/trans3d/unetr_pp_block.py b/pymic/net/net3d/trans3d/unetr_pp_block.py new file mode 100644 index 0000000..89a8769 --- /dev/null +++ b/pymic/net/net3d/trans3d/unetr_pp_block.py @@ -0,0 +1,278 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import numpy as np +import torch +import torch.nn as nn +from typing import Optional, Sequence, Tuple, Union +from monai.networks.blocks.convolutions import Convolution +from monai.networks.layers.factories import Act, Norm +from monai.networks.layers.utils import get_act_layer, get_norm_layer + + +class UnetResBlock(nn.Module): + """ + A skip-connection based module that can be used for DynUNet, based on: + `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. + `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + kernel_size: convolution kernel size. + stride: convolution stride. + norm_name: feature normalization type and arguments. + act_name: activation layer type and arguments. + dropout: dropout probability. + + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + stride: Union[Sequence[int], int], + norm_name: Union[Tuple, str], + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), + dropout: Optional[Union[Tuple, str, float]] = None, + ): + super().__init__() + self.conv1 = get_conv_layer( + spatial_dims, + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + dropout=dropout, + conv_only=True, + ) + self.conv2 = get_conv_layer( + spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True + ) + self.lrelu = get_act_layer(name=act_name) + self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) + self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) + self.downsample = in_channels != out_channels + stride_np = np.atleast_1d(stride) + if not np.all(stride_np == 1): + self.downsample = True + if self.downsample: + self.conv3 = get_conv_layer( + spatial_dims, in_channels, out_channels, kernel_size=1, stride=stride, dropout=dropout, conv_only=True + ) + self.norm3 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) + + def forward(self, inp): + residual = inp + out = self.conv1(inp) + out = self.norm1(out) + out = self.lrelu(out) + out = self.conv2(out) + out = self.norm2(out) + if hasattr(self, "conv3"): + residual = self.conv3(residual) + if hasattr(self, "norm3"): + residual = self.norm3(residual) + out += residual + out = self.lrelu(out) + return out + + +class UnetBasicBlock(nn.Module): + """ + A CNN module module that can be used for DynUNet, based on: + `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. + `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + kernel_size: convolution kernel size. + stride: convolution stride. + norm_name: feature normalization type and arguments. + act_name: activation layer type and arguments. + dropout: dropout probability. + + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + stride: Union[Sequence[int], int], + norm_name: Union[Tuple, str], + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), + dropout: Optional[Union[Tuple, str, float]] = None, + ): + super().__init__() + self.conv1 = get_conv_layer( + spatial_dims, + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + dropout=dropout, + conv_only=True, + ) + self.conv2 = get_conv_layer( + spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True + ) + self.lrelu = get_act_layer(name=act_name) + self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) + self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) + + def forward(self, inp): + out = self.conv1(inp) + out = self.norm1(out) + out = self.lrelu(out) + out = self.conv2(out) + out = self.norm2(out) + out = self.lrelu(out) + return out + + +class UnetUpBlock(nn.Module): + """ + An upsampling module that can be used for DynUNet, based on: + `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. + `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + kernel_size: convolution kernel size. + stride: convolution stride. + upsample_kernel_size: convolution kernel size for transposed convolution layers. + norm_name: feature normalization type and arguments. + act_name: activation layer type and arguments. + dropout: dropout probability. + trans_bias: transposed convolution bias. + + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + stride: Union[Sequence[int], int], + upsample_kernel_size: Union[Sequence[int], int], + norm_name: Union[Tuple, str], + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), + dropout: Optional[Union[Tuple, str, float]] = None, + trans_bias: bool = False, + ): + super().__init__() + upsample_stride = upsample_kernel_size + self.transp_conv = get_conv_layer( + spatial_dims, + in_channels, + out_channels, + kernel_size=upsample_kernel_size, + stride=upsample_stride, + dropout=dropout, + bias=trans_bias, + conv_only=True, + is_transposed=True, + ) + self.conv_block = UnetBasicBlock( + spatial_dims, + out_channels + out_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + dropout=dropout, + norm_name=norm_name, + act_name=act_name, + ) + + def forward(self, inp, skip): + # number of channels for skip should equals to out_channels + out = self.transp_conv(inp) + out = torch.cat((out, skip), dim=1) + out = self.conv_block(out) + return out + + +class UnetOutBlock(nn.Module): + def __init__( + self, spatial_dims: int, in_channels: int, out_channels: int, dropout: Optional[Union[Tuple, str, float]] = None + ): + super().__init__() + self.conv = get_conv_layer( + spatial_dims, in_channels, out_channels, kernel_size=1, stride=1, dropout=dropout, bias=True, conv_only=True + ) + + def forward(self, inp): + return self.conv(inp) + + +def get_conv_layer( + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int] = 3, + stride: Union[Sequence[int], int] = 1, + act: Optional[Union[Tuple, str]] = Act.PRELU, + norm: Union[Tuple, str] = Norm.INSTANCE, + dropout: Optional[Union[Tuple, str, float]] = None, + bias: bool = False, + conv_only: bool = True, + is_transposed: bool = False, +): + padding = get_padding(kernel_size, stride) + output_padding = None + if is_transposed: + output_padding = get_output_padding(kernel_size, stride, padding) + return Convolution( + spatial_dims, + in_channels, + out_channels, + strides=stride, + kernel_size=kernel_size, + act=act, + norm=norm, + dropout=dropout, + bias=bias, + conv_only=conv_only, + is_transposed=is_transposed, + padding=padding, + output_padding=output_padding, + ) + + +def get_padding( + kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int] +) -> Union[Tuple[int, ...], int]: + + kernel_size_np = np.atleast_1d(kernel_size) + stride_np = np.atleast_1d(stride) + padding_np = (kernel_size_np - stride_np + 1) / 2 + if np.min(padding_np) < 0: + raise AssertionError("padding value should not be negative, please change the kernel size and/or stride.") + padding = tuple(int(p) for p in padding_np) + + return padding if len(padding) > 1 else padding[0] + + +def get_output_padding( + kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], padding: Union[Sequence[int], int] +) -> Union[Tuple[int, ...], int]: + kernel_size_np = np.atleast_1d(kernel_size) + stride_np = np.atleast_1d(stride) + padding_np = np.atleast_1d(padding) + + out_padding_np = 2 * padding_np + stride_np - kernel_size_np + if np.min(out_padding_np) < 0: + raise AssertionError("out_padding value should not be negative, please change the kernel size and/or stride.") + out_padding = tuple(int(p) for p in out_padding_np) + + return out_padding if len(out_padding) > 1 else out_padding[0] diff --git a/pymic/net_run/agent_rec.py b/pymic/net_run/agent_rec.py index 188a5fc..4860144 100644 --- a/pymic/net_run/agent_rec.py +++ b/pymic/net_run/agent_rec.py @@ -76,6 +76,7 @@ def training(self): # for debug # from pymic.io.image_read_write import save_nd_array_as_image + # print(inputs.shape) # for i in range(inputs.shape[0]): # image_i = inputs[i][0] # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) @@ -168,7 +169,8 @@ def train_valid(self): ckpt_prefix = self.config['training'].get('ckpt_prefix', None) if(ckpt_prefix is None): ckpt_prefix = ckpt_dir.split('/')[-1] - iter_start = self.config['training']['iter_start'] + # iter_start = self.config['training']['iter_start'] + iter_start = 0 iter_max = self.config['training']['iter_max'] iter_valid = self.config['training']['iter_valid'] iter_save = self.config['training'].get('iter_save', None) @@ -193,7 +195,7 @@ def train_valid(self): else: self.net.load_state_dict(self.checkpoint['model_state_dict']) self.min_val_loss = self.checkpoint.get('valid_loss', 10000) - # self.max_val_it = self.checkpoint['iteration'] + iter_start = self.checkpoint['iteration'] self.max_val_it = iter_start self.best_model_wts = self.checkpoint['model_state_dict'] diff --git a/pymic/net_run/self_sup/__init__.py b/pymic/net_run/self_sup/__init__.py index 55f26bf..73308e6 100644 --- a/pymic/net_run/self_sup/__init__.py +++ b/pymic/net_run/self_sup/__init__.py @@ -1,2 +1,3 @@ from __future__ import absolute_import -from pymic.net_run.self_sup.self_sl_agent import SelfSLSegAgent \ No newline at end of file +from pymic.net_run.self_sup.self_sl_agent import SelfSLSegAgent +from pymic.net_run.self_sup.self_patch_mix_agent import SelfSLPatchMixAgent \ No newline at end of file diff --git a/pymic/net_run/self_sup/self_patch_mix_agent.py b/pymic/net_run/self_sup/self_patch_mix_agent.py new file mode 100644 index 0000000..e30a131 --- /dev/null +++ b/pymic/net_run/self_sup/self_patch_mix_agent.py @@ -0,0 +1,144 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import copy +import os +import sys +import shutil +import time +import logging +import scipy +import torch +import torchvision.transforms as transforms +import numpy as np +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from datetime import datetime +from random import random +from torch.optim import lr_scheduler +from tensorboardX import SummaryWriter +from pymic.io.image_read_write import save_nd_array_as_image +from pymic.io.nifty_dataset import NiftyDataset +from pymic.net.net_dict_seg import SegNetDict +from pymic.net_run.agent_abstract import NetRunAgent +from pymic.net_run.infer_func import Inferer +from pymic.loss.loss_dict_seg import SegLossDict +from pymic.loss.seg.combined import CombinedLoss +from pymic.loss.seg.deep_sup import DeepSuperviseLoss +from pymic.loss.seg.util import get_soft_label +from pymic.loss.seg.util import reshape_prediction_and_ground_truth +from pymic.loss.seg.util import get_classwise_dice +from pymic.transform.trans_dict import TransformDict +from pymic.util.post_process import PostProcessDict +from pymic.util.image_process import convert_label +from pymic.util.parse_config import * +from pymic.io.image_read_write import save_nd_array_as_image +from pymic.net_run.self_sup.util import patch_mix +from pymic.net_run.agent_seg import SegmentationAgent + +class SelfSLPatchMixAgent(SegmentationAgent): + """ + Abstract class for self-supervised segmentation. + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. + """ + def __init__(self, config, stage = 'train'): + super(SelfSLPatchMixAgent, self).__init__(config, stage) + + def training(self): + class_num = self.config['network']['class_num'] + iter_valid = self.config['training']['iter_valid'] + fg_num = self.config['network']['class_num'] - 1 + patch_num = self.config['patch_mix']['patch_num_range'] + size_d = self.config['patch_mix']['patch_depth_range'] + size_h = self.config['patch_mix']['patch_height_range'] + size_w = self.config['patch_mix']['patch_width_range'] + + train_loss = 0 + train_dice_list = [] + self.net.train() + for it in range(iter_valid): + try: + data = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data = next(self.trainIter) + # get the inputs + inputs = self.convert_tensor_type(data['image']) + inputs, labels_prob = patch_mix(inputs, fg_num, patch_num, size_d, size_h, size_w) + + # # for debug + # if(it==10): + # break + # for i in range(inputs.shape[0]): + # image_i = inputs[i][0] + # label_i = np.argmax(labels_prob[i], axis = 0) + # # pixw_i = pix_w[i][0] + # print(image_i.shape, label_i.shape) + # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) + # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) + # # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) + # save_nd_array_as_image(image_i, image_name, reference_name = None) + # save_nd_array_as_image(label_i, label_name, reference_name = None) + # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) + # continue + + inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + outputs = self.net(inputs) + loss = self.get_loss_value(data, outputs, labels_prob) + loss.backward() + self.optimizer.step() + train_loss = train_loss + loss.item() + # get dice evaluation for each class + if(isinstance(outputs, tuple) or isinstance(outputs, list)): + outputs = outputs[0] + outputs_argmax = torch.argmax(outputs, dim = 1, keepdim = True) + soft_out = get_soft_label(outputs_argmax, class_num, self.tensor_type) + soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob) + dice_list = get_classwise_dice(soft_out, labels_prob) + train_dice_list.append(dice_list.cpu().numpy()) + train_avg_loss = train_loss / iter_valid + train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) + train_avg_dice = train_cls_dice[1:].mean() + + train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ + 'class_dice': train_cls_dice} + return train_scalers + +def main(): + cfg_file = str(sys.argv[1]) + if(not os.path.isfile(cfg_file)): + raise ValueError("The config file does not exist: " + cfg_file) + config = parse_config(cfg_file) + config = synchronize_config(config) + log_dir = config['training']['ckpt_save_dir'] + if(not os.path.exists(log_dir)): + os.makedirs(log_dir, exist_ok=True) + dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] + shutil.copy(cfg_file, log_dir + "/" + dst_cfg) + if sys.version.startswith("3.9"): + logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), + level=logging.INFO, format='%(message)s', force=True) # for python 3.9 + else: + logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), + level=logging.INFO, format='%(message)s') # for python 3.6 + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging_config(config) + agent = SelfSLPatchMixAgent(config) + agent.run() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pymic/net_run/self_sup/util.py b/pymic/net_run/self_sup/util.py new file mode 100644 index 0000000..e131941 --- /dev/null +++ b/pymic/net_run/self_sup/util.py @@ -0,0 +1,167 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import os +import torch +import random +import numpy as np +from scipy import ndimage +from pymic.io.image_read_write import * +from pymic.util.image_process import * +from pymic.util.general import get_one_hot_seg + +def get_human_region_mask(img): + """ + Get the mask of human region in CT volumes + """ + dim = len(img.shape) + if( dim == 4): + img = img[0] + mask = np.asarray(img > -600) + se = np.ones([3,3,3]) + mask = ndimage.binary_opening(mask, se, iterations = 2) + mask = get_largest_k_components(mask, 1) + mask_close = ndimage.binary_closing(mask, se, iterations = 3) + + D, H, W = mask.shape + for d in [1, 2, D-3, D-2]: + mask_close[d] = mask[d] + for d in [0, -1, int(D/2)]: + mask_close[d, 1:-1, 1:-1] = np.ones((H-2, W-2)) + + bg = get_largest_k_components(1- mask_close, 1) + fg = 1 - bg + se = np.ones([3,3,3]) + fg = ndimage.binary_opening(fg, se, iterations = 1) + fg = get_largest_k_components(fg, 1) + if(dim == 4): + fg = np.expand_dims(fg, 0) + fg = np.asarray(fg, np.uint8) + return fg + +def crop_ct_scan(input_img, output_img, input_lab = None, output_lab = None): + """ + Crop a CT scan based on the bounding box of the human region. + """ + img_obj = sitk.ReadImage(input_img) + img = sitk.GetArrayFromImage(img_obj) + mask = np.asarray(img > -600) + se = np.ones([3,3,3]) + mask = ndimage.binary_opening(mask, se, iterations = 2) + mask = get_largest_k_components(mask, 1) + bbmin, bbmax = get_ND_bounding_box(mask, margin = [5, 10, 10]) + img_sub = crop_ND_volume_with_bounding_box(img, bbmin, bbmax) + img_sub_obj = sitk.GetImageFromArray(img_sub) + img_sub_obj.SetSpacing(img_obj.GetSpacing()) + sitk.WriteImage(img_sub_obj, output_img) + if(input_lab is not None): + lab_obj = sitk.ReadImage(input_lab) + lab = sitk.GetArrayFromImage(lab_obj) + lab_sub = crop_ND_volume_with_bounding_box(lab, bbmin, bbmax) + lab_sub_obj = sitk.GetImageFromArray(lab_sub) + lab_sub_obj.SetSpacing(img_obj.GetSpacing()) + sitk.WriteImage(lab_sub_obj, output_lab) + + +def patch_mix(x, fg_num, patch_num, size_d, size_h, size_w): + """ + Copy a sub region of an impage and paste to another one to generate + images and labels for self-supervised segmentation. + """ + N, C, D, H, W = list(x.shape) + fg_mask = torch.zeros_like(x) + # generate mask + for n in range(N): + p_num = random.randint(patch_num[0], patch_num[1]) + for i in range(p_num): + d = random.randint(size_d[0], size_d[1]) + h = random.randint(size_h[0], size_h[1]) + w = random.randint(size_w[0], size_w[1]) + d_c = random.randint(0, D) + h_c = random.randint(0, H) + w_c = random.randint(0, W) + d0, d1 = max(0, d_c - d), min(D, d_c + d) + h0, h1 = max(0, h_c - h), min(H, h_c + h) + w0, w1 = max(0, w_c - w), min(W, w_c + w) + temp_m = torch.ones([C, d1-d0, h1-h0, w1-w0]) * random.randint(1, fg_num) + fg_mask[n, :, d0:d1, h0:h1, w0:w1] = temp_m + fg_w = fg_mask * 1.0 / fg_num + x_roll = torch.roll(x, 1, 0) + x_fuse = fg_w*x_roll + (1.0 - fg_w)*x + y_prob = get_one_hot_seg(fg_mask.to(torch.int32), fg_num + 1) + return x_fuse, y_prob + +def create_mixed_dataset(input_dir, output_dir, fg_num = 1, crop_num = 1, + mask = 'default', data_format = "nii.gz"): + """ + Create dataset based on patch mix. + + :param input_dir: (str) The path of folder for input images + :param output_dir: (str) The path of folder for output images + :param fg_num: (int) The number of foreground classes + :param crop_num: (int) The number of patches to crop for each input image + :param mask: ND array to specify a mask, or 'default' or None. If default, + a mask for body region is automatically generated (just for CT). + :param data_format: (str) The format of images. + """ + img_names = os.listdir(input_dir) + img_names = [item for item in img_names if item.endswith(data_format)] + img_names = sorted(img_names) + out_img_dir = output_dir + "/image" + out_lab_dir = output_dir + "/label" + if(not os.path.exists(out_img_dir)): + os.mkdir(out_img_dir) + if(not os.path.exists(out_lab_dir)): + os.mkdir(out_lab_dir) + + img_num = len(img_names) + print("image number", img_num) + i_range = range(img_num) + j_range = list(i_range) + random.shuffle(j_range) + for i in i_range: + print(i, img_names[i]) + j = j_range[i] + if(i == j): + j = i + 1 if i < img_num - 1 else 0 + img_i = load_image_as_nd_array(input_dir + "/" + img_names[i])['data_array'] + img_j = load_image_as_nd_array(input_dir + "/" + img_names[j])['data_array'] + + chns = img_i.shape[0] + # random crop to patch size + if(mask == 'default'): + mask_i = get_human_region_mask(img_i) + mask_j = get_human_region_mask(img_j) + for k in range(crop_num): + if(mask is None): + img_ik = random_crop_ND_volume(img_i, [chns, 96, 96, 96]) + img_jk = random_crop_ND_volume(img_j, [chns, 96, 96, 96]) + else: + img_ik = random_crop_ND_volume_with_mask(img_i, [chns, 96, 96, 96], mask_i) + img_jk = random_crop_ND_volume_with_mask(img_j, [chns, 96, 96, 96], mask_j) + C, D, H, W = img_ik.shape + # generate mask + fg_mask = np.zeros_like(img_ik, np.uint8) + patch_num = random.randint(4, 40) + for patch in range(patch_num): + d = random.randint(4, 20) # half of window size + h = random.randint(4, 40) + w = random.randint(4, 40) + d_c = random.randint(0, D) + h_c = random.randint(0, H) + w_c = random.randint(0, W) + d0, d1 = max(0, d_c - d), min(D, d_c + d) + h0, h1 = max(0, h_c - h), min(H, h_c + h) + w0, w1 = max(0, w_c - w), min(W, w_c + w) + temp_m = np.ones([C, d1-d0, h1-h0, w1-w0]) * random.randint(1, fg_num) + fg_mask[:, d0:d1, h0:h1, w0:w1] = temp_m + fg_w = fg_mask * 1.0 / fg_num + x_fuse = fg_w*img_jk + (1.0 - fg_w)*img_ik + + out_name = img_names[i] + if crop_num > 1: + out_name = out_name.replace(".nii.gz", "_{0:}.nii.gz".format(k)) + save_nd_array_as_image(x_fuse[0], out_img_dir + "/" + out_name, + reference_name = input_dir + "/" + img_names[i]) + save_nd_array_as_image(fg_mask[0], out_lab_dir + "/" + out_name, + reference_name = input_dir + "/" + img_names[i]) + diff --git a/pymic/transform/label_convert.py b/pymic/transform/label_convert.py index e3accdf..00c505e 100644 --- a/pymic/transform/label_convert.py +++ b/pymic/transform/label_convert.py @@ -152,29 +152,16 @@ def __call__(self, sample): return sample -class SelfSuperviseLabel(AbstractTransform): +class SelfReconstructionLabel(AbstractTransform): """ - Convert one-channel partial label map to one-hot multi-channel probability map. - This is used for segmentation tasks only. In the input label map, 0 represents the - background class, 1 to C-1 represent the foreground classes, and C represents - unlabeled pixels. In the output dictionary, `label_prob` is the one-hot probability - map, and `pixel_weight` represents a weighting map, where the weight for a pixel - is 0 if the label is unkown. - - The arguments should be written in the `params` dictionary, and it has the - following fields: - - :param `PartialLabelToProbability_class_num`: (int) The class number for the - segmentation task. - :param `PartialLabelToProbability_inverse`: (optional, bool) - Is inverse transform needed for inference. Default is `False`. + Used for self-supervised learning with image reconstruction tasks. """ def __init__(self, params): """ class_num (int): the class number in the label map """ - super(SelfSuperviseLabel, self).__init__(params) - self.inverse = params.get('SelfSuperviseLabel_inverse'.lower(), False) + super(SelfReconstructionLabel, self).__init__(params) + self.inverse = params.get('SelfReconstructionLabel_inverse'.lower(), False) def __call__(self, sample): image = sample['image'] @@ -183,3 +170,41 @@ def __call__(self, sample): return sample +class MaskedImageModelingLabel(AbstractTransform): + """ + Used for self-supervised learning with image reconstruction tasks. + Only reconstruct the masked region in the input. + The input images is masked in local patches. + """ + def __init__(self, params): + """ + class_num (int): the class number in the label map + """ + super(MaskedImageModelingLabel, self).__init__(params) + self.patch_size = params.get('MaskedImageModelingLabel_patch_size'.lower(), [16, 16, 16]) + self.masking_ratio = params.get('MaskedImageModelingLabel_ratio'.lower(), 0.15) + self.inverse = params.get('MaskedImageModelingLabel_inverse'.lower(), False) + + def __call__(self, sample): + image = sample['image'] + C, D, H, W = image.shape + patch_size = self.patch_size + mask = np.ones([D, H, W], np.float32) + grid_size = [math.ceil((image.shape[i+1] + 0.0) / patch_size[i]) for i in range(3)] + for d in range(grid_size[0]): + d0 = d*patch_size[0] + for h in range(grid_size[1]): + h0 = h*patch_size[1] + for w in range(grid_size[2]): + w0 = w*patch_size[2] + if(random.random() > self.masking_ratio): + continue + d1 = min(d0 + patch_size[0], D) + h1 = min(h0 + patch_size[1], H) + w1 = min(w0 + patch_size[2], W) + mask[d0:d1, h0:h1, w0:w1] = np.zeros([d1 - d0, h1 - h0, w1 - w0]) + sample['pixel_weight'] = 1 - mask + sample['image'] = image * mask + sample['label'] = image + return sample + diff --git a/pymic/transform/mix.py b/pymic/transform/mix.py new file mode 100644 index 0000000..fe2f315 --- /dev/null +++ b/pymic/transform/mix.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import copy +import json +import math +import random +import numpy as np +from pymic.transform.abstract_transform import AbstractTransform +from pymic.util.image_process import * +try: # SciPy >= 0.19 + from scipy.special import comb +except ImportError: + from scipy.misc import comb + + +class CopyPaste(AbstractTransform): + """ + In-painting of an input image, used for self-supervised learning + """ + def __init__(self, params): + super(CopyPaste, self).__init__(params) + self.inverse = params.get('CopyPaste_inverse'.lower(), False) + self.block_range = params.get('CopyPaste_block_range'.lower(), (1, 6)) + self.block_size_min = params.get('CopyPaste_block_size_min'.lower(), None) + self.block_size_max = params.get('CopyPaste_block_size_max'.lower(), None) + + def __call__(self, sample): + image= sample['image'] + img_shape = image.shape + img_dim = len(img_shape) - 1 + assert(img_dim == 2 or img_dim == 3) + + if(self.block_size_min is None): + block_size_min = [img_shape[1+i]//6 for i in range(img_dim)] + elif(isinstance(self.block_size_min, int)): + block_size_min = [self.block_size_min] * img_dim + else: + assert(len(self.block_size_min) == img_dim) + block_size_min = self.block_size_min + + if(self.block_size_max is None): + block_size_max = [img_shape[1+i]//3 for i in range(img_dim)] + elif(isinstance(self.block_size_min, int)): + block_size_max = [self.block_size_max] * img_dim + else: + assert(len(self.block_size_max) == img_dim) + block_size_max = self.block_size_max + block_num = random.randint(self.block_range[0], self.block_range[1]) + + for n in range(block_num): + block_size = [random.randint(block_size_min[i], block_size_max[i]) \ + for i in range(img_dim)] + coord_min = [random.randint(3, img_shape[1+i] - block_size[i] - 3) \ + for i in range(img_dim)] + if(img_dim == 2): + random_block = np.random.rand(img_shape[0], block_size[0], block_size[1]) + image[:, coord_min[0]:coord_min[0] + block_size[0], + coord_min[1]:coord_min[1] + block_size[1]] = random_block + else: + random_block = np.random.rand(img_shape[0], block_size[0], + block_size[1], block_size[2]) + image[:, coord_min[0]:coord_min[0] + block_size[0], + coord_min[1]:coord_min[1] + block_size[1], + coord_min[2]:coord_min[2] + block_size[2]] = random_block + sample['image'] = image + return sample \ No newline at end of file diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index c1ecdc2..5dac73b 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -71,7 +71,8 @@ 'RandomRotate': RandomRotate, 'ReduceLabelDim': ReduceLabelDim, 'Rescale': Rescale, - 'SelfSuperviseLabel': SelfSuperviseLabel, + 'SelfReconstructionLabel': SelfReconstructionLabel, + 'MaskedImageModelingLabel': MaskedImageModelingLabel, 'OutPainting': OutPainting, 'Pad': Pad, } From bebf534ef9673961806bea4703cb6946560b8834 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 22 May 2023 12:40:04 +0800 Subject: [PATCH 151/225] update random crop allow default setting of foreground labels --- pymic/transform/crop.py | 21 +++++++++++---------- pymic/util/image_process.py | 36 +++++++++++++++++------------------- 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index 95e3489..b8130f9 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -221,7 +221,8 @@ class RandomCrop(CenterCrop): `RandomCrop_foreground_focus` is True. :param `RandomCrop_mask_label`: (optional, None, or list/tuple) Specifying the foreground labels for foreground focus cropping when - `RandomCrop_foreground_focus` is True. + `RandomCrop_foreground_focus` is True. If it is None (by default), + the mask label will be the list of all the foreground classes. :param `RandomCrop_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `True`. """ @@ -229,7 +230,7 @@ def __init__(self, params): self.output_size = params['RandomCrop_output_size'.lower()] self.fg_focus = params.get('RandomCrop_foreground_focus'.lower(), False) self.fg_ratio = params.get('RandomCrop_foreground_ratio'.lower(), 0.5) - self.mask_label = params.get('RandomCrop_mask_label'.lower(), [1]) + self.mask_label = params.get('RandomCrop_mask_label'.lower(), None) self.inverse = params.get('RandomCrop_inverse'.lower(), True) self.task = params['Task'.lower()] assert isinstance(self.output_size, (list, tuple)) @@ -246,16 +247,16 @@ def _get_crop_param(self, sample): crop_margin = [input_shape[i] - self.output_size[i] for i in range(input_dim)] crop_min = [0 if item == 0 else random.randint(0, item) for item in crop_margin] crop_max = [crop_min[i] + self.output_size[i] for i in range(input_dim)] + if(self.fg_focus and random.random() < self.fg_ratio): label = sample['label'][0] - mask = np.zeros_like(label) - for temp_lab in self.mask_label: - mask = np.maximum(mask, label == temp_lab) - if(mask.max() > 0): - crop_min, crop_max = get_random_box_from_mask(mask, self.output_size) - # to avoid Typeerror: object of type int64 is not json serializable - crop_min = [int(i) for i in crop_min] - crop_max = [int(i) for i in crop_max] + if(self.mask_label is None): + mask_label = np.unique(label)[1:] + else: + mask_label = self.mask_label + random_label = random.choice(mask_label) + crop_min, crop_max = get_random_box_from_mask(label == random_label, self.output_size) + crop_min = [0] + crop_min crop_max = [chns] + crop_max diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index d5d7a7e..98409d7 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -165,23 +165,19 @@ def random_crop_ND_volume(volume, out_shape): return crop_volume def get_random_box_from_mask(mask, out_shape): - mask_shape = mask.shape - dim = len(out_shape) - left_margin = [int(out_shape[i]/2) for i in range(dim)] - right_margin= [mask_shape[i] - (out_shape[i] - left_margin[i]) + 1 for i in range(dim)] - - valid_center_shape = [right_margin[i] - left_margin[i] for i in range(dim)] - valid_mask = np.zeros(mask_shape) - valid_mask = set_ND_volume_roi_with_bounding_box_range(valid_mask, - left_margin, right_margin, np.ones(valid_center_shape)) - valid_mask = valid_mask * mask - - indexes = np.where(valid_mask) + indexes = np.where(mask) voxel_num = len(indexes[0]) - j = random.randint(0, voxel_num - 1) - bb_c = [indexes[i][j] for i in range(dim)] - bb_min = [bb_c[i] - left_margin[i] for i in range(dim)] + dim = len(out_shape) + left_bound = [int(out_shape[i]/2) for i in range(dim)] + right_bound = [mask.shape[i] - (out_shape[i] - left_bound[i]) for i in range(dim)] + + j = random.randint(0, voxel_num - 1) + bb_c = [int(indexes[i][j]) for i in range(dim)] + bb_c = [max(left_bound[i], bb_c[i]) for i in range(dim)] + bb_c = [min(right_bound[i], bb_c[i]) for i in range(dim)] + bb_min = [bb_c[i] - left_bound[i] for i in range(dim)] bb_max = [bb_min[i] + out_shape[i] for i in range(dim)] + return bb_min, bb_max def random_crop_ND_volume_with_mask(volume, out_shape, mask): @@ -234,7 +230,8 @@ def get_largest_k_components(image, k = 1): :param image: The input ND array for binary segmentation. :param k: (int) The value of k. - :return: An output array with only the largest K components of the input. + :return: An output array (k == 1) or a list of ND array (k>1) + with only the largest K components of the input. """ dim = len(image.shape) if(image.sum() == 0 ): @@ -247,11 +244,12 @@ def get_largest_k_components(image, k = 1): sizes = ndimage.sum(image, labeled_array, range(1, numpatches + 1)) sizes_sort = sorted(sizes, reverse = True) kmin = min(k, numpatches) - output = np.zeros_like(image) + output = [] for i in range(kmin): labeli = np.where(sizes == sizes_sort[i])[0] + 1 - output = output + np.asarray(labeled_array == labeli, np.uint8) - return output + output_i = np.asarray(labeled_array == labeli, np.uint8) + output.append(output_i) + return output[0] if k == 1 else output def get_euclidean_distance(image, dim = 3, spacing = [1.0, 1.0, 1.0]): """ From 11ffcd01ce337d51d75130c297e8047ae6e002eb Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 2 Jun 2023 22:05:24 +0800 Subject: [PATCH 152/225] update infer_func update multi-scale prediction with gaussian weight --- pymic/net_run/agent_rec.py | 56 +++++++++++++++++----------- pymic/net_run/agent_seg.py | 8 +++- pymic/net_run/infer_func.py | 14 +++---- pymic/net_run/self_sup/util.py | 38 ++++++++++++------- pymic/util/image_process.py | 68 ++++++++++++++++++++++++++++++++++ 5 files changed, 142 insertions(+), 42 deletions(-) diff --git a/pymic/net_run/agent_rec.py b/pymic/net_run/agent_rec.py index 4860144..1e58bc6 100644 --- a/pymic/net_run/agent_rec.py +++ b/pymic/net_run/agent_rec.py @@ -16,6 +16,7 @@ from pymic.net_run.infer_func import Inferer from pymic.net_run.agent_seg import SegmentationAgent from pymic.loss.seg.mse import MAELoss, MSELoss +from pymic.util.general import mixup, tensor_shape_match ReconstructionLossDict = { 'MAELoss': MAELoss, @@ -165,6 +166,7 @@ def train_valid(self): else: self.device = torch.device("cuda:{0:}".format(device_ids[0])) self.net.to(self.device) + ckpt_dir = self.config['training']['ckpt_save_dir'] ckpt_prefix = self.config['training'].get('ckpt_prefix', None) if(ckpt_prefix is None): @@ -186,20 +188,31 @@ def train_valid(self): self.max_val_it = 0 self.best_model_wts = None self.checkpoint = None - if(iter_start > 0): - checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start) - self.checkpoint = torch.load(checkpoint_file, map_location = self.device) - # assert(self.checkpoint['iteration'] == iter_start) - if(len(device_ids) > 1): - self.net.module.load_state_dict(self.checkpoint['model_state_dict']) + # initialize the network with pre-trained weights + ckpt_init_name = self.config['training'].get('ckpt_init_name', None) + ckpt_init_mode = self.config['training'].get('ckpt_init_mode', 0) + ckpt_for_optm = None + if(ckpt_init_name is not None): + checkpoint = torch.load(ckpt_dir + "/" + ckpt_init_name, map_location = self.device) + pretrained_dict = checkpoint['model_state_dict'] + model_dict = self.net.module.state_dict() if (len(device_ids) > 1) else self.net.state_dict() + pretrained_dict = {k: v for k, v in pretrained_dict.items() if \ + k in model_dict and tensor_shape_match(pretrained_dict[k], model_dict[k])} + logging.info("Initializing the following parameters with pre-trained model") + for k in pretrained_dict: + logging.info(k) + if (len(device_ids) > 1): + self.net.module.load_state_dict(pretrained_dict, strict = False) else: - self.net.load_state_dict(self.checkpoint['model_state_dict']) - self.min_val_loss = self.checkpoint.get('valid_loss', 10000) - iter_start = self.checkpoint['iteration'] - self.max_val_it = iter_start - self.best_model_wts = self.checkpoint['model_state_dict'] + self.net.load_state_dict(pretrained_dict, strict = False) + if(ckpt_init_mode > 0): # Load other information + self.min_val_loss = self.checkpoint.get('valid_loss', 10000) + iter_start = checkpoint['iteration'] + self.max_val_it = iter_start + self.best_model_wts = checkpoint['model_state_dict'] + ckpt_for_optm = checkpoint - self.create_optimizer(self.get_parameters_to_update()) + self.create_optimizer(self.get_parameters_to_update(), ckpt_for_optm) self.create_loss_calculator() self.trainIter = iter(self.train_loader) @@ -231,6 +244,16 @@ def train_valid(self): self.best_model_wts = copy.deepcopy(self.net.module.state_dict()) else: self.best_model_wts = copy.deepcopy(self.net.state_dict()) + + save_dict = {'iteration': self.max_val_it, + 'valid_loss': self.min_val_loss, + 'model_state_dict': self.best_model_wts, + 'optimizer_state_dict': self.optimizer.state_dict()} + save_name = "{0:}/{1:}_best.pt".format(ckpt_dir, ckpt_prefix) + torch.save(save_dict, save_name) + txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') + txt_file.write(str(self.max_val_it)) + txt_file.close() stop_now = True if(early_stop_it is not None and \ self.glob_it - self.max_val_it > early_stop_it) else False @@ -249,15 +272,6 @@ def train_valid(self): logging.info("The training is early stopped") break # save the best performing checkpoint - save_dict = {'iteration': self.max_val_it, - 'valid_loss': self.min_val_loss, - 'model_state_dict': self.best_model_wts, - 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it) - torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') - txt_file.write(str(self.max_val_it)) - txt_file.close() logging.info('The best performing iter is {0:}, valid loss {1:}'.format(\ self.max_val_it, self.min_val_loss)) self.summ_writer.close() diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index b007574..3d2fb21 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -89,7 +89,12 @@ def create_network(self): logging.info('parameter number {0:}'.format(param_number)) def get_parameters_to_update(self): - return self.net.parameters() + if hasattr(self.net, "get_parameters_to_update"): + params = self.net.get_parameters_to_update() + else: + params = self.net.parameters() + return params + def create_loss_calculator(self): if(self.loss_dict is None): @@ -401,6 +406,7 @@ def test_time_dropout(m): raise ValueError("ckpt_mode should be 3 if ckpt_name is a list") # load network parameters and set the network as evaluation mode + print("ckpt name", ckpt_name) checkpoint = torch.load(ckpt_name, map_location = device) self.net.load_state_dict(checkpoint['model_state_dict']) diff --git a/pymic/net_run/infer_func.py b/pymic/net_run/infer_func.py index d2b9af2..b0190ad 100644 --- a/pymic/net_run/infer_func.py +++ b/pymic/net_run/infer_func.py @@ -146,7 +146,8 @@ def __infer_with_sliding_window(self, image): output_shape_i = [batch_size, class_num] + \ [int(img_shape[d] * scale_list[i][d]) for d in range(img_dim)] output_list.append(torch.zeros(output_shape_i).to(image.device)) - + temp_ws = [interpolate(temp_w, scale_factor = scale_list[i]) for i in range(out_num)] + weights = [interpolate(weight, scale_factor = scale_list[i]) for i in range(out_num)] for w_i in range(0, window_num, window_batch): for k in range(window_batch): if(w_i + k >= window_num): @@ -167,14 +168,13 @@ def __infer_with_sliding_window(self, image): c0_i = [int(c0[d] * scale_list[i][d]) for d in range(img_dim)] c1_i = [int(c1[d] * scale_list[i][d]) for d in range(img_dim)] if(img_dim == 2): - output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1]] += patches_out[i][k] * temp_w - weight[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_w + output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1]] += patches_out[i][k] * temp_ws[i] + weights[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1]] += temp_ws[i] else: - output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1], c0_i[2]:c1_i[2]] += patches_out[i][k] * temp_w - weight[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_w + output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1], c0_i[2]:c1_i[2]] += patches_out[i][k] * temp_ws[i] + weights[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1], c0_i[2]:c1_i[2]] += temp_ws[i] for i in range(out_num): - weight_i = interpolate(weight, scale_factor = scale_list[i]) - output_list[i] = output_list[i] / weight_i + output_list[i] = output_list[i] / weights[i] return output_list def run(self, model, image): diff --git a/pymic/net_run/self_sup/util.py b/pymic/net_run/self_sup/util.py index e131941..9cffaa7 100644 --- a/pymic/net_run/self_sup/util.py +++ b/pymic/net_run/self_sup/util.py @@ -20,17 +20,26 @@ def get_human_region_mask(img): se = np.ones([3,3,3]) mask = ndimage.binary_opening(mask, se, iterations = 2) mask = get_largest_k_components(mask, 1) - mask_close = ndimage.binary_closing(mask, se, iterations = 3) + mask_close = ndimage.binary_closing(mask, se, iterations = 2) D, H, W = mask.shape for d in [1, 2, D-3, D-2]: mask_close[d] = mask[d] - for d in [0, -1, int(D/2)]: - mask_close[d, 1:-1, 1:-1] = np.ones((H-2, W-2)) + for d in range(0, D, 2): + mask_close[d, 2:-2, 2:-2] = np.ones((H-4, W-4)) - bg = get_largest_k_components(1- mask_close, 1) + # get background component + bg = np.zeros_like(mask) + bgs = get_largest_k_components(1- mask_close, 10) + for bgi in bgs: + indices = np.where(bgi) + if(bgi.sum() < 1000): + break + if(indices[0].min() == 0 or indices[1].min() == 0 or indices[2].min() ==0 or \ + indices[0].max() == D-1 or indices[1].max() == H-1 or indices[2].max() ==W-1): + bg = bg + bgi fg = 1 - bg - se = np.ones([3,3,3]) + fg = ndimage.binary_opening(fg, se, iterations = 1) fg = get_largest_k_components(fg, 1) if(dim == 4): @@ -91,7 +100,7 @@ def patch_mix(x, fg_num, patch_num, size_d, size_h, size_w): return x_fuse, y_prob def create_mixed_dataset(input_dir, output_dir, fg_num = 1, crop_num = 1, - mask = 'default', data_format = "nii.gz"): + mask_dir = None, data_format = "nii.gz"): """ Create dataset based on patch mix. @@ -128,16 +137,19 @@ def create_mixed_dataset(input_dir, output_dir, fg_num = 1, crop_num = 1, chns = img_i.shape[0] # random crop to patch size - if(mask == 'default'): + if(mask_dir is None): mask_i = get_human_region_mask(img_i) mask_j = get_human_region_mask(img_j) + else: + mask_i = load_image_as_nd_array(mask_dir + "/" + img_names[i])['data_array'] + mask_j = load_image_as_nd_array(mask_dir + "/" + img_names[j])['data_array'] for k in range(crop_num): - if(mask is None): - img_ik = random_crop_ND_volume(img_i, [chns, 96, 96, 96]) - img_jk = random_crop_ND_volume(img_j, [chns, 96, 96, 96]) - else: - img_ik = random_crop_ND_volume_with_mask(img_i, [chns, 96, 96, 96], mask_i) - img_jk = random_crop_ND_volume_with_mask(img_j, [chns, 96, 96, 96], mask_j) + # if(mask is None): + # img_ik = random_crop_ND_volume(img_i, [chns, 96, 96, 96]) + # img_jk = random_crop_ND_volume(img_j, [chns, 96, 96, 96]) + # else: + img_ik = random_crop_ND_volume_with_mask(img_i, [chns, 96, 96, 96], mask_i) + img_jk = random_crop_ND_volume_with_mask(img_j, [chns, 96, 96, 96], mask_j) C, D, H, W = img_ik.shape # generate mask fg_mask = np.zeros_like(img_ik, np.uint8) diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index 98409d7..8ae8e80 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -1,10 +1,13 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, print_function +import csv import random +import pandas as pd import numpy as np import SimpleITK as sitk from scipy import ndimage +from pymic.io.image_read_write import load_image_as_nd_array def get_ND_bounding_box(volume, margin = None): """ @@ -315,3 +318,68 @@ def resample_sitk_image_to_given_spacing(image, spacing, order): out_img.SetSpacing(spacing) out_img.SetDirection(image.GetDirection()) return out_img + +def get_image_info(img_names): + space0, space1, slices = [], [], [] + for img_name in img_names: + img_obj = sitk.ReadImage(img_name) + img_arr = sitk.GetArrayFromImage(img_obj) + spacing = img_obj.GetSpacing() + slices.append(img_arr.shape[0]) + space0.append(spacing[0]) + space1.append(spacing[2]) + print(img_name, spacing, img_arr.shape) + + space0 = np.asarray(space0) + space1 = np.asarray(space1) + slices = np.asarray(slices) + print("intra-slice spacing") + print(space0.min(), space0.max(), space0.mean()) + print("inter-slice spacing") + print(space1.min(), space1.max(), space1.mean()) + print("slice number") + print(slices.min(), slices.max(), slices.mean()) + +def get_average_mean_std(data_dir, data_csv): + df = pd.read_csv(data_csv) + mean_list, std_list = [], [] + for i in range(len(df)): + img_name = data_dir + "/" + df.iloc[i, 0] + lab_name = data_dir + "/" + df.iloc[i, 1] + img = load_image_as_nd_array(img_name)["data_array"][0] + lab = load_image_as_nd_array(lab_name)["data_array"][0] + voxels = img[lab>0] + mean = voxels.mean() + std = voxels.std() + mean_list.append(mean) + std_list.append(std) + print(img_name, mean, std) + mean = np.asarray(mean_list).mean() + std = np.asarray(std_list).mean() + print("mean and std value", mean, std) + +def get_label_info(data_dir, label_csv, class_num): + df = pd.read_csv(label_csv) + size_list = [] + # mean_list, std_list = [], [] + num_no_tumor = 0 + for i in range(len(df)): + lab_name = data_dir + "/" + df.iloc[i, 1] + lab = load_image_as_nd_array(lab_name)["data_array"][0] + size_per_class = [] + for c in range(1, class_num): + labc = lab == c + size_per_class.append(np.sum(labc)) + if(np.sum(labc) == 0): + num_no_tumor = num_no_tumor + 1 + size_list.append(size_per_class) + print(lab_name, size_per_class) + size = np.asarray(size_list) + size_min = size.min(axis = 0) + size_max = size.max(axis = 0) + size_mean = size.mean(axis = 0) + + print("size min", size_min) + print("size max", size_max) + print("size mean", size_mean) + print("case number without tumor", num_no_tumor) \ No newline at end of file From 258041aa2d398a134f243b599b4d4b753d6c947e Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 5 Jun 2023 13:09:46 +0800 Subject: [PATCH 153/225] update ce and dice loss set pixel weight and class weight --- pymic/loss/seg/ce.py | 25 +++++++++++++------------ pymic/loss/seg/dice.py | 14 +++++++++++--- pymic/net_run/agent_seg.py | 20 +++++++++++++++----- 3 files changed, 39 insertions(+), 20 deletions(-) diff --git a/pymic/loss/seg/ce.py b/pymic/loss/seg/ce.py index 529482b..9524d57 100644 --- a/pymic/loss/seg/ce.py +++ b/pymic/loss/seg/ce.py @@ -23,6 +23,7 @@ def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] soft_y = loss_input_dict['ground_truth'] pix_w = loss_input_dict.get('pixel_weight', None) + cls_w = loss_input_dict.get('class_weight', None) if(isinstance(predict, (list, tuple))): predict = predict[0] @@ -34,7 +35,10 @@ def forward(self, loss_input_dict): # for numeric stability predict = predict * 0.999 + 5e-4 ce = - soft_y* torch.log(predict) - ce = torch.sum(ce, dim = 1) # shape is [N] + if(cls_w is not None): + ce = torch.sum(ce*cls_w, dim = 1) + else: + ce = torch.sum(ce, dim = 1) # shape is [N] if(pix_w is None): ce = torch.mean(ce) else: @@ -61,12 +65,12 @@ class GeneralizedCELoss(AbstractSegLoss): def __init__(self, params): super(GeneralizedCELoss, self).__init__(params) self.q = params.get('loss_gce_q', 0.5) - self.enable_pix_weight = params.get('loss_with_pixel_weight', False) - self.cls_weight = params.get('loss_class_weight', None) def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] - soft_y = loss_input_dict['ground_truth'] + soft_y = loss_input_dict['ground_truth'] + pix_w = loss_input_dict.get('pixel_weight', None) + cls_w = loss_input_dict.get('class_weight', None) if(isinstance(predict, (list, tuple))): predict = predict[0] @@ -76,17 +80,14 @@ def forward(self, loss_input_dict): soft_y = reshape_tensor_to_2D(soft_y) gce = (1.0 - torch.pow(predict, self.q)) / self.q * soft_y - if(self.cls_weight is not None): - gce = torch.sum(gce * self.cls_w, dim = 1) + if(cls_w is not None): + gce = torch.sum(gce * cls_w, dim = 1) else: gce = torch.sum(gce, dim = 1) - if(self.enable_pix_weight): - pix_w = loss_input_dict.get('pixel_weight', None) - if(pix_w is None): - raise ValueError("Pixel weight is enabled but not defined") - pix_w = reshape_tensor_to_2D(pix_w) - gce = torch.sum(gce * pix_w) / torch.sum(pix_w) + if(pix_w is not None): + pix_w = torch.squeeze(reshape_tensor_to_2D(pix_w)) + gce = torch.sum(gce * pix_w) / torch.sum(pix_w) else: gce = torch.mean(gce) return gce diff --git a/pymic/loss/seg/dice.py b/pymic/loss/seg/dice.py index 350e0c4..2c2df32 100644 --- a/pymic/loss/seg/dice.py +++ b/pymic/loss/seg/dice.py @@ -20,6 +20,8 @@ def __init__(self, params = None): def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] soft_y = loss_input_dict['ground_truth'] + pix_w = loss_input_dict.get('pixel_weight', None) + cls_w = loss_input_dict.get('class_weight', None) if(isinstance(predict, (list, tuple))): predict = predict[0] @@ -27,9 +29,15 @@ def forward(self, loss_input_dict): predict = nn.Softmax(dim = 1)(predict) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) - dice_score = get_classwise_dice(predict, soft_y) - dice_loss = 1.0 - dice_score.mean() - return dice_loss + if(pix_w is not None): + pix_w = reshape_tensor_to_2D(pix_w) + dice_loss = 1.0 - get_classwise_dice(predict, soft_y, pix_w) + if(cls_w is not None): + weighted_loss = dice_loss * cls_w + avg_loss = weighted_loss.sum() / cls_w.sum() + else: + avg_loss = dice_loss.mean() + return avg_loss class BinaryDiceLoss(AbstractSegLoss): ''' diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 3d2fb21..70f01b3 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -118,11 +118,21 @@ def create_loss_calculator(self): def get_loss_value(self, data, pred, gt, param = None): loss_input_dict = {'prediction':pred, 'ground_truth': gt} - if data.get('pixel_weight', None) is not None: - if(isinstance(pred, tuple) or isinstance(pred, list)): - loss_input_dict['pixel_weight'] = data['pixel_weight'].to(pred[0].device) - else: - loss_input_dict['pixel_weight'] = data['pixel_weight'].to(pred.device) + if(isinstance(pred, tuple) or isinstance(pred, list)): + device = pred[0].device + else: + device = pred.device + pixel_weight = data.get('pixel_weight', None) + if(pixel_weight is not None): + loss_input_dict['pixel_weight'] = pixel_weight.to(device) + + class_weight = self.config['training'].get('class_weight', None) + if(class_weight is not None): + class_num = self.config['network']['class_num'] + assert(len(class_weight) == class_num) + class_weight = torch.from_numpy(np.asarray(class_weight)) + class_weight = self.convert_tensor_type(class_weight) + loss_input_dict['class_weight'] = class_weight.to(device) loss_value = self.loss_calculator(loss_input_dict) return loss_value From d74ac93760a76a1cb32e947aaefe545627bfc9ce Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 27 Jun 2023 11:10:23 +0800 Subject: [PATCH 154/225] Update net_dict_seg.py --- pymic/net/net_dict_seg.py | 56 +++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index e6c10bd..ffaa023 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -21,24 +21,24 @@ from pymic.net.net2d.unet2d_attention import AttentionUNet2D from pymic.net.net2d.unet2d_nest import NestedUNet2D from pymic.net.net2d.unet2d_scse import UNet2D_ScSE -from pymic.net.net2d.trans2d.transunet import TransUNet -from pymic.net.net2d.trans2d.swinunet import SwinUNet +# from pymic.net.net2d.trans2d.transunet import TransUNet +# from pymic.net.net2d.trans2d.swinunet import SwinUNet from pymic.net.net3d.unet2d5 import UNet2D5 from pymic.net.net3d.unet3d import UNet3D from pymic.net.net3d.unet3d_scse import UNet3D_ScSE from pymic.net.net3d.unet3d_dual_branch import UNet3D_DualBranch -from pymic.net.net3d.trans3d.nnFormer_wrap import nnFormer_wrap -from pymic.net.net3d.trans3d.unetr import UNETR -from pymic.net.net3d.trans3d.unetr_pp import UNETR_PP -from pymic.net.net3d.trans3d.MedFormer_v1 import MedFormerV1 -from pymic.net.net3d.trans3d.MedFormer_v2 import MedFormerV2 -from pymic.net.net3d.trans3d.MedFormer_v3 import MedFormerV3 -from pymic.net.net3d.trans3d.MedFormer_va1 import MedFormerVA1 -from pymic.net.net3d.trans3d.HiFormer_v1 import HiFormer_v1 -from pymic.net.net3d.trans3d.HiFormer_v2 import HiFormer_v2 -from pymic.net.net3d.trans3d.HiFormer_v3 import HiFormer_v3 -from pymic.net.net3d.trans3d.HiFormer_v4 import HiFormer_v4 -from pymic.net.net3d.trans3d.HiFormer_v5 import HiFormer_v5 +# from pymic.net.net3d.trans3d.nnFormer_wrap import nnFormer_wrap +# from pymic.net.net3d.trans3d.unetr import UNETR +# from pymic.net.net3d.trans3d.unetr_pp import UNETR_PP +# from pymic.net.net3d.trans3d.MedFormer_v1 import MedFormerV1 +# from pymic.net.net3d.trans3d.MedFormer_v2 import MedFormerV2 +# from pymic.net.net3d.trans3d.MedFormer_v3 import MedFormerV3 +# from pymic.net.net3d.trans3d.MedFormer_va1 import MedFormerVA1 +# from pymic.net.net3d.trans3d.HiFormer_v1 import HiFormer_v1 +# from pymic.net.net3d.trans3d.HiFormer_v2 import HiFormer_v2 +# from pymic.net.net3d.trans3d.HiFormer_v3 import HiFormer_v3 +# from pymic.net.net3d.trans3d.HiFormer_v4 import HiFormer_v4 +# from pymic.net.net3d.trans3d.HiFormer_v5 import HiFormer_v5 SegNetDict = { 'UNet2D': UNet2D, @@ -48,22 +48,22 @@ 'AttentionUNet2D': AttentionUNet2D, 'NestedUNet2D': NestedUNet2D, 'UNet2D_ScSE': UNet2D_ScSE, - 'TransUNet': TransUNet, - 'SwinUNet': SwinUNet, + # 'TransUNet': TransUNet, + # 'SwinUNet': SwinUNet, 'UNet2D5': UNet2D5, 'UNet3D': UNet3D, 'UNet3D_ScSE': UNet3D_ScSE, 'UNet3D_DualBranch': UNet3D_DualBranch, - 'nnFormer': nnFormer_wrap, - 'UNETR': UNETR, - 'UNETR_PP': UNETR_PP, - 'MedFormerV1': MedFormerV1, - 'MedFormerV2': MedFormerV2, - 'MedFormerV3': MedFormerV3, - 'MedFormerVA1':MedFormerVA1, - 'HiFormer_v1': HiFormer_v1, - 'HiFormer_v2': HiFormer_v2, - 'HiFormer_v3': HiFormer_v3, - 'HiFormer_v4': HiFormer_v4, - 'HiFormer_v5': HiFormer_v5 + # 'nnFormer': nnFormer_wrap, + # 'UNETR': UNETR, + # 'UNETR_PP': UNETR_PP, + # 'MedFormerV1': MedFormerV1, + # 'MedFormerV2': MedFormerV2, + # 'MedFormerV3': MedFormerV3, + # 'MedFormerVA1':MedFormerVA1, + # 'HiFormer_v1': HiFormer_v1, + # 'HiFormer_v2': HiFormer_v2, + # 'HiFormer_v3': HiFormer_v3, + # 'HiFormer_v4': HiFormer_v4, + # 'HiFormer_v5': HiFormer_v5 } From f43f409d1f0023e5b07fcfb72ec53ed45fc1ee95 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 30 Jun 2023 18:37:29 +0800 Subject: [PATCH 155/225] Update train.py fix issues for datetime --- pymic/net_run/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index 0167f2f..426b620 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -78,11 +78,12 @@ def main(): os.makedirs(log_dir, exist_ok=True) dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] shutil.copy(cfg_file, log_dir + "/" + dst_cfg) + datetime_str = str(datetime.now())[:-7].replace(":", "_") if sys.version.startswith("3.9"): - logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), + logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), level=logging.INFO, format='%(message)s', force=True) # for python 3.9 else: - logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), + logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), level=logging.INFO, format='%(message)s') # for python 3.6 logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) From a7d179fee45edfe006bf8b9c863da2a8190b4b3c Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 18 Jul 2023 09:56:17 +0800 Subject: [PATCH 156/225] Update image_process.py --- pymic/util/image_process.py | 50 +++++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index 8ae8e80..4143f39 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -319,26 +319,44 @@ def resample_sitk_image_to_given_spacing(image, spacing, order): out_img.SetDirection(image.GetDirection()) return out_img -def get_image_info(img_names): - space0, space1, slices = [], [], [] +def get_image_info(img_names, output_csv = None): + spacing_list, shape_list = [], [] for img_name in img_names: img_obj = sitk.ReadImage(img_name) img_arr = sitk.GetArrayFromImage(img_obj) spacing = img_obj.GetSpacing() - slices.append(img_arr.shape[0]) - space0.append(spacing[0]) - space1.append(spacing[2]) - print(img_name, spacing, img_arr.shape) - - space0 = np.asarray(space0) - space1 = np.asarray(space1) - slices = np.asarray(slices) - print("intra-slice spacing") - print(space0.min(), space0.max(), space0.mean()) - print("inter-slice spacing") - print(space1.min(), space1.max(), space1.mean()) - print("slice number") - print(slices.min(), slices.max(), slices.mean()) + shape = img_arr.shape + spacing_list.append(spacing) + shape_list.append(shape) + print(img_name, spacing, shape) + spacings = np.asarray(spacing_list) + shapes = np.asarray(shape_list) + spacing_min = spacings.min(axis = 0) + spacing_max = spacings.max(axis = 0) + spacing_median = np.percentile(spacings, 50, axis = 0) + print("spacing min", spacing_min) + print("spacing max", spacing_max) + print("spacing median", spacing_median) + + shape_min = shapes.min(axis = 0) + shape_max = shapes.max(axis = 0) + shape_median = np.percentile(shapes, 50, axis = 0) + print("shape min", shape_min) + print("shape max", shape_max) + print("shape median", shape_median) + + if(output_csv is not None): + img_names_short = [item.split("/")[-1] for item in img_names] + img_names_short.extend(["spacing min", "spacing max", "spacing median", + "shape min", "shape max", "shape median"]) + spacing_list.extend([spacing_min, spacing_max, spacing_median, + shape_min, shape_max, shape_median]) + shape_list.extend(['']* 6) + out_dict = {"img_name": img_names_short, + "spacing": spacing_list, + "shape": shape_list} + df = pd.DataFrame.from_dict(out_dict) + df.to_csv(output_csv, index=False) def get_average_mean_std(data_dir, data_csv): df = pd.read_csv(data_csv) From 5db981ffbc35e5e7be69e1a15eae8b3d81419709 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 18 Jul 2023 20:47:54 +0800 Subject: [PATCH 157/225] Update image_read_write.py --- pymic/io/image_read_write.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pymic/io/image_read_write.py b/pymic/io/image_read_write.py index cb65e19..5278959 100644 --- a/pymic/io/image_read_write.py +++ b/pymic/io/image_read_write.py @@ -23,11 +23,9 @@ def load_nifty_volume_as_4d_array(filename): spacing = img_obj.GetSpacing() direction = img_obj.GetDirection() shape = data_array.shape - if(len(shape) == 4): - assert(shape[3] == 1) - elif(len(shape) == 3): + if(len(shape) == 3): data_array = np.expand_dims(data_array, axis = 0) - else: + elif(len(shape) > 4 or len(shape) < 3): raise ValueError("unsupported image dim: {0:}".format(len(shape))) output = {} output['data_array'] = data_array From d69a415a57e762b6aef71aee1a3ddeac2e52b62b Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 19 Jul 2023 09:40:27 +0800 Subject: [PATCH 158/225] Update rescale.py --- pymic/transform/rescale.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/pymic/transform/rescale.py b/pymic/transform/rescale.py index 355712e..36f122b 100644 --- a/pymic/transform/rescale.py +++ b/pymic/transform/rescale.py @@ -100,27 +100,26 @@ def __init__(self, params): self.ratio0 = params["RandomRescale_lower_bound".lower()] self.ratio1 = params["RandomRescale_upper_bound".lower()] self.prob = params.get('RandomRescale_probability'.lower(), 0.5) - self.inverse = params.get("RandomRescale_inverse".lower(), True) + self.inverse = params.get("RandomRescale_inverse".lower(), False) assert isinstance(self.ratio0, (float, list, tuple)) assert isinstance(self.ratio1, (float, list, tuple)) def __call__(self, sample): - # if(random.random() > self.prob): - # print("rescale not started") - # sample['RandomRescale_triggered'] = False - # return sample - # else: - # print("rescale started") - # sample['RandomRescale_triggered'] = True + image = sample['image'] input_shape = image.shape input_dim = len(input_shape) - 1 - + assert(input_dim == len(self.ratio0) and input_dim == len(self.ratio1)) + if isinstance(self.ratio0, (list, tuple)): - for i in range(len(self.ratio0)): + for i in range(input_dim): + if(self.ratio0[i] is None): + self.ratio0[i] = 1.0 + if(self.ratio1[i] is None): + self.ratio1[i] = 1.0 assert(self.ratio0[i] <= self.ratio1[i]) scale = [self.ratio0[i] + random.random()*(self.ratio1[i] - self.ratio0[i]) \ - for i in range(len(self.ratio0))] + for i in range(input_dim)] else: scale = self.ratio0 + random.random()*(self.ratio1 - self.ratio0) scale = [scale] * input_dim @@ -130,12 +129,12 @@ def __call__(self, sample): sample['image'] = image_t sample['RandomRescale_Param'] = json.dumps(input_shape) if('label' in sample and \ - self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): label = sample['label'] label = ndimage.interpolation.zoom(label, scale, order = 0) sample['label'] = label if('pixel_weight' in sample and \ - self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): weight = sample['pixel_weight'] weight = ndimage.interpolation.zoom(weight, scale, order = 1) sample['pixel_weight'] = weight @@ -143,8 +142,6 @@ def __call__(self, sample): return sample def inverse_transform_for_prediction(self, sample): - if(not sample['RandomRescale_triggered']): - return sample if(isinstance(sample['RandomRescale_Param'], list) or \ isinstance(sample['RandomRescale_Param'], tuple)): origin_shape = json.loads(sample['RandomRescale_Param'][0]) From c09e00a45c8081e82126abea47e65bb7126f5349 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 20 Jul 2023 08:37:53 +0800 Subject: [PATCH 159/225] Update image_read_write.py --- pymic/io/image_read_write.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pymic/io/image_read_write.py b/pymic/io/image_read_write.py index 5278959..6c8c6b0 100644 --- a/pymic/io/image_read_write.py +++ b/pymic/io/image_read_write.py @@ -97,7 +97,10 @@ def save_array_as_nifty_volume(data, image_name, reference_name = None): #img.CopyInformation(img_ref) img.SetSpacing(img_ref.GetSpacing()) img.SetOrigin(img_ref.GetOrigin()) - img.SetDirection(img_ref.GetDirection()) + direction0 = img_ref.GetDirection() + direction1 = img.GetDirection() + if(len(direction0) == len(direction1)): + img.SetDirection(direction0) sitk.WriteImage(img, image_name) def save_array_as_rgb_image(data, image_name): From ac96e04cfbba183903e623548440cd584076d4a6 Mon Sep 17 00:00:00 2001 From: taigw Date: Sun, 6 Aug 2023 10:26:03 +0800 Subject: [PATCH 160/225] update transforms --- pymic/transform/crop.py | 21 ++++++---- pymic/transform/intensity.py | 74 +++++++++++++++++++++++++++++++----- pymic/transform/rescale.py | 70 ++++++++++++++++++++++++++++++++++ pymic/transform/rotate.py | 4 +- 4 files changed, 151 insertions(+), 18 deletions(-) diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index b8130f9..36a0ca9 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -153,7 +153,6 @@ def _get_crop_param(self, sample): crop_min = [0] + crop_min crop_max = list(input_shape[0:1]) + crop_max sample['CropWithBoundingBox_Param'] = json.dumps((input_shape, crop_min, crop_max)) - print("for crop", crop_min, crop_max) return sample, crop_min, crop_max def _get_param_for_inverse_transform(self, sample): @@ -248,7 +247,8 @@ def _get_crop_param(self, sample): crop_min = [0 if item == 0 else random.randint(0, item) for item in crop_margin] crop_max = [crop_min[i] + self.output_size[i] for i in range(input_dim)] - if(self.fg_focus and random.random() < self.fg_ratio): + label_exist = False if ('label' not in sample or sample['label']) is None else True + if(label_exist and self.fg_focus and random.random() < self.fg_ratio): label = sample['label'][0] if(self.mask_label is None): mask_label = np.unique(label)[1:] @@ -279,26 +279,33 @@ class RandomResizedCrop(CenterCrop): :param `RandomResizedCrop_output_size`: (list/tuple) Desired output size [D, H, W]. The output channel is the same as the input channel. - :param `RandomResizedCrop_scale_range`: (list/tuple) Range of scale, e.g. (0.08, 1.0). + :param `RandomResizedCrop_scale_lower_bound`: (list/tuple) Lower bound of the range of scale + for each dimension. e.g. (1.0, 0.5, 0.5). + param `RandomResizedCrop_scale_upper_bound`: (list/tuple) Upper bound of the range of scale + for each dimension. e.g. (1.0, 2.0, 2.0). :param `RandomResizedCrop_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `False`. Currently, the inverse transform is not supported, and this transform is assumed to be used only during training stage. """ def __init__(self, params): self.output_size = params['RandomResizedCrop_output_size'.lower()] - self.scale = params['RandomResizedCrop_scale_range'.lower()] + self.scale_lower = params['RandomResizedCrop_scale_lower_bound'.lower()] + self.scale_upper = params['RandomResizedCrop_scale_upper_bound'.lower()] self.inverse = params.get('RandomResizedCrop_inverse'.lower(), False) self.task = params['Task'.lower()] assert isinstance(self.output_size, (list, tuple)) - assert isinstance(self.scale, (list, tuple)) + assert isinstance(self.scale_lower, (list, tuple)) + assert isinstance(self.scale_upper, (list, tuple)) def __call__(self, sample): image = sample['image'] channel, input_size = image.shape[0], image.shape[1:] input_dim = len(input_size) assert(input_dim == len(self.output_size)) - scale = self.scale[0] + random.random()*(self.scale[1] - self.scale[0]) - crop_size = [int(self.output_size[i] * scale) for i in range(input_dim)] + scale = [self.scale_lower[i] + (self.scale_upper[i] - self.scale_lower[i]) * random.random() \ + for i in range(input_dim)] + + crop_size = [int(self.output_size[i] * scale[i]) for i in range(input_dim)] crop_margin = [input_size[i] - crop_size[i] for i in range(input_dim)] pad_image = False if(min(crop_margin) < 0): diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index 1a13190..9dbf6d9 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -171,21 +171,31 @@ def __call__(self, sample): class NonLinearTransform(AbstractTransform): def __init__(self, params): super(NonLinearTransform, self).__init__(params) - self.inverse = params.get('NonLinearTransform_inverse'.lower(), False) + self.channels = params['NonLinearTransform_channels'.lower()] self.prob = params.get('NonLinearTransform_probability'.lower(), 0.5) + self.inverse = params.get('NonLinearTransform_inverse'.lower(), False) + def __call__(self, sample): if(random.random() > self.prob): return sample - image= sample['image'] - points = [[0, 0], [random.random(), random.random()], [random.random(), random.random()], [1, 1]] - xvals, yvals = bezier_curve(points, nTimes=100000) - if random.random() < 0.5: # Half change to get flip - xvals = np.sort(xvals) - else: - xvals, yvals = np.sort(xvals), np.sort(yvals) - image = np.interp(image, xvals, yvals) + image = sample['image'] + for chn in self.channels: + points = [[0, 0], [random.random(), random.random()], [random.random(), random.random()], [1, 1]] + xvals, yvals = bezier_curve(points, nTimes=10000) + if random.random() < 0.5: # Half change to get flip + xvals = np.sort(xvals) + else: + xvals, yvals = np.sort(xvals), np.sort(yvals) + # normalize the image intensity to [0, 1] before the non-linear tranform + img_c = image[chn] + v_min = img_c.min() + v_max = img_c.max() + if(v_min < v_max): + img_c = (img_c - v_min)/(v_max - v_min) + img_c = np.interp(img_c, xvals, yvals) + image[chn] = img_c * (v_max - v_min) + v_min sample['image'] = image return sample @@ -392,4 +402,50 @@ def __call__(self, sample): sample = self.inpaint(sample) else: sample = self.outpaint(sample) + return sample + +class PatchSwaping(AbstractTransform): + """ + Apply patch swaping for context restoration in self-supervised learning. + Reference: Liang Chen et al., Self-supervised learning for medical image analysis + using image context restoration, Medical Image Analysis, 2019. + """ + def __init__(self, params): + super(PatchSwaping, self).__init__(params) + self.inverse = params.get('PatchSwaping_inverse'.lower(), False) + self.swap_t = params.get('PatchSwaping_swap_time'.lower(), (1, 6)) + self.patch_size_min = params.get('PatchSwaping_patch_size_min'.lower(), None) + self.patch_size_max = params.get('PatchSwaping_patch_size_max'.lower(), None) + + def __call__(self, sample): + + image= sample['image'] + img_shape = image.shape + img_dim = len(img_shape) - 1 + assert(img_dim == 2 or img_dim == 3) + img_out = image + + C, D, H, W = image.shape + patch_size = [random.randint(self.patch_size_min[i], self.patch_size_max[i]) for \ + i in range(img_dim)] + + coordinate_list = [] + for d in range(0, D-patch_size[0], patch_size[0]): + for h in range(0, H-patch_size[1], patch_size[1]): + for w in range(0, W-patch_size[2], patch_size[2]): + coordinate_list.append((d, h, w)) + random.shuffle(coordinate_list) + + for t in range(self.swap_t): + pos_a0 = coordinate_list[2*t] + pos_b0 = coordinate_list[2*t + 1] + pos_a1 = [pos_a0[i] + patch_size[i] for i in range(img_dim)] + pos_b1 = [pos_b0[i] + patch_size[i] for i in range(img_dim)] + img_out[:, pos_a0[0]:pos_a1[0], pos_a0[1]:pos_a1[1], pos_a0[2]:pos_a1[2]] = \ + image[:, pos_b0[0]:pos_b1[0], pos_b0[1]:pos_b1[1], pos_b0[2]:pos_b1[2]] + img_out[:, pos_b0[0]:pos_b1[0], pos_b0[1]:pos_b1[1], pos_b0[2]:pos_b1[2]] = \ + image[:, pos_a0[0]:pos_a1[0], pos_a0[1]:pos_a1[1], pos_a0[2]:pos_a1[2]] + + sample['image'] = img_out + sample['label'] = image return sample \ No newline at end of file diff --git a/pymic/transform/rescale.py b/pymic/transform/rescale.py index 36f122b..fa4f052 100644 --- a/pymic/transform/rescale.py +++ b/pymic/transform/rescale.py @@ -154,6 +154,76 @@ def inverse_transform_for_prediction(self, sample): i in range(origin_dim)] scale = [1.0, 1.0] + scale + output_predict = ndimage.interpolation.zoom(predict, scale, order = 1) + sample['predict'] = output_predict + return sample + + +class Resample(Rescale): + """Resample the image to a given spatial resolution. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `Rescale_output_size`: (list/tuple or int) The output size along each spatial axis, + such as [D, H, W] or [H, W]. If D is None, the input image is only reslcaled in 2D. + If int, the smallest axis is matched to output_size keeping aspect ratio the same + as the input. + :param `Rescale_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `True`. + """ + def __init__(self, params): + super(Rescale, self).__init__(params) + self.output_spacing = params["Resample_output_spacing".lower()] + self.ignore_zspacing= params.get("Resample_ignore_zspacing_range".lower(), None) + self.inverse = params.get("Resample_inverse".lower(), True) + # assert isinstance(self.output_size, (int, list, tuple)) + + def __call__(self, sample): + image = sample['image'] + input_shape = image.shape + input_dim = len(input_shape) - 1 + spacing = sample['spacing'] + out_spacing = [item for item in self.output_spacing] + for i in range(input_dim): + out_spacing[i] = spacing[i] if out_spacing[i] is None else out_spacing[i] + if(self.ignore_zspacing is not None): + if(spacing[0] > self.ignore_zspacing[0] and spacing[0] < self.ignore_zspacing[1]): + out_spacing[0] = spacing[0] + scale = [spacing[i] / out_spacing[i] for i in range(input_dim)] + scale = [1.0] + scale + + image_t = ndimage.interpolation.zoom(image, scale, order = 1) + + sample['image'] = image_t + sample['spacing'] = out_spacing + sample['Resample_origin_shape'] = json.dumps(input_shape) + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + label = sample['label'] + label = ndimage.interpolation.zoom(label, scale, order = 0) + sample['label'] = label + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + weight = sample['pixel_weight'] + weight = ndimage.interpolation.zoom(weight, scale, order = 1) + sample['pixel_weight'] = weight + + return sample + + def inverse_transform_for_prediction(self, sample): + if(isinstance(sample['Resample_origin_shape'], list) or \ + isinstance(sample['Resample_origin_shape'], tuple)): + origin_shape = json.loads(sample['Resample_origin_shape'][0]) + else: + origin_shape = json.loads(sample['Resample_origin_shape']) + origin_dim = len(origin_shape) - 1 + predict = sample['predict'] + input_shape = predict.shape + scale = [(origin_shape[1:][i] + 0.0)/input_shape[2:][i] for \ + i in range(origin_dim)] + scale = [1.0, 1.0] + scale + output_predict = ndimage.interpolation.zoom(predict, scale, order = 1) sample['predict'] = output_predict return sample \ No newline at end of file diff --git a/pymic/transform/rotate.py b/pymic/transform/rotate.py index 65e5328..5f85e28 100644 --- a/pymic/transform/rotate.py +++ b/pymic/transform/rotate.py @@ -34,8 +34,8 @@ class RandomRotate(AbstractTransform): def __init__(self, params): super(RandomRotate, self).__init__(params) self.angle_range_d = params['RandomRotate_angle_range_d'.lower()] - self.angle_range_h = params['RandomRotate_angle_range_h'.lower()] - self.angle_range_w = params['RandomRotate_angle_range_w'.lower()] + self.angle_range_h = params.get('RandomRotate_angle_range_h'.lower(), None) + self.angle_range_w = params.get('RandomRotate_angle_range_w'.lower(), None) self.prob = params.get('RandomRotate_probability'.lower(), 0.5) self.inverse = params.get('RandomRotate_inverse'.lower(), True) From 05d1748b10870760d336f36a53fcbc57798c5354 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 10 Aug 2023 09:17:21 +0800 Subject: [PATCH 161/225] update transform --- pymic/transform/flip.py | 2 +- pymic/transform/intensity.py | 3 ++ pymic/transform/mix.py | 85 ++++++++++++++++++++++++++++++++++- pymic/transform/trans_dict.py | 4 ++ 4 files changed, 92 insertions(+), 2 deletions(-) diff --git a/pymic/transform/flip.py b/pymic/transform/flip.py index 24cafb4..486180c 100644 --- a/pymic/transform/flip.py +++ b/pymic/transform/flip.py @@ -54,7 +54,7 @@ def __call__(self, sample): image_t = np.flip(image, flip_axis).copy() sample['image'] = image_t if('label' in sample and \ - self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): sample['label'] = np.flip(sample['label'] , flip_axis).copy() if('pixel_weight' in sample and \ self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index 9dbf6d9..e5e30f6 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -98,6 +98,7 @@ def __init__(self, params): self.channels = params['GammaCorrection_channels'.lower()] self.gamma_min = params['GammaCorrection_gamma_min'.lower()] self.gamma_max = params['GammaCorrection_gamma_max'.lower()] + self.flip_prob = params.get('GammaCorrection_intensity_flip_probability'.lower(), 0.2) self.prob = params.get('GammaCorrection_probability'.lower(), 0.5) self.inverse = params.get('GammaCorrection_inverse'.lower(), False) @@ -112,6 +113,8 @@ def __call__(self, sample): v_max = img_c.max() if(v_min < v_max): img_c = (img_c - v_min)/(v_max - v_min) + if(np.random.uniform() < self.flip_prob): + img_c = 1.0 - img_c img_c = np.power(img_c, gamma_c)*(v_max - v_min) + v_min image[chn] = img_c diff --git a/pymic/transform/mix.py b/pymic/transform/mix.py index fe2f315..6e2fb8e 100644 --- a/pymic/transform/mix.py +++ b/pymic/transform/mix.py @@ -63,4 +63,87 @@ def __call__(self, sample): coord_min[1]:coord_min[1] + block_size[1], coord_min[2]:coord_min[2] + block_size[2]] = random_block sample['image'] = image - return sample \ No newline at end of file + return sample + +class PatchMix(AbstractTransform): + """ + In-painting of an input image, used for self-supervised learning + """ + def __init__(self, params): + super(PatchMix, self).__init__(params) + self.inverse = params.get('PatchMix_inverse'.lower(), False) + self.crop_size = params.get('PatchMix_crop_size'.lower(), [64, 128, 128]) + self.fg_cls_num = params.get('PatchMix_cls_num'.lower(), [4, 40]) + self.patch_num_range= params.get('PatchMix_patch_range'.lower(), [4, 40]) + self.patch_size_min = params.get('PatchMix_patch_size_min'.lower(), [4, 4, 4]) + self.patch_size_max = params.get('PatchMix_patch_size_max'.lower(), [20, 40, 40]) + + def __call__(self, sample): + x0, x1 = self._random_crop_and_flip(sample) + C, D, H, W = x0.shape + # generate mask + fg_mask = np.zeros_like(x0, np.uint8) + patch_num = random.randint(self.patch_num_range[0], self.patch_num_range[1]) + for patch in range(patch_num): + d = random.randint(self.patch_size_min[0], self.patch_size_max[0]) + h = random.randint(self.patch_size_min[1], self.patch_size_max[1]) + w = random.randint(self.patch_size_min[2], self.patch_size_max[2]) + d_c = random.randint(0, D) + h_c = random.randint(0, H) + w_c = random.randint(0, W) + d0, d1 = max(0, d_c - d // 2), min(D, d_c + d // 2) + h0, h1 = max(0, h_c - h // 2), min(H, h_c + h // 2) + w0, w1 = max(0, w_c - w // 2), min(W, w_c + w // 2) + temp_m = np.ones([C, d1-d0, h1-h0, w1-w0]) * random.randint(1, self.fg_cls_num) + fg_mask[:, d0:d1, h0:h1, w0:w1] = temp_m + fg_w = fg_mask * 1.0 / self.fg_cls_num + x_fuse = fg_w*x0 + (1.0 - fg_w)*x1 # x1 is used as background + + sample['image'] = x_fuse + sample['label'] = fg_mask + return sample + + def _random_crop_and_flip(self, sample): + input_shape = sample['image'].shape + input_dim = len(input_shape) - 1 + assert(input_dim == 3) + + if('label' in sample): + # get the center for crop randomly + mask = sample['label'] > 0 + C, D, H, W = input_shape + size_h = [i// 2 for i in self.crop_size] + temp_mask = np.zeros_like(mask) + temp_mask[:,size_h[0]:D-size_h[0]+1,size_h[1]:H-size_h[1]+1,size_h[2]:W-size_h[2]+1] = \ + np.ones([C, D-self.crop_size[0]+1, H-self.crop_size[1]+1, W-self.crop_size[2]+1]) + mask = mask * temp_mask + indices = np.where(mask) + n0 = random.randint(0, len(indices[0])-1) + n1 = random.randint(0, len(indices[0])-1) + center0 = [indices[i][n0] for i in range(1, 4)] + center1 = [indices[i][n1] for i in range(1, 4)] + crop_min0 = [center0[i] - size_h[i] for i in range(3)] + crop_min1 = [center1[i] - size_h[i] for i in range(3)] + else: + crop_margin = [input_shape[1+i] - self.crop_size[i] for i in range(input_dim)] + crop_min0 = [0 if item == 0 else random.randint(0, item) for item in crop_margin] + crop_min1 = [0 if item == 0 else random.randint(0, item) for item in crop_margin] + + patches = [] + for crop_min in [crop_min0, crop_min1]: + crop_max = [crop_min[i] + self.crop_size[i] for i in range(input_dim)] + crop_min = [0] + crop_min + crop_max = [C] + crop_max + x = crop_ND_volume_with_bounding_box(sample['image'], crop_min, crop_max) + flip_axis = [] + if(random.random() > 0.5): + flip_axis.append(-1) + if(random.random() > 0.5): + flip_axis.append(-2) + if(random.random() > 0.5): + flip_axis.append(-3) + if(len(flip_axis) > 0): + x = np.flip(x, flip_axis).copy() + patches.append(x) + + return patches \ No newline at end of file diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index 5dac73b..f779d00 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -40,6 +40,7 @@ from pymic.transform.threshold import * from pymic.transform.normalize import * from pymic.transform.crop import * +from pymic.transform.mix import * from pymic.transform.label_convert import * TransformDict = { @@ -71,8 +72,11 @@ 'RandomRotate': RandomRotate, 'ReduceLabelDim': ReduceLabelDim, 'Rescale': Rescale, + 'Resample': Resample, 'SelfReconstructionLabel': SelfReconstructionLabel, 'MaskedImageModelingLabel': MaskedImageModelingLabel, 'OutPainting': OutPainting, 'Pad': Pad, + 'PatchSwaping':PatchSwaping, + 'PatchMix': PatchMix } From b526f936e78f54f7a68157fe0a93e871c52b94da Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 2 Sep 2023 15:03:15 +0800 Subject: [PATCH 162/225] update transform allow default channel setting --- pymic/transform/intensity.py | 29 ++++++++++++++++------------- pymic/transform/normalize.py | 13 +++++++------ 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index e5e30f6..ffa5141 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -95,18 +95,20 @@ class GammaCorrection(AbstractTransform): """ def __init__(self, params): super(GammaCorrection, self).__init__(params) - self.channels = params['GammaCorrection_channels'.lower()] - self.gamma_min = params['GammaCorrection_gamma_min'.lower()] - self.gamma_max = params['GammaCorrection_gamma_max'.lower()] - self.flip_prob = params.get('GammaCorrection_intensity_flip_probability'.lower(), 0.2) + self.channels = params.get('GammaCorrection_channels'.lower(), None) + self.gamma_min = params.get('GammaCorrection_gamma_min'.lower(), 0.7) + self.gamma_max = params.get('GammaCorrection_gamma_max'.lower(), 1.5) + self.flip_prob = params.get('GammaCorrection_intensity_flip_probability'.lower(), 0.0) self.prob = params.get('GammaCorrection_probability'.lower(), 0.5) self.inverse = params.get('GammaCorrection_inverse'.lower(), False) def __call__(self, sample): - if(np.random.uniform() > self.prob): - return sample image= sample['image'] + if(self.channels is None): + self.channels = range(image.shape[0]) for chn in self.channels: + if(np.random.uniform() > self.prob): + continue gamma_c = random.random() * (self.gamma_max - self.gamma_min) + self.gamma_min img_c = image[chn] v_min = img_c.min() @@ -138,20 +140,21 @@ class GaussianNoise(AbstractTransform): """ def __init__(self, params): super(GaussianNoise, self).__init__(params) - self.channels = params['GaussianNoise_channels'.lower()] + self.channels = params.get('GaussianNoise_channels'.lower(), None) self.mean = params['GaussianNoise_mean'.lower()] self.std = params['GaussianNoise_std'.lower()] self.prob = params.get('GaussianNoise_probability'.lower(), 0.5) self.inverse = params.get('GaussianNoise_inverse'.lower(), False) def __call__(self, sample): - if(np.random.uniform() > self.prob): - return sample - image= sample['image'] + image = sample['image'] + if(self.channels is None): + self.channels = range(image.shape[0]) for chn in self.channels: - img_c = image[chn] - noise = np.random.normal(self.mean, self.std, img_c.shape) - image[chn] = img_c + noise + if(np.random.uniform() < self.prob): + img_c = image[chn] + noise = np.random.normal(self.mean, self.std, img_c.shape) + image[chn] = img_c + noise sample['image'] = image return sample diff --git a/pymic/transform/normalize.py b/pymic/transform/normalize.py index 77852d2..6531f17 100644 --- a/pymic/transform/normalize.py +++ b/pymic/transform/normalize.py @@ -34,7 +34,7 @@ class NormalizeWithMeanStd(AbstractTransform): """ def __init__(self, params): super(NormalizeWithMeanStd, self).__init__(params) - self.chns = params['NormalizeWithMeanStd_channels'.lower()] + self.chns = params.get('NormalizeWithMeanStd_channels'.lower(), None) self.mean = params.get('NormalizeWithMeanStd_mean'.lower(), None) self.std = params.get('NormalizeWithMeanStd_std'.lower(), None) self.ingore_np = params.get('NormalizeWithMeanStd_ignore_non_positive'.lower(), False) @@ -42,13 +42,14 @@ def __init__(self, params): def __call__(self, sample): image= sample['image'] - chns = self.chns if self.chns is not None else range(image.shape[0]) + if(self.chns is None): + self.chns = range(image.shape[0]) if(self.mean is None): - self.mean = [None] * len(chns) - self.std = [None] * len(chns) + self.mean = [None] * len(self.chns) + self.std = [None] * len(self.chns) - for i in range(len(chns)): - chn = chns[i] + for i in range(len(self.chns)): + chn = self.chns[i] chn_mean, chn_std = self.mean[i], self.std[i] if(chn_mean is None): if(self.ingore_np): From 73a13a39ab9b36d95cd75cbc93a1df38558d5215 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 2 Sep 2023 15:24:08 +0800 Subject: [PATCH 163/225] Update unetr_pp.py --- pymic/net/net3d/trans3d/unetr_pp.py | 63 ++++++++++++++++------------- 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/pymic/net/net3d/trans3d/unetr_pp.py b/pymic/net/net3d/trans3d/unetr_pp.py index 3ef6736..a4ab7e6 100644 --- a/pymic/net/net3d/trans3d/unetr_pp.py +++ b/pymic/net/net3d/trans3d/unetr_pp.py @@ -155,7 +155,6 @@ def __init__( def forward(self, x): B, C, H, W, D = x.shape - x = x.reshape(B, C, H * W * D).permute(0, 2, 1) if self.pos_embed is not None: @@ -170,12 +169,13 @@ def forward(self, x): class UnetrPPEncoder(nn.Module): def __init__(self, input_size=[32 * 32 * 32, 16 * 16 * 16, 8 * 8 * 8, 4 * 4 * 4],dims=[32, 64, 128, 256], - proj_size =[64,64,64,32], depths=[3, 3, 3, 3], num_heads=4, spatial_dims=3, in_channels=1, dropout=0.0, transformer_dropout_rate=0.15 ,**kwargs): + proj_size =[64,64,64,32], depths=[3, 3, 3, 3], num_heads=4, spatial_dims=3, + in_channels=1, dropout=0.0, transformer_dropout_rate=0.15, kernel_size=(2,4,4), **kwargs): super().__init__() self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers stem_layer = nn.Sequential( - get_conv_layer(spatial_dims, in_channels, dims[0], kernel_size=(2, 4, 4), stride=(2, 4, 4), + get_conv_layer(spatial_dims, in_channels, dims[0], kernel_size=kernel_size, stride=kernel_size, dropout=dropout, conv_only=True, ), get_norm_layer(name=("group", {"num_groups": in_channels}), channels=dims[0]), ) @@ -209,7 +209,6 @@ def _init_weights(self, m): def forward_features(self, x): hidden_states = [] - x = self.downsample_layers[0](x) x = self.stages[0](x) @@ -330,6 +329,7 @@ def __init__(self, params): in_channels = params['in_chns'] out_channels = params['class_num'] img_size = params['img_size'] + self.res_mode= params.get("resolution_mode", 1) feature_size = params.get('feature_size', 16) hidden_size = params.get('hidden_size', 256) num_heads = params.get('num_heads', 4) @@ -350,15 +350,20 @@ def __init__(self, params): if pos_embed not in ["conv", "perceptron"]: raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") - self.patch_size = (2, 4, 4) + kernel_ds = [4, 2, 1] + kernel_d = kernel_ds[self.res_mode] + self.patch_size = (kernel_d, 4, 4) + self.feat_size = ( img_size[0] // self.patch_size[0] // 8, # 8 is the downsampling happened through the four encoders stages img_size[1] // self.patch_size[1] // 8, # 8 is the downsampling happened through the four encoders stages img_size[2] // self.patch_size[2] // 8, # 8 is the downsampling happened through the four encoders stages ) + self.hidden_size = hidden_size - self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads) + self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads, + in_channels=in_channels, kernel_size=self.patch_size) self.encoder1 = UnetResBlock( spatial_dims=3, @@ -395,20 +400,21 @@ def __init__(self, params): norm_name=norm_name, out_size=32 * 32 * 32, ) + self.decoder2 = UnetrUpBlock( spatial_dims=3, in_channels=feature_size * 2, out_channels=feature_size, kernel_size=3, - upsample_kernel_size=(2, 4, 4), + upsample_kernel_size= self.patch_size, norm_name=norm_name, - out_size=64 * 128 * 128, + out_size= kernel_d*32 * 128 * 128, conv_decoder=True, ) self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) - if self.do_ds: - self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels) - self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels) + # if self.do_ds: + self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels) + self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels) def proj_feat(self, x, hidden_size, feat_size): x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) @@ -442,19 +448,22 @@ def forward(self, x_in): if __name__ == "__main__": - params = {'in_chns': 1, - 'class_num': 2, - 'img_size': [64, 128, 128] - } - net = UNETR_PP(params) - net.double() - - x = np.random.rand(2, 1, 64, 128, 128) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = net(xt) - print(len(y)) - for yi in y: - yi = yi.detach().numpy() - print(yi.shape) \ No newline at end of file + depths = [128, 64, 32] + for i in range(3): + params = {'in_chns': 4, + 'class_num': 2, + 'img_size': [depths[i], 128, 128], + 'resolution_mode': i + } + net = UNETR_PP(params) + net.double() + + x = np.random.rand(2, 4, depths[i], 128, 128) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = net(xt) + print(len(y)) + for yi in y: + yi = yi.detach().numpy() + print(yi.shape) \ No newline at end of file From c57f50e442a338b8be0d3e10a7d2cc51b1f15981 Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 13 Sep 2023 10:50:22 +0800 Subject: [PATCH 164/225] update crop update functions for random crop --- pymic/transform/crop.py | 61 +++++++++++++++++++++++++------------ pymic/util/image_process.py | 50 +++++++++++++++++++++++------- 2 files changed, 81 insertions(+), 30 deletions(-) diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index 36a0ca9..b4d0b63 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -255,7 +255,7 @@ def _get_crop_param(self, sample): else: mask_label = self.mask_label random_label = random.choice(mask_label) - crop_min, crop_max = get_random_box_from_mask(label == random_label, self.output_size) + crop_min, crop_max = get_random_box_from_mask(label == random_label, self.output_size, mode = 1) crop_min = [0] + crop_min crop_max = [chns] + crop_max @@ -289,8 +289,11 @@ class RandomResizedCrop(CenterCrop): """ def __init__(self, params): self.output_size = params['RandomResizedCrop_output_size'.lower()] - self.scale_lower = params['RandomResizedCrop_scale_lower_bound'.lower()] - self.scale_upper = params['RandomResizedCrop_scale_upper_bound'.lower()] + self.scale_lower = params['RandomResizedCrop_resize_lower_bound'.lower()] + self.scale_upper = params['RandomResizedCrop_resize_upper_bound'.lower()] + self.prob = params.get('RandomResizedCrop_resize_prob'.lower(), 0.5) + self.fg_ratio = params.get('RandomResizedCrop_foreground_ratio'.lower(), 0.0) + self.mask_label = params.get('RandomResizedCrop_mask_label'.lower(), None) self.inverse = params.get('RandomResizedCrop_inverse'.lower(), False) self.task = params['Task'.lower()] assert isinstance(self.output_size, (list, tuple)) @@ -302,14 +305,19 @@ def __call__(self, sample): channel, input_size = image.shape[0], image.shape[1:] input_dim = len(input_size) assert(input_dim == len(self.output_size)) - scale = [self.scale_lower[i] + (self.scale_upper[i] - self.scale_lower[i]) * random.random() \ - for i in range(input_dim)] - - crop_size = [int(self.output_size[i] * scale[i]) for i in range(input_dim)] + + # get the resized crop size + resize = random.random() < self.prob + if(resize): + scale = [self.scale_lower[i] + (self.scale_upper[i] - self.scale_lower[i]) * random.random() \ + for i in range(input_dim)] + crop_size = [int(self.output_size[i] * scale[i]) for i in range(input_dim)] + else: + crop_size = self.output_size + crop_margin = [input_size[i] - crop_size[i] for i in range(input_dim)] - pad_image = False - if(min(crop_margin) < 0): - pad_image = True + pad_image = min(crop_margin) < 0 + if(pad_image): # pad the image if necessary pad_size = [max(0, -crop_margin[i]) for i in range(input_dim)] pad_lower = [int(pad_size[i] / 2) for i in range(input_dim)] pad_upper = [pad_size[i] - pad_lower[i] for i in range(input_dim)] @@ -317,16 +325,29 @@ def __call__(self, sample): pad = tuple([(0, 0)] + pad) image = np.pad(image, pad, 'reflect') crop_margin = [max(0, crop_margin[i]) for i in range(input_dim)] - - crop_min = [random.randint(0, item) for item in crop_margin] - crop_max = [crop_min[i] + crop_size[i] for i in range(input_dim)] + # ge the bounding box for crop + if(random.random() < self.fg_ratio): + label = sample['label'] + if(pad_image): + label = np.pad(label, pad, 'reflect') + label = label[0] + if(self.mask_label is None): + mask_label = np.unique(label)[1:] + else: + mask_label = self.mask_label + random_label = random.choice(mask_label) + crop_min, crop_max = get_random_box_from_mask(label == random_label, crop_size, mode = 1) + else: + crop_min = [random.randint(0, item) for item in crop_margin] + crop_max = [crop_min[i] + crop_size[i] for i in range(input_dim)] crop_min = [0] + crop_min crop_max = [channel] + crop_max image_t = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) - scale = [(self.output_size[i] + 0.0)/crop_size[i] for i in range(input_dim)] - scale = [1.0] + scale - image_t = ndimage.interpolation.zoom(image_t, scale, order = 1) + if(resize): + scale = [(self.output_size[i] + 0.0)/crop_size[i] for i in range(input_dim)] + scale = [1.0] + scale + image_t = ndimage.interpolation.zoom(image_t, scale, order = 1) sample['image'] = image_t if('label' in sample and \ @@ -336,8 +357,9 @@ def __call__(self, sample): label = np.pad(label, pad, 'reflect') crop_max[0] = label.shape[0] label = crop_ND_volume_with_bounding_box(label, crop_min, crop_max) - order = 0 if(self.task == TaskType.SEGMENTATION) else 1 - label = ndimage.interpolation.zoom(label, scale, order = order) + if(resize): + order = 0 if(self.task == TaskType.SEGMENTATION) else 1 + label = ndimage.interpolation.zoom(label, scale, order = order) sample['label'] = label if('pixel_weight' in sample and \ self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): @@ -346,6 +368,7 @@ def __call__(self, sample): weight = np.pad(weight, pad, 'reflect') crop_max[0] = weight.shape[0] weight = crop_ND_volume_with_bounding_box(weight, crop_min, crop_max) - weight = ndimage.interpolation.zoom(weight, scale, order = 1) + if(resize): + weight = ndimage.interpolation.zoom(weight, scale, order = 1) sample['pixel_weight'] = weight return sample \ No newline at end of file diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index 4143f39..d6a7220 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -167,18 +167,46 @@ def random_crop_ND_volume(volume, out_shape): crop_volume = crop_ND_volume_with_bounding_box(image_pad, bb_min, bb_max) return crop_volume -def get_random_box_from_mask(mask, out_shape): - indexes = np.where(mask) - voxel_num = len(indexes[0]) - dim = len(out_shape) - left_bound = [int(out_shape[i]/2) for i in range(dim)] - right_bound = [mask.shape[i] - (out_shape[i] - left_bound[i]) for i in range(dim)] +def get_random_box_from_mask(mask, out_shape, mode = 0): + """ + get a bounding box of a subvolume according to a mask + + mode == 0: The output bounding box should be a sub region of the mask region + mode == 1: The center point of the output bounding box can be ahy where of the mask region + """ + dim = len(out_shape) + left_margin = [int(out_shape[i]/2) for i in range(dim)] + right_margin = [out_shape[i] - left_margin[i] for i in range(dim)] + + if(mode == 0): + bb_mask_min, bb_mask_max = get_ND_bounding_box(mask) + bb_valid_min, bb_valid_max = [], [] + for i in range(dim): + mask_size = bb_mask_max[i] - bb_mask_min[i] + if(mask_size > out_shape[i]): + valid_left = bb_mask_min[i] + left_margin[i] + valid_right = bb_mask_max[i] - right_margin[i] + else: + valid_left = (bb_mask_max[i] - bb_mask_min[i]) // 2 + valid_right = valid_left + 1 + bb_valid_min.append(valid_left) + bb_valid_max.append(valid_right) + + valid_region_shape = [bb_valid_max[i] - bb_valid_min[i] for i in range(dim)] + valid_mask = np.zeros_like(mask) + valid_mask = set_ND_volume_roi_with_bounding_box_range(valid_mask, + bb_valid_min, bb_valid_max, np.ones(valid_region_shape, np.bool), addition = True) + valid_mask = valid_mask * mask + else: + valid_mask = mask + indices = np.where(valid_mask) + voxel_num = len(indices[0]) j = random.randint(0, voxel_num - 1) - bb_c = [int(indexes[i][j]) for i in range(dim)] - bb_c = [max(left_bound[i], bb_c[i]) for i in range(dim)] - bb_c = [min(right_bound[i], bb_c[i]) for i in range(dim)] - bb_min = [bb_c[i] - left_bound[i] for i in range(dim)] + bb_c = [int(indices[i][j]) for i in range(dim)] + bb_min = [max(0, bb_c[i] - left_margin[i]) for i in range(dim)] + mask_shape = np.shape(mask) + bb_min = [min(bb_min[i], mask_shape[i] - out_shape[i]) for i in range(dim)] bb_max = [bb_min[i] + out_shape[i] for i in range(dim)] return bb_min, bb_max @@ -205,7 +233,7 @@ def random_crop_ND_volume_with_mask(volume, out_shape, mask): pad = [(ml[i], mr[i]) for i in range(dim)] pad = tuple(pad) image_pad = np.pad(volume, pad, 'reflect') - mask_pad = np.pad(mask, pad, 'reflect') + mask_pad = np.pad(mask, pad, 'constant') bb_min, bb_max = get_random_box_from_mask(mask_pad, out_shape) # left_margin = [int(out_shape[i]/2) for i in range(dim)] From fbf1b26d6f26a4fb79be6741a2d4344e5b6dfb76 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 30 Oct 2023 16:21:36 +0800 Subject: [PATCH 165/225] update 2D transformers update 2D transformers --- pymic/net/net2d/trans2d/__init__.py | 0 pymic/net/net2d/trans2d/swinunet.py | 124 ++++ pymic/net/net2d/trans2d/swinunet_sys.py | 749 ++++++++++++++++++++ pymic/net/net2d/trans2d/transunet.py | 491 +++++++++++++ pymic/net/net2d/trans2d/transunet_cfg.py | 135 ++++ pymic/net/net2d/trans2d/transunet_resnet.py | 164 +++++ 6 files changed, 1663 insertions(+) create mode 100644 pymic/net/net2d/trans2d/__init__.py create mode 100644 pymic/net/net2d/trans2d/swinunet.py create mode 100644 pymic/net/net2d/trans2d/swinunet_sys.py create mode 100644 pymic/net/net2d/trans2d/transunet.py create mode 100644 pymic/net/net2d/trans2d/transunet_cfg.py create mode 100644 pymic/net/net2d/trans2d/transunet_resnet.py diff --git a/pymic/net/net2d/trans2d/__init__.py b/pymic/net/net2d/trans2d/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pymic/net/net2d/trans2d/swinunet.py b/pymic/net/net2d/trans2d/swinunet.py new file mode 100644 index 0000000..f35539a --- /dev/null +++ b/pymic/net/net2d/trans2d/swinunet.py @@ -0,0 +1,124 @@ +# -*- coding: utf-8 -*- +""" +code adapted from: https://github.com/HuCaoFighting/Swin-Unet + +""" +from __future__ import print_function, division + +import copy +import numpy as np +import torch +import torch.nn as nn + +from pymic.net.net2d.trans2d.swinunet_sys import SwinTransformerSys + +class SwinUNet(nn.Module): + """ + Implementatin of Swin-UNet. + + * Reference: Hu Cao, Yueyue Wang et al: + Swin-Unet: Unet-Like Pure Transformer for Medical Image Segmentation. + `ECCV 2022 Workshops. `_ + + Note that the input channel can only be 1 or 3, and the input image size should be 224x224. + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param img_size: (tuple) The input image size, should be [224, 224]. + :param class_num: (int) The class number for segmentation task. + """ + def __init__(self, params): + super(SwinUNet, self).__init__() + img_size = params['img_size'] + if(isinstance(img_size, tuple) or isinstance(img_size, list)): + img_size = img_size[0] + self.num_classes = params['class_num'] + self.swin_unet = SwinTransformerSys(img_size = img_size, num_classes=self.num_classes) + # self.swin_unet = SwinTransformerSys(img_size=config.DATA.IMG_SIZE, + # patch_size=config.MODEL.SWIN.PATCH_SIZE, + # in_chans=config.MODEL.SWIN.IN_CHANS, + # num_classes=self.num_classes, + # embed_dim=config.MODEL.SWIN.EMBED_DIM, + # depths=config.MODEL.SWIN.DEPTHS, + # num_heads=config.MODEL.SWIN.NUM_HEADS, + # window_size=config.MODEL.SWIN.WINDOW_SIZE, + # mlp_ratio=config.MODEL.SWIN.MLP_RATIO, + # qkv_bias=config.MODEL.SWIN.QKV_BIAS, + # qk_scale=config.MODEL.SWIN.QK_SCALE, + # drop_rate=config.MODEL.DROP_RATE, + # drop_path_rate=config.MODEL.DROP_PATH_RATE, + # ape=config.MODEL.SWIN.APE, + # patch_norm=config.MODEL.SWIN.PATCH_NORM, + # use_checkpoint=config.TRAIN.USE_CHECKPOINT) + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + if x.size()[1] == 1: + x = x.repeat(1,3,1,1) + logits = self.swin_unet(x) + + if(len(x_shape) == 5): + new_shape = [N, D] + list(logits.shape)[1:] + logits = torch.reshape(logits, new_shape) + logits = torch.transpose(logits, 1, 2) + + return logits + + def load_from(self, config): + pretrained_path = config.MODEL.PRETRAIN_CKPT + if pretrained_path is not None: + print("pretrained_path:{}".format(pretrained_path)) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + pretrained_dict = torch.load(pretrained_path, map_location=device) + if "model" not in pretrained_dict: + print("---start load pretrained modle by splitting---") + pretrained_dict = {k[17:]:v for k,v in pretrained_dict.items()} + for k in list(pretrained_dict.keys()): + if "output" in k: + print("delete key:{}".format(k)) + del pretrained_dict[k] + msg = self.swin_unet.load_state_dict(pretrained_dict,strict=False) + # print(msg) + return + pretrained_dict = pretrained_dict['model'] + print("---start load pretrained modle of swin encoder---") + + model_dict = self.swin_unet.state_dict() + full_dict = copy.deepcopy(pretrained_dict) + for k, v in pretrained_dict.items(): + if "layers." in k: + current_layer_num = 3-int(k[7:8]) + current_k = "layers_up." + str(current_layer_num) + k[8:] + full_dict.update({current_k:v}) + for k in list(full_dict.keys()): + if k in model_dict: + if full_dict[k].shape != model_dict[k].shape: + print("delete:{};shape pretrain:{};shape model:{}".format(k,v.shape,model_dict[k].shape)) + del full_dict[k] + + msg = self.swin_unet.load_state_dict(full_dict, strict=False) + # print(msg) + else: + print("none pretrain") + + +if __name__ == "__main__": + params = {'img_size': [224, 224], + 'class_num': 2} + net = SwinUNet(params) + net.double() + + x = np.random.rand(4, 3, 224, 224) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = net(xt) + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) \ No newline at end of file diff --git a/pymic/net/net2d/trans2d/swinunet_sys.py b/pymic/net/net2d/trans2d/swinunet_sys.py new file mode 100644 index 0000000..a6e3552 --- /dev/null +++ b/pymic/net/net2d/trans2d/swinunet_sys.py @@ -0,0 +1,749 @@ +# -*- coding: utf-8 -*- +""" +code adapted from: https://github.com/HuCaoFighting/Swin-Unet + +""" +from __future__ import print_function, division + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from einops import rearrange +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + +class PatchExpand(nn.Module): + def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity() + self.norm = norm_layer(dim // dim_scale) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + x = self.expand(x) + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) + x = x.view(B,-1,C//4) + x= self.norm(x) + + return x + +class FinalPatchExpand_X4(nn.Module): + def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.dim_scale = dim_scale + self.expand = nn.Linear(dim, 16*dim, bias=False) + self.output_dim = dim + self.norm = norm_layer(self.output_dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + x = self.expand(x) + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2)) + x = x.view(B,-1,self.output_dim) + x= self.norm(x) + + return x + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + +class BasicLayer_up(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if upsample is not None: + self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) + else: + self.upsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.upsample is not None: + x = self.upsample(x) + return x + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformerSys(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, final_upsample="expand_first", **kwargs): + super().__init__() + + print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(depths, + depths_decoder,drop_path_rate,num_classes)) + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.num_features_up = int(embed_dim * 2) + self.mlp_ratio = mlp_ratio + self.final_upsample = final_upsample + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build encoder and bottleneck layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + # build decoder layers + self.layers_up = nn.ModuleList() + self.concat_back_dim = nn.ModuleList() + for i_layer in range(self.num_layers): + concat_linear = nn.Linear(2*int(embed_dim*2**(self.num_layers-1-i_layer)), + int(embed_dim*2**(self.num_layers-1-i_layer))) if i_layer > 0 else nn.Identity() + if i_layer ==0 : + layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), + patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer) + else: + layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), + input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), + patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), + depth=depths[(self.num_layers-1-i_layer)], + num_heads=num_heads[(self.num_layers-1-i_layer)], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])], + norm_layer=norm_layer, + upsample=PatchExpand if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers_up.append(layer_up) + self.concat_back_dim.append(concat_linear) + + self.norm = norm_layer(self.num_features) + self.norm_up= norm_layer(self.embed_dim) + + if self.final_upsample == "expand_first": + print("---final upsample expand_first---") + self.up = FinalPatchExpand_X4(input_resolution=(img_size//patch_size,img_size//patch_size),dim_scale=4,dim=embed_dim) + self.output = nn.Conv2d(in_channels=embed_dim,out_channels=self.num_classes,kernel_size=1,bias=False) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + #Encoder and Bottleneck + def forward_features(self, x): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + x_downsample = [] + + for layer in self.layers: + x_downsample.append(x) + x = layer(x) + + x = self.norm(x) # B L C + + return x, x_downsample + + #Dencoder and Skip connection + def forward_up_features(self, x, x_downsample): + for inx, layer_up in enumerate(self.layers_up): + if inx == 0: + x = layer_up(x) + else: + x = torch.cat([x,x_downsample[3-inx]],-1) + x = self.concat_back_dim[inx](x) + x = layer_up(x) + + x = self.norm_up(x) # B L C + + return x + + def up_x4(self, x): + H, W = self.patches_resolution + B, L, C = x.shape + assert L == H*W, "input features has wrong size" + + if self.final_upsample=="expand_first": + x = self.up(x) + x = x.view(B,4*H,4*W,-1) + x = x.permute(0,3,1,2) #B,C,H,W + x = self.output(x) + + return x + + def forward(self, x): + x, x_downsample = self.forward_features(x) + x = self.forward_up_features(x,x_downsample) + x = self.up_x4(x) + + return x + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return \ No newline at end of file diff --git a/pymic/net/net2d/trans2d/transunet.py b/pymic/net/net2d/trans2d/transunet.py new file mode 100644 index 0000000..9db5d2d --- /dev/null +++ b/pymic/net/net2d/trans2d/transunet.py @@ -0,0 +1,491 @@ +# -*- coding: utf-8 -*- +""" +code adapted from: https://github.com/Beckschen/TransUNet +""" +from __future__ import print_function, division + +import copy +# import logging +import math +import torch +import torch.nn as nn +from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm +from torch.nn.modules.utils import _pair + +import numpy as np +from scipy import ndimage +from os.path import join as pjoin +import pymic.net.net2d.trans2d.transunet_cfg as configs +from pymic.net.net2d.trans2d.transunet_resnet import ResNetV2 + + +VIT_CONFIGS = { + 'ViT-B_16': configs.get_b16_config(), + 'ViT-B_32': configs.get_b32_config(), + 'ViT-L_16': configs.get_l16_config(), + 'ViT-L_32': configs.get_l32_config(), + 'ViT-H_14': configs.get_h14_config(), + 'R50-ViT-B_16': configs.get_r50_b16_config(), + 'R50-ViT-L_16': configs.get_r50_l16_config(), + 'testing': configs.get_testing(), +} + +ATTENTION_Q = "MultiHeadDotProductAttention_1/query" +ATTENTION_K = "MultiHeadDotProductAttention_1/key" +ATTENTION_V = "MultiHeadDotProductAttention_1/value" +ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" +FC_0 = "MlpBlock_3/Dense_0" +FC_1 = "MlpBlock_3/Dense_1" +ATTENTION_NORM = "LayerNorm_0" +MLP_NORM = "LayerNorm_2" + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} + +class Attention(nn.Module): + def __init__(self, config, vis): + super(Attention, self).__init__() + self.vis = vis + self.num_attention_heads = config.transformer["num_heads"] + self.attention_head_size = int(config.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = Linear(config.hidden_size, self.all_head_size) + self.key = Linear(config.hidden_size, self.all_head_size) + self.value = Linear(config.hidden_size, self.all_head_size) + + self.out = Linear(config.hidden_size, config.hidden_size) + self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) + self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) + + self.softmax = Softmax(dim=-1) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.softmax(attention_scores) + weights = attention_probs if self.vis else None + attention_probs = self.attn_dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + attention_output = self.out(context_layer) + attention_output = self.proj_dropout(attention_output) + return attention_output, weights + +class Mlp(nn.Module): + def __init__(self, config): + super(Mlp, self).__init__() + self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) + self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) + self.act_fn = ACT2FN["gelu"] + self.dropout = Dropout(config.transformer["dropout_rate"]) + + self._init_weights() + + def _init_weights(self): + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + nn.init.normal_(self.fc1.bias, std=1e-6) + nn.init.normal_(self.fc2.bias, std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x = self.act_fn(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + +class Embeddings(nn.Module): + """Construct the embeddings from patch, position embeddings. + """ + def __init__(self, config, img_size, in_channels=3): + super(Embeddings, self).__init__() + self.hybrid = None + self.config = config + img_size = _pair(img_size) + + if config.patches.get("grid") is not None: # ResNet + grid_size = config.patches["grid"] + patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1]) + patch_size_real = (patch_size[0] * 16, patch_size[1] * 16) + n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1]) + self.hybrid = True + else: + patch_size = _pair(config.patches["size"]) + n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) + self.hybrid = False + + if self.hybrid: + self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor) + in_channels = self.hybrid_model.width * 16 + self.patch_embeddings = Conv2d(in_channels=in_channels, + out_channels=config.hidden_size, + kernel_size=patch_size, + stride=patch_size) + self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size)) + + self.dropout = Dropout(config.transformer["dropout_rate"]) + + + def forward(self, x): + if self.hybrid: + x, features = self.hybrid_model(x) + else: + features = None + x = self.patch_embeddings(x) # (B, hidden. n_patches^(1/2), n_patches^(1/2)) + x = x.flatten(2) + x = x.transpose(-1, -2) # (B, n_patches, hidden) + + embeddings = x + self.position_embeddings + embeddings = self.dropout(embeddings) + return embeddings, features + +class Block(nn.Module): + def __init__(self, config, vis): + super(Block, self).__init__() + self.hidden_size = config.hidden_size + self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn = Mlp(config) + self.attn = Attention(config, vis) + + def forward(self, x): + h = x + x = self.attention_norm(x) + x, weights = self.attn(x) + x = x + h + + h = x + x = self.ffn_norm(x) + x = self.ffn(x) + x = x + h + return x, weights + + def load_from(self, weights, n_block): + ROOT = f"Transformer/encoderblock_{n_block}" + with torch.no_grad(): + query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t() + key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() + value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t() + out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t() + + query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) + key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) + value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) + out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) + + self.attn.query.weight.copy_(query_weight) + self.attn.key.weight.copy_(key_weight) + self.attn.value.weight.copy_(value_weight) + self.attn.out.weight.copy_(out_weight) + self.attn.query.bias.copy_(query_bias) + self.attn.key.bias.copy_(key_bias) + self.attn.value.bias.copy_(value_bias) + self.attn.out.bias.copy_(out_bias) + + mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() + mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() + mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() + mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() + + self.ffn.fc1.weight.copy_(mlp_weight_0) + self.ffn.fc2.weight.copy_(mlp_weight_1) + self.ffn.fc1.bias.copy_(mlp_bias_0) + self.ffn.fc2.bias.copy_(mlp_bias_1) + + self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) + self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) + self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) + self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) + +class Encoder(nn.Module): + def __init__(self, config, vis): + super(Encoder, self).__init__() + self.vis = vis + self.layer = nn.ModuleList() + self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) + for _ in range(config.transformer["num_layers"]): + layer = Block(config, vis) + self.layer.append(copy.deepcopy(layer)) + + def forward(self, hidden_states): + attn_weights = [] + for layer_block in self.layer: + hidden_states, weights = layer_block(hidden_states) + if self.vis: + attn_weights.append(weights) + encoded = self.encoder_norm(hidden_states) + return encoded, attn_weights + +class Transformer(nn.Module): + def __init__(self, config, img_size, vis): + super(Transformer, self).__init__() + self.embeddings = Embeddings(config, img_size=img_size) + self.encoder = Encoder(config, vis) + + def forward(self, input_ids): + embedding_output, features = self.embeddings(input_ids) + encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden) + return encoded, attn_weights, features + +class Conv2dReLU(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding=0, + stride=1, + use_batchnorm=True, + ): + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=not (use_batchnorm), + ) + relu = nn.ReLU(inplace=True) + + bn = nn.BatchNorm2d(out_channels) + + super(Conv2dReLU, self).__init__(conv, bn, relu) + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + skip_channels=0, + use_batchnorm=True, + ): + super().__init__() + self.conv1 = Conv2dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.conv2 = Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.up = nn.UpsamplingBilinear2d(scale_factor=2) + + def forward(self, x, skip=None): + x = self.up(x) + if skip is not None: + x = torch.cat([x, skip], dim=1) + x = self.conv1(x) + x = self.conv2(x) + return x + + +class SegmentationHead(nn.Sequential): + + def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1): + conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) + upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() + super().__init__(conv2d, upsampling) + + +class DecoderCup(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + head_channels = 512 + self.conv_more = Conv2dReLU( + config.hidden_size, + head_channels, + kernel_size=3, + padding=1, + use_batchnorm=True, + ) + decoder_channels = config.decoder_channels + in_channels = [head_channels] + list(decoder_channels[:-1]) + out_channels = decoder_channels + + if self.config.n_skip != 0: + skip_channels = self.config.skip_channels + for i in range(4-self.config.n_skip): # re-select the skip channels according to n_skip + skip_channels[3-i]=0 + + else: + skip_channels=[0,0,0,0] + + blocks = [ + DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels) + ] + self.blocks = nn.ModuleList(blocks) + + def forward(self, hidden_states, features=None): + B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden) + h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch)) + x = hidden_states.permute(0, 2, 1) + x = x.contiguous().view(B, hidden, h, w) + x = self.conv_more(x) + for i, decoder_block in enumerate(self.blocks): + if features is not None: + skip = features[i] if (i < self.config.n_skip) else None + else: + skip = None + x = decoder_block(x, skip=skip) + return x + +class TransUNet(nn.Module): + """ + Implementatin of TransUNet. + + * Reference: Jieneng Chen, Yongyi Lu et al: + TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation. + `Arxiv 2021. `_ + + Note that the input channel can only be 1 or 3, and the input image size should be 256x256. + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param img_size: (tuple) The input image size, should be [256, 256]. + :param class_num: (int) The class number for segmentation task. + :param vit_name: (string) The name for vit backbone. It can be one of the following: 'ViT-B_16', + 'ViT-B_32','ViT-L_16', 'ViT-L_32', 'ViT-H_14'. 'R50-ViT-B_16', 'R50-ViT-L_16'. + By default, it is 'R50-ViT-B_16'. + """ + def __init__(self, params): + super(TransUNet, self).__init__() + vit_name = params.get("vit_name", 'R50-ViT-B_16') + img_size = params['img_size'] + vis = params.get("vis", False) + self.config = VIT_CONFIGS[vit_name] + self.num_classes = params['class_num'] + self.zero_head = params.get("zero_head", False) + + self.classifier = self.config.classifier + self.transformer = Transformer(self.config, img_size, vis) + self.decoder = DecoderCup(self.config) + self.segmentation_head = SegmentationHead( + in_channels=self.config['decoder_channels'][-1], + out_channels=self.num_classes, + kernel_size=3, + ) + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + if x.size()[1] == 1: + x = x.repeat(1,3,1,1) + elif(x.size()[1] !=3): + raise ValueError("The input channel number should be 1 or 3 for TransUNet") + x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden) + x = self.decoder(x, features) + logits = self.segmentation_head(x) + + if(len(x_shape) == 5): + new_shape = [N, D] + list(logits.shape)[1:] + logits = torch.reshape(logits, new_shape) + logits = torch.transpose(logits, 1, 2) + + return logits + + def load_from(self, weights): + with torch.no_grad(): + + res_weight = weights + self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) + self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) + + self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) + self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) + + posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) + + posemb_new = self.transformer.embeddings.position_embeddings + if posemb.size() == posemb_new.size(): + self.transformer.embeddings.position_embeddings.copy_(posemb) + elif posemb.size()[1]-1 == posemb_new.size()[1]: + posemb = posemb[:, 1:] + self.transformer.embeddings.position_embeddings.copy_(posemb) + else: + # logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size())) + ntok_new = posemb_new.size(1) + if self.classifier == "seg": + _, posemb_grid = posemb[:, :1], posemb[0, 1:] + gs_old = int(np.sqrt(len(posemb_grid))) + gs_new = int(np.sqrt(ntok_new)) + print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) + posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) + zoom = (gs_new / gs_old, gs_new / gs_old, 1) + posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) # th2np + posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) + posemb = posemb_grid + self.transformer.embeddings.position_embeddings.copy_(np2th(posemb)) + + # Encoder whole + for bname, block in self.transformer.encoder.named_children(): + for uname, unit in block.named_children(): + unit.load_from(weights, n_block=uname) + + if self.transformer.embeddings.hybrid: + self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True)) + gn_weight = np2th(res_weight["gn_root/scale"]).view(-1) + gn_bias = np2th(res_weight["gn_root/bias"]).view(-1) + self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight) + self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias) + + for bname, block in self.transformer.embeddings.hybrid_model.body.named_children(): + for uname, unit in block.named_children(): + unit.load_from(res_weight, n_block=bname, n_unit=uname) + +if __name__ == "__main__": + params = {'img_size': [256, 256], + 'class_num': 2} + net = TransUNet(params) + net.double() + + for c in [1,3]: + x = np.random.rand(4, c, 256, 256) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = net(xt) + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) \ No newline at end of file diff --git a/pymic/net/net2d/trans2d/transunet_cfg.py b/pymic/net/net2d/trans2d/transunet_cfg.py new file mode 100644 index 0000000..aab62d4 --- /dev/null +++ b/pymic/net/net2d/trans2d/transunet_cfg.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +""" +code adapted from: https://github.com/Beckschen/TransUNet +""" +import ml_collections + +def get_b16_config(): + """Returns the ViT-B/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 768 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 3072 + config.transformer.num_heads = 12 + config.transformer.num_layers = 12 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + + config.classifier = 'seg' + config.representation_size = None + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz' + config.patch_size = 16 + + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_testing(): + """Returns a minimal configuration for testing.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 1 + config.transformer.num_heads = 1 + config.transformer.num_layers = 1 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + return config + +def get_r50_b16_config(): + """Returns the Resnet50 + ViT-B/16 configuration.""" + config = get_b16_config() + config.patches.grid = (16, 16) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'seg' + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 2 + config.n_skip = 3 + config.activation = 'softmax' + + return config + + +def get_b32_config(): + """Returns the ViT-B/32 configuration.""" + config = get_b16_config() + config.patches.size = (32, 32) + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz' + return config + + +def get_l16_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1024 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 4096 + config.transformer.num_heads = 16 + config.transformer.num_layers = 24 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.representation_size = None + + # custom + config.classifier = 'seg' + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_r50_l16_config(): + """Returns the Resnet50 + ViT-L/16 configuration. customized """ + config = get_l16_config() + config.patches.grid = (16, 16) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'seg' + config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_l32_config(): + """Returns the ViT-L/32 configuration.""" + config = get_l16_config() + config.patches.size = (32, 32) + return config + + +def get_h14_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (14, 14)}) + config.hidden_size = 1280 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 5120 + config.transformer.num_heads = 16 + config.transformer.num_layers = 32 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + + return config \ No newline at end of file diff --git a/pymic/net/net2d/trans2d/transunet_resnet.py b/pymic/net/net2d/trans2d/transunet_resnet.py new file mode 100644 index 0000000..144a268 --- /dev/null +++ b/pymic/net/net2d/trans2d/transunet_resnet.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +""" +code adapted from: https://github.com/Beckschen/TransUNet +""" +from __future__ import print_function, division + +from os.path import join as pjoin +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +class StdConv2d(nn.Conv2d): + + def forward(self, x): + w = self.weight + v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) + w = (w - m) / torch.sqrt(v + 1e-5) + return F.conv2d(x, w, self.bias, self.stride, self.padding, + self.dilation, self.groups) + + +def conv3x3(cin, cout, stride=1, groups=1, bias=False): + return StdConv2d(cin, cout, kernel_size=3, stride=stride, + padding=1, bias=bias, groups=groups) + + +def conv1x1(cin, cout, stride=1, bias=False): + return StdConv2d(cin, cout, kernel_size=1, stride=stride, + padding=0, bias=bias) + + +class PreActBottleneck(nn.Module): + """Pre-activation (v2) bottleneck block. + """ + + def __init__(self, cin, cout=None, cmid=None, stride=1): + super().__init__() + cout = cout or cin + cmid = cmid or cout//4 + + self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv1 = conv1x1(cin, cmid, bias=False) + self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!! + self.gn3 = nn.GroupNorm(32, cout, eps=1e-6) + self.conv3 = conv1x1(cmid, cout, bias=False) + self.relu = nn.ReLU(inplace=True) + + if (stride != 1 or cin != cout): + # Projection also with pre-activation according to paper. + self.downsample = conv1x1(cin, cout, stride, bias=False) + self.gn_proj = nn.GroupNorm(cout, cout) + + def forward(self, x): + + # Residual branch + residual = x + if hasattr(self, 'downsample'): + residual = self.downsample(x) + residual = self.gn_proj(residual) + + # Unit's branch + y = self.relu(self.gn1(self.conv1(x))) + y = self.relu(self.gn2(self.conv2(y))) + y = self.gn3(self.conv3(y)) + + y = self.relu(residual + y) + return y + + def load_from(self, weights, n_block, n_unit): + conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True) + conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True) + conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True) + + gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")]) + gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")]) + + gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")]) + gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")]) + + gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")]) + gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")]) + + self.conv1.weight.copy_(conv1_weight) + self.conv2.weight.copy_(conv2_weight) + self.conv3.weight.copy_(conv3_weight) + + self.gn1.weight.copy_(gn1_weight.view(-1)) + self.gn1.bias.copy_(gn1_bias.view(-1)) + + self.gn2.weight.copy_(gn2_weight.view(-1)) + self.gn2.bias.copy_(gn2_bias.view(-1)) + + self.gn3.weight.copy_(gn3_weight.view(-1)) + self.gn3.bias.copy_(gn3_bias.view(-1)) + + if hasattr(self, 'downsample'): + proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True) + proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")]) + proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")]) + + self.downsample.weight.copy_(proj_conv_weight) + self.gn_proj.weight.copy_(proj_gn_weight.view(-1)) + self.gn_proj.bias.copy_(proj_gn_bias.view(-1)) + +class ResNetV2(nn.Module): + """Implementation of Pre-activation (v2) ResNet mode.""" + + def __init__(self, block_units, width_factor): + super().__init__() + width = int(64 * width_factor) + self.width = width + + self.root = nn.Sequential(OrderedDict([ + ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)), + ('gn', nn.GroupNorm(32, width, eps=1e-6)), + ('relu', nn.ReLU(inplace=True)), + # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)) + ])) + + self.body = nn.Sequential(OrderedDict([ + ('block1', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)], + ))), + ('block2', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)], + ))), + ('block3', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)], + ))), + ])) + + def forward(self, x): + features = [] + b, c, in_size, _ = x.size() + x = self.root(x) + features.append(x) + x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x) + for i in range(len(self.body)-1): + x = self.body[i](x) + right_size = int(in_size / 4 / (i+1)) + if x.size()[2] != right_size: + pad = right_size - x.size()[2] + assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size) + feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device) + feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:] + else: + feat = x + features.append(feat) + x = self.body[-1](x) + return x, features[::-1] \ No newline at end of file From 317da9febf192b3cbac047d5b883e80c8855a4c3 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 31 Oct 2023 16:58:43 +0800 Subject: [PATCH 166/225] add mcnet add mcnet for semi-supervised segmentation --- pymic/net/net2d/unet2d.py | 145 +++++++++++++------------- pymic/net/net2d/unet2d_dual_branch.py | 17 ++- pymic/net/net2d/unet2d_mcnet.py | 69 ++++++++++++ pymic/net/net2d/unet2d_urpc.py | 132 +++++++++++++++++++++++ pymic/net/net_dict_seg.py | 17 +-- pymic/net_run/semi_sup/ssl_mcnet.py | 129 +++++++++++++++++++++++ 6 files changed, 432 insertions(+), 77 deletions(-) create mode 100644 pymic/net/net2d/unet2d_mcnet.py create mode 100644 pymic/net/net2d/unet2d_urpc.py create mode 100644 pymic/net_run/semi_sup/ssl_mcnet.py diff --git a/pymic/net/net2d/unet2d.py b/pymic/net/net2d/unet2d.py index 9acc0ad..758bfe5 100644 --- a/pymic/net/net2d/unet2d.py +++ b/pymic/net/net2d/unet2d.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import logging import torch import torch.nn as nn import numpy as np @@ -56,22 +57,33 @@ class UpBlock(nn.Module): :param in_channels2: (int) Channel number of low-level features. :param out_channels: (int) Output channel number. :param dropout_p: (int) Dropout probability. - :param bilinear: (bool) Use bilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Bilinear`), 3 (`Bicubic`). The default value + is 2 (`Bilinear`). """ def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - bilinear=True): + up_mode = 2): super(UpBlock, self).__init__() - self.bilinear = bilinear - if bilinear: - self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + if(isinstance(up_mode, int)): + up_mode_values = ["transconv", "nearest", "bilinear", "bicubic"] + if(up_mode > 3): + raise ValueError("The upsample mode should be 0-3, but {0:} is given.".format(up_mode)) + self.up_mode = up_mode_values[up_mode] else: + self.up_mode = up_mode.lower() + + if (self.up_mode == "transconv"): self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2) + else: + self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) + if(self.up_mode == "nearest"): + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode) + else: + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode, align_corners=True) self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) def forward(self, x1, x2): - if self.bilinear: + if self.up_mode != "transconv": x1 = self.conv1x1(x1) x1 = self.up(x1) x = torch.cat([x2, x1], dim=1) @@ -129,8 +141,10 @@ class Decoder(nn.Module): :param dropout: (list) The dropout ratio for each resolution level. The length should be the same as that of `feature_chns`. :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (or `Nearest`), 2 (or `Bilinear`), 3 (or `Bicubic`). + The default value is 2 (or `Bilinear`). + :param multiscale_pred: (bool) Get multiscale prediction. """ def __init__(self, params): super(Decoder, self).__init__() @@ -139,17 +153,23 @@ def __init__(self, params): self.ft_chns = self.params['feature_chns'] self.dropout = self.params['dropout'] self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] + self.up_mode = self.params['up_mode'] + self.mul_pred = self.params['multiscale_pred'] assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) if(len(self.ft_chns) == 5): - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear) + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.up_mode) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.up_mode) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.up_mode) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.up_mode) self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1) + if(self.mul_pred): + self.out_conv1 = nn.Conv2d(self.ft_chns[1], self.n_class, kernel_size = 1) + self.out_conv2 = nn.Conv2d(self.ft_chns[2], self.n_class, kernel_size = 1) + self.out_conv3 = nn.Conv2d(self.ft_chns[3], self.n_class, kernel_size = 1) + def forward(self, x): if(len(self.ft_chns) == 5): assert(len(x) == 5) @@ -163,6 +183,11 @@ def forward(self, x): x_d1 = self.up3(x_d2, x1) x_d0 = self.up4(x_d1, x0) output = self.out_conv(x_d0) + if(self.mul_pred): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv1(x_d2) + output3 = self.out_conv1(x_d3) + output = [output, output1, output2, output3] return output class UNet2D(nn.Module): @@ -180,43 +205,39 @@ class UNet2D(nn.Module): following fields: :param in_chns: (int) Input channel number. + :param class_num: (int) The class number for segmentation task. + + Optional parameters: + :param feature_chns: (list) Feature channel for each resolution level. The length should be 4 or 5, such as [16, 32, 64, 128, 256]. :param dropout: (list) The dropout ratio for each resolution level. The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (or `Nearest`), 2 (or `Bilinear`), 3 (or `Bicubic`). + The default value is 2 (or `Bilinear`). :param multiscale_pred: (bool) Get multiscale prediction. """ def __init__(self, params): super(UNet2D, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] - self.mul_pred = self.params['multiscale_pred'] - - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - if(len(self.ft_chns) == 5): - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear) - - self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1) - if(self.mul_pred): - self.out_conv1 = nn.Conv2d(self.ft_chns[1], self.n_class, kernel_size = 1) - self.out_conv2 = nn.Conv2d(self.ft_chns[2], self.n_class, kernel_size = 1) - self.out_conv3 = nn.Conv2d(self.ft_chns[3], self.n_class, kernel_size = 1) + params = self.get_default_parameters(params) + for p in params: + print(p, params[p]) + self.encoder = Encoder(params) + self.decoder = Decoder(params) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], + 'up_mode': 2, + 'multiscale_pred': False + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params def forward(self, x): x_shape = list(x.shape) @@ -226,42 +247,26 @@ def forward(self, x): x = torch.transpose(x, 1, 2) x = torch.reshape(x, new_shape) - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - if(len(self.ft_chns) == 5): - x4 = self.down4(x3) - x_d3 = self.up1(x4, x3) - else: - x_d3 = x3 - x_d2 = self.up2(x_d3, x2) - x_d1 = self.up3(x_d2, x1) - x_d0 = self.up4(x_d1, x0) - output = self.out_conv(x_d0) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - - if(len(x_shape) == 5): + f = self.encoder(x) + output = self.decoder(f) + if(len(x_shape) == 5): + if(isinstance(output, (list,tuple))): for i in range(len(output)): new_shape = [N, D] + list(output[i].shape)[1:] output[i] = torch.transpose(torch.reshape(output[i], new_shape), 1, 2) - elif(len(x_shape) == 5): - new_shape = [N, D] + list(output.shape)[1:] - output = torch.reshape(output, new_shape) - output = torch.transpose(output, 1, 2) - + else: + new_shape = [N, D] + list(output.shape)[1:] + output = torch.transpose(torch.reshape(output, new_shape), 1, 2) + return output + if __name__ == "__main__": params = {'in_chns':4, 'feature_chns':[2, 8, 32, 48, 64], 'dropout': [0, 0, 0.3, 0.4, 0.5], 'class_num': 2, - 'bilinear': True, + 'up_mode': 0, 'multiscale_pred': False} Net = UNet2D(params) Net = Net.double() diff --git a/pymic/net/net2d/unet2d_dual_branch.py b/pymic/net/net2d/unet2d_dual_branch.py index 828bdfe..19a0788 100644 --- a/pymic/net/net2d/unet2d_dual_branch.py +++ b/pymic/net/net2d/unet2d_dual_branch.py @@ -25,11 +25,26 @@ class UNet2D_DualBranch(nn.Module): """ def __init__(self, params): super(UNet2D_DualBranch, self).__init__() - self.output_mode = params.get("output_mode", "average") + params = self.get_default_parameters(params) + self.output_mode = params["output_mode"] self.encoder = Encoder(params) self.decoder1 = Decoder(params) self.decoder2 = Decoder(params) + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], + 'up_mode': 2, + 'multiscale_pred': False, + 'output_mode': "average" + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + def forward(self, x): x_shape = list(x.shape) if(len(x_shape) == 5): diff --git a/pymic/net/net2d/unet2d_mcnet.py b/pymic/net/net2d/unet2d_mcnet.py new file mode 100644 index 0000000..c3e8a5f --- /dev/null +++ b/pymic/net/net2d/unet2d_mcnet.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch.nn as nn +from pymic.net.net2d.unet2d import * + + +class MCNet2D(nn.Module): + """ + A tri-branch network using UNet2D as backbone. + + * Reference: Yicheng Wu, Zongyuan Ge et al. Mutual consistency learning for + semi-supervised medical image segmentation. + `MIA 2022. `_ + + The original code is at: https://github.com/ycwu1997/MC-Net + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.UNet2D` for details. + """ + +class MCNet2D(nn.Module): + def __init__(self, params): + super(MCNet2D, self).__init__() + in_chns = params['in_chns'] + class_num = params['class_num'] + params1 = {'in_chns': in_chns, + 'feature_chns': [16, 32, 64, 128, 256], + 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], + 'class_num': class_num, + 'up_mode': 0, + 'multiscale_pred': False } + params2 = {'in_chns': in_chns, + 'feature_chns': [16, 32, 64, 128, 256], + 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], + 'class_num': class_num, + 'up_mode': 1, + 'multiscale_pred': False} + params3 = {'in_chns': in_chns, + 'feature_chns': [16, 32, 64, 128, 256], + 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], + 'class_num': class_num, + 'up_mode': 2, + 'multiscale_pred': False} + self.encoder = Encoder(params1) + self.decoder1 = Decoder(params1) + self.decoder2 = Decoder(params2) + self.decoder3 = Decoder(params3) + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + feature = self.encoder(x) + output1 = self.decoder1(feature) + new_shape = [N, D] + list(output1.shape)[1:] + output1 = torch.transpose(torch.reshape(output1, new_shape), 1, 2) + if(not self.training): + return output1 + output2 = self.decoder2(feature) + output3 = self.decoder3(feature) + if(len(x_shape) == 5): + output2 = torch.transpose(torch.reshape(output2, new_shape), 1, 2) + output3 = torch.transpose(torch.reshape(output3, new_shape), 1, 2) + return output1, output2, output3 diff --git a/pymic/net/net2d/unet2d_urpc.py b/pymic/net/net2d/unet2d_urpc.py new file mode 100644 index 0000000..ee8ab7c --- /dev/null +++ b/pymic/net/net2d/unet2d_urpc.py @@ -0,0 +1,132 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import torch.nn as nn +import numpy as np +from torch.distributions.uniform import Uniform +from pymic.net.net2d.unet2d import ConvBlock, DownBlock, UpBlock + +def FeatureDropout(x): + attention = torch.mean(x, dim=1, keepdim=True) + max_val, _ = torch.max(attention.view( + x.size(0), -1), dim=1, keepdim=True) + threshold = max_val * np.random.uniform(0.7, 0.9) + threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention) + drop_mask = (attention < threshold).float() + x = x.mul(drop_mask) + return x + +class FeatureNoise(nn.Module): + def __init__(self, uniform_range=0.3): + super(FeatureNoise, self).__init__() + self.uni_dist = Uniform(-uniform_range, uniform_range) + + def feature_based_noise(self, x): + noise_vector = self.uni_dist.sample( + x.shape[1:]).to(x.device).unsqueeze(0) + x_noise = x.mul(noise_vector) + x + return x_noise + + def forward(self, x): + x = self.feature_based_noise(x) + return x + +class UNet2D_URPC(nn.Module): + """ + An modification the U-Net to obtain multi-scale prediction according to + the URPC paper. + + * Reference: Xiangde Luo, Guotai Wang*, Wenjun Liao, Jieneng Chen, Tao Song, Yinan Chen, + Shichuan Zhang, Dimitris N. Metaxas, Shaoting Zhang. + Semi-Supervised Medical Image Segmentation via Uncertainty Rectified Pyramid Consistency . + `Medical Image Analysis 2022. `_ + + Also see: https://github.com/HiLab-git/SSL4MIS/blob/master/code/networks/unet.py + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param bilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + """ + def __init__(self, params): + super(UNet2D_URPC, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + self.n_class = self.params['class_num'] + self.bilinear = self.params['bilinear'] + assert(len(self.ft_chns) == 5) + + self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) + self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) + self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) + self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) + self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], 0.0, self.bilinear) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], 0.0, self.bilinear) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], 0.0, self.bilinear) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], 0.0, self.bilinear) + + self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, + kernel_size = 3, padding = 1) + self.out_conv_dp4 = nn.Conv2d(self.ft_chns[4], self.n_class, + kernel_size=3, padding=1) + self.out_conv_dp3 = nn.Conv2d(self.ft_chns[3], self.n_class, + kernel_size=3, padding=1) + self.out_conv_dp2 = nn.Conv2d(self.ft_chns[2], self.n_class, + kernel_size=3, padding=1) + self.out_conv_dp1 = nn.Conv2d(self.ft_chns[1], self.n_class, + kernel_size=3, padding=1) + self.feature_noise = FeatureNoise() + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + x0 = self.in_conv(x) + x1 = self.down1(x0) + x2 = self.down2(x1) + x3 = self.down3(x2) + x4 = self.down4(x3) + + x = self.up1(x4, x3) + if self.training: + x = nn.functional.dropout(x, p=0.5) + dp3_out = self.out_conv_dp3(x) + + x = self.up2(x, x2) + if self.training: + x = FeatureDropout(x) + dp2_out = self.out_conv_dp2(x) + + x = self.up3(x, x1) + if self.training: + x = self.feature_noise(x) + dp1_out = self.out_conv_dp1(x) + + x = self.up4(x, x0) + dp0_out = self.out_conv(x) + + out_shape = list(dp0_out.shape)[2:] + dp3_out = nn.functional.interpolate(dp3_out, out_shape) + dp2_out = nn.functional.interpolate(dp2_out, out_shape) + dp1_out = nn.functional.interpolate(dp1_out, out_shape) + out = [dp0_out, dp1_out, dp2_out, dp3_out] + + if(len(x_shape) == 5): + new_shape = [N, D] + list(dp0_out.shape)[1:] + for i in range(len(out)): + out[i] = torch.transpose(torch.reshape(out[i], new_shape), 1, 2) + return out \ No newline at end of file diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index ffaa023..e1e250c 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -7,6 +7,7 @@ * UNet2D_CCT :mod:`pymic.net.net2d.unet2d_cct.UNet2D_CCT` * UNet2D_ScSE :mod:`pymic.net.net2d.unet2d_scse.UNet2D_ScSE` * AttentionUNet2D :mod:`pymic.net.net2d.unet2d_attention.AttentionUNet2D` +* MCNet2D :mod:`pymic.net.net2d.unet2d_mcnet.MCNet2D` * NestedUNet2D :mod:`pymic.net.net2d.unet2d_nest.NestedUNet2D` * COPLENet :mod:`pymic.net.net2d.cople_net.COPLENet` * UNet2D5 :mod:`pymic.net.net3d.unet2d5.UNet2D5` @@ -17,12 +18,13 @@ from pymic.net.net2d.unet2d import UNet2D from pymic.net.net2d.unet2d_dual_branch import UNet2D_DualBranch from pymic.net.net2d.unet2d_cct import UNet2D_CCT +from pymic.net.net2d.unet2d_mcnet import MCNet2D from pymic.net.net2d.cople_net import COPLENet from pymic.net.net2d.unet2d_attention import AttentionUNet2D from pymic.net.net2d.unet2d_nest import NestedUNet2D from pymic.net.net2d.unet2d_scse import UNet2D_ScSE -# from pymic.net.net2d.trans2d.transunet import TransUNet -# from pymic.net.net2d.trans2d.swinunet import SwinUNet +from pymic.net.net2d.trans2d.transunet import TransUNet +from pymic.net.net2d.trans2d.swinunet import SwinUNet from pymic.net.net3d.unet2d5 import UNet2D5 from pymic.net.net3d.unet3d import UNet3D from pymic.net.net3d.unet3d_scse import UNet3D_ScSE @@ -39,24 +41,26 @@ # from pymic.net.net3d.trans3d.HiFormer_v3 import HiFormer_v3 # from pymic.net.net3d.trans3d.HiFormer_v4 import HiFormer_v4 # from pymic.net.net3d.trans3d.HiFormer_v5 import HiFormer_v5 +# from pymic.net.net3d.trans3d.SwitchNet import SwitchNet SegNetDict = { 'UNet2D': UNet2D, 'UNet2D_DualBranch': UNet2D_DualBranch, 'UNet2D_CCT': UNet2D_CCT, + 'MCNet2D': MCNet2D, 'COPLENet': COPLENet, 'AttentionUNet2D': AttentionUNet2D, 'NestedUNet2D': NestedUNet2D, 'UNet2D_ScSE': UNet2D_ScSE, - # 'TransUNet': TransUNet, - # 'SwinUNet': SwinUNet, + 'TransUNet': TransUNet, + 'SwinUNet': SwinUNet, 'UNet2D5': UNet2D5, 'UNet3D': UNet3D, 'UNet3D_ScSE': UNet3D_ScSE, 'UNet3D_DualBranch': UNet3D_DualBranch, # 'nnFormer': nnFormer_wrap, - # 'UNETR': UNETR, - # 'UNETR_PP': UNETR_PP, + 'UNETR': UNETR, + 'UNETR_PP': UNETR_PP, # 'MedFormerV1': MedFormerV1, # 'MedFormerV2': MedFormerV2, # 'MedFormerV3': MedFormerV3, @@ -66,4 +70,5 @@ # 'HiFormer_v3': HiFormer_v3, # 'HiFormer_v4': HiFormer_v4, # 'HiFormer_v5': HiFormer_v5 + # 'SwitchNet': SwitchNet } diff --git a/pymic/net_run/semi_sup/ssl_mcnet.py b/pymic/net_run/semi_sup/ssl_mcnet.py new file mode 100644 index 0000000..66e1034 --- /dev/null +++ b/pymic/net_run/semi_sup/ssl_mcnet.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from pymic.loss.seg.util import get_soft_label +from pymic.loss.seg.util import reshape_prediction_and_ground_truth +from pymic.loss.seg.util import get_classwise_dice +from pymic.net_run.semi_sup import SSLSegAgent +from pymic.util.ramps import get_rampup_ratio + +def sharpening(P, T = 0.1): + T = 1.0/T + P_sharpen = P**T / (P**T + (1-P)**T) + return P_sharpen + +class SSLMCNet(SSLSegAgent): + """ + Mutual Consistency Learning for semi-supervised segmentation. It requires a network + with multiple decoders for learning, such as `pymic.net.net2d.unet2d_mcnet.MCNet2D`. + + * Reference: Yicheng Wu, Zongyuan Ge et al. Mutual consistency learning for + semi-supervised medical image segmentation. + `MIA 2022. `_ + + The original code is at: https://github.com/ycwu1997/MC-Net + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. + """ + def training(self): + class_num = self.config['network']['class_num'] + iter_valid = self.config['training']['iter_valid'] + ssl_cfg = self.config['semi_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = ssl_cfg.get('rampup_start', 0) + rampup_end = ssl_cfg.get('rampup_end', iter_max) + temperature = ssl_cfg.get('temperature', 0.1) + unsup_loss_name = ssl_cfg.get('unsupervised_loss', "MSE") + train_loss = 0 + train_loss_sup = 0 + train_loss_reg = 0 + train_dice_list = [] + self.net.train() + + for it in range(iter_valid): + try: + data_lab = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data_lab = next(self.trainIter) + try: + data_unlab = next(self.trainIter_unlab) + except StopIteration: + self.trainIter_unlab = iter(self.train_loader_unlab) + data_unlab = next(self.trainIter_unlab) + + # get the inputs + x0 = self.convert_tensor_type(data_lab['image']) + y0 = self.convert_tensor_type(data_lab['label_prob']) + x1 = self.convert_tensor_type(data_unlab['image']) + inputs = torch.cat([x0, x1], dim = 0) + inputs, y0 = inputs.to(self.device), y0.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward pass to obtain multiple predictions + outputs = self.net(inputs) + num_outputs = len(outputs) + n0 = list(x0.shape)[0] + p0 = F.softmax(outputs[0], dim=1)[:n0] + # for probability prediction and pseudo respectively + p_ori = torch.zeros((num_outputs,) + outputs[0].shape) + y_psu = torch.zeros((num_outputs,) + outputs[0].shape) + + # get supervised loss + loss_sup = 0 + for idx in range(num_outputs): + p0i = outputs[idx][:n0] + loss_sup += self.get_loss_value(data_lab, p0i, y0) + + # get pseudo labels + p_i = F.softmax(outputs[idx], dim=1) + p_ori[idx] = p_i + y_psu[idx] = sharpening(p_i, temperature) + + # get regularization loss + loss_reg = 0.0 + for i in range(num_outputs): + for j in range(num_outputs): + if (i!=j): + loss_reg += F.mse_loss(p_ori[i], y_psu[j], reduction='mean') + + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio + loss = loss_sup + regular_w*loss_reg + + loss.backward() + self.optimizer.step() + + train_loss = train_loss + loss.item() + train_loss_sup = train_loss_sup + loss_sup.item() + train_loss_reg = train_loss_reg + loss_reg.item() + # get dice evaluation for each class in annotated images + if(isinstance(p0, tuple) or isinstance(p0, list)): + p0 = p0[0] + p0_argmax = torch.argmax(p0, dim = 1, keepdim = True) + p0_soft = get_soft_label(p0_argmax, class_num, self.tensor_type) + p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) + dice_list = get_classwise_dice(p0_soft, y0) + train_dice_list.append(dice_list.cpu().numpy()) + train_avg_loss = train_loss / iter_valid + train_avg_loss_sup = train_loss_sup / iter_valid + train_avg_loss_reg = train_loss_reg / iter_valid + train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) + train_avg_dice = train_cls_dice[1:].mean() + + train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, + 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + return train_scalers From 649e62d59a0582c1d06fc0d7f4ffdc3fb9b1cd3f Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 31 Oct 2023 21:23:43 +0800 Subject: [PATCH 167/225] update docs update docs for mecnet --- docs/source/api.rst | 3 --- docs/source/pymic.net.net2d.rst | 8 ++++++++ docs/source/pymic.net_run.semi_sup.rst | 8 ++++++++ pymic/net/net2d/unet2d_mcnet.py | 7 ++----- pymic/net/net_dict_seg.py | 4 ++-- pymic/net_run/semi_sup/__init__.py | 13 +++++++------ pymic/net_run/semi_sup/ssl_mcnet.py | 9 +++++++-- 7 files changed, 34 insertions(+), 18 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 206000d..d09809c 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -9,8 +9,5 @@ API pymic.loss pymic.net pymic.net_run - pymic.net_run_nll - pymic.net_run_ssl - pymic.net_run_wsl pymic.transform pymic.util \ No newline at end of file diff --git a/docs/source/pymic.net.net2d.rst b/docs/source/pymic.net.net2d.rst index d978dfe..bd54bc6 100644 --- a/docs/source/pymic.net.net2d.rst +++ b/docs/source/pymic.net.net2d.rst @@ -52,6 +52,14 @@ pymic.net.net2d.unet2d\_dual\_branch module :undoc-members: :show-inheritance: +pymic.net.net2d.unet2d\_mcnet module +------------------------------------------- + +.. automodule:: pymic.net.net2d.unet2d_mcnet + :members: + :undoc-members: + :show-inheritance: + pymic.net.net2d.unet2d\_nest module ----------------------------------- diff --git a/docs/source/pymic.net_run.semi_sup.rst b/docs/source/pymic.net_run.semi_sup.rst index 6ed157d..15692b2 100644 --- a/docs/source/pymic.net_run.semi_sup.rst +++ b/docs/source/pymic.net_run.semi_sup.rst @@ -28,6 +28,14 @@ pymic.net\_run.semi\_sup.ssl\_cps module :undoc-members: :show-inheritance: +pymic.net\_run.semi\_sup.ssl\_mcnet module +---------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_mcnet + :members: + :undoc-members: + :show-inheritance: + pymic.net\_run.semi\_sup.ssl\_em module --------------------------------------- diff --git a/pymic/net/net2d/unet2d_mcnet.py b/pymic/net/net2d/unet2d_mcnet.py index c3e8a5f..76ee5de 100644 --- a/pymic/net/net2d/unet2d_mcnet.py +++ b/pymic/net/net2d/unet2d_mcnet.py @@ -4,22 +4,19 @@ import torch.nn as nn from pymic.net.net2d.unet2d import * - class MCNet2D(nn.Module): """ A tri-branch network using UNet2D as backbone. * Reference: Yicheng Wu, Zongyuan Ge et al. Mutual consistency learning for - semi-supervised medical image segmentation. - `MIA 2022. `_ + semi-supervised medical image segmentation. + `Medical Image Analysis 2022. `_ The original code is at: https://github.com/ycwu1997/MC-Net The parameters for the backbone should be given in the `params` dictionary. See :mod:`pymic.net.net2d.unet2d.UNet2D` for details. """ - -class MCNet2D(nn.Module): def __init__(self, params): super(MCNet2D, self).__init__() in_chns = params['in_chns'] diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index e1e250c..8862fd1 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -59,8 +59,8 @@ 'UNet3D_ScSE': UNet3D_ScSE, 'UNet3D_DualBranch': UNet3D_DualBranch, # 'nnFormer': nnFormer_wrap, - 'UNETR': UNETR, - 'UNETR_PP': UNETR_PP, + # 'UNETR': UNETR, + # 'UNETR_PP': UNETR_PP, # 'MedFormerV1': MedFormerV1, # 'MedFormerV2': MedFormerV2, # 'MedFormerV3': MedFormerV3, diff --git a/pymic/net_run/semi_sup/__init__.py b/pymic/net_run/semi_sup/__init__.py index be753c2..39ca5dc 100644 --- a/pymic/net_run/semi_sup/__init__.py +++ b/pymic/net_run/semi_sup/__init__.py @@ -1,4 +1,5 @@ from __future__ import absolute_import +# from . import * from pymic.net_run.semi_sup.ssl_abstract import SSLSegAgent from pymic.net_run.semi_sup.ssl_em import SSLEntropyMinimization from pymic.net_run.semi_sup.ssl_mt import SSLMeanTeacher @@ -8,9 +9,9 @@ from pymic.net_run.semi_sup.ssl_urpc import SSLURPC -SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization, - 'MeanTeacher': SSLMeanTeacher, - 'UAMT': SSLUncertaintyAwareMeanTeacher, - 'CCT': SSLCCT, - 'CPS': SSLCPS, - 'URPC': SSLURPC} \ No newline at end of file +# SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization, +# 'MeanTeacher': SSLMeanTeacher, +# 'UAMT': SSLUncertaintyAwareMeanTeacher, +# 'CCT': SSLCCT, +# 'CPS': SSLCPS, +# 'URPC': SSLURPC} \ No newline at end of file diff --git a/pymic/net_run/semi_sup/ssl_mcnet.py b/pymic/net_run/semi_sup/ssl_mcnet.py index 66e1034..955357d 100644 --- a/pymic/net_run/semi_sup/ssl_mcnet.py +++ b/pymic/net_run/semi_sup/ssl_mcnet.py @@ -22,8 +22,8 @@ class SSLMCNet(SSLSegAgent): with multiple decoders for learning, such as `pymic.net.net2d.unet2d_mcnet.MCNet2D`. * Reference: Yicheng Wu, Zongyuan Ge et al. Mutual consistency learning for - semi-supervised medical image segmentation. - `MIA 2022. `_ + semi-supervised medical image segmentation. + `Medical Image Analysis 2022. `_ The original code is at: https://github.com/ycwu1997/MC-Net @@ -34,6 +34,11 @@ class SSLMCNet(SSLSegAgent): In the configuration dictionary, in addition to the four sections (`dataset`, `network`, `training` and `inference`) used in fully supervised learning, an extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. + + Special parameters required for MCNet in `semi_supervised_learning` section: + + :param temperature: (float) temperature for label sharpening. The default value is 0.1. + """ def training(self): class_num = self.config['network']['class_num'] From ad5957e083b26e82fad5e27d90c09bbf1ba5b573 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 4 Nov 2023 14:10:44 +0800 Subject: [PATCH 168/225] add multi-net --- pymic/net/multi_net.py | 32 ++++++++++++++++++++ pymic/net_run/agent_seg.py | 49 +++++++++++++++++++++---------- pymic/net_run/semi_sup/ssl_cps.py | 28 +----------------- 3 files changed, 66 insertions(+), 43 deletions(-) create mode 100644 pymic/net/multi_net.py diff --git a/pymic/net/multi_net.py b/pymic/net/multi_net.py new file mode 100644 index 0000000..8807b0c --- /dev/null +++ b/pymic/net/multi_net.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import torch.nn as nn + +class MultiNet(nn.Module): + ''' + A combination of multiple networks. + Parameters should be saved in the `params` dictionary. + + :param `net_names`: (list) A list of network class name. + :param `infer_mode`: (int) Mode for inference. 0: only use the first network. + 1: taking an average of all the networks. + ''' + def __init__(self, net_dict, params): + super(MultiNet, self).__init__() + net_names = params['net_names'] # should be a list of network class name + self.output_mode = params.get('infer_mode', 0) + self.networks = nn.ModuleList([net_dict[item](params) for item in net_names]) + + def forward(self, x): + if(self.training): + output = [net(x) for net in self.networks] + else: + output = self.networks[0](x) + if(self.output_mode == 1): + for i in range(1, len(self.networks)): + output += self.networks[i](x) + output = output / len(self.networks) + return output + \ No newline at end of file diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 70f01b3..65c2b68 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -18,6 +18,7 @@ from pymic.io.image_read_write import save_nd_array_as_image from pymic.io.nifty_dataset import NiftyDataset from pymic.net.net_dict_seg import SegNetDict +from pymic.net.multi_net import MultiNet from pymic.net_run.agent_abstract import NetRunAgent from pymic.net_run.infer_func import Inferer from pymic.loss.loss_dict_seg import SegLossDict @@ -78,9 +79,12 @@ def get_stage_dataset_from_config(self, stage): def create_network(self): if(self.net is None): net_name = self.config['network']['net_type'] - if(net_name not in self.net_dict): - raise ValueError("Undefined network {0:}".format(net_name)) - self.net = self.net_dict[net_name](self.config['network']) + if(isinstance(net_name, (tuple, list))): + self.net = MultiNet(self.net_dict, self.config['network']) + else: + if(net_name not in self.net_dict): + raise ValueError("Undefined network {0:}".format(net_name)) + self.net = self.net_dict[net_name](self.config['network']) if(self.tensor_type == 'float'): self.net.float() else: @@ -164,10 +168,11 @@ def training(self): if(mixup_prob > 0 and random() < mixup_prob): inputs, labels_prob = mixup(inputs, labels_prob) - # # for debug + # for debug # for i in range(inputs.shape[0]): # image_i = inputs[i][0] - # label_i = labels_prob[i][1] + # # label_i = labels_prob[i][1] + # label_i = np.argmax(labels_prob[i], axis = 0) # # pixw_i = pix_w[i][0] # print(image_i.shape, label_i.shape) # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) @@ -176,7 +181,7 @@ def training(self): # save_nd_array_as_image(image_i, image_name, reference_name = None) # save_nd_array_as_image(label_i, label_name, reference_name = None) # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) - # # continue + # continue inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) @@ -271,6 +276,27 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + def load_pretrained_weights(self, network, pretrained_dict, device_ids): + if(len(device_ids) > 1): + if(hasattr(network.module, "get_parameters_to_load")): + model_dict = network.module.get_parameters_to_load() + else: + model_dict = network.module.state_dict() + else: + if(hasattr(network, "get_parameters_to_load")): + model_dict = network.get_parameters_to_load() + else: + model_dict = network.state_dict() + pretrained_dict = {k: v for k, v in pretrained_dict.items() if \ + k in model_dict and tensor_shape_match(pretrained_dict[k], model_dict[k])} + logging.info("Initializing the following parameters with pre-trained model") + for k in pretrained_dict: + logging.info(k) + if (len(device_ids) > 1): + network.module.load_state_dict(pretrained_dict, strict = False) + else: + network.load_state_dict(pretrained_dict, strict = False) + def train_valid(self): device_ids = self.config['training']['gpus'] if(len(device_ids) > 1): @@ -310,16 +336,7 @@ def train_valid(self): if(ckpt_init_name is not None): checkpoint = torch.load(ckpt_dir + "/" + ckpt_init_name, map_location = self.device) pretrained_dict = checkpoint['model_state_dict'] - model_dict = self.net.module.state_dict() if (len(device_ids) > 1) else self.net.state_dict() - pretrained_dict = {k: v for k, v in pretrained_dict.items() if \ - k in model_dict and tensor_shape_match(pretrained_dict[k], model_dict[k])} - logging.info("Initializing the following parameters with pre-trained model") - for k in pretrained_dict: - logging.info(k) - if (len(device_ids) > 1): - self.net.module.load_state_dict(pretrained_dict, strict = False) - else: - self.net.load_state_dict(pretrained_dict, strict = False) + self.load_pretrained_weights(self.net, pretrained_dict, device_ids) if(ckpt_init_mode > 0): # Load other information self.max_val_dice = checkpoint.get('valid_pred', 0) diff --git a/pymic/net_run/semi_sup/ssl_cps.py b/pymic/net_run/semi_sup/ssl_cps.py index 4a3be9c..db1fb28 100644 --- a/pymic/net_run/semi_sup/ssl_cps.py +++ b/pymic/net_run/semi_sup/ssl_cps.py @@ -3,30 +3,12 @@ import logging import numpy as np import torch -import torch.nn as nn from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.net_run.semi_sup import SSLSegAgent -from pymic.net.net_dict_seg import SegNetDict from pymic.util.ramps import get_rampup_ratio -class BiNet(nn.Module): - def __init__(self, params): - super(BiNet, self).__init__() - net_name = params['net_type'] - self.net1 = SegNetDict[net_name](params) - self.net2 = SegNetDict[net_name](params) - - def forward(self, x): - out1 = self.net1(x) - out2 = self.net2(x) - - if(self.training): - return out1, out2 - else: - return (out1 + out2) / 2 - class SSLCPS(SSLSegAgent): """ Using cross pseudo supervision for semi-supervised segmentation. @@ -47,14 +29,6 @@ class SSLCPS(SSLSegAgent): def __init__(self, config, stage = 'train'): super(SSLCPS, self).__init__(config, stage) - def create_network(self): - if(self.net is None): - self.net = BiNet(self.config['network']) - if(self.tensor_type == 'float'): - self.net.float() - else: - self.net.double() - def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] @@ -89,7 +63,7 @@ def training(self): # zero the parameter gradients self.optimizer.zero_grad() - outputs1, outputs2 = self.net(inputs) + outputs1, outputs2 = self.net(inputs) outputs_soft1 = torch.softmax(outputs1, dim=1) outputs_soft2 = torch.softmax(outputs2, dim=1) From 6017047877c3860326b283cd20ee9c30347917e9 Mon Sep 17 00:00:00 2001 From: taigw Date: Sun, 5 Nov 2023 16:26:13 +0800 Subject: [PATCH 169/225] add CANet --- pymic/net/multi_net.py | 2 +- pymic/net/net2d/canet_module.py | 578 +++++++++++++++++++++ pymic/net/net2d/unet2d_attention.py | 15 +- pymic/net/net2d/unet2d_canet.py | 756 ++++++++++++++++++++++++++++ pymic/net/net2d/unet2d_mcnet.py | 2 +- pymic/util/evaluation_seg.py | 55 +- 6 files changed, 1379 insertions(+), 29 deletions(-) create mode 100644 pymic/net/net2d/canet_module.py create mode 100644 pymic/net/net2d/unet2d_canet.py diff --git a/pymic/net/multi_net.py b/pymic/net/multi_net.py index 8807b0c..78209b1 100644 --- a/pymic/net/multi_net.py +++ b/pymic/net/multi_net.py @@ -15,7 +15,7 @@ class MultiNet(nn.Module): ''' def __init__(self, net_dict, params): super(MultiNet, self).__init__() - net_names = params['net_names'] # should be a list of network class name + net_names = params['net_type'] # should be a list of network class name self.output_mode = params.get('infer_mode', 0) self.networks = nn.ModuleList([net_dict[item](params) for item in net_names]) diff --git a/pymic/net/net2d/canet_module.py b/pymic/net/net2d/canet_module.py new file mode 100644 index 0000000..097a4f1 --- /dev/null +++ b/pymic/net/net2d/canet_module.py @@ -0,0 +1,578 @@ +# -*- coding: utf-8 -*- +""" +Building blcoks for CA-Net. + +Oringinal file is on `Github. +`_ +""" + +from __future__ import print_function, division +import torch +import torch.nn as nn +import functools +from torch.nn import functional as F + + +class conv_block(nn.Module): + def __init__(self, ch_in, ch_out, drop_out=False): + super(conv_block, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), + nn.BatchNorm2d(ch_out), + nn.ReLU(inplace=True), + nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True), + nn.BatchNorm2d(ch_out), + nn.ReLU(inplace=True), + ) + self.dropout = drop_out + + def forward(self, x): + x = self.conv(x) + if self.dropout: + x = nn.Dropout2d(0.5)(x) + return x + + +# # UpCat(nn.Module) for U-net UP convolution +class UpCat(nn.Module): + def __init__(self, in_feat, out_feat, is_deconv=True): + super(UpCat, self).__init__() + if is_deconv: + self.up = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2) + else: + self.up = nn.Upsample(scale_factor=2, mode='bilinear') + + def forward(self, inputs, down_outputs): + # TODO: Upsampling required after deconv? + outputs = self.up(down_outputs) + offset = inputs.size()[3] - outputs.size()[3] + if offset == 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2]), out=None).unsqueeze( + 3).cuda() + outputs = torch.cat([outputs, addition], dim=3) + elif offset > 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2], offset), out=None).cuda() + outputs = torch.cat([outputs, addition], dim=3) + out = torch.cat([inputs, outputs], dim=1) + + return out + + +# # UpCatconv(nn.Module) for up convolution +class UpCatconv(nn.Module): + def __init__(self, in_feat, out_feat, is_deconv=True, drop_out=False): + super(UpCatconv, self).__init__() + + if is_deconv: + self.conv = conv_block(in_feat, out_feat, drop_out=drop_out) + self.up = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2) + else: + self.conv = conv_block(in_feat + out_feat, out_feat, drop_out=drop_out) + self.up = nn.Upsample(scale_factor=2, mode='bilinear') + + def forward(self, inputs, down_outputs): + # TODO: Upsampling required after deconv + outputs = self.up(down_outputs) + offset = inputs.size()[3] - outputs.size()[3] + if offset == 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2]), out=None).unsqueeze( + 3).cuda() + outputs = torch.cat([outputs, addition], dim=3) + elif offset > 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2], offset), out=None).cuda() + outputs = torch.cat([outputs, addition], dim=3) + out = self.conv(torch.cat([inputs, outputs], dim=1)) + + return out + + +class UnetDsv3(nn.Module): + def __init__(self, in_size, out_size, scale_factor): + super(UnetDsv3, self).__init__() + self.dsv = nn.Sequential(nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0), + nn.Upsample(size=scale_factor, mode='bilinear'), ) + + def forward(self, input): + return self.dsv(input) + + +###### Intial weights ##### +def weights_init_normal(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + nn.init.normal(m.weight.data, 0.0, 0.02) + elif classname.find('Linear') != -1: + nn.init.normal(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm') != -1: + nn.init.normal(m.weight.data, 1.0, 0.02) + nn.init.constant(m.bias.data, 0.0) + + +def weights_init_xavier(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + nn.init.xavier_normal(m.weight.data, gain=1) + elif classname.find('Linear') != -1: + nn.init.xavier_normal(m.weight.data, gain=1) + elif classname.find('BatchNorm') != -1: + nn.init.normal(m.weight.data, 1.0, 0.02) + nn.init.constant(m.bias.data, 0.0) + + +def weights_init_kaiming(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in') + elif classname.find('Linear') != -1: + nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in') + elif classname.find('BatchNorm') != -1: + nn.init.normal(m.weight.data, 1.0, 0.02) + nn.init.constant(m.bias.data, 0.0) + + +def weights_init_orthogonal(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + nn.init.orthogonal(m.weight.data, gain=1) + elif classname.find('Linear') != -1: + nn.init.orthogonal(m.weight.data, gain=1) + elif classname.find('BatchNorm') != -1: + nn.init.normal(m.weight.data, 1.0, 0.02) + nn.init.constant(m.bias.data, 0.0) + + +def init_weights(net, init_type='normal'): + #print('initialization method [%s]' % init_type) + if init_type == 'normal': + net.apply(weights_init_normal) + elif init_type == 'xavier': + net.apply(weights_init_xavier) + elif init_type == 'kaiming': + net.apply(weights_init_kaiming) + elif init_type == 'orthogonal': + net.apply(weights_init_orthogonal) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + + +def get_norm_layer(norm_type='instance'): + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) + elif norm_type == 'none': + norm_layer = None + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm_type) + return norm_layer + + +###### For attention ###### +class _GridAttentionBlockND(nn.Module): + def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation', + sub_sample_factor=(2,2,2)): + super(_GridAttentionBlockND, self).__init__() + + assert dimension in [2, 3] + assert mode in ['concatenation', 'concatenation_debug', 'concatenation_residual'] + + # Downsampling rate for the input featuremap + if isinstance(sub_sample_factor, tuple): self.sub_sample_factor = sub_sample_factor + elif isinstance(sub_sample_factor, list): self.sub_sample_factor = tuple(sub_sample_factor) + else: self.sub_sample_factor = tuple([sub_sample_factor]) * dimension + + # Default parameter set + self.mode = mode + self.dimension = dimension + self.sub_sample_kernel_size = self.sub_sample_factor + + # Number of channels (pixel dimensions) + self.in_channels = in_channels + self.gating_channels = gating_channels + self.inter_channels = inter_channels + + if self.inter_channels is None: + self.inter_channels = in_channels // 2 + if self.inter_channels == 0: + self.inter_channels = 1 + + if dimension == 3: + conv_nd = nn.Conv3d + bn = nn.BatchNorm3d + self.upsample_mode = 'trilinear' + elif dimension == 2: + conv_nd = nn.Conv2d + bn = nn.BatchNorm2d + self.upsample_mode = 'bilinear' + else: + raise NotImplemented + + # Output transform + self.W = nn.Sequential( + conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), + bn(self.in_channels), + ) + + # Theta^T * x_ij + Phi^T * gating_signal + bias + self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=True) + self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels, + kernel_size=(1, 1), stride=1, padding=0, bias=True) + self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True) + + # Initialise weights + for m in self.children(): + init_weights(m, init_type='kaiming') + + # Define the operation + if mode == 'concatenation': + self.operation_function = self._concatenation + elif mode == 'concatenation_debug': + self.operation_function = self._concatenation_debug + elif mode == 'concatenation_residual': + self.operation_function = self._concatenation_residual + else: + raise NotImplementedError('Unknown operation function.') + + + def forward(self, x, g): + ''' + :param x: (b, c, t, h, w) + :param g: (b, g_d) + :return: + ''' + + output = self.operation_function(x, g) + return output + + def _concatenation(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.relu(theta_x + phi_g, inplace=True) + + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + sigm_psi_f = F.sigmoid(self.psi(f)) + + # upsample the attentions and multiply + sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + def _concatenation_debug(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.softplus(theta_x + phi_g) + + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + sigm_psi_f = F.sigmoid(self.psi(f)) + + # upsample the attentions and multiply + sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + + def _concatenation_residual(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.relu(theta_x + phi_g, inplace=True) + + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + f = self.psi(f).view(batch_size, 1, -1) + sigm_psi_f = F.softmax(f, dim=2).view(batch_size, 1, *theta_x.size()[2:]) + + # upsample the attentions and multiply + sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + +class GridAttentionBlock2D(_GridAttentionBlockND): + def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', + sub_sample_factor=(2, 2)): + super(GridAttentionBlock2D, self).__init__(in_channels, + inter_channels=inter_channels, + gating_channels=gating_channels, + dimension=2, mode=mode, + sub_sample_factor=sub_sample_factor, + ) + + +class GridAttentionBlock3D(_GridAttentionBlockND): + def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', + sub_sample_factor=(2,2,2)): + super(GridAttentionBlock3D, self).__init__(in_channels, + inter_channels=inter_channels, + gating_channels=gating_channels, + dimension=3, mode=mode, + sub_sample_factor=sub_sample_factor, + ) + +class _GridAttentionBlockND_TORR(nn.Module): + def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation', + sub_sample_factor=(1,1,1), bn_layer=True, use_W=True, use_phi=True, use_theta=True, use_psi=True, nonlinearity1='relu'): + super(_GridAttentionBlockND_TORR, self).__init__() + + assert dimension in [2, 3] + assert mode in ['concatenation', 'concatenation_softmax', + 'concatenation_sigmoid', 'concatenation_mean', + 'concatenation_range_normalise', 'concatenation_mean_flow'] + + # Default parameter set + self.mode = mode + self.dimension = dimension + self.sub_sample_factor = sub_sample_factor if isinstance(sub_sample_factor, tuple) else tuple([sub_sample_factor])*dimension + self.sub_sample_kernel_size = self.sub_sample_factor + + # Number of channels (pixel dimensions) + self.in_channels = in_channels + self.gating_channels = gating_channels + self.inter_channels = inter_channels + + if self.inter_channels is None: + self.inter_channels = in_channels // 2 + if self.inter_channels == 0: + self.inter_channels = 1 + + if dimension == 3: + conv_nd = nn.Conv3d + bn = nn.BatchNorm3d + self.upsample_mode = 'trilinear' + elif dimension == 2: + conv_nd = nn.Conv2d + bn = nn.BatchNorm2d + self.upsample_mode = 'bilinear' + else: + raise NotImplemented + + # initialise id functions + # Theta^T * x_ij + Phi^T * gating_signal + bias + self.W = lambda x: x + self.theta = lambda x: x + self.psi = lambda x: x + self.phi = lambda x: x + self.nl1 = lambda x: x + + if use_W: + if bn_layer: + self.W = nn.Sequential( + conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), + bn(self.in_channels), + ) + else: + self.W = conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0) + + if use_theta: + self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=False) + + + if use_phi: + self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels, + kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=False) + + + if use_psi: + self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True) + + + if nonlinearity1: + if nonlinearity1 == 'relu': + self.nl1 = lambda x: F.relu(x, inplace=True) + + if 'concatenation' in mode: + self.operation_function = self._concatenation + else: + raise NotImplementedError('Unknown operation function.') + + # Initialise weights + for m in self.children(): + init_weights(m, init_type='kaiming') + + + if use_psi and self.mode == 'concatenation_sigmoid': + nn.init.constant(self.psi.bias.data, 3.0) + + if use_psi and self.mode == 'concatenation_softmax': + nn.init.constant(self.psi.bias.data, 10.0) + + # if use_psi and self.mode == 'concatenation_mean': + # nn.init.constant(self.psi.bias.data, 3.0) + + # if use_psi and self.mode == 'concatenation_range_normalise': + # nn.init.constant(self.psi.bias.data, 3.0) + + parallel = False + if parallel: + if use_W: self.W = nn.DataParallel(self.W) + if use_phi: self.phi = nn.DataParallel(self.phi) + if use_psi: self.psi = nn.DataParallel(self.psi) + if use_theta: self.theta = nn.DataParallel(self.theta) + + def forward(self, x, g): + ''' + :param x: (b, c, t, h, w) + :param g: (b, g_d) + :return: + ''' + + output = self.operation_function(x, g) + return output + + def _concatenation(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + ############################# + # compute compatibility score + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) + # phi => (b, c, t, h, w) -> (b, i_c, t, h, w) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # nl(theta.x + phi.g + bias) -> f = (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + + f = theta_x + phi_g + f = self.nl1(f) + + psi_f = self.psi(f) + + ############################################ + # normalisation -- scale compatibility score + # psi^T . f -> (b, 1, t/s1, h/s2, w/s3) + if self.mode == 'concatenation_softmax': + sigm_psi_f = F.softmax(psi_f.view(batch_size, 1, -1), dim=2) + sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:]) + elif self.mode == 'concatenation_mean': + psi_f_flat = psi_f.view(batch_size, 1, -1) + psi_f_sum = torch.sum(psi_f_flat, dim=2)#clamp(1e-6) + psi_f_sum = psi_f_sum[:,:,None].expand_as(psi_f_flat) + + sigm_psi_f = psi_f_flat / psi_f_sum + sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:]) + elif self.mode == 'concatenation_mean_flow': + psi_f_flat = psi_f.view(batch_size, 1, -1) + ss = psi_f_flat.shape + psi_f_min = psi_f_flat.min(dim=2)[0].view(ss[0],ss[1],1) + psi_f_flat = psi_f_flat - psi_f_min + psi_f_sum = torch.sum(psi_f_flat, dim=2).view(ss[0],ss[1],1).expand_as(psi_f_flat) + + sigm_psi_f = psi_f_flat / psi_f_sum + sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:]) + elif self.mode == 'concatenation_range_normalise': + psi_f_flat = psi_f.view(batch_size, 1, -1) + ss = psi_f_flat.shape + psi_f_max = torch.max(psi_f_flat, dim=2)[0].view(ss[0], ss[1], 1) + psi_f_min = torch.min(psi_f_flat, dim=2)[0].view(ss[0], ss[1], 1) + + sigm_psi_f = (psi_f_flat - psi_f_min) / (psi_f_max - psi_f_min).expand_as(psi_f_flat) + sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:]) + + elif self.mode == 'concatenation_sigmoid': + sigm_psi_f = F.sigmoid(psi_f) + else: + raise NotImplementedError + + # sigm_psi_f is attention map! upsample the attentions and multiply + sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + +class GridAttentionBlock2D_TORR(_GridAttentionBlockND_TORR): + def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', + sub_sample_factor=(1,1), bn_layer=True, + use_W=True, use_phi=True, use_theta=True, use_psi=True, + nonlinearity1='relu'): + super(GridAttentionBlock2D_TORR, self).__init__(in_channels, + inter_channels=inter_channels, + gating_channels=gating_channels, + dimension=2, mode=mode, + sub_sample_factor=sub_sample_factor, + bn_layer=bn_layer, + use_W=use_W, + use_phi=use_phi, + use_theta=use_theta, + use_psi=use_psi, + nonlinearity1=nonlinearity1) + + +class GridAttentionBlock3D_TORR(_GridAttentionBlockND_TORR): + def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', + sub_sample_factor=(1,1,1), bn_layer=True): + super(GridAttentionBlock3D_TORR, self).__init__(in_channels, + inter_channels=inter_channels, + gating_channels=gating_channels, + dimension=3, mode=mode, + sub_sample_factor=sub_sample_factor, + bn_layer=bn_layer) + + +class MultiAttentionBlock(nn.Module): + def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor): + super(MultiAttentionBlock, self).__init__() + self.gate_block_1 = GridAttentionBlock2D(in_channels=in_size, gating_channels=gate_size, + inter_channels=inter_size, mode=nonlocal_mode, + sub_sample_factor=sub_sample_factor) + self.gate_block_2 = GridAttentionBlock2D(in_channels=in_size, gating_channels=gate_size, + inter_channels=inter_size, mode=nonlocal_mode, + sub_sample_factor=sub_sample_factor) + self.combine_gates = nn.Sequential(nn.Conv2d(in_size*2, in_size, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(in_size), + nn.ReLU(inplace=True)) + + # initialise the blocks + for m in self.children(): + if m.__class__.__name__.find('GridAttentionBlock2D') != -1: continue + init_weights(m, init_type='kaiming') + + def forward(self, input, gating_signal): + gate_1, attention_1 = self.gate_block_1(input, gating_signal) + gate_2, attention_2 = self.gate_block_2(input, gating_signal) + + return self.combine_gates(torch.cat([gate_1, gate_2], 1)), torch.cat([attention_1, attention_2], 1) \ No newline at end of file diff --git a/pymic/net/net2d/unet2d_attention.py b/pymic/net/net2d/unet2d_attention.py index 6afdfdc..36faec8 100644 --- a/pymic/net/net2d/unet2d_attention.py +++ b/pymic/net/net2d/unet2d_attention.py @@ -4,14 +4,7 @@ import torch import torch.nn as nn from pymic.net.net2d.unet2d import * -""" -A Reimplementation of the attention U-Net paper: - Ozan Oktay, Jo Schlemper et al.: - Attentin U-Net: Looking Where to Look for the Pancreas. MIDL, 2018. -Note that there are some modifications from the original paper, such as -the use of batch normalization, dropout, and leaky relu here. -""" class AttentionGateBlock(nn.Module): def __init__(self, chns_l, chns_h): """ @@ -80,6 +73,14 @@ def forward(self, x1, x2): return self.conv(x) class AttentionUNet2D(UNet2D): + """ + A Reimplementation of the attention U-Net paper: + Ozan Oktay, Jo Schlemper et al.: + Attentin U-Net: Looking Where to Look for the Pancreas. MIDL, 2018. + + Note that there are some modifications from the original paper, such as + the use of batch normalization, dropout, and leaky relu here. + """ def __init__(self, params): super(AttentionUNet2D, self).__init__(params) self.up1 = UpBlockWithAttention(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = 0.0) diff --git a/pymic/net/net2d/unet2d_canet.py b/pymic/net/net2d/unet2d_canet.py new file mode 100644 index 0000000..7aeba84 --- /dev/null +++ b/pymic/net/net2d/unet2d_canet.py @@ -0,0 +1,756 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +from pymic.net.net2d.canet_module import * + + +def conv3x3(in_planes, out_planes, stride=1, bias=False, group=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,padding=1, groups=group, bias=bias) + + +class SE_Conv_Block(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, drop_out=False): + super(SE_Conv_Block, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes * 2) + self.bn2 = nn.BatchNorm2d(planes * 2) + self.conv3 = conv3x3(planes * 2, planes) + self.bn3 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + self.dropout = drop_out + + self.fc1 = nn.Linear(in_features=planes * 2, out_features=round(planes / 2)) + self.fc2 = nn.Linear(in_features=round(planes / 2), out_features=planes * 2) + self.sigmoid = nn.Sigmoid() + + self.downchannel = None + if inplanes != planes: + self.downchannel = nn.Sequential(nn.Conv2d(inplanes, planes * 2, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * 2),) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downchannel is not None: + residual = self.downchannel(x) + + original_out = out + out1 = out + # For global average pool + out = F.adaptive_avg_pool2d(out, (1,1)) + out = out.view(out.size(0), -1) + out = self.fc1(out) + out = self.relu(out) + out = self.fc2(out) + out = self.sigmoid(out) + out = out.view(out.size(0), out.size(1), 1, 1) + avg_att = out + out = out * original_out + # For global maximum pool + out1 = F.adaptive_max_pool2d(out1, (1,1)) + out1 = out1.view(out1.size(0), -1) + out1 = self.fc1(out1) + out1 = self.relu(out1) + out1 = self.fc2(out1) + out1 = self.sigmoid(out1) + out1 = out1.view(out1.size(0), out1.size(1), 1, 1) + max_att = out1 + out1 = out1 * original_out + + att_weight = avg_att + max_att + out += out1 + out += residual + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.relu(out) + if self.dropout: + out = nn.Dropout2d(0.5)(out) + + return out, att_weight + +# # CBAM Convolutional block attention module +class BasicConv(nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, + relu=True, bn=True, bias=False): + super(BasicConv, self).__init__() + self.out_channels = out_planes + self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, groups=groups, bias=bias) + self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None + self.relu = nn.ReLU() if relu else None + + def forward(self, x): + x = self.conv(x) + if self.bn is not None: + x = self.bn(x) + if self.relu is not None: + x = self.relu(x) + return x + + +class Flatten(nn.Module): + def forward(self, x): + return x.view(x.size(0), -1) + + +class ChannelGate(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): + super(ChannelGate, self).__init__() + self.gate_channels = gate_channels + self.mlp = nn.Sequential( + Flatten(), + nn.Linear(gate_channels, gate_channels // reduction_ratio), + nn.ReLU(), + nn.Linear(gate_channels // reduction_ratio, gate_channels) + ) + self.pool_types = pool_types + + def forward(self, x): + channel_att_sum = None + for pool_type in self.pool_types: + if pool_type == 'avg': + avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp(avg_pool) + elif pool_type == 'max': + max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp(max_pool) + elif pool_type == 'lp': + lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp(lp_pool) + elif pool_type == 'lse': + # LSE pool only + lse_pool = logsumexp_2d(x) + channel_att_raw = self.mlp(lse_pool) + + if channel_att_sum is None: + channel_att_sum = channel_att_raw + else: + channel_att_sum = channel_att_sum + channel_att_raw + + # scalecoe = F.sigmoid(channel_att_sum) + # print("channel att_sum", channel_att_sum.shape) + # channel_att_sum = channel_att_sum.reshape(channel_att_sum.shape[0], 4, 4) + # avg_weight = torch.mean(channel_att_sum, dim=2).unsqueeze(2) + # avg_weight = avg_weight.expand(channel_att_sum.shape[0], 4, 4).reshape(channel_att_sum.shape[0], 16) + # scale = F.sigmoid(avg_weight).unsqueeze(2).unsqueeze(3).expand_as(x) + scale = F.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x) + return x * scale, scale + + +def logsumexp_2d(tensor): + tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) + s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) + outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() + return outputs + + +class ChannelPool(nn.Module): + def forward(self, x): + return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1) + + +class SpatialGate(nn.Module): + def __init__(self): + super(SpatialGate, self).__init__() + kernel_size = 7 + self.compress = ChannelPool() + self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) + + def forward(self, x): + x_compress = self.compress(x) + x_out = self.spatial(x_compress) + scale = F.sigmoid(x_out) # broadcasting + return x * scale, scale + +class SpatialAtten(nn.Module): + def __init__(self, in_size, out_size, kernel_size=3, stride=1): + super(SpatialAtten, self).__init__() + self.conv1 = BasicConv(in_size, out_size, kernel_size, stride=stride, + padding=(kernel_size-1) // 2, relu=True) + self.conv2 = BasicConv(out_size, in_size, kernel_size=1, stride=stride, + padding=0, relu=True, bn=False) + + def forward(self, x): + residual = x + x_out = self.conv1(x) + x_out = self.conv2(x_out) + spatial_att = F.sigmoid(x_out) + # .unsqueeze(4).permute(0, 1, 4, 2, 3) + # spatial_att = spatial_att.expand(spatial_att.shape[0], 4, 4, spatial_att.shape[3], spatial_att.shape[4]).reshape( + # spatial_att.shape[0], 16, spatial_att.shape[3], spatial_att.shape[4]) + x_out = residual * spatial_att + + x_out += residual + + return x_out, spatial_att + +class Scale_atten_block(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): + super(Scale_atten_block, self).__init__() + self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) + self.no_spatial = no_spatial + if not no_spatial: + self.SpatialGate = SpatialAtten(gate_channels, gate_channels //reduction_ratio) + + def forward(self, x): + x_out, ca_atten = self.ChannelGate(x) + if not self.no_spatial: + x_out, sa_atten = self.SpatialGate(x_out) + + return x_out, ca_atten, sa_atten + + +class scale_atten_convblock(nn.Module): + def __init__(self, in_size, out_size, stride=1, downsample=None, use_cbam=True, no_spatial=False, drop_out=False): + super(scale_atten_convblock, self).__init__() + self.downsample = downsample + self.stride = stride + self.no_spatial = no_spatial + self.dropout = drop_out + + self.relu = nn.ReLU(inplace=True) + self.conv3 = conv3x3(in_size, out_size) + self.bn3 = nn.BatchNorm2d(out_size) + + if use_cbam: + self.cbam = Scale_atten_block(in_size, reduction_ratio=4, no_spatial=self.no_spatial) # out_size + else: + self.cbam = None + + def forward(self, x): + residual = x + + if self.downsample is not None: + residual = self.downsample(x) + + if not self.cbam is None: + out, scale_c_atten, scale_s_atten = self.cbam(x) + + out += residual + out = self.relu(out) + out = self.conv3(out) + out = self.bn3(out) + out = self.relu(out) + + if self.dropout: + out = nn.Dropout2d(0.5)(out) + + return out + +class _NonLocalBlockND(nn.Module): + def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian', + sub_sample_factor=4, bn_layer=True): + super(_NonLocalBlockND, self).__init__() + + assert dimension in [1, 2, 3] + assert mode in ['embedded_gaussian', 'gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down'] + + # print('Dimension: %d, mode: %s' % (dimension, mode)) + + self.mode = mode + self.dimension = dimension + self.sub_sample_factor = sub_sample_factor if isinstance(sub_sample_factor, list) else [sub_sample_factor] + + self.in_channels = in_channels + self.inter_channels = inter_channels + + if self.inter_channels is None: + self.inter_channels = in_channels // 2 + if self.inter_channels == 0: + self.inter_channels = 1 + + if dimension == 3: + conv_nd = nn.Conv3d + max_pool = nn.MaxPool3d + bn = nn.BatchNorm3d + elif dimension == 2: + conv_nd = nn.Conv2d + max_pool = nn.MaxPool2d + bn = nn.BatchNorm2d + else: + conv_nd = nn.Conv1d + max_pool = nn.MaxPool1d + bn = nn.BatchNorm1d + + self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=1, stride=1, padding=0) + + if bn_layer: + self.W = nn.Sequential( + conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, + kernel_size=1, stride=1, padding=0), + bn(self.in_channels) + ) + nn.init.constant(self.W[1].weight, 0) + nn.init.constant(self.W[1].bias, 0) + else: + self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, + kernel_size=1, stride=1, padding=0) + nn.init.constant(self.W.weight, 0) + nn.init.constant(self.W.bias, 0) + + self.theta = None + self.phi = None + + if mode in ['embedded_gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down']: + self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=1, stride=1, padding=0) + self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=1, stride=1, padding=0) + + if mode in ['concatenation']: + self.wf_phi = nn.Linear(self.inter_channels, 1, bias=False) + self.wf_theta = nn.Linear(self.inter_channels, 1, bias=False) + elif mode in ['concat_proper', 'concat_proper_down']: + self.psi = nn.Conv2d(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, + padding=0, bias=True) + + if mode == 'embedded_gaussian': + self.operation_function = self._embedded_gaussian + elif mode == 'dot_product': + self.operation_function = self._dot_product + elif mode == 'gaussian': + self.operation_function = self._gaussian + elif mode == 'concatenation': + self.operation_function = self._concatenation + elif mode == 'concat_proper': + self.operation_function = self._concatenation_proper + elif mode == 'concat_proper_down': + self.operation_function = self._concatenation_proper_down + else: + raise NotImplementedError('Unknown operation function.') + + if any(ss > 1 for ss in self.sub_sample_factor): + self.g = nn.Sequential(self.g, max_pool(kernel_size=sub_sample_factor)) + if self.phi is None: + self.phi = max_pool(kernel_size=sub_sample_factor) + else: + self.phi = nn.Sequential(self.phi, max_pool(kernel_size=sub_sample_factor)) + if mode == 'concat_proper_down': + self.theta = nn.Sequential(self.theta, max_pool(kernel_size=sub_sample_factor)) + + # Initialise weights + for m in self.children(): + init_weights(m, init_type='kaiming') + + def forward(self, x): + ''' + :param x: (b, c, t, h, w) + :return: + ''' + + output = self.operation_function(x) + return output + + def _embedded_gaussian(self, x): + batch_size = x.size(0) + + # g=>(b, c, t, h, w)->(b, 0.5c, t, h, w)->(b, thw, 0.5c) + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + g_x = g_x.permute(0, 2, 1) + + # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c) + # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw) + # f=>(b, thw, 0.5c)dot(b, 0.5c, twh) = (b, thw, thw) + theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) + f = torch.matmul(theta_x, phi_x) + f_div_C = F.softmax(f, dim=-1) + + # (b, thw, thw)dot(b, thw, 0.5c) = (b, thw, 0.5c)->(b, 0.5c, t, h, w)->(b, c, t, h, w) + y = torch.matmul(f_div_C, g_x) + y = y.permute(0, 2, 1).contiguous() + y = y.view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + z = W_y + x + + return z + + def _gaussian(self, x): + batch_size = x.size(0) + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + g_x = g_x.permute(0, 2, 1) + + theta_x = x.view(batch_size, self.in_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + + if self.sub_sample_factor > 1: + phi_x = self.phi(x).view(batch_size, self.in_channels, -1) + else: + phi_x = x.view(batch_size, self.in_channels, -1) + + f = torch.matmul(theta_x, phi_x) + f_div_C = F.softmax(f, dim=-1) + + y = torch.matmul(f_div_C, g_x) + y = y.permute(0, 2, 1).contiguous() + y = y.view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + z = W_y + x + + return z + + def _dot_product(self, x): + batch_size = x.size(0) + + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + g_x = g_x.permute(0, 2, 1) + + theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) + f = torch.matmul(theta_x, phi_x) + N = f.size(-1) + f_div_C = f / N + + y = torch.matmul(f_div_C, g_x) + y = y.permute(0, 2, 1).contiguous() + y = y.view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + z = W_y + x + + return z + + def _concatenation(self, x): + batch_size = x.size(0) + + # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2) + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + + # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c) + # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw/s**2, 0.5c) + theta_x = self.theta(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1) + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1) + + # theta => (b, thw, 0.5c) -> (b, thw, 1) -> (b, 1, thw) -> (expand) (b, thw/s**2, thw) + # phi => (b, thw/s**2, 0.5c) -> (b, thw/s**2, 1) -> (expand) (b, thw/s**2, thw) + # f=> RELU[(b, thw/s**2, thw) + (b, thw/s**2, thw)] = (b, thw/s**2, thw) + f = self.wf_theta(theta_x).permute(0, 2, 1).repeat(1, phi_x.size(1), 1) + \ + self.wf_phi(phi_x).repeat(1, 1, theta_x.size(1)) + f = F.relu(f, inplace=True) + + # Normalise the relations + N = f.size(-1) + f_div_c = f / N + + # g(x_j) * f(x_j, x_i) + # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw) + y = torch.matmul(g_x, f_div_c) + y = y.contiguous().view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + z = W_y + x + + return z + + def _concatenation_proper(self, x): + batch_size = x.size(0) + + # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2) + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + + # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw) + # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw/s**2) + theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) + + # theta => (b, 0.5c, thw) -> (expand) (b, 0.5c, thw/s**2, thw) + # phi => (b, 0.5c, thw/s**2) -> (expand) (b, 0.5c, thw/s**2, thw) + # f=> RELU[(b, 0.5c, thw/s**2, thw) + (b, 0.5c, thw/s**2, thw)] = (b, 0.5c, thw/s**2, thw) + f = theta_x.unsqueeze(dim=2).repeat(1,1,phi_x.size(2),1) + \ + phi_x.unsqueeze(dim=3).repeat(1,1,1,theta_x.size(2)) + f = F.relu(f, inplace=True) + + # psi -> W_psi^t * f -> (b, 1, thw/s**2, thw) -> (b, thw/s**2, thw) + f = torch.squeeze(self.psi(f), dim=1) + + # Normalise the relations + f_div_c = F.softmax(f, dim=1) + + # g(x_j) * f(x_j, x_i) + # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw) + y = torch.matmul(g_x, f_div_c) + y = y.contiguous().view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + z = W_y + x + + return z + + def _concatenation_proper_down(self, x): + batch_size = x.size(0) + + # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2) + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + + # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw) + # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw/s**2) + theta_x = self.theta(x) + downsampled_size = theta_x.size() + theta_x = theta_x.view(batch_size, self.inter_channels, -1) + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) + + # theta => (b, 0.5c, thw) -> (expand) (b, 0.5c, thw/s**2, thw) + # phi => (b, 0.5, thw/s**2) -> (expand) (b, 0.5c, thw/s**2, thw) + # f=> RELU[(b, 0.5c, thw/s**2, thw) + (b, 0.5c, thw/s**2, thw)] = (b, 0.5c, thw/s**2, thw) + f = theta_x.unsqueeze(dim=2).repeat(1,1,phi_x.size(2),1) + \ + phi_x.unsqueeze(dim=3).repeat(1,1,1,theta_x.size(2)) + f = F.relu(f, inplace=True) + + # psi -> W_psi^t * f -> (b, 0.5c, thw/s**2, thw) -> (b, 1, thw/s**2, thw) -> (b, thw/s**2, thw) + f = torch.squeeze(self.psi(f), dim=1) + + # Normalise the relations + f_div_c = F.softmax(f, dim=1) + + # g(x_j) * f(x_j, x_i) + # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw) + y = torch.matmul(g_x, f_div_c) + y = y.contiguous().view(batch_size, self.inter_channels, *downsampled_size[2:]) + + # upsample the final featuremaps # (b,0.5c,t/s1,h/s2,w/s3) + y = F.upsample(y, size=x.size()[2:], mode='trilinear') + + # attention block output + W_y = self.W(y) + z = W_y + x + + return z + + +class NONLocalBlock2D(_NonLocalBlockND): + def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample_factor=2, bn_layer=True): + super(NONLocalBlock2D, self).__init__(in_channels, + inter_channels=inter_channels, + dimension=2, mode=mode, + sub_sample_factor=sub_sample_factor, + bn_layer=bn_layer) + + +class CANet(nn.Module): + """ + Implementation of CANet (Comprehensive Attention Network) for image segmentation. + + * Reference: R. Gu et al. `CA-Net: Comprehensive Attention Convolutional Neural Networks + for Explainable Medical Image Segmentation `_. + IEEE Transactions on Medical Imaging, 40(2),2021:699-711. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param bilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + """ + def __init__(self, params): #args, in_ch=3, n_classes=2, feature_scale=4, is_deconv=True, is_batchnorm=True, + # nonlocal_mode='concatenation', attention_dsample=(1, 1)): + super(CANet, self).__init__() + self.in_channels = params['in_chns'] + self.num_classes = params['class_num'] + self.is_deconv = params.get('is_deconv', True) + self.is_batchnorm = params.get('is_batchnorm', True) + self.feature_scale = params.get('feature_scale', 4) + nonlocal_mode = 'concatenation' + attention_dsample = (1, 1) + + filters = [64, 128, 256, 512, 1024] + filters = [int(x / self.feature_scale) for x in filters] + + # downsampling + self.conv1 = conv_block(self.in_channels, filters[0]) + self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2)) + + self.conv2 = conv_block(filters[0], filters[1]) + self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2)) + + self.conv3 = conv_block(filters[1], filters[2]) + self.maxpool3 = nn.MaxPool2d(kernel_size=(2, 2)) + + self.conv4 = conv_block(filters[2], filters[3], drop_out=True) + self.maxpool4 = nn.MaxPool2d(kernel_size=(2, 2)) + + self.center = conv_block(filters[3], filters[4], drop_out=True) + + # attention blocks + # self.attentionblock1 = GridAttentionBlock2D(in_channels=filters[0], gating_channels=filters[1], + # inter_channels=filters[0]) + self.attentionblock2 = MultiAttentionBlock(in_size=filters[1], gate_size=filters[2], inter_size=filters[1], + nonlocal_mode=nonlocal_mode, sub_sample_factor=attention_dsample) + self.attentionblock3 = MultiAttentionBlock(in_size=filters[2], gate_size=filters[3], inter_size=filters[2], + nonlocal_mode=nonlocal_mode, sub_sample_factor=attention_dsample) + self.nonlocal4_2 = NONLocalBlock2D(in_channels=filters[4], inter_channels=filters[4] // 4) + + # upsampling + self.up_concat4 = UpCat(filters[4], filters[3], self.is_deconv) + self.up_concat3 = UpCat(filters[3], filters[2], self.is_deconv) + self.up_concat2 = UpCat(filters[2], filters[1], self.is_deconv) + self.up_concat1 = UpCat(filters[1], filters[0], self.is_deconv) + self.up4 = SE_Conv_Block(filters[4], filters[3], drop_out=True) + self.up3 = SE_Conv_Block(filters[3], filters[2]) + self.up2 = SE_Conv_Block(filters[2], filters[1]) + self.up1 = SE_Conv_Block(filters[1], filters[0]) + + # For deep supervision, project the multi-scale feature maps to the same number of channels + self.dsv1 = nn.Conv2d(in_channels=filters[0], out_channels=filters[0]//2, kernel_size=1) + self.dsv2 = nn.Conv2d(in_channels=filters[1], out_channels=filters[0]//2, kernel_size=1) + self.dsv3 = nn.Conv2d(in_channels=filters[2], out_channels=filters[0]//2, kernel_size=1) + self.dsv4 = nn.Conv2d(in_channels=filters[3], out_channels=filters[0]//2, kernel_size=1) + + self.scale_att = scale_atten_convblock(in_size=filters[0]//2 * 4, out_size=filters[0]) + self.final = nn.Conv2d(filters[0], self.num_classes, kernel_size=1) + + def forward(self, inputs): + # Feature Extraction + conv1 = self.conv1(inputs) + maxpool1 = self.maxpool1(conv1) + + conv2 = self.conv2(maxpool1) + maxpool2 = self.maxpool2(conv2) + + conv3 = self.conv3(maxpool2) + maxpool3 = self.maxpool3(conv3) + + conv4 = self.conv4(maxpool3) + maxpool4 = self.maxpool4(conv4) + + # Gating Signal Generation + center = self.center(maxpool4) + + # Attention Mechanism + # Upscaling Part (Decoder) + up4 = self.up_concat4(conv4, center) + g_conv4 = self.nonlocal4_2(up4) + + up4, att_weight4 = self.up4(g_conv4) + g_conv3, att3 = self.attentionblock3(conv3, up4) + + # atten3_map = att3.cpu().detach().numpy().astype(np.float) + # atten3_map = ndimage.interpolation.zoom(atten3_map, [1.0, 1.0, 224 / atten3_map.shape[2], + # 300 / atten3_map.shape[3]], order=0) + + up3 = self.up_concat3(g_conv3, up4) + up3, att_weight3 = self.up3(up3) + g_conv2, att2 = self.attentionblock2(conv2, up3) + + up2 = self.up_concat2(g_conv2, up3) + up2, att_weight2 = self.up2(up2) + + up1 = self.up_concat1(conv1, up2) + up1, att_weight1 = self.up1(up1) + + # Deep Supervision + dsv1 = self.dsv1(up1) + dsv2 = F.interpolate(self.dsv2(up2), dsv1.shape[2:], mode = 'bilinear') + dsv3 = F.interpolate(self.dsv3(up3), dsv1.shape[2:], mode = 'bilinear') + dsv4 = F.interpolate(self.dsv4(up4), dsv1.shape[2:], mode = 'bilinear') + + dsv_cat = torch.cat([dsv1, dsv2, dsv3, dsv4], dim=1) + out = self.scale_att(dsv_cat) + + out = self.final(out) + + return out + +if __name__ == "__main__": + params = {'in_chns':3, + 'class_num':2} + Net = CANet(params) + Net = Net.double() + + x = np.random.rand(4, 3, 224, 224) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) +axpool3(conv3) + + conv4 = self.conv4(maxpool3) + maxpool4 = self.maxpool4(conv4) + + # Gating Signal Generation + center = self.center(maxpool4) + + # Attention Mechanism + # Upscaling Part (Decoder) + up4 = self.up_concat4(conv4, center) + g_conv4 = self.nonlocal4_2(up4) + + up4, att_weight4 = self.up4(g_conv4) + g_conv3, att3 = self.attentionblock3(conv3, up4) + + # atten3_map = att3.cpu().detach().numpy().astype(np.float) + # atten3_map = ndimage.interpolation.zoom(atten3_map, [1.0, 1.0, 224 / atten3_map.shape[2], + # 300 / atten3_map.shape[3]], order=0) + + up3 = self.up_concat3(g_conv3, up4) + up3, att_weight3 = self.up3(up3) + g_conv2, att2 = self.attentionblock2(conv2, up3) + + # atten2_map = att2.cpu().detach().numpy().astype(np.float) + # atten2_map = ndimage.interpolation.zoom(atten2_map, [1.0, 1.0, 224 / atten2_map.shape[2], + # 300 / atten2_map.shape[3]], order=0) + + up2 = self.up_concat2(g_conv2, up3) + up2, att_weight2 = self.up2(up2) + # g_conv1, att1 = self.attentionblock1(conv1, up2) + + # atten1_map = att1.cpu().detach().numpy().astype(np.float) + # atten1_map = ndimage.interpolation.zoom(atten1_map, [1.0, 1.0, 224 / atten1_map.shape[2], + # 300 / atten1_map.shape[3]], order=0) + up1 = self.up_concat1(conv1, up2) + up1, att_weight1 = self.up1(up1) + + # Deep Supervision + dsv4 = self.dsv4(up4) + dsv3 = self.dsv3(up3) + dsv2 = self.dsv2(up2) + dsv1 = self.dsv1(up1) + dsv_cat = torch.cat([dsv1, dsv2, dsv3, dsv4], dim=1) + out = self.scale_att(dsv_cat) + + out = self.final(out) + + return out + + + +if __name__ == "__main__": + params = {'in_chns':3, + 'class_num':2} + Net = CANet(params) + Net = Net.double() + + x = np.random.rand(2, 3, 256, 256) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) diff --git a/pymic/net/net2d/unet2d_mcnet.py b/pymic/net/net2d/unet2d_mcnet.py index 76ee5de..be5b16b 100644 --- a/pymic/net/net2d/unet2d_mcnet.py +++ b/pymic/net/net2d/unet2d_mcnet.py @@ -39,7 +39,7 @@ def __init__(self, params): 'class_num': class_num, 'up_mode': 2, 'multiscale_pred': False} - self.encoder = Encoder(params1) + self.encoder = Encoder(params1) self.decoder1 = Decoder(params1) self.decoder2 = Decoder(params2) self.decoder3 = Decoder(params3) diff --git a/pymic/util/evaluation_seg.py b/pymic/util/evaluation_seg.py index 836401d..a9b114b 100644 --- a/pymic/util/evaluation_seg.py +++ b/pymic/util/evaluation_seg.py @@ -108,22 +108,30 @@ def binary_hd95(s, g, spacing = None): """ s_edge = get_edge_points(s) g_edge = get_edge_points(g) - image_dim = len(s.shape) - assert(image_dim == len(g.shape)) - if(spacing == None): - spacing = [1.0] * image_dim + ns = s_edge.sum() + ng = g_edge.sum() + if(ns + ng == 0): + hd95 = 0.0 + elif(ns * ng == 0): + hd95 = 100.0 else: - assert(image_dim == len(spacing)) - s_dis = ndimage.distance_transform_edt(1-s_edge, sampling = spacing) - g_dis = ndimage.distance_transform_edt(1-g_edge, sampling = spacing) - - dist_list1 = s_dis[g_edge > 0] - dist_list1 = sorted(dist_list1) - dist1 = dist_list1[int(len(dist_list1)*0.95)] - dist_list2 = g_dis[s_edge > 0] - dist_list2 = sorted(dist_list2) - dist2 = dist_list2[int(len(dist_list2)*0.95)] - return max(dist1, dist2) + image_dim = len(s.shape) + assert(image_dim == len(g.shape)) + if(spacing == None): + spacing = [1.0] * image_dim + else: + assert(image_dim == len(spacing)) + s_dis = ndimage.distance_transform_edt(1-s_edge, sampling = spacing) + g_dis = ndimage.distance_transform_edt(1-g_edge, sampling = spacing) + + dist_list1 = s_dis[g_edge > 0] + dist_list1 = sorted(dist_list1) + dist1 = dist_list1[int(len(dist_list1)*0.95)] + dist_list2 = g_dis[s_edge > 0] + dist_list2 = sorted(dist_list2) + dist2 = dist_list2[int(len(dist_list2)*0.95)] + hd95 = max(dist1, dist2) + return hd95 def binary_assd(s, g, spacing = None): @@ -150,9 +158,14 @@ def binary_assd(s, g, spacing = None): ns = s_edge.sum() ng = g_edge.sum() - s_dis_g_edge = s_dis * g_edge - g_dis_s_edge = g_dis * s_edge - assd = (s_dis_g_edge.sum() + g_dis_s_edge.sum()) / (ns + ng) + if(ns + ng == 0): + assd = 0.0 + elif(ns*ng == 0): + assd = 20.0 + else: + s_dis_g_edge = s_dis * g_edge + g_dis_s_edge = g_dis * s_edge + assd = (s_dis_g_edge.sum() + g_dis_s_edge.sum()) / (ns + ng) return assd # relative volume error evaluation @@ -315,8 +328,10 @@ def evaluation(config): # save the result as csv if(output_name is None): - output_name = "{0:}/eval_{1:}.csv".format(seg_root, metric) - with open(output_name, mode='w') as csv_file: + metric_output_name = "{0:}/eval_{1:}.csv".format(seg_root, metric) + else: + metric_output_name = output_name + with open(metric_output_name, mode='w') as csv_file: csv_writer = csv.writer(csv_file, delimiter=',', quotechar='"',quoting=csv.QUOTE_MINIMAL) head = ['image'] + ["class_{0:}".format(i) for i in label_list] From 28000e230842e4a69155c5d3ed15ac146aecb5e0 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 10 Nov 2023 11:44:06 +0800 Subject: [PATCH 170/225] update unet2d only give one prediction for inference --- pymic/net/net2d/unet2d.py | 28 ++++--- pymic/net/net2d/unet2d_urpc.py | 132 --------------------------------- 2 files changed, 16 insertions(+), 144 deletions(-) delete mode 100644 pymic/net/net2d/unet2d_urpc.py diff --git a/pymic/net/net2d/unet2d.py b/pymic/net/net2d/unet2d.py index 758bfe5..7d14a2e 100644 --- a/pymic/net/net2d/unet2d.py +++ b/pymic/net/net2d/unet2d.py @@ -153,8 +153,8 @@ def __init__(self, params): self.ft_chns = self.params['feature_chns'] self.dropout = self.params['dropout'] self.n_class = self.params['class_num'] - self.up_mode = self.params['up_mode'] - self.mul_pred = self.params['multiscale_pred'] + self.up_mode = self.params.get('up_mode', 2) + self.mul_pred = self.params.get('multiscale_pred', False) assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) @@ -183,10 +183,10 @@ def forward(self, x): x_d1 = self.up3(x_d2, x1) x_d0 = self.up4(x_d1, x0) output = self.out_conv(x_d0) - if(self.mul_pred): + if(self.mul_pred and self.training): output1 = self.out_conv1(x_d1) - output2 = self.out_conv1(x_d2) - output3 = self.out_conv1(x_d3) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) output = [output, output1, output2, output3] return output @@ -263,19 +263,23 @@ def forward(self, x): if __name__ == "__main__": params = {'in_chns':4, - 'feature_chns':[2, 8, 32, 48, 64], + 'feature_chns':[16, 32, 64, 128, 256], 'dropout': [0, 0, 0.3, 0.4, 0.5], 'class_num': 2, 'up_mode': 0, - 'multiscale_pred': False} + 'multiscale_pred': True} Net = UNet2D(params) Net = Net.double() - x = np.random.rand(4, 4, 10, 96, 96) + x = np.random.rand(4, 4, 10, 256, 256) xt = torch.from_numpy(x) xt = torch.tensor(xt) - y = Net(xt) - print(len(y.size())) - y = y.detach().numpy() - print(y.shape) + out = Net(xt) + if params['multiscale_pred']: + for y in out: + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) + else: + print(out.shape) diff --git a/pymic/net/net2d/unet2d_urpc.py b/pymic/net/net2d/unet2d_urpc.py deleted file mode 100644 index ee8ab7c..0000000 --- a/pymic/net/net2d/unet2d_urpc.py +++ /dev/null @@ -1,132 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import torch -import torch.nn as nn -import numpy as np -from torch.distributions.uniform import Uniform -from pymic.net.net2d.unet2d import ConvBlock, DownBlock, UpBlock - -def FeatureDropout(x): - attention = torch.mean(x, dim=1, keepdim=True) - max_val, _ = torch.max(attention.view( - x.size(0), -1), dim=1, keepdim=True) - threshold = max_val * np.random.uniform(0.7, 0.9) - threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention) - drop_mask = (attention < threshold).float() - x = x.mul(drop_mask) - return x - -class FeatureNoise(nn.Module): - def __init__(self, uniform_range=0.3): - super(FeatureNoise, self).__init__() - self.uni_dist = Uniform(-uniform_range, uniform_range) - - def feature_based_noise(self, x): - noise_vector = self.uni_dist.sample( - x.shape[1:]).to(x.device).unsqueeze(0) - x_noise = x.mul(noise_vector) + x - return x_noise - - def forward(self, x): - x = self.feature_based_noise(x) - return x - -class UNet2D_URPC(nn.Module): - """ - An modification the U-Net to obtain multi-scale prediction according to - the URPC paper. - - * Reference: Xiangde Luo, Guotai Wang*, Wenjun Liao, Jieneng Chen, Tao Song, Yinan Chen, - Shichuan Zhang, Dimitris N. Metaxas, Shaoting Zhang. - Semi-Supervised Medical Image Segmentation via Uncertainty Rectified Pyramid Consistency . - `Medical Image Analysis 2022. `_ - - Also see: https://github.com/HiLab-git/SSL4MIS/blob/master/code/networks/unet.py - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - """ - def __init__(self, params): - super(UNet2D_URPC, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] - assert(len(self.ft_chns) == 5) - - self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], 0.0, self.bilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], 0.0, self.bilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], 0.0, self.bilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], 0.0, self.bilinear) - - self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, - kernel_size = 3, padding = 1) - self.out_conv_dp4 = nn.Conv2d(self.ft_chns[4], self.n_class, - kernel_size=3, padding=1) - self.out_conv_dp3 = nn.Conv2d(self.ft_chns[3], self.n_class, - kernel_size=3, padding=1) - self.out_conv_dp2 = nn.Conv2d(self.ft_chns[2], self.n_class, - kernel_size=3, padding=1) - self.out_conv_dp1 = nn.Conv2d(self.ft_chns[1], self.n_class, - kernel_size=3, padding=1) - self.feature_noise = FeatureNoise() - - def forward(self, x): - x_shape = list(x.shape) - if(len(x_shape) == 5): - [N, C, D, H, W] = x_shape - new_shape = [N*D, C, H, W] - x = torch.transpose(x, 1, 2) - x = torch.reshape(x, new_shape) - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - x4 = self.down4(x3) - - x = self.up1(x4, x3) - if self.training: - x = nn.functional.dropout(x, p=0.5) - dp3_out = self.out_conv_dp3(x) - - x = self.up2(x, x2) - if self.training: - x = FeatureDropout(x) - dp2_out = self.out_conv_dp2(x) - - x = self.up3(x, x1) - if self.training: - x = self.feature_noise(x) - dp1_out = self.out_conv_dp1(x) - - x = self.up4(x, x0) - dp0_out = self.out_conv(x) - - out_shape = list(dp0_out.shape)[2:] - dp3_out = nn.functional.interpolate(dp3_out, out_shape) - dp2_out = nn.functional.interpolate(dp2_out, out_shape) - dp1_out = nn.functional.interpolate(dp1_out, out_shape) - out = [dp0_out, dp1_out, dp2_out, dp3_out] - - if(len(x_shape) == 5): - new_shape = [N, D] + list(dp0_out.shape)[1:] - for i in range(len(out)): - out[i] = torch.transpose(torch.reshape(out[i], new_shape), 1, 2) - return out \ No newline at end of file From 5948827dd670c53aa1b96fe6faab14af50514298 Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 29 Nov 2023 09:58:36 +0800 Subject: [PATCH 171/225] update unet2d_scse --- pymic/net/net2d/unet2d_scse.py | 143 +++++++++++---------------------- pymic/test/test_net2d.py | 57 +++++++++++++ pymic/test/test_net3d.py | 108 +++++++++++++++++++++++++ 3 files changed, 211 insertions(+), 97 deletions(-) create mode 100644 pymic/test/test_net2d.py create mode 100644 pymic/test/test_net3d.py diff --git a/pymic/net/net2d/unet2d_scse.py b/pymic/net/net2d/unet2d_scse.py index 125843e..54a5d2f 100644 --- a/pymic/net/net2d/unet2d_scse.py +++ b/pymic/net/net2d/unet2d_scse.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn import numpy as np +from pymic.net.net2d.unet2d import UpBlock, Encoder, Decoder, UNet2D from pymic.net.net2d.scse2d import * class ConvScSEBlock(nn.Module): @@ -50,116 +51,64 @@ def __init__(self, in_channels, out_channels, dropout_p): def forward(self, x): return self.maxpool_conv(x) -class UpBlock(nn.Module): +class UpBlockScSE(UpBlock): """Up-sampling followed by `ConvScSEBlock` in U-Net structure. - :param in_channels1: (int) Input channel number for low-resolution feature map. - :param in_channels2: (int) Input channel number for high-resolution feature map. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - :param bilinear: (bool) Use bilinear for up-sampling or not. + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.UpBlock` for details. """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - bilinear=True): - super(UpBlock, self).__init__() - self.bilinear = bilinear - if bilinear: - self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) - else: - self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2) + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, up_mode = 2): + super(UpBlockScSE, self).__init__(in_channels1, in_channels2, out_channels, dropout_p, up_mode) self.conv = ConvScSEBlock(in_channels2 * 2, out_channels, dropout_p) - def forward(self, x1, x2): - if self.bilinear: - x1 = self.conv1x1(x1) - x1 = self.up(x1) - x = torch.cat([x2, x1], dim=1) - return self.conv(x) - -class UNet2D_ScSE(nn.Module): +class EncoderScSE(Encoder): """ - Combining 2D U-Net with SCSE module. + Encoder of 2D UNet with ScSE. - * Reference: Abhijit Guha Roy, Nassir Navab, Christian Wachinger: - Recalibrating Fully Convolutional Networks With Spatial and Channel - "Squeeze and Excitation" Blocks. - `IEEE Trans. Med. Imaging 38(2): 540-549 (2019). `_ - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.Encoder` for details. """ def __init__(self, params): - super(UNet2D_ScSE, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] - assert(len(self.ft_chns) == 5) + super(EncoderScSE, self).__init__(params) self.in_conv= ConvScSEBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = self.dropout[3]) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p = self.dropout[2]) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p = self.dropout[1]) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p = self.dropout[0]) - - self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, - kernel_size = 3, padding = 1) + if(len(self.ft_chns) == 5): + self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - def forward(self, x): - x_shape = list(x.shape) - if(len(x_shape) == 5): - [N, C, D, H, W] = x_shape - new_shape = [N*D, C, H, W] - x = torch.transpose(x, 1, 2) - x = torch.reshape(x, new_shape) - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - x4 = self.down4(x3) +class DecoderScSE(Decoder): + """ + Decoder of 2D UNet with ScSE. + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.Decoder` for details. + """ + def __init__(self, params): + super(DecoderScSE, self).__init__(params) - x = self.up1(x4, x3) - x = self.up2(x, x2) - x = self.up3(x, x1) - x = self.up4(x, x0) - output = self.out_conv(x) - - if(len(x_shape) == 5): - new_shape = [N, D] + list(output.shape)[1:] - output = torch.reshape(output, new_shape) - output = torch.transpose(output, 1, 2) - return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'feature_chns':[2, 8, 32, 48, 64], - 'dropout': [0, 0, 0.3, 0.4, 0.5], - 'class_num': 2, - 'bilinear': True} - Net = UNet2D_ScSE(params) - Net = Net.double() - - x = np.random.rand(4, 4, 10, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print(len(y.size())) - y = y.detach().numpy() - print(y.shape) \ No newline at end of file + + if(len(self.ft_chns) == 5): + self.up1 = UpBlockScSE(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.up_mode) + self.up2 = UpBlockScSE(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.up_mode) + self.up3 = UpBlockScSE(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.up_mode) + self.up4 = UpBlockScSE(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.up_mode) + + +class UNet2D_ScSE(UNet2D): + """ + Combining 2D U-Net with SCSE module. + + * Reference: Abhijit Guha Roy, Nassir Navab, Christian Wachinger: + Recalibrating Fully Convolutional Networks With Spatial and Channel + "Squeeze and Excitation" Blocks. + `IEEE Trans. Med. Imaging 38(2): 540-549 (2019). `_ + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.unet2d` for details. + """ + def __init__(self, params): + super(UNet2D_ScSE, self).__init__(params) + self.encoder = Encoder(params) + self.decoder = Decoder(params) diff --git a/pymic/test/test_net2d.py b/pymic/test/test_net2d.py new file mode 100644 index 0000000..aafaf20 --- /dev/null +++ b/pymic/test/test_net2d.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import numpy as np +from pymic.net.net2d.unet2d import UNet2D +from pymic.net.net2d.unet2d_scse import UNet2D_ScSE + +def test_unet2d(): + params = {'in_chns':4, + 'feature_chns':[16, 32, 64, 128, 256], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'up_mode': 0, + 'multiscale_pred': True} + Net = UNet2D(params) + Net = Net.double() + + x = np.random.rand(4, 4, 10, 256, 256) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + if params['multiscale_pred']: + for y in out: + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) + else: + print(out.shape) + +def test_unet2d_scse(): + params = {'in_chns':4, + 'feature_chns':[16, 32, 64, 128, 256], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'up_mode': 0, + 'multiscale_pred': True} + Net = UNet2D_ScSE(params) + Net = Net.double() + + x = np.random.rand(4, 4, 10, 256, 256) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + if params['multiscale_pred']: + for y in out: + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) + else: + print(out.shape) + +if __name__ == "__main__": + # test_unet2d() + test_unet2d_scse() \ No newline at end of file diff --git a/pymic/test/test_net3d.py b/pymic/test/test_net3d.py new file mode 100644 index 0000000..180dcff --- /dev/null +++ b/pymic/test/test_net3d.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import numpy as np +from pymic.net.net3d.unet3d import UNet3D +from pymic.net.net3d.unet3d_scse import UNet3D_ScSE +from pymic.net.net3d.unet2d5 import UNet2D5 + +def test_unet3d(): + params = {'in_chns':4, + 'class_num': 2, + 'feature_chns':[2, 8, 32, 64], + 'dropout' : [0, 0, 0, 0.5], + 'up_mode': 2, + 'multiscale_pred': False} + Net = UNet3D(params) + Net = Net.double() + + x = np.random.rand(4, 4, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + y = y.detach().numpy() + print(y.shape) + + params = {'in_chns':4, + 'class_num': 2, + 'feature_chns':[2, 8, 32, 64, 128], + 'dropout' : [0, 0, 0, 0.4, 0.5], + 'up_mode': 3, + 'multiscale_pred': True} + Net = UNet3D(params) + Net = Net.double() + + x = np.random.rand(4, 4, 96, 128, 128) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + for y in out: + y = y.detach().numpy() + print(y.shape) + +def test_unet3d_scse(): + params = {'in_chns':4, + 'feature_chns':[2, 8, 32, 48, 64], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'up_mode': 2} + Net = UNet3D_ScSE(params) + Net = Net.double() + + x = np.random.rand(4, 4, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) + +def test_unet2d5(): + params = {'in_chns':4, + 'feature_chns':[8, 16, 32, 64, 128], + 'conv_dims': [2, 2, 3, 3, 3], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'up_mode': 2, + 'multiscale_pred': True} + Net = UNet2D5(params) + Net = Net.double() + + x = np.random.rand(4, 4, 32, 128, 128) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + for y in out: + y = y.detach().numpy() + print(y.shape) + + params = {'in_chns':4, + 'feature_chns':[8, 16, 32, 64, 128], + 'conv_dims': [2, 3, 3, 3, 3], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'up_mode': 0, + 'multiscale_pred': True} + Net = UNet2D5(params) + Net = Net.double() + + x = np.random.rand(4, 4, 64, 128, 128) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + for y in out: + y = y.detach().numpy() + print(y.shape) + +if __name__ == "__main__": + # test_unet3d() + # test_unet3d_scse() + test_unet2d5() + + \ No newline at end of file From 0248db0b164e6994188fd6487d9cca2897486b7b Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 4 Dec 2023 17:14:20 +0800 Subject: [PATCH 172/225] add affine transform add affine transform --- pymic/io/nifty_dataset.py | 19 ++-- pymic/net_run/agent_preprocess.py | 104 +++++++++++++++++++ pymic/net_run/agent_seg.py | 2 + pymic/net_run/get_optimizer.py | 2 +- pymic/net_run/semi_sup/ssl_cps.py | 14 +++ pymic/net_run/semi_sup/ssl_mcnet.py | 9 +- pymic/net_run/train.py | 5 +- pymic/test/test_assd.py | 3 +- pymic/transform/affine.py | 156 ++++++++++++++++++++++++++++ pymic/transform/crop.py | 31 ++++-- pymic/transform/flip.py | 1 - pymic/transform/intensity.py | 7 +- pymic/transform/normalize.py | 24 +++-- pymic/transform/pad.py | 1 - pymic/transform/trans_dict.py | 2 + 15 files changed, 338 insertions(+), 42 deletions(-) create mode 100644 pymic/net_run/agent_preprocess.py create mode 100644 pymic/transform/affine.py diff --git a/pymic/io/nifty_dataset.py b/pymic/io/nifty_dataset.py index 9812d13..438a07b 100644 --- a/pymic/io/nifty_dataset.py +++ b/pymic/io/nifty_dataset.py @@ -38,7 +38,8 @@ def __init__(self, root_dir, csv_file, modal_num = 1, if('label' not in csv_keys): logging.warning("`label` section is not found in the csv file {0:}".format( csv_file) + "\n -- This is only allowed for self-supervised learning" + - "\n -- when `SelfSuperviseLabel` is used in the transform.") + "\n -- when `SelfSuperviseLabel` is used in the transform, or when" + + "\n -- loading the unlabeled data for preprocessing.") self.with_label = False self.image_weight_idx = None self.pixel_weight_idx = None @@ -52,15 +53,15 @@ def __len__(self): def __getlabel__(self, idx): csv_keys = list(self.csv_items.keys()) - label_idx = csv_keys.index('label') - label_name = "{0:}/{1:}".format(self.root_dir, - self.csv_items.iloc[idx, label_idx]) - label = load_image_as_nd_array(label_name)['data_array'] + label_idx = csv_keys.index('label') + label_name = self.csv_items.iloc[idx, label_idx] + label_name_full = "{0:}/{1:}".format(self.root_dir, label_name) + label = load_image_as_nd_array(label_name_full)['data_array'] if(self.task == TaskType.SEGMENTATION): label = np.asarray(label, np.int32) elif(self.task == TaskType.RECONSTRUCTION): label = np.asarray(label, np.float32) - return label + return label, label_name def __get_pixel_weight__(self, idx): weight_name = "{0:}/{1:}".format(self.root_dir, @@ -80,12 +81,14 @@ def __getitem__(self, idx): image_list.append(image_data) image = np.concatenate(image_list, axis = 0) image = np.asarray(image, np.float32) - sample = {'image': image, 'names' : names_list[0], + + sample = {'image': image, 'names' : names_list, 'origin':image_dict['origin'], 'spacing': image_dict['spacing'], 'direction':image_dict['direction']} if (self.with_label): - sample['label'] = self.__getlabel__(idx) + sample['label'], label_name = self.__getlabel__(idx) + sample['names'].append(label_name) assert(image.shape[1:] == sample['label'].shape[1:]) if (self.image_weight_idx is not None): sample['image_weight'] = self.csv_items.iloc[idx, self.image_weight_idx] diff --git a/pymic/net_run/agent_preprocess.py b/pymic/net_run/agent_preprocess.py new file mode 100644 index 0000000..c681de9 --- /dev/null +++ b/pymic/net_run/agent_preprocess.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import os +import sys +import torch +import torchvision.transforms as transforms +from pymic.util.parse_config import * +from pymic.io.image_read_write import save_nd_array_as_image +from pymic.io.nifty_dataset import NiftyDataset +from pymic.transform.trans_dict import TransformDict + + + +class PreprocessAgent(object): + def __init__(self, config): + super(PreprocessAgent, self).__init__() + self.config = config + self.transform_dict = TransformDict + self.task_type = config['dataset']['task_type'] + self.dataloader = None + self.dataloader_unlab= None + + def get_dataset_from_config(self): + root_dir = self.config['dataset']['root_dir'] + modal_num = self.config['dataset'].get('modal_num', 1) + transform_names = self.config['dataset']["transform"] + + self.transform_list = [] + if(transform_names is None or len(transform_names) == 0): + data_transform = None + else: + transform_param = self.config['dataset'] + transform_param['task'] = self.task_type + for name in transform_names: + if(name not in self.transform_dict): + raise(ValueError("Undefined transform {0:}".format(name))) + one_transform = self.transform_dict[name](transform_param) + self.transform_list.append(one_transform) + data_transform = transforms.Compose(self.transform_list) + + data_csv = self.config['dataset'].get('data_csv', None) + data_csv_unlab = self.config['dataset'].get('data_csv_unlab', None) + if(data_csv is not None): + dataset = NiftyDataset(root_dir = root_dir, + csv_file = data_csv, + modal_num = modal_num, + with_label= True, + transform = data_transform, + task = self.task_type) + self.dataloader = torch.utils.data.DataLoader(dataset, + batch_size = 1, shuffle=False, num_workers= 8, + worker_init_fn=None, generator = torch.Generator()) + if(data_csv_unlab is not None): + dataset_unlab = NiftyDataset(root_dir = root_dir, + csv_file = data_csv_unlab, + modal_num = modal_num, + with_label= False, + transform = data_transform, + task = self.task_type) + self.dataloader_unlab = torch.utils.data.DataLoader(dataset_unlab, + batch_size = 1, shuffle=False, num_workers= 8, + worker_init_fn=None, generator = torch.Generator()) + + def run(self): + """ + Do preprocessing for labeled and unlabeled data. + """ + self.get_dataset_from_config() + out_dir = self.config['dataset']['output_dir'] + for dataloader in [self.dataloader, self.dataloader_unlab]: + for item in dataloader: + img = item['image'][0] # the batch size is 1 + # save differnt modaliteis + img_names = item['names'] + spacing = [x.numpy()[0] for x in item['spacing']] + for i in range(img.shape[0]): + image_name = out_dir + "/" + img_names[i][0] + print(image_name) + save_nd_array_as_image(img[i], image_name, reference_name = None, spacing=spacing) + if('label' in item): + lab = item['label'][0] + label_name = out_dir + "/" + img_names[-1][0] + print(label_name) + save_nd_array_as_image(lab[0], label_name, reference_name = None, spacing=spacing) + +def main(): + """ + The main function for data preprocessing. + """ + if(len(sys.argv) < 2): + print('Number of arguments should be 2. e.g.') + print(' pymic_preprocess config.cfg') + exit() + cfg_file = str(sys.argv[1]) + if(not os.path.isfile(cfg_file)): + raise ValueError("The config file does not exist: " + cfg_file) + config = parse_config(cfg_file) + config = synchronize_config(config) + agent = PreprocessAgent(config) + agent.run() + +if __name__ == "__main__": + main() + diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 65c2b68..c376546 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -89,6 +89,8 @@ def create_network(self): self.net.float() else: self.net.double() + if(hasattr(self.net, "set_stage")): + self.net.set_stage(self.stage) param_number = sum(p.numel() for p in self.net.parameters() if p.requires_grad) logging.info('parameter number {0:}'.format(param_number)) diff --git a/pymic/net_run/get_optimizer.py b/pymic/net_run/get_optimizer.py index ad8fda0..53de5e3 100644 --- a/pymic/net_run/get_optimizer.py +++ b/pymic/net_run/get_optimizer.py @@ -14,7 +14,7 @@ def get_optimizer(name, net_params, optim_params): param_group = [{'params': net_params, 'initial_lr': lr}] if(keyword_match(name, "SGD")): return optim.SGD(param_group, lr, - momentum = momentum, weight_decay = weight_decay) + momentum = momentum, weight_decay = weight_decay, nesterov = True) elif(keyword_match(name, "Adam")): return optim.Adam(param_group, lr, weight_decay = weight_decay) elif(keyword_match(name, "SparseAdam")): diff --git a/pymic/net_run/semi_sup/ssl_cps.py b/pymic/net_run/semi_sup/ssl_cps.py index db1fb28..df4b5af 100644 --- a/pymic/net_run/semi_sup/ssl_cps.py +++ b/pymic/net_run/semi_sup/ssl_cps.py @@ -6,6 +6,7 @@ from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice +from pymic.io.image_read_write import save_nd_array_as_image from pymic.net_run.semi_sup import SSLSegAgent from pymic.util.ramps import get_rampup_ratio @@ -57,6 +58,19 @@ def training(self): x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) x1 = self.convert_tensor_type(data_unlab['image']) + + # for debug + # for i in range(x0.shape[0]): + # image_i = x0[i][0] + # label_i = np.argmax(y0[i], axis = 0) + # # pixw_i = pix_w[i][0] + # print(image_i.shape, label_i.shape) + # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) + # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) + # save_nd_array_as_image(image_i, image_name, reference_name = None) + # save_nd_array_as_image(label_i, label_name, reference_name = None) + # continue + inputs = torch.cat([x0, x1], dim = 0) inputs, y0 = inputs.to(self.device), y0.to(self.device) diff --git a/pymic/net_run/semi_sup/ssl_mcnet.py b/pymic/net_run/semi_sup/ssl_mcnet.py index 955357d..66e1034 100644 --- a/pymic/net_run/semi_sup/ssl_mcnet.py +++ b/pymic/net_run/semi_sup/ssl_mcnet.py @@ -22,8 +22,8 @@ class SSLMCNet(SSLSegAgent): with multiple decoders for learning, such as `pymic.net.net2d.unet2d_mcnet.MCNet2D`. * Reference: Yicheng Wu, Zongyuan Ge et al. Mutual consistency learning for - semi-supervised medical image segmentation. - `Medical Image Analysis 2022. `_ + semi-supervised medical image segmentation. + `MIA 2022. `_ The original code is at: https://github.com/ycwu1997/MC-Net @@ -34,11 +34,6 @@ class SSLMCNet(SSLSegAgent): In the configuration dictionary, in addition to the four sections (`dataset`, `network`, `training` and `inference`) used in fully supervised learning, an extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. - - Special parameters required for MCNet in `semi_supervised_learning` section: - - :param temperature: (float) temperature for label sharpening. The default value is 0.1. - """ def training(self): class_num = self.config['network']['class_num'] diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index 0167f2f..426b620 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -78,11 +78,12 @@ def main(): os.makedirs(log_dir, exist_ok=True) dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] shutil.copy(cfg_file, log_dir + "/" + dst_cfg) + datetime_str = str(datetime.now())[:-7].replace(":", "_") if sys.version.startswith("3.9"): - logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), + logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), level=logging.INFO, format='%(message)s', force=True) # for python 3.9 else: - logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), + logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), level=logging.INFO, format='%(message)s') # for python 3.6 logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) diff --git a/pymic/test/test_assd.py b/pymic/test/test_assd.py index 35c1804..6b2732a 100644 --- a/pymic/test/test_assd.py +++ b/pymic/test/test_assd.py @@ -18,7 +18,8 @@ def test_assd_2d(): plt.show() def test_assd_3d(): - img_name = "/home/x/projects/PyMIC_project/PyMIC_examples/seg_ssl/ACDC/result/unet2d_baseline/patient001_frame01.nii.gz" + # img_name = "/home/x/projects/PyMIC_project/PyMIC_examples/seg_ssl/ACDC/result/unet2d_baseline/patient001_frame01.nii.gz" + img_name = "/home/disk4t/data/heart/ACDC/preprocess/patient001_frame12_gt.nii.gz" img_obj = sitk.ReadImage(img_name) spacing = img_obj.GetSpacing() spacing = spacing[::-1] diff --git a/pymic/transform/affine.py b/pymic/transform/affine.py new file mode 100644 index 0000000..552516f --- /dev/null +++ b/pymic/transform/affine.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import json +import math +import random +import numpy as np +from skimage import transform +from scipy import ndimage +from pymic import TaskType +from pymic.transform.abstract_transform import AbstractTransform +from pymic.util.image_process import * + +class Affine(AbstractTransform): + """ + Apply Affine Transform to an ND volume in the x-y plane. + Input shape should be [C, D, H, W] or [C, H, W]. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `Affine_scale_range`: (list or tuple) The range for scaling, e.g., (0.5, 2.0) + :param `Affine_shear_range`: (list or tuple) The range for shearing angle, e.g., (0, 30) + :param `Affine_rotate_range`: (list or tuple) The range for rotation, e.g., (-45, 45) + :param `Affine_output_size`: (None, list or tuple of length 2) The output size after affine transformation. + For 3D volumes, as we only apply affine transformation in x-y plane, the output slice + number will be the same as the input slice number, so only the output height and width + need to be given here, e.g., (H, W). By default (`None`), the output size will be the + same as the input size. + """ + def __init__(self, params): + super(Affine, self).__init__(params) + self.scale_range = params['Affine_scale_range'.lower()] + self.shear_range = params['Affine_shear_range'.lower()] + self.rotat_range = params['Affine_rotate_range'.lower()] + self.output_shape= params.get('Affine_output_size'.lower(), None) + self.inverse = params.get('Affine_inverse'.lower(), True) + + def _get_affine_param(self, sample, output_shape): + """ + output_shape should only has two dimensions, e.g., (H, W) + """ + input_shape = sample['image'].shape + input_dim = len(input_shape) - 1 + assert(len(output_shape) >=2) + + in_y, in_x = input_shape[-2:] + out_y, out_x = output_shape[-2:] + points = [[0, out_y], + [0, 0], + [out_x, 0], + [out_x, out_y]] + + sx = random.random() * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0] + sy = random.random() * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0] + shx = (random.random() * (self.shear_range[1] - self.shear_range[0]) + self.shear_range[0]) * 3.14159/180 + shy = (random.random() * (self.shear_range[1] - self.shear_range[0]) + self.shear_range[0]) * 3.14159/180 + rot = (random.random() * (self.rotat_range[1] - self.rotat_range[0]) + self.rotat_range[0]) * 3.14159/180 + # get affine transform parameters + new_points = [] + for p in points: + x = sx * p[0] * (math.cos(rot) + math.tan(shy) * math.sin(rot)) - \ + sy * p[1] * (math.tan(shx) * math.cos(rot) + math.sin(rot)) + y = sx * p[0] * (math.sin(rot) - math.tan(shy) * math.cos(rot)) - \ + sy * p[1] * (math.tan(shx) * math.sin(rot) - math.cos(rot)) + new_points.append([x,y]) + bb_min = np.array(new_points).min(axis = 0) + bb_max = np.array(new_points).max(axis = 0) + bbx, bby = int(bb_max[0] - bb_min[0]), int(bb_max[1] - bb_min[1]) + # transform the points to the image coordinate + margin_x = in_x - bbx + margin_y = in_y - bby + p0x = random.random() * margin_x if margin_x > 0 else margin_x / 2 + p0y = random.random() * margin_y if margin_y > 0 else margin_y / 2 + dst = [[new_points[i][0] - bb_min[0] + p0x, new_points[i][1] - bb_min[1] + p0y] \ + for i in range(3)] + + tform = transform.AffineTransform() + tform.estimate(np.array(points[:3]), np.array(dst)) + # to do: need to find a solution to save the affine transform matrix + # Use the matplotlib.transforms.Affine2D function to generate transform matrices, + # and the scipy.ndimage.warp function to warp images using the transform matrices. + # The skimage AffineTransform shear functionality is weird, + # and the scipy affine_transform function for warping images swaps the X and Y axes. + # sample['Affine_Param'] = json.dumps((input_shape, tform["matrix"])) + return sample, tform + + def _apply_affine_to_ND_volume(self, image, output_shape, tform, order = 3): + """ + output_shape should only has two dimensions, e.g., (H, W) + """ + dim = len(image.shape) - 1 + if(dim == 2): + C, H, W = image.shape + output = np.zeros([C] + output_shape) + for c in range(C): + output[c] = ndimage.affine_transform(image[c], tform, + output_shape = output_shape, mode='mirror', order = order) + elif(dim == 3): + C, D, H, W = image.shape + output = np.zeros([C, D] + output_shape) + for c in range(C): + for d in range(D): + output[c,d] = ndimage.affine_transform(image[c,d], tform, + output_shape = output_shape, mode='mirror', order = order) + return output + + def __call__(self, sample): + image = sample['image'] + input_shape = sample['image'].shape + output_shape= input_shape if self.output_shape is None else self.output_shape + aff_out_shape = output_shape[-2:] + sample, tform = self._get_affine_param(sample, aff_out_shape) + image_t = self._apply_affine_to_ND_volume(image, aff_out_shape, tform) + sample['image'] = image_t + + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + label = sample['label'] + label = self._apply_affine_to_ND_volume(label, aff_out_shape, tform, order = 0) + sample['label'] = label + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + weight = sample['pixel_weight'] + weight = self._apply_affine_to_ND_volume(weight, aff_out_shape, tform) + sample['pixel_weight'] = weight + return sample + + def _get_param_for_inverse_transform(self, sample): + if(isinstance(sample['Affine_Param'], list) or \ + isinstance(sample['Affine_Param'], tuple)): + params = json.loads(sample['Affine_Param'][0]) + else: + params = json.loads(sample['Affine_Param']) + return params + + # def inverse_transform_for_prediction(self, sample): + # params = self._get_param_for_inverse_transform(sample) + # origin_shape = params[0] + # tform = params[1] + + # predict = sample['predict'] + # if(isinstance(predict, tuple) or isinstance(predict, list)): + # output_predict = [] + # for predict_i in predict: + # aff_out_shape = origin_shape[-2:] + # output_predict_i = self._apply_affine_to_ND_volume(predict_i, + # aff_out_shape, tform.inverse) + # output_predict.append(output_predict_i) + # else: + # aff_out_shape = origin_shape[-2:] + # output_predict = self._apply_affine_to_ND_volume(predict, aff_out_shape, tform.inverse) + + # sample['predict'] = output_predict + # return sample diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index b4d0b63..6977639 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -113,16 +113,20 @@ class CropWithBoundingBox(CenterCrop): :param `CropWithBoundingBox_start`: (None, or list/tuple) The start index along each spatial axis. If None, calculate the start index automatically - so that the cropped region is centered at the non-zero region. + so that the cropped region is centered at the mask region defined by the threshold. :param `CropWithBoundingBox_output_size`: (None or tuple/list): Desired spatial output size. - If None, set it as the size of bounding box of non-zero region. + If None, set it as the size of bounding box of the mask region defined by the threshold. + :param `CropWithBoundingBox_threshold`: (None or float): + Threshold for obtaining a mask. This is used only when + `CropWithBoundingBox_start` is None. Default is 1.0 :param `CropWithBoundingBox_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `True`. """ def __init__(self, params): self.start = params['CropWithBoundingBox_start'.lower()] self.output_size = params['CropWithBoundingBox_output_size'.lower()] + self.threshold = params.get('CropWithBoundingBox_threshold'.lower(), 1.0) self.inverse = params.get('CropWithBoundingBox_inverse'.lower(), True) self.task = params['task'] @@ -130,8 +134,9 @@ def _get_crop_param(self, sample): image = sample['image'] input_shape = sample['image'].shape input_dim = len(input_shape) - 1 - bb_min, bb_max = get_ND_bounding_box(image) - bb_min, bb_max = bb_min[1:], bb_max[1:] + if(self.start is None or self.output_size is None): + bb_min, bb_max = get_ND_bounding_box(image > self.threshold) + bb_min, bb_max = bb_min[1:], bb_max[1:] if(self.start is None): if(self.output_size is None): crop_min, crop_max = bb_min, bb_max @@ -212,7 +217,9 @@ class RandomCrop(CenterCrop): :param `RandomCrop_output_size`: (list/tuple) Desired output size [D, H, W] or [H, W]. The output channel is the same as the input channel. - If D is None for 3D images, the z-axis is not cropped. + If `None` is set for a certain axis, that axis will not be cropped. For example, + for 3D vlumes, (None, H, W) means only crop in 2D, and (D, None, None) means only + crop along the z axis. :param `RandomCrop_foreground_focus`: (optional, bool) If true, allow crop around the foreground. Default is False. :param `RandomCrop_foreground_ratio`: (optional, float) @@ -242,10 +249,16 @@ def _get_crop_param(self, sample): input_shape = image.shape[1:] input_dim = len(input_shape) assert(input_dim == len(self.output_size)) - - crop_margin = [input_shape[i] - self.output_size[i] for i in range(input_dim)] + + output_size = [item for item in self.output_size] + # print("crop input and output size", input_shape, output_size) + for i in range(input_dim): + if(output_size[i] is None): + output_size[i] = input_shape[i] + # print(output_size) + crop_margin = [input_shape[i] - output_size[i] for i in range(input_dim)] crop_min = [0 if item == 0 else random.randint(0, item) for item in crop_margin] - crop_max = [crop_min[i] + self.output_size[i] for i in range(input_dim)] + crop_max = [crop_min[i] + output_size[i] for i in range(input_dim)] label_exist = False if ('label' not in sample or sample['label']) is None else True if(label_exist and self.fg_focus and random.random() < self.fg_ratio): @@ -255,7 +268,7 @@ def _get_crop_param(self, sample): else: mask_label = self.mask_label random_label = random.choice(mask_label) - crop_min, crop_max = get_random_box_from_mask(label == random_label, self.output_size, mode = 1) + crop_min, crop_max = get_random_box_from_mask(label == random_label, output_size, mode = 1) crop_min = [0] + crop_min crop_max = [chns] + crop_max diff --git a/pymic/transform/flip.py b/pymic/transform/flip.py index 486180c..6ea017c 100644 --- a/pymic/transform/flip.py +++ b/pymic/transform/flip.py @@ -6,7 +6,6 @@ import math import random import numpy as np -from scipy import ndimage from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index ffa5141..d39eb2c 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -62,6 +62,7 @@ def __init__(self, params): self.channels = params['IntensityClip_channels'.lower()] self.lower = params.get('IntensityClip_lower'.lower(), None) self.upper = params.get('IntensityClip_upper'.lower(), None) + self.perct = params.get('IntensityClip_percentile_mode'.lower(), False) self.inverse = params.get('IntensityClip_inverse'.lower(), False) def __call__(self, sample): @@ -72,8 +73,12 @@ def __call__(self, sample): lower_c, upper_c = lower[chn], upper[chn] if(lower_c is None): lower_c = np.percentile(image[chn], 0.05) + elif(self.perct): + lower_c = np.percentile(image[chn], lower_c) if(upper_c is None): - upper_c = np.percentile(image[chn, 99.95]) + upper_c = np.percentile(image[chn], 99.95) + elif(self.perct): + upper_c = np.percentile(image[chn], upper_c) image[chn] = np.clip(image[chn], lower_c, upper_c) sample['image'] = image return sample diff --git a/pymic/transform/normalize.py b/pymic/transform/normalize.py index 6531f17..3d28c0d 100644 --- a/pymic/transform/normalize.py +++ b/pymic/transform/normalize.py @@ -24,11 +24,12 @@ class NormalizeWithMeanStd(AbstractTransform): :param `NormalizeWithMeanStd_std`: (list/tuple or None) The std values along each specified channel. If None, the std values are calculated automatically. - :param `NormalizeWithMeanStd_ignore_non_positive`: (optional, bool) - Only used when mean and std are not given. Default is False. - If True, calculate mean and std in the positive region for normalization, - and set non-positive region to random. If False, calculate - the mean and std values in the entire image region. + :param `NormalizeWithMeanStd_mask_threshold`: (optional, float) + Only used when mean and std are not given. Default is 1.0. + Calculate mean and std in the mask region where the intensity is higher than the mask. + :param `NormalizeWithMeanStd_set_background_to_random`: (optional, bool) + Set background region to random or not, and only applicable when + `NormalizeWithMeanStd_mask_threshold` is not None. Default is True. :param `NormalizeWithMeanStd_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `False`. """ @@ -37,8 +38,9 @@ def __init__(self, params): self.chns = params.get('NormalizeWithMeanStd_channels'.lower(), None) self.mean = params.get('NormalizeWithMeanStd_mean'.lower(), None) self.std = params.get('NormalizeWithMeanStd_std'.lower(), None) - self.ingore_np = params.get('NormalizeWithMeanStd_ignore_non_positive'.lower(), False) - self.inverse = params.get('NormalizeWithMeanStd_inverse'.lower(), False) + self.mask_thrd = params.get('NormalizeWithMeanStd_mask_threshold'.lower(), 1.0) + self.bg_random = params.get('NormalizeWithMeanStd_set_background_to_random'.lower(), True) + self.inverse = params.get('NormalizeWithMeanStd_inverse'.lower(), False) def __call__(self, sample): image= sample['image'] @@ -52,8 +54,8 @@ def __call__(self, sample): chn = self.chns[i] chn_mean, chn_std = self.mean[i], self.std[i] if(chn_mean is None): - if(self.ingore_np): - pixels = image[chn][image[chn] > 0] + if(self.mask_thrd is not None): + pixels = image[chn][image[chn] > self.mask_thrd] if(len(pixels) > 0): chn_mean, chn_std = pixels.mean(), pixels.std() + 1e-5 else: @@ -63,9 +65,9 @@ def __call__(self, sample): chn_norm = (image[chn] - chn_mean)/chn_std - if(self.ingore_np): + if(self.mask_thrd is not None and self.bg_random): chn_random = np.random.normal(0, 1, size = chn_norm.shape) - chn_norm[image[chn] <= 0] = chn_random[image[chn] <= 0] + chn_norm[image[chn] <= self.mask_thrd] = chn_random[image[chn] <=self.mask_thrd] image[chn] = chn_norm sample['image'] = image return sample diff --git a/pymic/transform/pad.py b/pymic/transform/pad.py index c9b75fe..8624aa2 100644 --- a/pymic/transform/pad.py +++ b/pymic/transform/pad.py @@ -6,7 +6,6 @@ import math import random import numpy as np -from scipy import ndimage from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index f779d00..a28e848 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -31,6 +31,7 @@ """ from __future__ import print_function, division +from pymic.transform.affine import * from pymic.transform.intensity import * from pymic.transform.flip import * from pymic.transform.pad import * @@ -44,6 +45,7 @@ from pymic.transform.label_convert import * TransformDict = { + 'Affine': Affine, 'ChannelWiseThreshold': ChannelWiseThreshold, 'ChannelWiseThresholdWithNormalize': ChannelWiseThresholdWithNormalize, 'CropWithBoundingBox': CropWithBoundingBox, From 0a65b4994987e4c69604f9d1d2fba44f8726bc0c Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 13 Dec 2023 10:46:34 +0800 Subject: [PATCH 173/225] update segmentation and reconstruction agent --- pymic/net_run/agent_rec.py | 44 +++++++++-------- pymic/net_run/agent_seg.py | 63 ++++++++++++++---------- pymic/net_run/self_sup/util.py | 88 ++++++++++++++++++++++++++++++---- pymic/net_run/train.py | 26 ++-------- pymic/transform/mix.py | 86 +++++++++++++++++---------------- 5 files changed, 188 insertions(+), 119 deletions(-) diff --git a/pymic/net_run/agent_rec.py b/pymic/net_run/agent_rec.py index 1e58bc6..634ea20 100644 --- a/pymic/net_run/agent_rec.py +++ b/pymic/net_run/agent_rec.py @@ -29,14 +29,6 @@ class ReconstructionAgent(SegmentationAgent): """ def __init__(self, config, stage = 'train'): super(ReconstructionAgent, self).__init__(config, stage) - output_act_name = config['network'].get('output_activation', 'sigmoid') - if(output_act_name == "sigmoid"): - self.out_act = nn.Sigmoid() - elif(output_act_name == "tanh"): - self.out_act = nn.Tanh() - else: - raise ValueError("For reconstruction task, only sigmoid and tanh are " + \ - "supported for output_activation.") def create_loss_calculator(self): if(self.loss_dict is None): @@ -48,7 +40,6 @@ def create_loss_calculator(self): raise ValueError("Undefined loss function {0:}".format(loss_name)) else: loss_param = self.config['training'] - loss_param['loss_softmax'] = False base_loss = self.loss_dict[loss_name](self.config['training']) if(self.config['training'].get('deep_supervise', False)): raise ValueError("Deep supervised loss not implemented for reconstruction tasks") @@ -80,8 +71,13 @@ def training(self): # print(inputs.shape) # for i in range(inputs.shape[0]): # image_i = inputs[i][0] + # label_i = label[i][0] # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) + # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) # save_nd_array_as_image(image_i, image_name, reference_name = None) + # save_nd_array_as_image(label_i, label_name, reference_name = None) + # if(it > 10): + # break # return inputs, label = inputs.to(self.device), label.to(self.device) @@ -91,7 +87,18 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) - outputs = self.out_act(outputs) + + # for debug + # if it < 5: + # outputs = nn.Tanh()(outputs) + # for i in range(inputs.shape[0]): + # out_name = "temp/output_{0:}_{1:}.nii.gz".format(it, i) + # output = outputs[i][0] + # output = output.cpu().detach().numpy() + # save_nd_array_as_image(output, out_name, reference_name = None) + # else: + # break + loss = self.get_loss_value(data, outputs, label) loss.backward() self.optimizer.step() @@ -123,7 +130,6 @@ def validation(self): label = self.convert_tensor_type(data['label']) inputs, label = inputs.to(self.device), label.to(self.device) outputs = self.inferer.run(self.net, inputs) - outputs = self.out_act(outputs) # The tensors are on CPU when calculating loss for validation data loss = self.get_loss_value(data, outputs, label) valid_loss_list.append(loss.item()) @@ -293,19 +299,19 @@ def save_outputs(self, data): names, pred = data['names'], data['predict'] if(isinstance(pred, (list, tuple))): pred = pred[0] - if(isinstance(self.out_act, nn.Sigmoid)): - pred = scipy.special.expit(pred) - else: - pred = np.tanh(pred) + pred = np.tanh(pred) # save the output predictions - root_dir = self.config['dataset']['root_dir'] + test_dir = self.config['dataset'].get('test_dir', None) + if(test_dir is None): + test_dir = self.config['dataset']['train_dir'] + for i in range(len(names)): - save_name = names[i].split('/')[-1] if ignore_dir else \ - names[i].replace('/', '_') + save_name = names[i][0].split('/')[-1] if ignore_dir else \ + names[i][0].replace('/', '_') if((filename_replace_source is not None) and (filename_replace_target is not None)): save_name = save_name.replace(filename_replace_source, filename_replace_target) print(save_name) save_name = "{0:}/{1:}".format(output_dir, save_name) - save_nd_array_as_image(pred[i][i], save_name, root_dir + '/' + names[i]) + save_nd_array_as_image(pred[i][i], save_name, test_dir + '/' + names[i][0]) \ No newline at end of file diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index c376546..14f7567 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -39,36 +39,44 @@ def __init__(self, config, stage = 'train'): self.net_dict = SegNetDict self.postprocess_dict = PostProcessDict self.postprocessor = None - - def get_stage_dataset_from_config(self, stage): - assert(stage in ['train', 'valid', 'test']) - root_dir = self.config['dataset']['root_dir'] - modal_num = self.config['dataset'].get('modal_num', 1) + def get_transform_names_and_parameters(self, stage): + """ + Get a list of transform objects for creating a dataset + """ + assert(stage in ['train', 'valid', 'test']) transform_key = stage + '_transform' if(stage == "valid" and transform_key not in self.config['dataset']): transform_key = "train_transform" - transform_names = self.config['dataset'][transform_key] - - self.transform_list = [] - if(transform_names is None or len(transform_names) == 0): - data_transform = None - else: - transform_param = self.config['dataset'] - transform_param['task'] = self.task_type - for name in transform_names: + trans_names = self.config['dataset'][transform_key] + trans_params = self.config['dataset'] + trans_params['task'] = self.task_type + return trans_names, trans_params + + def get_stage_dataset_from_config(self, stage): + trans_names, trans_params = self.get_transform_names_and_parameters(stage) + transform_list = [] + if(trans_names is not None and len(trans_names) > 0): + for name in trans_names: if(name not in self.transform_dict): raise(ValueError("Undefined transform {0:}".format(name))) - one_transform = self.transform_dict[name](transform_param) - self.transform_list.append(one_transform) - data_transform = transforms.Compose(self.transform_list) + one_transform = self.transform_dict[name](trans_params) + transform_list.append(one_transform) + data_transform = transforms.Compose(transform_list) - csv_file = self.config['dataset'].get(stage + '_csv', None) + csv_file = self.config['dataset'].get(stage + '_csv', None) if(stage == 'test'): with_label = False + self.test_transforms = transform_list else: with_label = self.config['dataset'].get(stage + '_label', True) - dataset = NiftyDataset(root_dir = root_dir, + modal_num = self.config['dataset'].get('modal_num', 1) + stage_dir = self.config['dataset'].get('train_dir', None) + if(stage == 'valid' and "valid_dir" in self.config['dataset']): + stage_dir = self.config['dataset']['valid_dir'] + if(stage == 'test' and "test_dir" in self.config['dataset']): + stage_dir = self.config['dataset']['test_dir'] + dataset = NiftyDataset(root_dir = stage_dir, csv_file = csv_file, modal_num = modal_num, with_label= with_label, @@ -471,7 +479,7 @@ def test_time_dropout(m): pred = pred.cpu().numpy() data['predict'] = pred # inverse transform - for transform in self.transform_list[::-1]: + for transform in self.test_transforms[::-1]: if (transform.inverse): data = transform.inverse_transform_for_prediction(data) @@ -525,7 +533,7 @@ def infer_with_multiple_checkpoints(self): pred = np.mean(predict_list, axis=0) data['predict'] = pred # inverse transform - for transform in self.transform_list[::-1]: + for transform in self.test_transforms[::-1]: if (transform.inverse): data = transform.inverse_transform_for_prediction(data) @@ -564,15 +572,18 @@ def save_outputs(self, data): for i in range(len(names)): output[i] = self.postprocessor(output[i]) # save the output and (optionally) probability predictions - root_dir = self.config['dataset']['root_dir'] + test_dir = self.config['dataset'].get('test_dir', None) + if(test_dir is None): + test_dir = self.config['dataset']['train_dir'] + for i in range(len(names)): - save_name = names[i].split('/')[-1] if ignore_dir else \ - names[i].replace('/', '_') + save_name = names[i][0].split('/')[-1] if ignore_dir else \ + names[i][0].replace('/', '_') if((filename_replace_source is not None) and (filename_replace_target is not None)): save_name = save_name.replace(filename_replace_source, filename_replace_target) print(save_name) save_name = "{0:}/{1:}".format(output_dir, save_name) - save_nd_array_as_image(output[i], save_name, root_dir + '/' + names[i]) + save_nd_array_as_image(output[i], save_name, test_dir + '/' + names[i][0]) save_name_split = save_name.split('.') if(not save_prob): @@ -590,4 +601,4 @@ def save_outputs(self, data): prob_save_name = "{0:}_prob_{1:}.{2:}".format(save_prefix, c, save_format) if(len(temp_prob.shape) == 2): temp_prob = np.asarray(temp_prob * 255, np.uint8) - save_nd_array_as_image(temp_prob, prob_save_name, root_dir + '/' + names[i]) + save_nd_array_as_image(temp_prob, prob_save_name, test_dir + '/' + names[i][0]) diff --git a/pymic/net_run/self_sup/util.py b/pymic/net_run/self_sup/util.py index 9cffaa7..0ab1670 100644 --- a/pymic/net_run/self_sup/util.py +++ b/pymic/net_run/self_sup/util.py @@ -19,6 +19,10 @@ def get_human_region_mask(img): mask = np.asarray(img > -600) se = np.ones([3,3,3]) mask = ndimage.binary_opening(mask, se, iterations = 2) + D, H, W = mask.shape + for h in range(H): + if(mask[:,h,:].sum() < 2000): + mask[:,h, :] = np.zeros((D, W)) mask = get_largest_k_components(mask, 1) mask_close = ndimage.binary_closing(mask, se, iterations = 2) @@ -47,20 +51,39 @@ def get_human_region_mask(img): fg = np.asarray(fg, np.uint8) return fg -def crop_ct_scan(input_img, output_img, input_lab = None, output_lab = None): +def get_human_region_mask_fast(img, itk_spacing): + # downsample + D, H, W = img.shape + # scale_down = [1, 1, 1] + if(itk_spacing[2] <= 1): + scale_down = [1/2, 1/2, 1/2] + else: + scale_down = [1, 1/2, 1/2] + img_sub = ndimage.interpolation.zoom(img, scale_down, order = 0) + mask = get_human_region_mask(img_sub) + D1, H1, W1 = mask.shape + scale_up = [D/D1, H/H1, W/W1] + mask = ndimage.interpolation.zoom(mask, scale_up, order = 0) + return mask + +def crop_ct_scan(input_img, output_img, input_lab = None, output_lab = None, z_axis_density = 0.5): """ Crop a CT scan based on the bounding box of the human region. """ img_obj = sitk.ReadImage(input_img) - img = sitk.GetArrayFromImage(img_obj) - mask = np.asarray(img > -600) - se = np.ones([3,3,3]) - mask = ndimage.binary_opening(mask, se, iterations = 2) - mask = get_largest_k_components(mask, 1) - bbmin, bbmax = get_ND_bounding_box(mask, margin = [5, 10, 10]) + img = sitk.GetArrayFromImage(img_obj) + mask = np.asarray(img > -600) + mask2d = np.mean(mask, axis = 0) > z_axis_density + se = np.ones([3,3]) + mask2d = ndimage.binary_opening(mask2d, se, iterations = 2) + mask2d = get_largest_k_components(mask2d, 1) + bbmin, bbmax = get_ND_bounding_box(mask2d, margin = [0, 0]) + bbmin = [0] + bbmin + bbmax = [img.shape[0]] + bbmax img_sub = crop_ND_volume_with_bounding_box(img, bbmin, bbmax) img_sub_obj = sitk.GetImageFromArray(img_sub) img_sub_obj.SetSpacing(img_obj.GetSpacing()) + img_sub_obj.SetDirection(img_obj.GetDirection()) sitk.WriteImage(img_sub_obj, output_img) if(input_lab is not None): lab_obj = sitk.ReadImage(input_lab) @@ -70,6 +93,49 @@ def crop_ct_scan(input_img, output_img, input_lab = None, output_lab = None): lab_sub_obj.SetSpacing(img_obj.GetSpacing()) sitk.WriteImage(lab_sub_obj, output_lab) +def get_human_body_mask_and_crop(input_dir, out_img_dir, out_mask_dir): + if(not os.path.exists(out_img_dir)): + os.mkdir(out_img_dir) + os.mkdir(out_mask_dir) + + img_names = [item for item in os.listdir(input_dir) if "nii.gz" in item] + img_names = sorted(img_names) + for img_name in img_names: + print(img_name) + input_name = input_dir + "/" + img_name + out_name = out_img_dir + "/" + img_name + mask_name = out_mask_dir + "/" + img_name + if(os.path.isfile(out_name)): + continue + img_obj = sitk.ReadImage(input_name) + img = sitk.GetArrayFromImage(img_obj) + spacing = img_obj.GetSpacing() + + # downsample + D, H, W = img.shape + spacing = img_obj.GetSpacing() + # scale_down = [1, 1, 1] + if(spacing[2] <= 1): + scale_down = [1/2, 1/2, 1/2] + else: + scale_down = [1, 1/2, 1/2] + img_sub = ndimage.interpolation.zoom(img, scale_down, order = 0) + mask = get_human_region_mask(img_sub) + D1, H1, W1 = mask.shape + scale_up = [D/D1, H/H1, W/W1] + mask = ndimage.interpolation.zoom(mask, scale_up, order = 0) + + bbmin, bbmax = get_ND_bounding_box(mask) + img_crop = crop_ND_volume_with_bounding_box(img, bbmin, bbmax) + mask_crop = crop_ND_volume_with_bounding_box(mask, bbmin, bbmax) + + out_img_obj = sitk.GetImageFromArray(img_crop) + out_img_obj.SetSpacing(spacing) + sitk.WriteImage(out_img_obj, out_name) + mask_obj = sitk.GetImageFromArray(mask_crop) + mask_obj.CopyInformation(out_img_obj) + sitk.WriteImage(mask_obj, mask_name) + def patch_mix(x, fg_num, patch_num, size_d, size_h, size_w): """ @@ -99,7 +165,7 @@ def patch_mix(x, fg_num, patch_num, size_d, size_h, size_w): y_prob = get_one_hot_seg(fg_mask.to(torch.int32), fg_num + 1) return x_fuse, y_prob -def create_mixed_dataset(input_dir, output_dir, fg_num = 1, crop_num = 1, +def create_mixed_dataset(input_dir, output_dir, fg_num = 1, crop_num = 1, patch_size=[128,128,128], mask_dir = None, data_format = "nii.gz"): """ Create dataset based on patch mix. @@ -136,6 +202,7 @@ def create_mixed_dataset(input_dir, output_dir, fg_num = 1, crop_num = 1, img_j = load_image_as_nd_array(input_dir + "/" + img_names[j])['data_array'] chns = img_i.shape[0] + crop_size = [chns] + patch_size # random crop to patch size if(mask_dir is None): mask_i = get_human_region_mask(img_i) @@ -148,8 +215,9 @@ def create_mixed_dataset(input_dir, output_dir, fg_num = 1, crop_num = 1, # img_ik = random_crop_ND_volume(img_i, [chns, 96, 96, 96]) # img_jk = random_crop_ND_volume(img_j, [chns, 96, 96, 96]) # else: - img_ik = random_crop_ND_volume_with_mask(img_i, [chns, 96, 96, 96], mask_i) - img_jk = random_crop_ND_volume_with_mask(img_j, [chns, 96, 96, 96], mask_j) + + img_ik = random_crop_ND_volume_with_mask(img_i, crop_size, mask_i) + img_jk = random_crop_ND_volume_with_mask(img_j, crop_size, mask_j) C, D, H, W = img_ik.shape # generate mask fg_mask = np.zeros_like(img_ik, np.uint8) diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index 426b620..f2bbe0f 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -11,8 +11,9 @@ from pymic.net_run.agent_seg import SegmentationAgent from pymic.net_run.semi_sup import SSLMethodDict from pymic.net_run.weak_sup import WSLMethodDict +from pymic.net_run.self_sup import SelfSupMethodDict from pymic.net_run.noisy_label import NLLMethodDict -from pymic.net_run.self_sup import SelfSLSegAgent +# from pymic.net_run.self_sup import SelfSLSegAgent def get_seg_rec_agent(config, sup_type): assert(sup_type in ['fully_sup', 'semi_sup', 'self_sup', 'weak_sup', 'noisy_label']) @@ -34,28 +35,7 @@ def get_seg_rec_agent(config, sup_type): elif(sup_type == 'self_sup'): logging.info("\n********** Self Supervised Learning **********\n") method = config['self_supervised_learning']['method_name'] - if(method == "custom"): - pass - elif(method == "model_genesis"): - transforms = ['RandomFlip', 'LocalShuffling', 'NonLinearTransform', 'InOutPainting'] - genesis_cfg = { - 'randomflip_flip_depth': True, - 'randomflip_flip_height': True, - 'randomflip_flip_width': True, - 'localshuffling_probability': 0.5, - 'nonLineartransform_probability': 0.9, - 'inoutpainting_probability': 0.9, - 'inpainting_probability': 0.2 - } - config['dataset']['train_transform'].extend(transforms) - # config['dataset']['valid_transform'].extend(transforms) - config['dataset'].update(genesis_cfg) - logging_config(config['dataset']) - else: - raise ValueError("The specified method {0:} is not implemented. ".format(method) + \ - "Consider to set `self_sl_method = custom` and use customized" + \ - " transforms for self-supervised learning.") - agent = SelfSLSegAgent(config, 'train') + agent = SelfSupMethodDict[method](config, 'train') else: raise ValueError("undefined supervision type: {0:}".format(sup_type)) return agent diff --git a/pymic/transform/mix.py b/pymic/transform/mix.py index 6e2fb8e..6efed6a 100644 --- a/pymic/transform/mix.py +++ b/pymic/transform/mix.py @@ -71,7 +71,8 @@ class PatchMix(AbstractTransform): """ def __init__(self, params): super(PatchMix, self).__init__(params) - self.inverse = params.get('PatchMix_inverse'.lower(), False) + self.inverse = params.get('PatchMix_inverse'.lower(), False) + self.threshold = params.get('PatchMix_threshold'.lower(), 0) self.crop_size = params.get('PatchMix_crop_size'.lower(), [64, 128, 128]) self.fg_cls_num = params.get('PatchMix_cls_num'.lower(), [4, 40]) self.patch_num_range= params.get('PatchMix_patch_range'.lower(), [4, 40]) @@ -79,7 +80,8 @@ def __init__(self, params): self.patch_size_max = params.get('PatchMix_patch_size_max'.lower(), [20, 40, 40]) def __call__(self, sample): - x0, x1 = self._random_crop_and_flip(sample) + x0 = self._random_crop_and_flip(sample) + x1 = self._random_crop_and_flip(sample) C, D, H, W = x0.shape # generate mask fg_mask = np.zeros_like(x0, np.uint8) @@ -104,46 +106,48 @@ def __call__(self, sample): return sample def _random_crop_and_flip(self, sample): - input_shape = sample['image'].shape - input_dim = len(input_shape) - 1 + image = sample['image'] + input_dim = len(image.shape) - 1 assert(input_dim == 3) - + C, D, H, W = image.shape + + half_size = [x // 2 for x in self.crop_size] + dc = random.randint(half_size[0], D - half_size[0]) + image2d = image[0, dc, :, :] + mask2d = np.zeros_like(image2d) + mask2d[half_size[1]:H+1-half_size[1], half_size[2]:W+1-half_size[2]] = \ + np.ones([H-self.crop_size[1]+1, W-self.crop_size[2]+1]) if('label' in sample): - # get the center for crop randomly - mask = sample['label'] > 0 - C, D, H, W = input_shape - size_h = [i// 2 for i in self.crop_size] - temp_mask = np.zeros_like(mask) - temp_mask[:,size_h[0]:D-size_h[0]+1,size_h[1]:H-size_h[1]+1,size_h[2]:W-size_h[2]+1] = \ - np.ones([C, D-self.crop_size[0]+1, H-self.crop_size[1]+1, W-self.crop_size[2]+1]) - mask = mask * temp_mask - indices = np.where(mask) - n0 = random.randint(0, len(indices[0])-1) - n1 = random.randint(0, len(indices[0])-1) - center0 = [indices[i][n0] for i in range(1, 4)] - center1 = [indices[i][n1] for i in range(1, 4)] - crop_min0 = [center0[i] - size_h[i] for i in range(3)] - crop_min1 = [center1[i] - size_h[i] for i in range(3)] - else: - crop_margin = [input_shape[1+i] - self.crop_size[i] for i in range(input_dim)] - crop_min0 = [0 if item == 0 else random.randint(0, item) for item in crop_margin] - crop_min1 = [0 if item == 0 else random.randint(0, item) for item in crop_margin] + temp_mask = sample['label'][0, dc, :, :] > 0 + mask2d = temp_mask * mask2d + elif(self.threshold is not None): + temp_mask = image2d > self.threshold + se = np.ones([3,3]) + temp_mask = ndimage.binary_opening(temp_mask, se, iterations = 2) + temp_mask = get_largest_k_components(temp_mask, 1) + mask2d = temp_mask * mask2d + + indices = np.where(mask2d) + n = random.randint(0, len(indices[0])-1) + center = [indices[i][n] for i in range(2)] + crop_min = [dc - half_size[0], center[0]-half_size[1], center[1] - half_size[2]] + crop_max = [crop_min[i] + self.crop_size[i] for i in range(input_dim)] + crop_min = [0] + crop_min + crop_max = [C] + crop_max + x = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) - patches = [] - for crop_min in [crop_min0, crop_min1]: - crop_max = [crop_min[i] + self.crop_size[i] for i in range(input_dim)] - crop_min = [0] + crop_min - crop_max = [C] + crop_max - x = crop_ND_volume_with_bounding_box(sample['image'], crop_min, crop_max) - flip_axis = [] - if(random.random() > 0.5): - flip_axis.append(-1) - if(random.random() > 0.5): - flip_axis.append(-2) - if(random.random() > 0.5): - flip_axis.append(-3) - if(len(flip_axis) > 0): - x = np.flip(x, flip_axis).copy() - patches.append(x) + flip_axis = [] + if(random.random() > 0.5): + flip_axis.append(-1) + if(random.random() > 0.5): + flip_axis.append(-2) + if(random.random() > 0.5): + flip_axis.append(-3) + if(len(flip_axis) > 0): + x = np.flip(x, flip_axis).copy() - return patches \ No newline at end of file + if(x.shape[1] == 63): + print("crop shape == 63", x.shape) + print(sample['names']) + print(image.shape, crop_min, crop_max) + return x \ No newline at end of file From a906dfc1e23045c9aa0b235adf7ae550889dc9dc Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 3 Jan 2024 10:18:34 +0800 Subject: [PATCH 174/225] update self-supervised learning --- .gitignore | 1 + pymic/io/h5_dataset.py | 7 +- pymic/io/image_read_write.py | 13 +- pymic/io/nifty_dataset.py | 23 +- pymic/loss/seg/abstract.py | 15 +- pymic/loss/seg/ce.py | 15 +- pymic/loss/seg/dice.py | 20 +- pymic/loss/seg/exp_log.py | 4 +- pymic/loss/seg/mse.py | 8 +- pymic/net/net2d/unet2d.py | 44 ++-- pymic/net/net2d/unet2d_canet.py | 66 ------ pymic/net/net_dict_cls.py | 2 +- pymic/net/net_dict_seg.py | 2 + pymic/net_run/agent_abstract.py | 13 +- pymic/net_run/agent_rec.py | 5 +- pymic/net_run/agent_seg.py | 8 +- pymic/net_run/get_optimizer.py | 3 +- pymic/net_run/preprocess.py | 82 +++++++ pymic/net_run/self_sup/__init__.py | 11 +- pymic/net_run/self_sup/self_genesis.py | 51 +++++ pymic/net_run/self_sup/self_patch_swapping.py | 44 ++++ pymic/net_run/self_sup/self_sl_agent.py | 3 +- ...tch_mix_agent.py => self_volume_fusion.py} | 49 +--- pymic/net_run/self_sup/util.py | 121 ++-------- pymic/net_run/semi_sup/__init__.py | 15 +- pymic/net_run/semi_sup/ssl_abstract.py | 6 +- pymic/net_run/semi_sup/ssl_cps.py | 8 +- pymic/net_run/train.py | 6 +- pymic/transform/crop.py | 90 +++++++- pymic/transform/extract_channel.py | 39 ++++ pymic/transform/intensity.py | 210 ++++++++---------- pymic/transform/normalize.py | 8 +- pymic/transform/rescale.py | 1 + pymic/transform/trans_dict.py | 2 + pymic/util/image_process.py | 2 +- 35 files changed, 580 insertions(+), 417 deletions(-) create mode 100644 pymic/net_run/preprocess.py create mode 100644 pymic/net_run/self_sup/self_genesis.py create mode 100644 pymic/net_run/self_sup/self_patch_swapping.py rename pymic/net_run/self_sup/{self_patch_mix_agent.py => self_volume_fusion.py} (71%) create mode 100644 pymic/transform/extract_channel.py diff --git a/.gitignore b/.gitignore index f7da7ac..639f873 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ dist/* *egg*/* *stop* files.txt +pymic/test/runs/* # Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks # Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks diff --git a/pymic/io/h5_dataset.py b/pymic/io/h5_dataset.py index 02f94f3..34fa1a4 100644 --- a/pymic/io/h5_dataset.py +++ b/pymic/io/h5_dataset.py @@ -8,8 +8,9 @@ import pandas as pd from torch.utils.data import Dataset from torch.utils.data.sampler import Sampler +from pymic import TaskType -class H5DataSet(Dataset): +class H5DataSet_backup(Dataset): """ Dataset for loading images stored in h5 format. It generates 4D tensors with dimention order [C, D, H, W] for 3D images, and @@ -39,7 +40,9 @@ def __getitem__(self, idx): if self.transform: sample = self.transform(sample) return sample - + + + class TwoStreamBatchSampler(Sampler): """Iterate two sets of indices diff --git a/pymic/io/image_read_write.py b/pymic/io/image_read_write.py index 6c8c6b0..3aa87bd 100644 --- a/pymic/io/image_read_write.py +++ b/pymic/io/image_read_write.py @@ -79,10 +79,10 @@ def load_image_as_nd_array(image_name): image_name.endswith(".tif") or image_name.endswith(".png")): image_dict = load_rgb_image_as_3d_array(image_name) else: - raise ValueError("unsupported image format") + raise ValueError("unsupported image format: {0:}".format(image_name)) return image_dict -def save_array_as_nifty_volume(data, image_name, reference_name = None): +def save_array_as_nifty_volume(data, image_name, reference_name = None, spacing = [1.0,1.0,1.0]): """ Save a numpy array as nifty image @@ -90,6 +90,7 @@ def save_array_as_nifty_volume(data, image_name, reference_name = None): :param image_name: (str) The ouput file name. :param reference_name: (str) File name of the reference image of which meta information is used. + :param spacing: (list or tuple) the spacing of a volume data when `reference_name` is not provided. """ img = sitk.GetImageFromArray(data) if(reference_name is not None): @@ -101,6 +102,9 @@ def save_array_as_nifty_volume(data, image_name, reference_name = None): direction1 = img.GetDirection() if(len(direction0) == len(direction1)): img.SetDirection(direction0) + else: + nifty_spacing = spacing[1:] + spacing[:1] + img.SetSpacing(nifty_spacing) sitk.WriteImage(img, image_name) def save_array_as_rgb_image(data, image_name): @@ -119,7 +123,7 @@ def save_array_as_rgb_image(data, image_name): img = Image.fromarray(data) img.save(image_name) -def save_nd_array_as_image(data, image_name, reference_name = None): +def save_nd_array_as_image(data, image_name, reference_name = None, spacing = [1.0,1.0,1.0]): """ Save a 3D or 2D numpy array as medical image or RGB image @@ -127,13 +131,14 @@ def save_nd_array_as_image(data, image_name, reference_name = None): [H, W, 3] or [H, W]. :param reference_name: (str) File name of the reference image of which meta information is used. + :param spacing: (list or tuple) the spacing of a volume data when `reference_name` is not provided. """ data_dim = len(data.shape) assert(data_dim == 2 or data_dim == 3) if (image_name.endswith(".nii.gz") or image_name.endswith(".nii") or image_name.endswith(".mha")): assert(data_dim == 3) - save_array_as_nifty_volume(data, image_name, reference_name) + save_array_as_nifty_volume(data, image_name, reference_name, spacing) elif(image_name.endswith(".jpg") or image_name.endswith(".jpeg") or image_name.endswith(".tif") or image_name.endswith(".png")): diff --git a/pymic/io/nifty_dataset.py b/pymic/io/nifty_dataset.py index 438a07b..aefe4da 100644 --- a/pymic/io/nifty_dataset.py +++ b/pymic/io/nifty_dataset.py @@ -3,11 +3,9 @@ import logging import os -import torch import pandas as pd import numpy as np -from torch.utils.data import Dataset, DataLoader -from torchvision import transforms, utils +from torch.utils.data import Dataset from pymic import TaskType from pymic.io.image_read_write import load_image_as_nd_array @@ -70,6 +68,25 @@ def __get_pixel_weight__(self, idx): weight = np.asarray(weight, np.float32) return weight + # def __getitem__(self, idx): + # sample_name = self.csv_items.iloc[idx, 0] + # h5f = h5py.File(self.root_dir + '/' + sample_name, 'r') + # image = np.asarray(h5f['image'][:], np.float32) + + # # this a temporaory process, will be delieted later + # if(len(image.shape) == 3 and image.shape[0] > 1): + # image = np.expand_dims(image, 0) + # sample = {'image': image, 'names':sample_name} + + # if('label' in h5f): + # label = np.asarray(h5f['label'][:], np.uint8) + # if(len(label.shape) == 3 and label.shape[0] > 1): + # label = np.expand_dims(label, 0) + # sample['label'] = label + # if self.transform: + # sample = self.transform(sample) + # return sample + def __getitem__(self, idx): names_list, image_list = [], [] for i in range (self.modal_num): diff --git a/pymic/loss/seg/abstract.py b/pymic/loss/seg/abstract.py index f42d816..68643e8 100644 --- a/pymic/loss/seg/abstract.py +++ b/pymic/loss/seg/abstract.py @@ -16,9 +16,20 @@ class AbstractSegLoss(nn.Module): def __init__(self, params = None): super(AbstractSegLoss, self).__init__() if(params is None): - self.softmax = True + self.acti_func = 'softmax' else: - self.softmax = params.get('loss_softmax', True) + self.acti_func = params.get('loss_acti_func', 'softmax') + + def get_activated_prediction(self, p, acti_func = 'softmax'): + if(acti_func == "softmax"): + p = nn.Softmax(dim = 1)(p) + elif(acti_func == "tanh"): + p = nn.Tanh()(p) + elif(acti_func == "sigmoid"): + p = nn.Sigmoid()(p) + else: + raise ValueError("activation for output is not supported: {0:}".format(acti_func)) + return p def forward(self, loss_input_dict): """ diff --git a/pymic/loss/seg/ce.py b/pymic/loss/seg/ce.py index 9524d57..4edbbc3 100644 --- a/pymic/loss/seg/ce.py +++ b/pymic/loss/seg/ce.py @@ -13,8 +13,10 @@ class CrossEntropyLoss(AbstractSegLoss): The parameters should be written in the `params` dictionary, and it has the following fields: - :param `loss_softmax`: (optional, bool) - Apply softmax to the prediction of network or not. Default is True. + :param `loss_acti_func`: (optional, string) + Apply an activation function to the prediction of network or not, for example, + 'softmax' for image segmentation tasks, 'tanh' for reconstruction tasks, and None + means no activation is used. """ def __init__(self, params = None): super(CrossEntropyLoss, self).__init__(params) @@ -27,8 +29,9 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) + predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) @@ -74,8 +77,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) gce = (1.0 - torch.pow(predict, self.q)) / self.q * soft_y diff --git a/pymic/loss/seg/dice.py b/pymic/loss/seg/dice.py index 2c2df32..c423c2c 100644 --- a/pymic/loss/seg/dice.py +++ b/pymic/loss/seg/dice.py @@ -25,8 +25,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) if(pix_w is not None): @@ -52,8 +52,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = 1.0 - predict[:, :1, :, :, :] soft_y = 1.0 - soft_y[:, :1, :, :, :] predict = reshape_tensor_to_2D(predict) @@ -76,8 +76,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) num_class = list(predict.size())[1] @@ -115,8 +115,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) @@ -149,8 +149,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) diff --git a/pymic/loss/seg/exp_log.py b/pymic/loss/seg/exp_log.py index c1b3f00..8c0d494 100644 --- a/pymic/loss/seg/exp_log.py +++ b/pymic/loss/seg/exp_log.py @@ -32,8 +32,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) diff --git a/pymic/loss/seg/mse.py b/pymic/loss/seg/mse.py index 5b657c5..eb53af4 100644 --- a/pymic/loss/seg/mse.py +++ b/pymic/loss/seg/mse.py @@ -19,8 +19,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) mse = torch.square(predict - soft_y) mse = torch.mean(mse) return mse @@ -44,8 +44,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) mae = torch.abs(predict - soft_y) if(weight is None): mae = torch.mean(mae) diff --git a/pymic/net/net2d/unet2d.py b/pymic/net/net2d/unet2d.py index 7d14a2e..be69f0d 100644 --- a/pymic/net/net2d/unet2d.py +++ b/pymic/net/net2d/unet2d.py @@ -5,7 +5,6 @@ import torch import torch.nn as nn import numpy as np -from torch.nn.functional import interpolate class ConvBlock(nn.Module): """ @@ -61,8 +60,7 @@ class UpBlock(nn.Module): 0 (or `TransConv`), 1 (`Nearest`), 2 (`Bilinear`), 3 (`Bicubic`). The default value is 2 (`Bilinear`). """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - up_mode = 2): + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, up_mode = 2): super(UpBlock, self).__init__() if(isinstance(up_mode, int)): up_mode_values = ["transconv", "nearest", "bilinear", "bicubic"] @@ -144,7 +142,7 @@ class Decoder(nn.Module): :param up_mode: (string or int) The mode for upsampling. The allowed values are: 0 (or `TransConv`), 1 (or `Nearest`), 2 (or `Bilinear`), 3 (or `Bicubic`). The default value is 2 (or `Bilinear`). - :param multiscale_pred: (bool) Get multiscale prediction. + :param multiscale_pred: (bool) Get multi-scale prediction. """ def __init__(self, params): super(Decoder, self).__init__() @@ -165,10 +163,14 @@ def __init__(self, params): self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.up_mode) self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1) - if(self.mul_pred): + if(self.mul_pred and (self.training or self.mul_infer)): self.out_conv1 = nn.Conv2d(self.ft_chns[1], self.n_class, kernel_size = 1) self.out_conv2 = nn.Conv2d(self.ft_chns[2], self.n_class, kernel_size = 1) self.out_conv3 = nn.Conv2d(self.ft_chns[3], self.n_class, kernel_size = 1) + self.stage = 'train' + + def set_stage(self, stage): + self.stage = stage def forward(self, x): if(len(self.ft_chns) == 5): @@ -183,7 +185,7 @@ def forward(self, x): x_d1 = self.up3(x_d2, x1) x_d0 = self.up4(x_d1, x0) output = self.out_conv(x_d0) - if(self.mul_pred and self.training): + if(self.mul_pred and self.stage == 'train'): output1 = self.out_conv1(x_d1) output2 = self.out_conv2(x_d2) output3 = self.out_conv3(x_d3) @@ -239,6 +241,10 @@ def get_default_parameters(self, params): logging.info("{0:} = {1:}".format(key, params[key])) return params + def set_stage(self, stage): + self.stage = stage + self.decoder.set_stage(stage) + def forward(self, x): x_shape = list(x.shape) if(len(x_shape) == 5): @@ -258,28 +264,4 @@ def forward(self, x): new_shape = [N, D] + list(output.shape)[1:] output = torch.transpose(torch.reshape(output, new_shape), 1, 2) - return output - - -if __name__ == "__main__": - params = {'in_chns':4, - 'feature_chns':[16, 32, 64, 128, 256], - 'dropout': [0, 0, 0.3, 0.4, 0.5], - 'class_num': 2, - 'up_mode': 0, - 'multiscale_pred': True} - Net = UNet2D(params) - Net = Net.double() - - x = np.random.rand(4, 4, 10, 256, 256) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - out = Net(xt) - if params['multiscale_pred']: - for y in out: - print(len(y.size())) - y = y.detach().numpy() - print(y.shape) - else: - print(out.shape) + return output \ No newline at end of file diff --git a/pymic/net/net2d/unet2d_canet.py b/pymic/net/net2d/unet2d_canet.py index 7aeba84..defcb60 100644 --- a/pymic/net/net2d/unet2d_canet.py +++ b/pymic/net/net2d/unet2d_canet.py @@ -684,72 +684,6 @@ def forward(self, inputs): xt = torch.from_numpy(x) xt = torch.tensor(xt) - y = Net(xt) - print(len(y.size())) - y = y.detach().numpy() - print(y.shape) -axpool3(conv3) - - conv4 = self.conv4(maxpool3) - maxpool4 = self.maxpool4(conv4) - - # Gating Signal Generation - center = self.center(maxpool4) - - # Attention Mechanism - # Upscaling Part (Decoder) - up4 = self.up_concat4(conv4, center) - g_conv4 = self.nonlocal4_2(up4) - - up4, att_weight4 = self.up4(g_conv4) - g_conv3, att3 = self.attentionblock3(conv3, up4) - - # atten3_map = att3.cpu().detach().numpy().astype(np.float) - # atten3_map = ndimage.interpolation.zoom(atten3_map, [1.0, 1.0, 224 / atten3_map.shape[2], - # 300 / atten3_map.shape[3]], order=0) - - up3 = self.up_concat3(g_conv3, up4) - up3, att_weight3 = self.up3(up3) - g_conv2, att2 = self.attentionblock2(conv2, up3) - - # atten2_map = att2.cpu().detach().numpy().astype(np.float) - # atten2_map = ndimage.interpolation.zoom(atten2_map, [1.0, 1.0, 224 / atten2_map.shape[2], - # 300 / atten2_map.shape[3]], order=0) - - up2 = self.up_concat2(g_conv2, up3) - up2, att_weight2 = self.up2(up2) - # g_conv1, att1 = self.attentionblock1(conv1, up2) - - # atten1_map = att1.cpu().detach().numpy().astype(np.float) - # atten1_map = ndimage.interpolation.zoom(atten1_map, [1.0, 1.0, 224 / atten1_map.shape[2], - # 300 / atten1_map.shape[3]], order=0) - up1 = self.up_concat1(conv1, up2) - up1, att_weight1 = self.up1(up1) - - # Deep Supervision - dsv4 = self.dsv4(up4) - dsv3 = self.dsv3(up3) - dsv2 = self.dsv2(up2) - dsv1 = self.dsv1(up1) - dsv_cat = torch.cat([dsv1, dsv2, dsv3, dsv4], dim=1) - out = self.scale_att(dsv_cat) - - out = self.final(out) - - return out - - - -if __name__ == "__main__": - params = {'in_chns':3, - 'class_num':2} - Net = CANet(params) - Net = Net.double() - - x = np.random.rand(2, 3, 256, 256) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - y = Net(xt) print(len(y.size())) y = y.detach().numpy() diff --git a/pymic/net/net_dict_cls.py b/pymic/net/net_dict_cls.py index 7996e59..3a7808b 100644 --- a/pymic/net/net_dict_cls.py +++ b/pymic/net/net_dict_cls.py @@ -3,7 +3,7 @@ Built-in networks for classification. * resnet18 :mod:`pymic.net.cls.torch_pretrained_net.ResNet18` -* vgg16 :mod:`pymic.net.cls.torch_pretrained_net.VGG16` +* vgg16 :mod:`pymic.net.cls.torch_pretrained_net.VGG16` * mobilenetv2 :mod:`pymic.net.cls.torch_pretrained_net.MobileNetV2` """ diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index 8862fd1..e381421 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -17,6 +17,7 @@ from __future__ import print_function, division from pymic.net.net2d.unet2d import UNet2D from pymic.net.net2d.unet2d_dual_branch import UNet2D_DualBranch +from pymic.net.net2d.unet2d_canet import CANet from pymic.net.net2d.unet2d_cct import UNet2D_CCT from pymic.net.net2d.unet2d_mcnet import MCNet2D from pymic.net.net2d.cople_net import COPLENet @@ -48,6 +49,7 @@ 'UNet2D_DualBranch': UNet2D_DualBranch, 'UNet2D_CCT': UNet2D_CCT, 'MCNet2D': MCNet2D, + 'CANet': CANet, 'COPLENet': COPLENet, 'AttentionUNet2D': AttentionUNet2D, 'NestedUNet2D': NestedUNet2D, diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index 7a49a2b..f9575ab 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -154,6 +154,15 @@ def get_checkpoint_name(self): ckpt_name = self.config['testing']['ckpt_name'] return ckpt_name + @abstractmethod + def get_stage_transform_from_config(self, stage): + """ + Get the transform list required by dataset for training, validation or inference stage. + + :param stage: (str) `train`, `valid` or `test`. + """ + raise(ValueError("not implemented")) + @abstractmethod def get_stage_dataset_from_config(self, stage): """ @@ -261,13 +270,13 @@ def worker_init_fn(worker_id): bn_train = self.config['dataset']['train_batch_size'] bn_valid = self.config['dataset'].get('valid_batch_size', 1) - num_worker = self.config['dataset'].get('num_worker', 16) + num_worker = self.config['dataset'].get('num_worker', 8) g_train, g_valid = torch.Generator(), torch.Generator() g_train.manual_seed(self.random_seed) g_valid.manual_seed(self.random_seed) self.train_loader = torch.utils.data.DataLoader(self.train_set, batch_size = bn_train, shuffle=True, num_workers= num_worker, - worker_init_fn=worker_init, generator = g_train) + worker_init_fn=worker_init, generator = g_train, drop_last = True) self.valid_loader = torch.utils.data.DataLoader(self.valid_set, batch_size = bn_valid, shuffle=False, num_workers= num_worker, worker_init_fn=worker_init, generator = g_valid) diff --git a/pymic/net_run/agent_rec.py b/pymic/net_run/agent_rec.py index 634ea20..cd311ad 100644 --- a/pymic/net_run/agent_rec.py +++ b/pymic/net_run/agent_rec.py @@ -193,7 +193,7 @@ def train_valid(self): self.min_val_loss = 10000.0 self.max_val_it = 0 self.best_model_wts = None - self.checkpoint = None + checkpoint = None # initialize the network with pre-trained weights ckpt_init_name = self.config['training'].get('ckpt_init_name', None) ckpt_init_mode = self.config['training'].get('ckpt_init_mode', 0) @@ -212,7 +212,7 @@ def train_valid(self): else: self.net.load_state_dict(pretrained_dict, strict = False) if(ckpt_init_mode > 0): # Load other information - self.min_val_loss = self.checkpoint.get('valid_loss', 10000) + self.min_val_loss = checkpoint.get('valid_loss', 10000) iter_start = checkpoint['iteration'] self.max_val_it = iter_start self.best_model_wts = checkpoint['model_state_dict'] @@ -300,6 +300,7 @@ def save_outputs(self, data): if(isinstance(pred, (list, tuple))): pred = pred[0] pred = np.tanh(pred) + # pred = scipy.special.expit(pred) # save the output predictions test_dir = self.config['dataset'].get('test_dir', None) if(test_dir is None): diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 14f7567..2d6d489 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -46,8 +46,6 @@ def get_transform_names_and_parameters(self, stage): """ assert(stage in ['train', 'valid', 'test']) transform_key = stage + '_transform' - if(stage == "valid" and transform_key not in self.config['dataset']): - transform_key = "train_transform" trans_names = self.config['dataset'][transform_key] trans_params = self.config['dataset'] trans_params['task'] = self.task_type @@ -179,6 +177,8 @@ def training(self): inputs, labels_prob = mixup(inputs, labels_prob) # for debug + # if(it > 10): + # break # for i in range(inputs.shape[0]): # image_i = inputs[i][0] # # label_i = labels_prob[i][1] @@ -192,6 +192,7 @@ def training(self): # save_nd_array_as_image(label_i, label_name, reference_name = None) # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) # continue + inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) @@ -241,6 +242,9 @@ def validation(self): self.net.eval() for data in validIter: inputs = self.convert_tensor_type(data['image']) + if('label_prob' not in data): + raise ValueError("label_prob is not found in validation data, make sure" + + "that LabelToProbability is used in valid_transform.") labels_prob = self.convert_tensor_type(data['label_prob']) inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) batch_n = inputs.shape[0] diff --git a/pymic/net_run/get_optimizer.py b/pymic/net_run/get_optimizer.py index 53de5e3..771b448 100644 --- a/pymic/net_run/get_optimizer.py +++ b/pymic/net_run/get_optimizer.py @@ -13,8 +13,9 @@ def get_optimizer(name, net_params, optim_params): # see https://www.codeleading.com/article/44815584159/ param_group = [{'params': net_params, 'initial_lr': lr}] if(keyword_match(name, "SGD")): + nesterov = optim_params.get('nesterov', True) return optim.SGD(param_group, lr, - momentum = momentum, weight_decay = weight_decay, nesterov = True) + momentum = momentum, weight_decay = weight_decay, nesterov = nesterov) elif(keyword_match(name, "Adam")): return optim.Adam(param_group, lr, weight_decay = weight_decay) elif(keyword_match(name, "SparseAdam")): diff --git a/pymic/net_run/preprocess.py b/pymic/net_run/preprocess.py new file mode 100644 index 0000000..f2bbe0f --- /dev/null +++ b/pymic/net_run/preprocess.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import logging +import os +import sys +import shutil +from datetime import datetime +from pymic import TaskType +from pymic.util.parse_config import * +from pymic.net_run.agent_cls import ClassificationAgent +from pymic.net_run.agent_seg import SegmentationAgent +from pymic.net_run.semi_sup import SSLMethodDict +from pymic.net_run.weak_sup import WSLMethodDict +from pymic.net_run.self_sup import SelfSupMethodDict +from pymic.net_run.noisy_label import NLLMethodDict +# from pymic.net_run.self_sup import SelfSLSegAgent + +def get_seg_rec_agent(config, sup_type): + assert(sup_type in ['fully_sup', 'semi_sup', 'self_sup', 'weak_sup', 'noisy_label']) + if(sup_type == 'fully_sup'): + logging.info("\n********** Fully Supervised Learning **********\n") + agent = SegmentationAgent(config, 'train') + elif(sup_type == 'semi_sup'): + logging.info("\n********** Semi Supervised Learning **********\n") + method = config['semi_supervised_learning']['method_name'] + agent = SSLMethodDict[method](config, 'train') + elif(sup_type == 'weak_sup'): + logging.info("\n********** Weakly Supervised Learning **********\n") + method = config['weakly_supervised_learning']['method_name'] + agent = WSLMethodDict[method](config, 'train') + elif(sup_type == 'noisy_label'): + logging.info("\n********** Noisy Label Learning **********\n") + method = config['noisy_label_learning']['method_name'] + agent = NLLMethodDict[method](config, 'train') + elif(sup_type == 'self_sup'): + logging.info("\n********** Self Supervised Learning **********\n") + method = config['self_supervised_learning']['method_name'] + agent = SelfSupMethodDict[method](config, 'train') + else: + raise ValueError("undefined supervision type: {0:}".format(sup_type)) + return agent + +def main(): + """ + The main function for running a network for training. + """ + if(len(sys.argv) < 2): + print('Number of arguments should be 2. e.g.') + print(' pymic_train config.cfg') + exit() + cfg_file = str(sys.argv[1]) + if(not os.path.isfile(cfg_file)): + raise ValueError("The config file does not exist: " + cfg_file) + config = parse_config(cfg_file) + config = synchronize_config(config) + log_dir = config['training']['ckpt_save_dir'] + if(not os.path.exists(log_dir)): + os.makedirs(log_dir, exist_ok=True) + dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] + shutil.copy(cfg_file, log_dir + "/" + dst_cfg) + datetime_str = str(datetime.now())[:-7].replace(":", "_") + if sys.version.startswith("3.9"): + logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), + level=logging.INFO, format='%(message)s', force=True) # for python 3.9 + else: + logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), + level=logging.INFO, format='%(message)s') # for python 3.6 + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging_config(config) + task = config['dataset']['task_type'] + if(task == TaskType.CLASSIFICATION_ONE_HOT or task == TaskType.CLASSIFICATION_COEXIST): + agent = ClassificationAgent(config, 'train') + else: + sup_type = config['dataset'].get('supervise_type', 'fully_sup') + agent = get_seg_rec_agent(config, sup_type) + + agent.run() + +if __name__ == "__main__": + main() + + diff --git a/pymic/net_run/self_sup/__init__.py b/pymic/net_run/self_sup/__init__.py index 73308e6..d73e42a 100644 --- a/pymic/net_run/self_sup/__init__.py +++ b/pymic/net_run/self_sup/__init__.py @@ -1,3 +1,10 @@ from __future__ import absolute_import -from pymic.net_run.self_sup.self_sl_agent import SelfSLSegAgent -from pymic.net_run.self_sup.self_patch_mix_agent import SelfSLPatchMixAgent \ No newline at end of file +from pymic.net_run.self_sup.self_genesis import SelfSupModelGenesis +from pymic.net_run.self_sup.self_patch_swapping import SelfSupPatchSwapping +from pymic.net_run.self_sup.self_volume_fusion import SelfSupVolumeFusion + +SelfSupMethodDict = { + 'ModelGenesis': SelfSupModelGenesis, + 'PatchSwapping': SelfSupPatchSwapping, + 'VolumeFusion': SelfSupVolumeFusion + } \ No newline at end of file diff --git a/pymic/net_run/self_sup/self_genesis.py b/pymic/net_run/self_sup/self_genesis.py new file mode 100644 index 0000000..85ee194 --- /dev/null +++ b/pymic/net_run/self_sup/self_genesis.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import copy +import logging +import time +from pymic.net_run.agent_rec import ReconstructionAgent + +class SelfSupModelGenesis(ReconstructionAgent): + """ + Patch swapping-based self-supervised learning. + + Reference: Liang Chen et al., Self-supervised learning for medical image analysis + using image context restoration, Medical Image Analysis, 2019. + + A PatchSwaping transform need to be used in the cnfiguration. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `self_supervised_learning` is needed. See :doc:`usage.selfsl` for details. + + In the configuration file, it should look like this: + ``` + [dataset] + task_type = rec + supervise_type = self_sup + train_transform = [..., ..., PatchSwaping] + valid_transform = [..., ..., PatchSwaping] + + [self_supervised_learning] + method_name = ModelGenesis + + """ + def __init__(self, config, stage = 'train'): + super(SelfSupModelGenesis, self).__init__(config, stage) + + def get_transform_names_and_parameters(self, stage): + trans_names, trans_params = super(SelfSupModelGenesis, self).get_transform_names_and_parameters(stage) + # if(stage == 'train'): + # print('training transforms:', trans_names) + # if("LocalShuffling" not in trans_names): + # raise ValueError("LocalShuffling is required for model genesis, \ + # but it is not given in training transform") + # if("NonLinearTransform" not in trans_names): + # raise ValueError("NonLinearTransform is required for model genesis, \ + # but it is not given in training transform") + # if("InOutPainting" not in trans_names): + # raise ValueError("InOutPainting is required for model genesis, \ + # but it is not given in training transform") + return trans_names, trans_params diff --git a/pymic/net_run/self_sup/self_patch_swapping.py b/pymic/net_run/self_sup/self_patch_swapping.py new file mode 100644 index 0000000..1692fa7 --- /dev/null +++ b/pymic/net_run/self_sup/self_patch_swapping.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import copy +import logging +import time +from pymic.net_run.agent_rec import ReconstructionAgent + +class SelfSupPatchSwapping(ReconstructionAgent): + """ + Patch swapping-based self-supervised learning. + + Reference: Liang Chen et al., Self-supervised learning for medical image analysis + using image context restoration, Medical Image Analysis, 2019. + + A PatchSwaping transform need to be used in the cnfiguration. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `self_supervised_learning` is needed. See :doc:`usage.selfsl` for details. + + In the configuration file, it should look like this: + ``` + [dataset] + task_type = rec + supervise_type = self_sup + train_transform = [..., ..., PatchSwaping] + valid_transform = [..., ..., PatchSwaping] + + [self_supervised_learning] + method_name = PatchSwapping + + """ + def __init__(self, config, stage = 'train'): + super(SelfSupPatchSwapping, self).__init__(config, stage) + + def get_transform_names_and_parameters(self, stage): + trans_names, trans_params = super(SelfSupPatchSwapping, self).get_transform_names_and_parameters(stage) + if(stage == 'train'): + print('training transforms:', trans_names) + assert("PatchSwaping" in trans_names) + return trans_names, trans_params + diff --git a/pymic/net_run/self_sup/self_sl_agent.py b/pymic/net_run/self_sup/self_sl_agent.py index c352adf..45bee26 100644 --- a/pymic/net_run/self_sup/self_sl_agent.py +++ b/pymic/net_run/self_sup/self_sl_agent.py @@ -6,6 +6,7 @@ from pymic.net_run.agent_rec import ReconstructionAgent + class SelfSLSegAgent(ReconstructionAgent): """ Abstract class for self-supervised segmentation. @@ -17,7 +18,7 @@ class SelfSLSegAgent(ReconstructionAgent): In the configuration dictionary, in addition to the four sections (`dataset`, `network`, `training` and `inference`) used in fully supervised learning, an - extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. + extra section `self_supervised_learning` is needed. See :doc:`usage.selfsl` for details. """ def __init__(self, config, stage = 'train'): super(SelfSLSegAgent, self).__init__(config, stage) diff --git a/pymic/net_run/self_sup/self_patch_mix_agent.py b/pymic/net_run/self_sup/self_volume_fusion.py similarity index 71% rename from pymic/net_run/self_sup/self_patch_mix_agent.py rename to pymic/net_run/self_sup/self_volume_fusion.py index e30a131..91fe088 100644 --- a/pymic/net_run/self_sup/self_patch_mix_agent.py +++ b/pymic/net_run/self_sup/self_volume_fusion.py @@ -32,11 +32,13 @@ from pymic.util.post_process import PostProcessDict from pymic.util.image_process import convert_label from pymic.util.parse_config import * +from pymic.util.general import get_one_hot_seg from pymic.io.image_read_write import save_nd_array_as_image -from pymic.net_run.self_sup.util import patch_mix +from pymic.net_run.self_sup.util import volume_fusion from pymic.net_run.agent_seg import SegmentationAgent -class SelfSLPatchMixAgent(SegmentationAgent): + +class SelfSupVolumeFusion(SegmentationAgent): """ Abstract class for self-supervised segmentation. @@ -50,16 +52,15 @@ class SelfSLPatchMixAgent(SegmentationAgent): extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. """ def __init__(self, config, stage = 'train'): - super(SelfSLPatchMixAgent, self).__init__(config, stage) + super(SelfSupVolumeFusion, self).__init__(config, stage) def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] - fg_num = self.config['network']['class_num'] - 1 - patch_num = self.config['patch_mix']['patch_num_range'] - size_d = self.config['patch_mix']['patch_depth_range'] - size_h = self.config['patch_mix']['patch_height_range'] - size_w = self.config['patch_mix']['patch_width_range'] + cls_num = self.config['network']['class_num'] + block_range = self.config['self_supervised_learning']['VolumeFusion_block_range'.lower()] + size_min = self.config['self_supervised_learning']['VolumeFusion_size_min'.lower()] + size_max = self.config['self_supervised_learning']['VolumeFusion_size_max'.lower()] train_loss = 0 train_dice_list = [] @@ -72,16 +73,16 @@ def training(self): data = next(self.trainIter) # get the inputs inputs = self.convert_tensor_type(data['image']) - inputs, labels_prob = patch_mix(inputs, fg_num, patch_num, size_d, size_h, size_w) + inputs, labels = volume_fusion(inputs, cls_num - 1, block_range, size_min, size_max) + labels_prob = get_one_hot_seg(labels, cls_num) - # # for debug + # for debug # if(it==10): # break # for i in range(inputs.shape[0]): # image_i = inputs[i][0] # label_i = np.argmax(labels_prob[i], axis = 0) # # pixw_i = pix_w[i][0] - # print(image_i.shape, label_i.shape) # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) # # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) @@ -116,29 +117,3 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ 'class_dice': train_cls_dice} return train_scalers - -def main(): - cfg_file = str(sys.argv[1]) - if(not os.path.isfile(cfg_file)): - raise ValueError("The config file does not exist: " + cfg_file) - config = parse_config(cfg_file) - config = synchronize_config(config) - log_dir = config['training']['ckpt_save_dir'] - if(not os.path.exists(log_dir)): - os.makedirs(log_dir, exist_ok=True) - dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] - shutil.copy(cfg_file, log_dir + "/" + dst_cfg) - if sys.version.startswith("3.9"): - logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), - level=logging.INFO, format='%(message)s', force=True) # for python 3.9 - else: - logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), - level=logging.INFO, format='%(message)s') # for python 3.6 - logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging_config(config) - agent = SelfSLPatchMixAgent(config) - agent.run() - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/pymic/net_run/self_sup/util.py b/pymic/net_run/self_sup/util.py index 0ab1670..db27702 100644 --- a/pymic/net_run/self_sup/util.py +++ b/pymic/net_run/self_sup/util.py @@ -7,7 +7,7 @@ from scipy import ndimage from pymic.io.image_read_write import * from pymic.util.image_process import * -from pymic.util.general import get_one_hot_seg + def get_human_region_mask(img): """ @@ -137,111 +137,36 @@ def get_human_body_mask_and_crop(input_dir, out_img_dir, out_mask_dir): sitk.WriteImage(mask_obj, mask_name) -def patch_mix(x, fg_num, patch_num, size_d, size_h, size_w): +def volume_fusion(x, fg_num, block_range, size_min, size_max): """ - Copy a sub region of an impage and paste to another one to generate + Fuse a subregion of an impage with another one to generate images and labels for self-supervised segmentation. + input x should be a batch of tensors """ + #n_min, n_max, N, C, D, H, W = list(x.shape) - fg_mask = torch.zeros_like(x) + fg_mask = torch.zeros_like(x).to(torch.int32) # generate mask for n in range(N): - p_num = random.randint(patch_num[0], patch_num[1]) + p_num = random.randint(block_range[0], block_range[1]) for i in range(p_num): - d = random.randint(size_d[0], size_d[1]) - h = random.randint(size_h[0], size_h[1]) - w = random.randint(size_w[0], size_w[1]) - d_c = random.randint(0, D) - h_c = random.randint(0, H) - w_c = random.randint(0, W) - d0, d1 = max(0, d_c - d), min(D, d_c + d) - h0, h1 = max(0, h_c - h), min(H, h_c + h) - w0, w1 = max(0, w_c - w), min(W, w_c + w) - temp_m = torch.ones([C, d1-d0, h1-h0, w1-w0]) * random.randint(1, fg_num) + d = random.randint(size_min[0], size_max[0]) + h = random.randint(size_min[1], size_max[1]) + w = random.randint(size_min[2], size_max[2]) + dc = random.randint(0, D - 1) + hc = random.randint(0, H - 1) + wc = random.randint(0, W - 1) + d0 = dc - d // 2 + h0 = hc - h // 2 + w0 = wc - w // 2 + d1 = min(D, d0 + d) + h1 = min(H, h0 + h) + w1 = min(W, w0 + w) + d0, h0, w0 = max(0, d0), max(0, h0), max(0, w0) + temp_m = torch.ones([C, d1 - d0, h1 - h0, w1 - w0]) * random.randint(1, fg_num) fg_mask[n, :, d0:d1, h0:h1, w0:w1] = temp_m fg_w = fg_mask * 1.0 / fg_num x_roll = torch.roll(x, 1, 0) x_fuse = fg_w*x_roll + (1.0 - fg_w)*x - y_prob = get_one_hot_seg(fg_mask.to(torch.int32), fg_num + 1) - return x_fuse, y_prob - -def create_mixed_dataset(input_dir, output_dir, fg_num = 1, crop_num = 1, patch_size=[128,128,128], - mask_dir = None, data_format = "nii.gz"): - """ - Create dataset based on patch mix. - - :param input_dir: (str) The path of folder for input images - :param output_dir: (str) The path of folder for output images - :param fg_num: (int) The number of foreground classes - :param crop_num: (int) The number of patches to crop for each input image - :param mask: ND array to specify a mask, or 'default' or None. If default, - a mask for body region is automatically generated (just for CT). - :param data_format: (str) The format of images. - """ - img_names = os.listdir(input_dir) - img_names = [item for item in img_names if item.endswith(data_format)] - img_names = sorted(img_names) - out_img_dir = output_dir + "/image" - out_lab_dir = output_dir + "/label" - if(not os.path.exists(out_img_dir)): - os.mkdir(out_img_dir) - if(not os.path.exists(out_lab_dir)): - os.mkdir(out_lab_dir) - - img_num = len(img_names) - print("image number", img_num) - i_range = range(img_num) - j_range = list(i_range) - random.shuffle(j_range) - for i in i_range: - print(i, img_names[i]) - j = j_range[i] - if(i == j): - j = i + 1 if i < img_num - 1 else 0 - img_i = load_image_as_nd_array(input_dir + "/" + img_names[i])['data_array'] - img_j = load_image_as_nd_array(input_dir + "/" + img_names[j])['data_array'] - - chns = img_i.shape[0] - crop_size = [chns] + patch_size - # random crop to patch size - if(mask_dir is None): - mask_i = get_human_region_mask(img_i) - mask_j = get_human_region_mask(img_j) - else: - mask_i = load_image_as_nd_array(mask_dir + "/" + img_names[i])['data_array'] - mask_j = load_image_as_nd_array(mask_dir + "/" + img_names[j])['data_array'] - for k in range(crop_num): - # if(mask is None): - # img_ik = random_crop_ND_volume(img_i, [chns, 96, 96, 96]) - # img_jk = random_crop_ND_volume(img_j, [chns, 96, 96, 96]) - # else: - - img_ik = random_crop_ND_volume_with_mask(img_i, crop_size, mask_i) - img_jk = random_crop_ND_volume_with_mask(img_j, crop_size, mask_j) - C, D, H, W = img_ik.shape - # generate mask - fg_mask = np.zeros_like(img_ik, np.uint8) - patch_num = random.randint(4, 40) - for patch in range(patch_num): - d = random.randint(4, 20) # half of window size - h = random.randint(4, 40) - w = random.randint(4, 40) - d_c = random.randint(0, D) - h_c = random.randint(0, H) - w_c = random.randint(0, W) - d0, d1 = max(0, d_c - d), min(D, d_c + d) - h0, h1 = max(0, h_c - h), min(H, h_c + h) - w0, w1 = max(0, w_c - w), min(W, w_c + w) - temp_m = np.ones([C, d1-d0, h1-h0, w1-w0]) * random.randint(1, fg_num) - fg_mask[:, d0:d1, h0:h1, w0:w1] = temp_m - fg_w = fg_mask * 1.0 / fg_num - x_fuse = fg_w*img_jk + (1.0 - fg_w)*img_ik - - out_name = img_names[i] - if crop_num > 1: - out_name = out_name.replace(".nii.gz", "_{0:}.nii.gz".format(k)) - save_nd_array_as_image(x_fuse[0], out_img_dir + "/" + out_name, - reference_name = input_dir + "/" + img_names[i]) - save_nd_array_as_image(fg_mask[0], out_lab_dir + "/" + out_name, - reference_name = input_dir + "/" + img_names[i]) - + # y_prob = get_one_hot_seg(fg_mask.to(torch.int32), fg_num + 1) + return x_fuse, fg_mask diff --git a/pymic/net_run/semi_sup/__init__.py b/pymic/net_run/semi_sup/__init__.py index 39ca5dc..d3095f6 100644 --- a/pymic/net_run/semi_sup/__init__.py +++ b/pymic/net_run/semi_sup/__init__.py @@ -1,17 +1,18 @@ from __future__ import absolute_import -# from . import * from pymic.net_run.semi_sup.ssl_abstract import SSLSegAgent from pymic.net_run.semi_sup.ssl_em import SSLEntropyMinimization from pymic.net_run.semi_sup.ssl_mt import SSLMeanTeacher +from pymic.net_run.semi_sup.ssl_mcnet import SSLMCNet from pymic.net_run.semi_sup.ssl_uamt import SSLUncertaintyAwareMeanTeacher from pymic.net_run.semi_sup.ssl_cct import SSLCCT from pymic.net_run.semi_sup.ssl_cps import SSLCPS from pymic.net_run.semi_sup.ssl_urpc import SSLURPC -# SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization, -# 'MeanTeacher': SSLMeanTeacher, -# 'UAMT': SSLUncertaintyAwareMeanTeacher, -# 'CCT': SSLCCT, -# 'CPS': SSLCPS, -# 'URPC': SSLURPC} \ No newline at end of file +SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization, + 'MeanTeacher': SSLMeanTeacher, + 'MCNet': SSLMCNet, + 'UAMT': SSLUncertaintyAwareMeanTeacher, + 'CCT': SSLCCT, + 'CPS': SSLCPS, + 'URPC': SSLURPC} \ No newline at end of file diff --git a/pymic/net_run/semi_sup/ssl_abstract.py b/pymic/net_run/semi_sup/ssl_abstract.py index b27edc9..0e05281 100644 --- a/pymic/net_run/semi_sup/ssl_abstract.py +++ b/pymic/net_run/semi_sup/ssl_abstract.py @@ -35,7 +35,7 @@ def get_unlabeled_dataset_from_config(self): """ Create a dataset for the unlabeled images based on configuration. """ - root_dir = self.config['dataset']['root_dir'] + train_dir = self.config['dataset']['train_dir'] modal_num = self.config['dataset'].get('modal_num', 1) transform_names = self.config['dataset']['train_transform_unlab'] @@ -53,7 +53,7 @@ def get_unlabeled_dataset_from_config(self): data_transform = transforms.Compose(self.transform_list) csv_file = self.config['dataset'].get('train_csv_unlab', None) - dataset = NiftyDataset(root_dir=root_dir, + dataset = NiftyDataset(root_dir = train_dir, csv_file = csv_file, modal_num = modal_num, with_label= False, @@ -76,7 +76,7 @@ def worker_init_fn(worker_id): num_worker = self.config['dataset'].get('num_worker', 16) self.train_loader_unlab = torch.utils.data.DataLoader(self.train_set_unlab, batch_size = bn_train_unlab, shuffle=True, num_workers= num_worker, - worker_init_fn=worker_init) + worker_init_fn=worker_init, drop_last = True) def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): loss_scalar ={'train':train_scalars['loss'], diff --git a/pymic/net_run/semi_sup/ssl_cps.py b/pymic/net_run/semi_sup/ssl_cps.py index df4b5af..7acfe17 100644 --- a/pymic/net_run/semi_sup/ssl_cps.py +++ b/pymic/net_run/semi_sup/ssl_cps.py @@ -3,12 +3,14 @@ import logging import numpy as np import torch +from random import random from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.io.image_read_write import save_nd_array_as_image from pymic.net_run.semi_sup import SSLSegAgent from pymic.util.ramps import get_rampup_ratio +from pymic.util.general import mixup, tensor_shape_match class SSLCPS(SSLSegAgent): """ @@ -34,7 +36,8 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] ssl_cfg = self.config['semi_supervised_learning'] - iter_max = self.config['training']['iter_max'] + iter_max = self.config['training']['iter_max'] + mixup_prob = self.config['training'].get('mixup_probability', 0.0) rampup_start = ssl_cfg.get('rampup_start', 0) rampup_end = ssl_cfg.get('rampup_end', iter_max) train_loss = 0 @@ -70,7 +73,8 @@ def training(self): # save_nd_array_as_image(image_i, image_name, reference_name = None) # save_nd_array_as_image(label_i, label_name, reference_name = None) # continue - + if(mixup_prob > 0 and random() < mixup_prob): + x0, y0 = mixup(x0, y0) inputs = torch.cat([x0, x1], dim = 0) inputs, y0 = inputs.to(self.device), y0.to(self.device) diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index f2bbe0f..50a5fb7 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -9,6 +9,7 @@ from pymic.util.parse_config import * from pymic.net_run.agent_cls import ClassificationAgent from pymic.net_run.agent_seg import SegmentationAgent +from pymic.net_run.agent_rec import ReconstructionAgent from pymic.net_run.semi_sup import SSLMethodDict from pymic.net_run.weak_sup import WSLMethodDict from pymic.net_run.self_sup import SelfSupMethodDict @@ -19,7 +20,10 @@ def get_seg_rec_agent(config, sup_type): assert(sup_type in ['fully_sup', 'semi_sup', 'self_sup', 'weak_sup', 'noisy_label']) if(sup_type == 'fully_sup'): logging.info("\n********** Fully Supervised Learning **********\n") - agent = SegmentationAgent(config, 'train') + if config['dataset']['task_type'] == TaskType.SEGMENTATION: + agent = SegmentationAgent(config, 'train') + else: + agent = ReconstructionAgent(config, 'train') elif(sup_type == 'semi_sup'): logging.info("\n********** Semi Supervised Learning **********\n") method = config['semi_supervised_learning']['method_name'] diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index 6977639..b821bb2 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -384,4 +384,92 @@ def __call__(self, sample): if(resize): weight = ndimage.interpolation.zoom(weight, scale, order = 1) sample['pixel_weight'] = weight - return sample \ No newline at end of file + return sample + +class RandomSlice(AbstractTransform): + """Randomly selecting N slices from a volume + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `RandomSlice_output_size`: (int) Desired number of slice for output. + :param `RandomSlice_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `True`. + """ + def __init__(self, params): + self.output_size = params['RandomSlice_output_size'.lower()] + self.shuffle = params.get('RandomSlice_shuffle'.lower(), False) + self.inverse = params.get('RandomSlice_inverse'.lower(), False) + self.task = params['Task'.lower()] + + def __call__(self, sample): + image = sample['image'] + D = image.shape[1] + assert( D >= self.output_size) + slice_idx = list(range(D)) + if(self.shuffle): + random.shuffle(slice_idx) + slice_idx = slice_idx[:self.output_size] + else: + d0 = random.randint(0, D - self.output_size) + d1 = d0 + self.output_size + slice_idx = slice_idx[d0:d1] + sample['image'] = image[:, slice_idx, :, :] + + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + label = sample['label'] + sample['label'] = label[:, slice_idx, :, :] + + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + weight = sample['pixel_weight'] + sample['pixel_weight'] = weight[:, slice_idx, :, :] + + return sample + +class CropHumanRegionFromCT(CenterCrop): + """ + Crop the human region from a CT volume. + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `CropWithBoundingBox_start`: (None, or list/tuple) The start index + along each spatial axis. If None, calculate the start index automatically + so that the cropped region is centered at the mask region defined by the threshold. + :param `CropWithBoundingBox_output_size`: (None or tuple/list): + Desired spatial output size. + If None, set it as the size of bounding box of the mask region defined by the threshold. + :param `CropWithBoundingBox_threshold`: (None or float): + Threshold for obtaining a mask. This is used only when + `CropWithBoundingBox_start` is None. Default is 1.0 + :param `CropWithBoundingBox_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `True`. + """ + def __init__(self, params): + self.threshold_i = params.get('CropHumanRegionFromCT_intensity_threshold'.lower(), -600) + self.threshold_z = params.get('CropHumanRegionFromCT_zaxis_threshold'.lower(), 0.5) + self.inverse = params.get('CropHumanRegionFromCT_inverse'.lower(), True) + self.task = params['task'] + + def _get_crop_param(self, sample): + image = sample['image'] + input_shape = image.shape + mask = np.asarray(image[0] > self.threshold_i) + mask2d = np.mean(mask, axis = 0) > self.threshold_z + se = np.ones([3,3]) + mask2d = ndimage.binary_opening(mask2d, se, iterations = 2) + mask2d = get_largest_k_components(mask2d, 1) + bbmin, bbmax = get_ND_bounding_box(mask2d, margin = [0, 0]) + crop_min = [0, 0] + bbmin + crop_max = list(input_shape[:2]) + bbmax + sample['CropHumanRegionFromCT_Param'] = json.dumps((input_shape, crop_min, crop_max)) + return sample, crop_min, crop_max + + def _get_param_for_inverse_transform(self, sample): + if(isinstance(sample['CropHumanRegionFromCT_Param'], list) or \ + isinstance(sample['CropHumanRegionFromCT_Param'], tuple)): + params = json.loads(sample['CropHumanRegionFromCT_Param'][0]) + else: + params = json.loads(sample['CropHumanRegionFromCT_Param']) + return params \ No newline at end of file diff --git a/pymic/transform/extract_channel.py b/pymic/transform/extract_channel.py new file mode 100644 index 0000000..c4974be --- /dev/null +++ b/pymic/transform/extract_channel.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import json +import math +import random +import numpy as np +from scipy import ndimage +from pymic import TaskType +from pymic.transform.abstract_transform import AbstractTransform +from pymic.util.image_process import * + + +class ExtractChannel(AbstractTransform): + """ Random flip the image. The shape is [C, D, H, W] or [C, H, W]. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `RandomFlip_flip_depth`: (bool) + Random flip along depth axis or not, only used for 3D images. + :param `RandomFlip_flip_height`: (bool) Random flip along height axis or not. + :param `RandomFlip_flip_width`: (bool) Random flip along width axis or not. + :param `RandomFlip_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `True`. + """ + def __init__(self, params): + super(ExtractChannel, self).__init__(params) + self.channels = params['ExtractChannel_channels'.lower()] + self.inverse = params.get('ExtractChannel_inverse'.lower(), False) + + def __call__(self, sample): + image = sample['image'] + image_extract = [] + for i in self.channels: + image_extract.append(image[i]) + sample['image'] = np.asarray(image_extract) + return sample diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index d39eb2c..2b19ebc 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -37,7 +37,7 @@ def bezier_curve(points, nTimes=1000): t = np.linspace(0.0, 1.0, nTimes) - polynomial_array = np.array([ bernstein_poly(i, nPoints-1, t) for i in range(0, nPoints) ]) + polynomial_array = np.array([ bernstein_poly(i, nPoints-1, t) for i in range(0, nPoints)]) xvals = np.dot(xPoints, polynomial_array) yvals = np.dot(yPoints, polynomial_array) @@ -182,30 +182,54 @@ def __call__(self, sample): class NonLinearTransform(AbstractTransform): def __init__(self, params): super(NonLinearTransform, self).__init__(params) - self.channels = params['NonLinearTransform_channels'.lower()] + self.channels = params.get('NonLinearTransform_channels'.lower(), None) self.prob = params.get('NonLinearTransform_probability'.lower(), 0.5) self.inverse = params.get('NonLinearTransform_inverse'.lower(), False) + self.block_range = params.get('NonLinearTransform_block_range'.lower(), None) + self.block_size = params.get('NonLinearTransform_block_size'.lower(), [8, 16, 16]) + def __apply_nonlinear_transform(self, img): + """ + the input img should be normlized to [0, 1]""" + points = [[0, 0], [random.random(), random.random()], [random.random(), random.random()], [1, 1]] + xvals, yvals = bezier_curve(points, nTimes=10000) + if random.random() < 0.5: # Half chance to get flip + xvals = np.sort(xvals) + else: + xvals, yvals = np.sort(xvals), np.sort(yvals) + + img = np.interp(img, xvals, yvals) + return img + def __call__(self, sample): if(random.random() > self.prob): return sample - image = sample['image'] - for chn in self.channels: - points = [[0, 0], [random.random(), random.random()], [random.random(), random.random()], [1, 1]] - xvals, yvals = bezier_curve(points, nTimes=10000) - if random.random() < 0.5: # Half change to get flip - xvals = np.sort(xvals) - else: - xvals, yvals = np.sort(xvals), np.sort(yvals) + image = sample['image'] + img_shape = image.shape + img_dim = len(img_shape) - 1 + channels = self.channels if self.channels is not None else range(image.shape[0]) + for chn in channels: # normalize the image intensity to [0, 1] before the non-linear tranform img_c = image[chn] - v_min = img_c.min() - v_max = img_c.max() + v_min, v_max = img_c.min(), img_c.max() if(v_min < v_max): img_c = (img_c - v_min)/(v_max - v_min) - img_c = np.interp(img_c, xvals, yvals) + if(self.block_range is None): # apply non-linear transform to the entire image + img_c = self.__apply_nonlinear_transform(img_c) + else: # non-linear transform to random blocks + img_c_sr = copy.deepcopy(img_c) + for n in range(self.block_range[0], self.block_range[1]): + coord_min = [random.randint(0, img_shape[1+i] - self.block_size[i]) \ + for i in range(img_dim)] + window = img_c_sr[coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1], + coord_min[2]:coord_min[2] + self.block_size[2]] + img_c[coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1], + coord_min[2]:coord_min[2] + self.block_size[2]] = \ + self.__apply_nonlinear_transform(window) image[chn] = img_c * (v_max - v_min) + v_min sample['image'] = image return sample @@ -218,9 +242,8 @@ def __init__(self, params): super(LocalShuffling, self).__init__(params) self.inverse = params.get('LocalShuffling_inverse'.lower(), False) self.prob = params.get('LocalShuffling_probability'.lower(), 0.5) - self.block_range = params.get('LocalShuffling_block_range'.lower(), (5000, 10000)) - self.block_size_min = params.get('LocalShuffling_block_size_min'.lower(), None) - self.block_size_max = params.get('LocalShuffling_block_size_max'.lower(), None) + self.block_range = params.get('LocalShuffling_block_range'.lower(), [40, 80]) + self.block_size = params.get('LocalShuffling_block_size'.lower(), [4, 8, 8]) def __call__(self, sample): if(random.random() > self.prob): @@ -231,49 +254,33 @@ def __call__(self, sample): img_dim = len(img_shape) - 1 assert(img_dim == 2 or img_dim == 3) img_out = copy.deepcopy(image) - if(self.block_size_min is None): - block_size_min = [2] * img_dim - elif(isinstance(self.block_size_min, int)): - block_size_min = [self.block_size_min] * img_dim - else: - assert(len(self.block_size_min) == img_dim) - block_size_min = self.block_size_min - - if(self.block_size_max is None): - block_size_max = [img_shape[1+i]//10 for i in range(img_dim)] - elif(isinstance(self.block_size_min, int)): - block_size_max = [self.block_size_max] * img_dim - else: - assert(len(self.block_size_max) == img_dim) - block_size_max = self.block_size_max + block_num = random.randint(self.block_range[0], self.block_range[1]) for n in range(block_num): - block_size = [random.randint(block_size_min[i], block_size_max[i]) \ - for i in range(img_dim)] - coord_min = [random.randint(0, img_shape[1+i] - block_size[i]) \ + coord_min = [random.randint(0, img_shape[1+i] - self.block_size[i]) \ for i in range(img_dim)] if(img_dim == 2): - window = image[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1]] - n_pixels = block_size[0] * block_size[1] + window = image[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1]] + n_pixels = self.block_size[0] * self.block_size[1] else: - window = image[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1], - coord_min[2]:coord_min[2] + block_size[2]] - n_pixels = block_size[0] * block_size[1] * block_size[2] + window = image[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1], + coord_min[2]:coord_min[2] + self.block_size[2]] + n_pixels = self.block_size[0] * self.block_size[1] * self.block_size[2] window = np.reshape(window, [-1, n_pixels]) np.random.shuffle(np.transpose(window)) window = np.transpose(window) if(img_dim == 2): - window = np.reshape(window, [-1, block_size[0], block_size[1]]) - img_out[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1]] = window + window = np.reshape(window, [-1, self.block_size[0], self.block_size[1]]) + img_out[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1]] = window else: - window = np.reshape(window, [-1, block_size[0], block_size[1], block_size[2]]) - img_out[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1], - coord_min[2]:coord_min[2] + block_size[2]] = window + window = np.reshape(window, [-1, self.block_size[0], self.block_size[1], self.block_size[2]]) + img_out[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1], + coord_min[2]:coord_min[2] + self.block_size[2]] = window sample['image'] = img_out return sample @@ -285,10 +292,9 @@ def __init__(self, params): super(InPainting, self).__init__(params) self.inverse = params.get('InPainting_inverse'.lower(), False) self.prob = params.get('InPainting_probability'.lower(), 0.5) - self.block_range = params.get('InPainting_block_range'.lower(), (1, 6)) - self.block_size_min = params.get('InPainting_block_size_min'.lower(), None) - self.block_size_max = params.get('InPainting_block_size_max'.lower(), None) - + self.block_range = params.get('InPainting_block_range'.lower(), (20, 40)) + self.block_size = params.get('InPainting_block_size'.lower(), [8, 16, 16]) + def __call__(self, sample): if(random.random() > self.prob): return sample @@ -298,38 +304,21 @@ def __call__(self, sample): img_dim = len(img_shape) - 1 assert(img_dim == 2 or img_dim == 3) - if(self.block_size_min is None): - block_size_min = [img_shape[1+i]//6 for i in range(img_dim)] - elif(isinstance(self.block_size_min, int)): - block_size_min = [self.block_size_min] * img_dim - else: - assert(len(self.block_size_min) == img_dim) - block_size_min = self.block_size_min - - if(self.block_size_max is None): - block_size_max = [img_shape[1+i]//3 for i in range(img_dim)] - elif(isinstance(self.block_size_min, int)): - block_size_max = [self.block_size_max] * img_dim - else: - assert(len(self.block_size_max) == img_dim) - block_size_max = self.block_size_max block_num = random.randint(self.block_range[0], self.block_range[1]) - for n in range(block_num): - block_size = [random.randint(block_size_min[i], block_size_max[i]) \ - for i in range(img_dim)] - coord_min = [random.randint(3, img_shape[1+i] - block_size[i] - 3) \ + for n in range(block_num): + coord_min = [random.randint(3, img_shape[1+i] - self.block_size[i] - 3) \ for i in range(img_dim)] if(img_dim == 2): - random_block = np.random.rand(img_shape[0], block_size[0], block_size[1]) - image[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1]] = random_block + random_block = np.random.rand(img_shape[0], self.block_size[0], self.block_size[1]) * 2 -1 + image[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1]] = random_block else: - random_block = np.random.rand(img_shape[0], block_size[0], - block_size[1], block_size[2]) - image[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1], - coord_min[2]:coord_min[2] + block_size[2]] = random_block + random_block = np.random.rand(img_shape[0], self.block_size[0], + self.block_size[1], self.block_size[2]) * 2 -1 + image[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1], + coord_min[2]:coord_min[2] + self.block_size[2]] = random_block sample['image'] = image return sample @@ -341,9 +330,8 @@ def __init__(self, params): super(OutPainting, self).__init__(params) self.inverse = params.get('OutPainting_inverse'.lower(), False) self.prob = params.get('OutPainting_probability'.lower(), 0.5) - self.block_range = params.get('OutPainting_block_range'.lower(), (1, 6)) - self.block_size_min = params.get('OutPainting_block_size_min'.lower(), None) - self.block_size_max = params.get('OutPainting_block_size_max'.lower(), None) + self.block_range = params.get('OutPainting_block_range'.lower(), (2, 8)) + self.block_size = params.get('OutPainting_block_size'.lower(), None) def __call__(self, sample): if(random.random() > self.prob): @@ -353,28 +341,18 @@ def __call__(self, sample): img_shape = image.shape img_dim = len(img_shape) - 1 assert(img_dim == 2 or img_dim == 3) - img_out = np.random.rand(*img_shape) + img_out = np.random.rand(*img_shape) * 2 -1 - if(self.block_size_min is None): - block_size_min = [img_shape[1+i] - 4 * img_shape[1+i]//7 for i in range(img_dim)] - elif(isinstance(self.block_size_min, int)): - block_size_min = [self.block_size_min] * img_dim + if(self.block_size is None): + margin = [16, 32, 32] + block_size = [img_shape[1+i] - margin[i] for i in range(img_dim)] else: - assert(len(self.block_size_min) == img_dim) - block_size_min = self.block_size_min + assert(len(self.block_size) == img_dim) + block_size = self.block_size - if(self.block_size_max is None): - block_size_max = [img_shape[1+i] - 3 * img_shape[1+i]//7 for i in range(img_dim)] - elif(isinstance(self.block_size_min, int)): - block_size_max = [self.block_size_max] * img_dim - else: - assert(len(self.block_size_max) == img_dim) - block_size_max = self.block_size_max block_num = random.randint(self.block_range[0], self.block_range[1]) for n in range(block_num): - block_size = [random.randint(block_size_min[i], block_size_max[i]) \ - for i in range(img_dim)] coord_min = [random.randint(3, img_shape[1+i] - block_size[i] - 3) \ for i in range(img_dim)] if(img_dim == 2): @@ -401,8 +379,8 @@ def __init__(self, params): self.inverse = params.get('InOutPainting_inverse'.lower(), False) self.prob = params.get('InOutPainting_probability'.lower(), 0.5) self.in_prob = params.get('InPainting_probability'.lower(), 0.5) - params['InPainting_probability'] = 1.0 - params['outPainting_probability'] = 1.0 + params['InPainting_probability'.lower()] = 1.0 + params['OutPainting_probability'.lower()] = 1.0 self.inpaint = InPainting(params) self.outpaint = OutPainting(params) @@ -423,35 +401,23 @@ class PatchSwaping(AbstractTransform): """ def __init__(self, params): super(PatchSwaping, self).__init__(params) + self.block_range = params.get('PatchSwaping_block_range'.lower(), (10, 20)) + self.block_size = params.get('PatchSwaping_block_size'.lower(), [8, 16, 16]) self.inverse = params.get('PatchSwaping_inverse'.lower(), False) - self.swap_t = params.get('PatchSwaping_swap_time'.lower(), (1, 6)) - self.patch_size_min = params.get('PatchSwaping_patch_size_min'.lower(), None) - self.patch_size_max = params.get('PatchSwaping_patch_size_max'.lower(), None) - - def __call__(self, sample): + def __call__(self, sample): image= sample['image'] img_shape = image.shape img_dim = len(img_shape) - 1 assert(img_dim == 2 or img_dim == 3) - img_out = image - - C, D, H, W = image.shape - patch_size = [random.randint(self.patch_size_min[i], self.patch_size_max[i]) for \ - i in range(img_dim)] - - coordinate_list = [] - for d in range(0, D-patch_size[0], patch_size[0]): - for h in range(0, H-patch_size[1], patch_size[1]): - for w in range(0, W-patch_size[2], patch_size[2]): - coordinate_list.append((d, h, w)) - random.shuffle(coordinate_list) + img_out = copy.deepcopy(image) - for t in range(self.swap_t): - pos_a0 = coordinate_list[2*t] - pos_b0 = coordinate_list[2*t + 1] - pos_a1 = [pos_a0[i] + patch_size[i] for i in range(img_dim)] - pos_b1 = [pos_b0[i] + patch_size[i] for i in range(img_dim)] + block_num = random.randint(self.block_range[0], self.block_range[1]) + for t in range(block_num): + pos_a0 = [random.randint(0, img_shape[-3+i] - self.block_size[i]) for i in range(img_dim)] + pos_b0 = [random.randint(0, img_shape[-3+i] - self.block_size[i]) for i in range(img_dim)] + pos_a1 = [pos_a0[i] + self.block_size[i] for i in range(img_dim)] + pos_b1 = [pos_b0[i] + self.block_size[i] for i in range(img_dim)] img_out[:, pos_a0[0]:pos_a1[0], pos_a0[1]:pos_a1[1], pos_a0[2]:pos_a1[2]] = \ image[:, pos_b0[0]:pos_b1[0], pos_b0[1]:pos_b1[1], pos_b0[2]:pos_b1[2]] img_out[:, pos_b0[0]:pos_b1[0], pos_b0[1]:pos_b1[1], pos_b0[2]:pos_b1[2]] = \ diff --git a/pymic/transform/normalize.py b/pymic/transform/normalize.py index 3d28c0d..5f0e4ec 100644 --- a/pymic/transform/normalize.py +++ b/pymic/transform/normalize.py @@ -74,7 +74,7 @@ def __call__(self, sample): class NormalizeWithMinMax(AbstractTransform): - """Nomralize the image to [0, 1]. The shape should be [C, D, H, W] or [C, H, W]. + """Nomralize the image to [-1, 1]. The shape should be [C, D, H, W] or [C, H, W]. The arguments should be written in the `params` dictionary, and it has the following fields: @@ -112,13 +112,13 @@ def __call__(self, sample): img_chn[img_chn < v0] = v0 img_chn[img_chn > v1] = v1 - img_chn = (img_chn - v0) / (v1 - v0) + img_chn = 2.0* (img_chn - v0) / (v1 - v0) -1.0 image[chn] = img_chn sample['image'] = image return sample class NormalizeWithPercentiles(AbstractTransform): - """Nomralize the image to [0, 1] with percentiles for given channels. + """Nomralize the image to [-1, 1] with percentiles for given channels. The shape should be [C, D, H, W] or [C, H, W]. The arguments should be written in the `params` dictionary, and it has the @@ -152,7 +152,7 @@ def __call__(self, sample): img_chn[img_chn < v0] = v0 img_chn[img_chn > v1] = v1 - img_chn = (img_chn - v0) / (v1 - v0) + img_chn = 2.0* (img_chn - v0) / (v1 - v0) -1.0 image[chn] = img_chn sample['image'] = image return sample \ No newline at end of file diff --git a/pymic/transform/rescale.py b/pymic/transform/rescale.py index fa4f052..2896a4e 100644 --- a/pymic/transform/rescale.py +++ b/pymic/transform/rescale.py @@ -217,6 +217,7 @@ def inverse_transform_for_prediction(self, sample): origin_shape = json.loads(sample['Resample_origin_shape'][0]) else: origin_shape = json.loads(sample['Resample_origin_shape']) + origin_dim = len(origin_shape) - 1 predict = sample['predict'] input_shape = predict.shape diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index a28e848..ed5ad0c 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -50,6 +50,7 @@ 'ChannelWiseThresholdWithNormalize': ChannelWiseThresholdWithNormalize, 'CropWithBoundingBox': CropWithBoundingBox, 'CropWithForeground': CropWithForeground, + 'CropHumanRegionFromCT': CropHumanRegionFromCT, 'CenterCrop': CenterCrop, 'GrayscaleToRGB': GrayscaleToRGB, 'GammaCorrection': GammaCorrection, @@ -67,6 +68,7 @@ 'NormalizeWithPercentiles': NormalizeWithPercentiles, 'PartialLabelToProbability':PartialLabelToProbability, 'RandomCrop': RandomCrop, + 'RandomSlice': RandomSlice, 'RandomResizedCrop': RandomResizedCrop, 'RandomRescale': RandomRescale, 'RandomTranspose': RandomTranspose, diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index d6a7220..c813e5d 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -327,7 +327,7 @@ def convert_label(label, source_list, target_list): label_converted[label_s > 0] = label_t[label_s > 0] return label_converted -def resample_sitk_image_to_given_spacing(image, spacing, order): +def resample_sitk_image_to_given_spacing(image, spacing, order = 3): """ Resample an sitk image objct to a given spacing. From 3a0e1501e75b58bdca5e089309fcec221a92813e Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 10 Jan 2024 14:32:22 +0800 Subject: [PATCH 175/225] update readme --- README.md | 4 +- pymic/net/net3d/unet3d.py | 151 +++++++++++++++-------------------- pymic/net_run/preprocess.py | 82 ------------------- pymic/transform/normalize.py | 2 +- setup.py | 2 +- 5 files changed, 68 insertions(+), 173 deletions(-) delete mode 100644 pymic/net_run/preprocess.py diff --git a/README.md b/README.md index ac7f9b6..4af5eff 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # PyMIC: A Pytorch-Based Toolkit for Medical Image Computing -PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with high dimension and large volume, multiple modalities and difficulies in annotating. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations. +PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with high dimension and large volume, multiple modalities and difficulies in annotating. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised, self-supervised, and weakly supervised learning, and learning with noisy annotations. Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. If you use this toolkit, please cite the following paper: @@ -23,7 +23,7 @@ BibTeX entry: # Features PyMIC provides flixible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions: -* Support for annotation-efficient image segmentation, especially for semi-supervised, self-supervised, weakly-supervised and noisy-label learning. +* Support for annotation-efficient image segmentation, especially for semi-supervised, self-supervised, self-supervised, weakly-supervised and noisy-label learning. * User friendly: For beginners, you only need to edit the configuration files for model training and inference, without writing code. For advanced users, you can customize different modules (networks, loss functions, training pipeline, etc) and easily integrate them into PyMIC. * Easy-to-use I/O interface to read and write different 2D and 3D images. * Various data pre-processing/transformation methods before sending a tensor into a network. diff --git a/pymic/net/net3d/unet3d.py b/pymic/net/net3d/unet3d.py index a17bcb8..e383e77 100644 --- a/pymic/net/net3d/unet3d.py +++ b/pymic/net/net3d/unet3d.py @@ -1,10 +1,11 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import logging import torch import torch.nn as nn import numpy as np -from torch.nn.functional import interpolate + class ConvBlock(nn.Module): """ @@ -56,22 +57,32 @@ class UpBlock(nn.Module): :param in_channels2: (int) Channel number of low-level features. :param out_channels: (int) Output channel number. :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - trilinear=True): + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, up_mode=2): super(UpBlock, self).__init__() - self.trilinear = trilinear - if trilinear: - self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) + if(isinstance(up_mode, int)): + up_mode_values = ["transconv", "nearest", "trilinear"] + if(up_mode > 2): + raise ValueError("The upsample mode should be 0-2, but {0:} is given.".format(up_mode)) + self.up_mode = up_mode_values[up_mode] else: + self.up_mode = up_mode.lower() + + if (self.up_mode == "transconv"): self.up = nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) + else: + self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) + if(self.up_mode == "nearest"): + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode) + else: + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode, align_corners=True) self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) def forward(self, x1, x2): - if self.trilinear: + if self.up_mode != "transconv": x1 = self.conv1x1(x1) x1 = self.up(x1) x = torch.cat([x2, x1], dim=1) @@ -129,9 +140,10 @@ class Decoder(nn.Module): :param dropout: (list) The dropout ratio for each resolution level. The length should be the same as that of `feature_chns`. :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - :param multiscale_pred: (bool) Get multi-scale prediction. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). + :param multiscale_pred: (bool) Get multi-scale prediction. """ def __init__(self, params): super(Decoder, self).__init__() @@ -140,21 +152,25 @@ def __init__(self, params): self.ft_chns = self.params['feature_chns'] self.dropout = self.params['dropout'] self.n_class = self.params['class_num'] - self.trilinear = self.params.get('trilinear', True) + self.up_mode = self.params.get('up_mode', 2) self.mul_pred = self.params.get('multiscale_pred', False) - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) if(len(self.ft_chns) == 5): - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear) + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.up_mode) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.up_mode) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.up_mode) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.up_mode) self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) + if(self.mul_pred): self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) + self.stage = 'train' + + def set_stage(self, stage): + self.stage = stage def forward(self, x): if(len(self.ft_chns) == 5): @@ -169,7 +185,7 @@ def forward(self, x): x_d1 = self.up3(x_d2, x1) x_d0 = self.up4(x_d1, x0) output = self.out_conv(x_d0) - if(self.mul_pred): + if(self.mul_pred and self.stage == 'train'): output1 = self.out_conv1(x_d1) output2 = self.out_conv2(x_d2) output3 = self.out_conv3(x_d3) @@ -196,77 +212,38 @@ class UNet3D(nn.Module): :param dropout: (list) The dropout ratio for each resolution level. The length should be the same as that of `feature_chns`. :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using trilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). :param multiscale_pred: (bool) Get multi-scale prediction. """ def __init__(self, params): super(UNet3D, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.trilinear = self.params['trilinear'] - self.mul_pred = self.params['multiscale_pred'] - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - if(len(self.ft_chns) == 5): - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], - dropout_p = self.dropout[3], trilinear=self.trilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], - dropout_p = self.dropout[2], trilinear=self.trilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], - dropout_p = self.dropout[1], trilinear=self.trilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], - dropout_p = self.dropout[0], trilinear=self.trilinear) - - self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) - if(self.mul_pred): - self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) - self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) - self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) + params = self.get_default_parameters(params) + for p in params: + print(p, params[p]) + self.stage = 'train' + self.encoder = Encoder(params) + self.decoder = Decoder(params) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], + 'up_mode': 2, + 'multiscale_pred': False + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def set_stage(self, stage): + self.stage = stage + self.decoder.set_stage(stage) def forward(self, x): - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - if(len(self.ft_chns) == 5): - x4 = self.down4(x3) - x_d3 = self.up1(x4, x3) - else: - x_d3 = x3 - x_d2 = self.up2(x_d3, x2) - x_d1 = self.up3(x_d2, x1) - x_d0 = self.up4(x_d1, x0) - output = self.out_conv(x_d0) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] + f = self.encoder(x) + output = self.decoder(f) return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'class_num': 2, - 'feature_chns':[2, 8, 32, 64], - 'dropout' : [0, 0, 0, 0.5], - 'trilinear': True, - 'multiscale_pred': False} - Net = UNet3D(params) - Net = Net.double() - - x = np.random.rand(4, 4, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - y = y.detach().numpy() - print(y.shape) diff --git a/pymic/net_run/preprocess.py b/pymic/net_run/preprocess.py deleted file mode 100644 index f2bbe0f..0000000 --- a/pymic/net_run/preprocess.py +++ /dev/null @@ -1,82 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division -import logging -import os -import sys -import shutil -from datetime import datetime -from pymic import TaskType -from pymic.util.parse_config import * -from pymic.net_run.agent_cls import ClassificationAgent -from pymic.net_run.agent_seg import SegmentationAgent -from pymic.net_run.semi_sup import SSLMethodDict -from pymic.net_run.weak_sup import WSLMethodDict -from pymic.net_run.self_sup import SelfSupMethodDict -from pymic.net_run.noisy_label import NLLMethodDict -# from pymic.net_run.self_sup import SelfSLSegAgent - -def get_seg_rec_agent(config, sup_type): - assert(sup_type in ['fully_sup', 'semi_sup', 'self_sup', 'weak_sup', 'noisy_label']) - if(sup_type == 'fully_sup'): - logging.info("\n********** Fully Supervised Learning **********\n") - agent = SegmentationAgent(config, 'train') - elif(sup_type == 'semi_sup'): - logging.info("\n********** Semi Supervised Learning **********\n") - method = config['semi_supervised_learning']['method_name'] - agent = SSLMethodDict[method](config, 'train') - elif(sup_type == 'weak_sup'): - logging.info("\n********** Weakly Supervised Learning **********\n") - method = config['weakly_supervised_learning']['method_name'] - agent = WSLMethodDict[method](config, 'train') - elif(sup_type == 'noisy_label'): - logging.info("\n********** Noisy Label Learning **********\n") - method = config['noisy_label_learning']['method_name'] - agent = NLLMethodDict[method](config, 'train') - elif(sup_type == 'self_sup'): - logging.info("\n********** Self Supervised Learning **********\n") - method = config['self_supervised_learning']['method_name'] - agent = SelfSupMethodDict[method](config, 'train') - else: - raise ValueError("undefined supervision type: {0:}".format(sup_type)) - return agent - -def main(): - """ - The main function for running a network for training. - """ - if(len(sys.argv) < 2): - print('Number of arguments should be 2. e.g.') - print(' pymic_train config.cfg') - exit() - cfg_file = str(sys.argv[1]) - if(not os.path.isfile(cfg_file)): - raise ValueError("The config file does not exist: " + cfg_file) - config = parse_config(cfg_file) - config = synchronize_config(config) - log_dir = config['training']['ckpt_save_dir'] - if(not os.path.exists(log_dir)): - os.makedirs(log_dir, exist_ok=True) - dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] - shutil.copy(cfg_file, log_dir + "/" + dst_cfg) - datetime_str = str(datetime.now())[:-7].replace(":", "_") - if sys.version.startswith("3.9"): - logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), - level=logging.INFO, format='%(message)s', force=True) # for python 3.9 - else: - logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), - level=logging.INFO, format='%(message)s') # for python 3.6 - logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging_config(config) - task = config['dataset']['task_type'] - if(task == TaskType.CLASSIFICATION_ONE_HOT or task == TaskType.CLASSIFICATION_COEXIST): - agent = ClassificationAgent(config, 'train') - else: - sup_type = config['dataset'].get('supervise_type', 'fully_sup') - agent = get_seg_rec_agent(config, sup_type) - - agent.run() - -if __name__ == "__main__": - main() - - diff --git a/pymic/transform/normalize.py b/pymic/transform/normalize.py index 5f0e4ec..643c12e 100644 --- a/pymic/transform/normalize.py +++ b/pymic/transform/normalize.py @@ -38,7 +38,7 @@ def __init__(self, params): self.chns = params.get('NormalizeWithMeanStd_channels'.lower(), None) self.mean = params.get('NormalizeWithMeanStd_mean'.lower(), None) self.std = params.get('NormalizeWithMeanStd_std'.lower(), None) - self.mask_thrd = params.get('NormalizeWithMeanStd_mask_threshold'.lower(), 1.0) + self.mask_thrd = params.get('NormalizeWithMeanStd_mask_threshold'.lower(), None) self.bg_random = params.get('NormalizeWithMeanStd_set_background_to_random'.lower(), True) self.inverse = params.get('NormalizeWithMeanStd_inverse'.lower(), False) diff --git a/setup.py b/setup.py index 36daf9a..22406a0 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.4.0", + version = "0.4.0.1", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, From cafaac0ca5699750e34b1eeee092985085775949 Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 10 Jan 2024 14:39:40 +0800 Subject: [PATCH 176/225] fix minor issues for config --- pymic/loss/seg/mumford_shah.py | 7 +- pymic/net/net3d/trans3d/HiFormer_v1.py | 1010 --------------------- pymic/net/net3d/trans3d/HiFormer_v2.py | 381 -------- pymic/net/net3d/trans3d/HiFormer_v3.py | 455 ---------- pymic/net/net3d/trans3d/HiFormer_v4.py | 455 ---------- pymic/net/net3d/trans3d/HiFormer_v5.py | 308 ------- pymic/net/net3d/trans3d/MedFormer_v1.py | 173 ---- pymic/net/net3d/trans3d/MedFormer_v2.py | 464 ---------- pymic/net/net3d/trans3d/MedFormer_v3.py | 255 ------ pymic/net/net3d/trans3d/MedFormer_va1.py | 105 --- pymic/net/net3d/trans3d/__init__.py | 0 pymic/net/net3d/trans3d/nnFormer_wrap.py | 43 - pymic/net/net3d/trans3d/unetr.py | 227 ----- pymic/net/net3d/trans3d/unetr_pp.py | 469 ---------- pymic/net/net3d/trans3d/unetr_pp_block.py | 278 ------ pymic/net_run/agent_preprocess.py | 80 +- pymic/net_run/train.py | 1 - pymic/util/evaluation_seg.py | 96 +- pymic/util/image_process.py | 12 + pymic/util/parse_config.py | 29 +- pymic/util/preprocess.py | 63 -- setup.py | 2 +- 22 files changed, 128 insertions(+), 4785 deletions(-) delete mode 100644 pymic/net/net3d/trans3d/HiFormer_v1.py delete mode 100644 pymic/net/net3d/trans3d/HiFormer_v2.py delete mode 100644 pymic/net/net3d/trans3d/HiFormer_v3.py delete mode 100644 pymic/net/net3d/trans3d/HiFormer_v4.py delete mode 100644 pymic/net/net3d/trans3d/HiFormer_v5.py delete mode 100644 pymic/net/net3d/trans3d/MedFormer_v1.py delete mode 100644 pymic/net/net3d/trans3d/MedFormer_v2.py delete mode 100644 pymic/net/net3d/trans3d/MedFormer_v3.py delete mode 100644 pymic/net/net3d/trans3d/MedFormer_va1.py delete mode 100644 pymic/net/net3d/trans3d/__init__.py delete mode 100644 pymic/net/net3d/trans3d/nnFormer_wrap.py delete mode 100644 pymic/net/net3d/trans3d/unetr.py delete mode 100644 pymic/net/net3d/trans3d/unetr_pp.py delete mode 100644 pymic/net/net3d/trans3d/unetr_pp_block.py delete mode 100644 pymic/util/preprocess.py diff --git a/pymic/loss/seg/mumford_shah.py b/pymic/loss/seg/mumford_shah.py index 6da51b5..db9368a 100644 --- a/pymic/loss/seg/mumford_shah.py +++ b/pymic/loss/seg/mumford_shah.py @@ -3,8 +3,9 @@ import torch import torch.nn as nn +from pymic.loss.seg.abstract import AbstractSegLoss -class MumfordShahLoss(nn.Module): +class MumfordShahLoss(AbstractSegLoss): """ Implementation of Mumford Shah Loss for weakly supervised learning. @@ -76,8 +77,8 @@ def forward(self, loss_input_dict): image = loss_input_dict['image'] if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) pred_shape = list(predict.shape) if(len(pred_shape) == 5): diff --git a/pymic/net/net3d/trans3d/HiFormer_v1.py b/pymic/net/net3d/trans3d/HiFormer_v1.py deleted file mode 100644 index af73683..0000000 --- a/pymic/net/net3d/trans3d/HiFormer_v1.py +++ /dev/null @@ -1,1010 +0,0 @@ -from einops import rearrange -from copy import deepcopy -from nnformer.utilities.nd_softmax import softmax_helper -from torch import nn -import torch -import numpy as np -import torch.nn.functional -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from timm.models.layers import DropPath, to_3tuple, trunc_normal_ -from pymic.net.net3d.unet3d import ConvBlock, DownBlock -# from nnFormer -class ContiguousGrad(torch.autograd.Function): - @staticmethod - def forward(ctx, x): - return x - @staticmethod - def backward(ctx, grad_out): - return grad_out.contiguous() - -# from nnFormer -class Mlp(nn.Module): - """ Multilayer perceptron.""" - - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - -# from nnFormer -def window_partition(x, window_size): - - B, S, H, W, C = x.shape - x = x.view(B, S // window_size, window_size, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, C) - return windows - -# from nnFormer -def window_reverse(windows, window_size, S, H, W): - - B = int(windows.shape[0] / (S * H * W / window_size / window_size / window_size)) - x = windows.view(B, S // window_size, H // window_size, W // window_size, window_size, window_size, window_size, -1) - x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, S, H, W, -1) - return x - - -# from nnFormer -class SwinTransformerBlock_kv(nn.Module): - - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention_kv( - dim, window_size=to_3tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - #self.window_size=to_3tuple(self.window_size) - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - def forward(self, x, mask_matrix,skip=None,x_up=None): - - B, L, C = x.shape - S, H, W = self.input_resolution - - assert L == S * H * W, "input feature has wrong size" - - shortcut = x - skip = self.norm1(skip) - x_up = self.norm1(x_up) - - skip = skip.view(B, S, H, W, C) - x_up = x_up.view(B, S, H, W, C) - x = x.view(B, S, H, W, C) - # pad feature maps to multiples of window size - pad_r = (self.window_size - W % self.window_size) % self.window_size - pad_b = (self.window_size - H % self.window_size) % self.window_size - pad_g = (self.window_size - S % self.window_size) % self.window_size - - skip = F.pad(skip, (0, 0, 0, pad_r, 0, pad_b, 0, pad_g)) - x_up = F.pad(x_up, (0, 0, 0, pad_r, 0, pad_b, 0, pad_g)) - _, Sp, Hp, Wp, _ = skip.shape - - - - # cyclic shift - if self.shift_size > 0: - skip = torch.roll(skip, shifts=(-self.shift_size, -self.shift_size,-self.shift_size), dims=(1, 2,3)) - x_up = torch.roll(x_up, shifts=(-self.shift_size, -self.shift_size,-self.shift_size), dims=(1, 2,3)) - attn_mask = mask_matrix - else: - skip = skip - x_up=x_up - attn_mask = None - # partition windows - skip = window_partition(skip, self.window_size) - skip = skip.view(-1, self.window_size * self.window_size * self.window_size, - C) - x_up = window_partition(x_up, self.window_size) - x_up = x_up.view(-1, self.window_size * self.window_size * self.window_size, - C) - attn_windows=self.attn(skip,x_up,mask=attn_mask,pos_embed=None) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, Sp, Hp, Wp) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size, self.shift_size), dims=(1, 2, 3)) - else: - x = shifted_x - - if pad_r > 0 or pad_b > 0 or pad_g > 0: - x = x[:, :S, :H, :W, :].contiguous() - - x = x.view(B, S * H * W, C) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - - return x - -# from nnFormer -class WindowAttention_kv(nn.Module): - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), - num_heads)) - - # get pair-wise relative position index for each token inside the window - coords_s = torch.arange(self.window_size[0]) - coords_h = torch.arange(self.window_size[1]) - coords_w = torch.arange(self.window_size[2]) - coords = torch.stack(torch.meshgrid([coords_s, coords_h, coords_w])) - coords_flatten = torch.flatten(coords, 1) - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] - relative_coords = relative_coords.permute(1, 2, 0).contiguous() - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 2] += self.window_size[2] - 1 - - relative_coords[:, :, 0] *= 3 * self.window_size[1] - 1 - relative_coords[:, :, 1] *= 2 * self.window_size[1] - 1 - - relative_position_index = relative_coords.sum(-1) - self.register_buffer("relative_position_index", relative_position_index) - - self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - self.softmax = nn.Softmax(dim=-1) - trunc_normal_(self.relative_position_bias_table, std=.02) - - - def forward(self, skip,x_up,pos_embed=None, mask=None): - - B_, N, C = skip.shape - - kv = self.kv(skip) - q = x_up - - kv=kv.reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() - q = q.reshape(B_,N,self.num_heads,C//self.num_heads).permute(0,2,1,3).contiguous() - k,v = kv[0], kv[1] - q = q * self.scale - attn = (q @ k.transpose(-2, -1).contiguous()) - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1] * self.window_size[2], - self.window_size[0] * self.window_size[1] * self.window_size[2], -1) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous() - if pos_embed is not None: - x = x + pos_embed - x = self.proj(x) - x = self.proj_drop(x) - return x - -# from nnFormer -class WindowAttention(nn.Module): - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), - num_heads)) - - # get pair-wise relative position index for each token inside the window - coords_s = torch.arange(self.window_size[0]) - coords_h = torch.arange(self.window_size[1]) - coords_w = torch.arange(self.window_size[2]) - coords = torch.stack(torch.meshgrid([coords_s, coords_h, coords_w])) - coords_flatten = torch.flatten(coords, 1) - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] - relative_coords = relative_coords.permute(1, 2, 0).contiguous() - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 2] += self.window_size[2] - 1 - - relative_coords[:, :, 0] *= 3 * self.window_size[1] - 1 - relative_coords[:, :, 1] *= 2 * self.window_size[1] - 1 - - relative_position_index = relative_coords.sum(-1) - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - trunc_normal_(self.relative_position_bias_table, std=.02) - - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None,pos_embed=None): - - B_, N, C = x.shape - - qkv = self.qkv(x) - - qkv=qkv.reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1).contiguous()) - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1] * self.window_size[2], - self.window_size[0] * self.window_size[1] * self.window_size[2], -1) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous() - if pos_embed is not None: - x = x+pos_embed - x = self.proj(x) - x = self.proj_drop(x) - return x - -# from nnFormer -class SwinTransformerBlock(nn.Module): - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - - self.attn = WindowAttention( - dim, window_size=to_3tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - - def forward(self, x, mask_matrix): - - B, L, C = x.shape - S, H, W = self.input_resolution - - assert L == S * H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, S, H, W, C) - - # pad feature maps to multiples of window size - pad_r = (self.window_size - W % self.window_size) % self.window_size - pad_b = (self.window_size - H % self.window_size) % self.window_size - pad_g = (self.window_size - S % self.window_size) % self.window_size - - x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b, 0, pad_g)) - _, Sp, Hp, Wp, _ = x.shape - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size,-self.shift_size), dims=(1, 2,3)) - attn_mask = mask_matrix - else: - shifted_x = x - attn_mask = None - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size * self.window_size, - C) - - # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=attn_mask,pos_embed=None) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, Sp, Hp, Wp) - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size, self.shift_size), dims=(1, 2, 3)) - else: - x = shifted_x - - if pad_r > 0 or pad_b > 0 or pad_g > 0: - x = x[:, :S, :H, :W, :].contiguous() - - x = x.view(B, S * H * W, C) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - - return x - -# from nnFormer -class PatchMerging(nn.Module): - - - def __init__(self, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.reduction = nn.Conv3d(dim,dim*2,kernel_size=3,stride=2,padding=1) - - self.norm = norm_layer(dim) - - def forward(self, x, S, H, W): - - B, L, C = x.shape - assert L == H * W * S, "input feature has wrong size" - x = x.view(B, S, H, W, C) - - x = F.gelu(x) - x = self.norm(x) - x=x.permute(0,4,1,2,3).contiguous() - x=self.reduction(x) - x=x.permute(0,2,3,4,1).contiguous().view(B,-1,2*C) - - return x - -# from nnFormer -class Patch_Expanding(nn.Module): - def __init__(self, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - - self.norm = norm_layer(dim) - self.up=nn.ConvTranspose3d(dim,dim//2,2,2) - def forward(self, x, S, H, W): - - - B, L, C = x.shape - assert L == H * W * S, "input feature has wrong size" - - x = x.view(B, S, H, W, C) - - - - x = self.norm(x) - x=x.permute(0,4,1,2,3).contiguous() - x = self.up(x) - x = ContiguousGrad.apply(x) - x=x.permute(0,2,3,4,1).contiguous().view(B,-1,C//2) - - return x - -# from nnFormer -class BasicLayer(nn.Module): - - def __init__(self, - dim, - input_resolution, - depth, - num_heads, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop=0., - attn_drop=0., - drop_path=0., - norm_layer=nn.LayerNorm, - downsample=True - ): - super().__init__() - self.window_size = window_size - self.shift_size = window_size // 2 - self.depth = depth - # build blocks - - self.blocks = nn.ModuleList([ - SwinTransformerBlock( - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, S, H, W): - - - # calculate attention mask for SW-MSA - Sp = int(np.ceil(S / self.window_size)) * self.window_size - Hp = int(np.ceil(H / self.window_size)) * self.window_size - Wp = int(np.ceil(W / self.window_size)) * self.window_size - img_mask = torch.zeros((1, Sp, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 - s_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for s in s_slices: - for h in h_slices: - for w in w_slices: - img_mask[:, s, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) - mask_windows = mask_windows.view(-1, - self.window_size * self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - for blk in self.blocks: - - x = blk(x, attn_mask) - if self.downsample is not None: - x_down = self.downsample(x, S, H, W) - Ws, Wh, Ww = (S + 1) // 2, (H + 1) // 2, (W + 1) // 2 - return x, S, H, W, x_down, Ws, Wh, Ww - else: - return x, S, H, W, x, S, H, W - -# from nnFormer -class BasicLayer_up(nn.Module): - - def __init__(self, - dim, - input_resolution, - depth, - num_heads, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop=0., - attn_drop=0., - drop_path=0., - norm_layer=nn.LayerNorm, - upsample=True - ): - super().__init__() - self.window_size = window_size - self.shift_size = window_size // 2 - self.depth = depth - - - # build blocks - self.blocks = nn.ModuleList() - self.blocks.append( - SwinTransformerBlock_kv( - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - window_size=window_size, - shift_size=0 , - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path[0] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) - ) - for i in range(depth-1): - self.blocks.append( - SwinTransformerBlock( - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - window_size=window_size, - shift_size=window_size // 2 , - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path[i+1] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) - ) - - - - self.Upsample = upsample(dim=2*dim, norm_layer=norm_layer) - def forward(self, x,skip, S, H, W): - - - x_up = self.Upsample(x, S, H, W) - - x = x_up + skip - S, H, W = S * 2, H * 2, W * 2 - # calculate attention mask for SW-MSA - Sp = int(np.ceil(S / self.window_size)) * self.window_size - Hp = int(np.ceil(H / self.window_size)) * self.window_size - Wp = int(np.ceil(W / self.window_size)) * self.window_size - img_mask = torch.zeros((1, Sp, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 - s_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for s in s_slices: - for h in h_slices: - for w in w_slices: - img_mask[:, s, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, - self.window_size * self.window_size * self.window_size) # 3d��3��winds�˻�����Ŀ�Ǻܴ�ģ�����winds����̫�� - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - x = self.blocks[0](x, attn_mask,skip=skip,x_up=x_up) - for i in range(self.depth-1): - x = self.blocks[i+1](x,attn_mask) - - return x, S, H, W - - -# from nnFormer -class project(nn.Module): - def __init__(self,in_dim,out_dim,stride,padding,activate,norm,last=False): - super().__init__() - self.out_dim=out_dim - self.conv1=nn.Conv3d(in_dim,out_dim,kernel_size=3,stride=stride,padding=padding) - self.conv2=nn.Conv3d(out_dim,out_dim,kernel_size=3,stride=1,padding=1) - self.activate=activate() - self.norm1=norm(out_dim) - self.last=last - if not last: - self.norm2=norm(out_dim) - - def forward(self,x): - x=self.conv1(x) - x=self.activate(x) - #norm1 - Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.norm1(x) - x = x.transpose(1, 2).contiguous().view(-1, self.out_dim, Ws, Wh, Ww) - - - x=self.conv2(x) - if not self.last: - x=self.activate(x) - #norm2 - Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.norm2(x) - x = x.transpose(1, 2).contiguous().view(-1, self.out_dim, Ws, Wh, Ww) - return x - - -# from nnFormer -class PatchEmbed_backup(nn.Module): - def __init__(self, patch_size=4, in_chans=4, embed_dim=96, norm_layer=None): - super().__init__() - patch_size = to_3tuple(patch_size) - self.patch_size = patch_size - - self.in_chans = in_chans - self.embed_dim = embed_dim - stride1=[patch_size[0]//2,patch_size[1]//2,patch_size[2]//2] - stride2=[patch_size[0]//2,patch_size[1]//2,patch_size[2]//2] - self.proj1 = project(in_chans,embed_dim//2,stride1,1,nn.GELU,nn.LayerNorm,False) - self.proj2 = project(embed_dim//2,embed_dim,stride2,1,nn.GELU,nn.LayerNorm,True) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - """Forward function.""" - # padding - _, _, S, H, W = x.size() - if W % self.patch_size[2] != 0: - x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) - if H % self.patch_size[1] != 0: - x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) - if S % self.patch_size[0] != 0: - x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - S % self.patch_size[0])) - x = self.proj1(x) # B C Ws Wh Ww - x = self.proj2(x) # B C Ws Wh Ww - if self.norm is not None: - Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.norm(x) - x = x.transpose(1, 2).contiguous().view(-1, self.embed_dim, Ws, Wh, Ww) - - return x - - -class PatchEmbed(nn.Module): - """ - replace patch embed with conv layers""" - def __init__(self, in_chns=1, ft_chns = [32, 64, 128], dropout = [0, 0, 0.2]): - super().__init__() - self.in_conv= ConvBlock(in_chns, ft_chns[0], dropout[0]) - self.down1 = DownBlock(ft_chns[0], ft_chns[1], dropout[1]) - self.down2 = DownBlock(ft_chns[1], ft_chns[2], dropout[2]) - - - def forward(self, x): - """Forward function.""" - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - return x2 - -# from nnFormer -class Encoder(nn.Module): - - def __init__(self, - pretrain_img_size=224, - patch_size=4, - in_chans=1 , - embed_dim=96, - depths=[2, 2, 2, 2], - num_heads=[4, 8, 16, 32], - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.2, - norm_layer=nn.LayerNorm, - patch_norm=True, - out_indices=(0, 1, 2, 3) - ): - super().__init__() - - self.pretrain_img_size = pretrain_img_size - - self.num_layers = len(depths) - print("number of layers in encoder", self.num_layers, depths) - self.embed_dim = embed_dim - self.patch_norm = patch_norm - self.out_indices = out_indices - - # split image into non-overlapping patches - # self.patch_embed = PatchEmbed( - # patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - # norm_layer=norm_layer if self.patch_norm else None) - self.patch_embed = PatchEmbed(in_chans, ft_chns=[embed_dim // 4, embed_dim //2, embed_dim]) - - - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build layers - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = BasicLayer( - dim=int(embed_dim * 2 ** i_layer), - input_resolution=( - pretrain_img_size[0] // patch_size[0] // 2 ** i_layer, pretrain_img_size[1] // patch_size[1] // 2 ** i_layer, - pretrain_img_size[2] // patch_size[2] // 2 ** i_layer), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size[i_layer], - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - attn_drop=attn_drop_rate, - drop_path=dpr[sum( - depths[:i_layer]):sum(depths[:i_layer + 1])], - norm_layer=norm_layer, - downsample=PatchMerging - if (i_layer < self.num_layers - 1) else None - ) - self.layers.append(layer) - - num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] - self.num_features = num_features - - # add a norm layer for each output - for i_layer in out_indices: - layer = norm_layer(num_features[i_layer]) - layer_name = f'norm{i_layer}' - self.add_module(layer_name, layer) - - - def forward(self, x): - """Forward function.""" - - x = self.patch_embed(x) - down=[] - Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) - - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.pos_drop(x) - - - for i in range(self.num_layers): - layer = self.layers[i] - x_out, S, H, W, x, Ws, Wh, Ww = layer(x, Ws, Wh, Ww) - if i in self.out_indices: - norm_layer = getattr(self, f'norm{i}') - x_out = norm_layer(x_out) - - out = x_out.view(-1, S, H, W, self.num_features[i]).permute(0, 4, 1, 2, 3).contiguous() - - down.append(out) - return down - - -# from nnFormer -class Decoder(nn.Module): - def __init__(self, - pretrain_img_size, - embed_dim, - patch_size=4, - depths=[2,2,2], - num_heads=[24,12,6], - window_size=4, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.2, - norm_layer=nn.LayerNorm - ): - super().__init__() - - - self.num_layers = len(depths) - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build layers - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers)[::-1]: - - layer = BasicLayer_up( - dim=int(embed_dim * 2 ** (len(depths)-i_layer-1)), - input_resolution=( - pretrain_img_size[0] // patch_size[0] // 2 ** (len(depths)-i_layer-1), pretrain_img_size[1] // patch_size[1] // 2 ** (len(depths)-i_layer-1), - pretrain_img_size[2] // patch_size[2] // 2 ** (len(depths)-i_layer-1)), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size[i_layer], - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - attn_drop=attn_drop_rate, - drop_path=dpr[sum( - depths[:i_layer]):sum(depths[:i_layer + 1])], - norm_layer=norm_layer, - upsample=Patch_Expanding - ) - self.layers.append(layer) - self.num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] - def forward(self,x,skips): - - outs=[] - S, H, W = x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2).contiguous() - for index,i in enumerate(skips): - i = i.flatten(2).transpose(1, 2).contiguous() - skips[index]=i - x = self.pos_drop(x) - - for i in range(self.num_layers)[::-1]: - - layer = self.layers[i] - - x, S, H, W, = layer(x,skips[i], S, H, W) - out = x.view(-1, S, H, W, self.num_features[i]) - outs.append(out) - return outs - - -class final_patch_expanding(nn.Module): - def __init__(self,dim,num_class,patch_size): - super().__init__() - self.up=nn.ConvTranspose3d(dim,num_class,patch_size,patch_size) - - def forward(self,x): - x=x.permute(0,4,1,2,3).contiguous() - x=self.up(x) - - - return x - - - - - -class HiFormer_v1(nn.Module): - def __init__(self, params): - """ - replace the embedding layer with convolutional blocks - """ - super(HiFormer_v1, self).__init__() - # crop_size=[96,96,96], - # embedding_dim=192, - # input_channels=1, - # num_classes=9, - # conv_op=nn.Conv3d, - # depths=[2,2,2,2], - # num_heads=[6, 12, 24, 48], - # patch_size=[4,4,4], - # window_size=[4,4,8,4], - # deep_supervision=False): - - crop_size = params["input_size"] - embed_dim = params.get("embedding_dim", 192) - input_channels = params["in_chns"] - num_classes = params["class_num"] - self.conv_op = nn.Conv3d - depths = params.get("depths", [2, 2, 2, 2]) - num_heads = params.get("num_heads", [6, 12, 24, 48]) - patch_size = params.get("patch_size", [4, 4, 4]) # for patch embedding - window_size = params.get("window_size", [4, 4, 8, 4]) # for swin transformer window - self._deep_supervision = params.get("deep_supervision", False) - self.do_ds = params.get("deep_supervision", False) - - - self.num_classes = num_classes - self.upscale_logits_ops = [] - self.upscale_logits_ops.append(lambda x: x) - - self.model_down=Encoder(pretrain_img_size=crop_size,window_size=window_size,embed_dim=embed_dim, - patch_size=patch_size,depths=depths,num_heads=num_heads,in_chans=input_channels, out_indices=range(len(depths))) - self.decoder=Decoder(pretrain_img_size=crop_size,embed_dim=embed_dim,window_size=window_size[::-1][1:],patch_size=patch_size,num_heads=num_heads[::-1][:-1],depths=depths[::-1][1:]) - - self.final=[] - if self.do_ds: - - for i in range(len(depths)-1): - self.final.append(final_patch_expanding(embed_dim*2**i,num_classes,patch_size=patch_size)) - - else: - self.final.append(final_patch_expanding(embed_dim,num_classes,patch_size=patch_size)) - - self.final=nn.ModuleList(self.final) - - - def forward(self, x): - - - seg_outputs=[] - skips = self.model_down(x) - neck=skips[-1] - - out=self.decoder(neck,skips) - - - - if self.do_ds: - for i in range(len(out)): - seg_outputs.append(self.final[-(i+1)](out[i])) - - - return seg_outputs[::-1] - else: - seg_outputs.append(self.final[0](out[-1])) - return seg_outputs[-1] - - -if __name__ == "__main__": - # params = {"input_size": [96, 96, 96], - # "in_chns": 1, - # "depth": [2, 2, 2, 2], - # "num_heads": [6, 12, 24, 48], - # "window_size": [6, 6, 6, 3], - # "class_num": 5} - params = {"input_size": [96, 96, 96], - "in_chns": 1, - "depths": [2, 2, 2], - "num_heads": [6, 12, 24], - "window_size": [6, 6, 6], - "class_num": 5} - Net = HiFormer_v1(params) - Net = Net.double() - - x = np.random.rand(1, 1, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print(y.shape) - - - diff --git a/pymic/net/net3d/trans3d/HiFormer_v2.py b/pymic/net/net3d/trans3d/HiFormer_v2.py deleted file mode 100644 index 7d4c440..0000000 --- a/pymic/net/net3d/trans3d/HiFormer_v2.py +++ /dev/null @@ -1,381 +0,0 @@ - -import torch -import numpy as np -import torch.utils.checkpoint as checkpoint -from einops import rearrange -from copy import deepcopy -from torch import nn -from pymic.net.net3d.trans3d.HiFormer_v1 import BasicLayer - -class ConvBlock(nn.Module): - """ - 2D or 3D convolutional block - - :param in_channels: (int) Input channel number. - :param out_channels: (int) Output channel number. - :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. - :param dropout_p: (int) Dropout probability. - """ - def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0): - super(ConvBlock, self).__init__() - assert(dim == 2 or dim == 3) - if(dim == 2): - kernel_size = [1, 3, 3] - padding = [0, 1, 1] - else: - kernel_size = 3 - padding = 1 - - self.conv_conv = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.PReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), - nn.BatchNorm3d(out_channels), - nn.PReLU(), - nn.Dropout(dropout_p), - nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x): - return self.conv_conv(x) - - -class DownSample(nn.Module): - def __init__(self, in_channels, out_channels, dim = 2, first_layer = False): - super(DownSample, self).__init__() - assert(dim == 2 or dim == 3) - if(dim == 2): - kernel_size = [1, 3, 3] - stride = [1, 2, 2] - padding = [0, 1, 1] - else: - kernel_size = 3 - stride = 2 - padding = 1 - - if(first_layer): - self.down = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, - padding=padding, stride = stride) - else: - self.down = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.PReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, - padding=padding, stride = stride), - ) - - def forward(self, x): - return self.down(x) - - - -class ConvTransBlock(nn.Module): - def __init__(self, - input_resolution= [32, 32, 32], - chns=96, - depth=2, - num_head=4, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.2, - norm_layer=nn.LayerNorm, - patch_norm=True, - ): - super().__init__() - self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) - self.trans = BasicLayer( - dim= chns, - input_resolution= input_resolution, - depth=depth, - num_heads=num_head, - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - attn_drop=attn_drop_rate, - drop_path=drop_path_rate, - norm_layer=norm_layer, - downsample= None - ) - self.norm_layer = nn.LayerNorm(chns) - self.pos_drop = nn.Dropout(p=drop_rate) - - def forward(self, x): - """Forward function.""" - x1 = self.conv(x) - C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.pos_drop(x) - x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) - # x2 = self.norm_layer(x2) - x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() - return x1 + x2 - - -class UpCatBlock(nn.Module): - """ - 3D upsampling followed by ConvBlock - - :param in_channels1: (int) Channel number of high-level features. - :param in_channels2: (int) Channel number of low-level features. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. - """ - def __init__(self, chns_l, chns_h, up_dim = 3, conv_dim = 3): - super(UpCatBlock, self).__init__() - assert(up_dim == 2 or up_dim == 3) - if(up_dim == 2): - kernel_size, stride = [1, 2, 2], [1, 2, 2] - else: - kernel_size, stride = 2, 2 - self.up = nn.ConvTranspose3d(chns_h, chns_l, - kernel_size = kernel_size, stride=stride) - - if(conv_dim == 2): - kernel_size, padding = [1, 3, 3], [0, 1, 1] - else: - kernel_size, padding = 3, 1 - self.conv = nn.Sequential( - nn.BatchNorm3d(chns_l*2), - nn.PReLU(), - nn.Conv3d(chns_l*2, chns_l, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x_l, x_h): - # print("input shapes", x1.shape, x2.shape) - # print("after upsample", x1.shape) - y = torch.cat([x_l, self.up(x_h)], dim=1) - return self.conv(y) - -class Encoder(nn.Module): - def __init__(self, - in_chns = 1 , - ft_chns = [48, 192, 384, 768], - input_size= [32, 128, 128], - down_dims = [2, 2, 3, 3], - conv_dims = [2, 3, 3, 3], - dropout = [0, 0.2, 0.2, 0.2], - depths = [2, 2, 2], - num_heads = [4, 8, 16], - window_sizes = [6, 6, 6], - ): - super().__init__() - - self.down1 = DownSample(in_chns, ft_chns[0], down_dims[0], first_layer=True) - self.down2 = DownSample(ft_chns[0], ft_chns[1], down_dims[1]) - self.down3 = DownSample(ft_chns[1], ft_chns[2], down_dims[2]) - self.down4 = DownSample(ft_chns[2], ft_chns[3], down_dims[3]) - - self.conv1 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) - self.conv2 = ConvBlock(ft_chns[1], ft_chns[1], conv_dims[1], dropout[1]) - - down_scales = [] - for i in range(4): - down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] - down_scales.append(down_scale) - - r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] - r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] - r_t4 = [r_t3[i] // down_scales[3][i] for i in range(3)] - - self.conv_t2 = ConvTransBlock(chns = ft_chns[1], - input_resolution = r_t2, - window_size = window_sizes[0], - depth = depths[0], - num_head = num_heads[0], - drop_rate = dropout[1], - attn_drop_rate=dropout[1] - ) - self.conv_t3 = ConvTransBlock(chns = ft_chns[2], - input_resolution = r_t3, - window_size = window_sizes[1], - depth = depths[1], - num_head = num_heads[1], - drop_rate = dropout[2], - attn_drop_rate=dropout[2] - ) - self.conv_t4 = ConvTransBlock(chns = ft_chns[3], - input_resolution = r_t4, - window_size = window_sizes[2], - depth = depths[2], - num_head = num_heads[2], - drop_rate = dropout[3], - attn_drop_rate=dropout[3] - ) - - - - def forward(self, x): - """Forward function.""" - x1 = self.conv1(self.down1(x)) - x2 = self.conv2(self.down2(x1)) - x2 = self.conv_t2(x2) - x3 = self.conv_t3(self.down3(x2)) - x4 = self.conv_t4(self.down4(x3)) - - return x1, x2, x3, x4 - -class Decoder(nn.Module): - """ - Decoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - :param multiscale_pred: (bool) Get multi-scale prediction. - """ - def __init__(self, - ft_chns = [48, 192, 384, 768], - input_size = [32, 128, 128], - down_dims = [2, 2, 3, 3], - conv_dims = [2, 3, 3, 3], - dropout = [0, 0, 0.2, 0.2], - depths = [2, 2, 2], - num_heads = [4, 8, 16], - window_sizes = [6, 6, 6], - class_num = 2, - multiscale_pred = False - ): - super(Decoder, self).__init__() - - self.up1 = UpCatBlock(ft_chns[0], ft_chns[1], down_dims[1], conv_dims[0]) - self.up2 = UpCatBlock(ft_chns[1], ft_chns[2], down_dims[2], conv_dims[1]) - self.up3 = UpCatBlock(ft_chns[2], ft_chns[3], down_dims[3], conv_dims[2]) - - down_scales = [] - for i in range(4): - down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] - down_scales.append(down_scale) - - r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] - r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] - - self.conv1 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) - self.conv2 = ConvTransBlock(chns = ft_chns[1], - input_resolution = r_t2, - window_size = window_sizes[0], - depth = depths[0], - num_head = num_heads[0], - drop_rate = dropout[1], - attn_drop_rate=dropout[1] - ) - self.conv3 = ConvTransBlock(chns = ft_chns[2], - input_resolution = r_t3, - window_size = window_sizes[1], - depth = depths[1], - num_head = num_heads[1], - drop_rate = dropout[2], - attn_drop_rate=dropout[2] - ) - - kernel_size, stride = 2, 2 - if down_dims[0] == 2: - kernel_size, stride = [1, 2, 2], [1, 2, 2] - self.out_conv0 = nn.ConvTranspose3d(ft_chns[0], class_num, - kernel_size = kernel_size, stride= stride) - - self.mul_pred = multiscale_pred - if(self.mul_pred): - self.out_conv1 = nn.Conv3d(ft_chns[0], class_num, kernel_size = 1) - self.out_conv2 = nn.Conv3d(ft_chns[1], class_num, kernel_size = 1) - self.out_conv3 = nn.Conv3d(ft_chns[2], class_num, kernel_size = 1) - - def forward(self, x): - x1, x2, x3, x4 = x - x_d3 = self.conv3(self.up3(x3, x4)) - x_d2 = self.conv2(self.up2(x2, x_d3)) - x_d1 = self.conv1(self.up1(x1, x_d2)) - - output = self.out_conv0(x_d1) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - return output - -class HiFormer_v2(nn.Module): - def __init__(self, params): - """ - replace the embedding layer with convolutional blocks - """ - super(HiFormer_v2, self).__init__() - in_chns = params["in_chns"] - class_num = params["class_num"] - input_size = params["input_size"] - ft_chns = params.get("feature_chns", [48, 192, 384, 764]) - down_dims = params.get("down_dims", [2, 2, 3, 3]) - conv_dims = params.get("conv_dims", [2, 3, 3, 3]) - dropout = params.get('dropout', [0, 0.2, 0.2, 0.2]) - depths = params.get("depths", [2, 2, 2]) - num_heads = params.get("num_heads", [4, 8, 16]) - window_sizes= params.get("window_sizes", [6, 6, 6]) - multiscale_pred = params.get("multiscale_pred", False) - - self.encoder = Encoder(in_chns, - ft_chns = ft_chns, - input_size = input_size, - down_dims = down_dims, - conv_dims = conv_dims, - dropout = dropout, - depths = depths, - num_heads = num_heads, - window_sizes= window_sizes) - - self.decoder = Decoder(ft_chns = ft_chns, - input_size = input_size, - down_dims = down_dims, - conv_dims = conv_dims, - dropout = dropout, - depths = depths, - num_heads = num_heads, - window_sizes= window_sizes, - class_num = class_num, - multiscale_pred = multiscale_pred - ) - - def forward(self, x): - x = self.encoder(x) - x = self.decoder(x) - return x - - -if __name__ == "__main__": - params = {"input_size": [32, 128, 128], - "in_chns": 1, - "down_dims": [2, 2, 3, 3], - "conv_dims": [2, 3, 3, 3], - "feature_chns": [96, 192, 384, 768], - "class_num": 5, - "multiscale_pred": True} - Net = HiFormer_v2(params) - Net = Net.double() - - x = np.random.rand(1, 1, 32, 128, 128) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - if(params['multiscale_pred']): - for yi in y: - print(yi.shape) - else: - print(y.shape) - - - diff --git a/pymic/net/net3d/trans3d/HiFormer_v3.py b/pymic/net/net3d/trans3d/HiFormer_v3.py deleted file mode 100644 index 2f8c831..0000000 --- a/pymic/net/net3d/trans3d/HiFormer_v3.py +++ /dev/null @@ -1,455 +0,0 @@ - -import torch -import numpy as np -import torch.utils.checkpoint as checkpoint -from einops import rearrange -from copy import deepcopy -from torch import nn -from pymic.net.net3d.trans3d.HiFormer_v1 import BasicLayer - -class ConvBlock(nn.Module): - """ - 2D or 3D convolutional block - - :param in_channels: (int) Input channel number. - :param out_channels: (int) Output channel number. - :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. - :param dropout_p: (int) Dropout probability. - """ - def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0): - super(ConvBlock, self).__init__() - assert(dim == 2 or dim == 3) - if(dim == 2): - kernel_size = [1, 3, 3] - padding = [0, 1, 1] - else: - kernel_size = 3 - padding = 1 - - self.conv_conv = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.PReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), - nn.BatchNorm3d(out_channels), - nn.PReLU(), - nn.Dropout(dropout_p), - nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x): - return self.conv_conv(x) - - -class DownSample(nn.Module): - def __init__(self, in_channels, out_channels, dim = 2, first_layer = False): - super(DownSample, self).__init__() - assert(dim == 2 or dim == 3) - if(dim == 2): - kernel_size = [1, 3, 3] - stride = [1, 2, 2] - padding = [0, 1, 1] - else: - kernel_size = 3 - stride = 2 - padding = 1 - - if(first_layer): - self.down = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, - padding=padding, stride = stride) - else: - self.down = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.PReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, - padding=padding, stride = stride), - ) - - def forward(self, x): - return self.down(x) - - - -class ConvTransBlock_backup(nn.Module): - def __init__(self, - input_resolution= [32, 32, 32], - chns=96, - depth=2, - num_head=4, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.2, - norm_layer=nn.LayerNorm, - patch_norm=True, - ): - super().__init__() - self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) - self.trans = BasicLayer( - dim= chns, - input_resolution= input_resolution, - depth=depth, - num_heads=num_head, - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - attn_drop=attn_drop_rate, - drop_path=drop_path_rate, - norm_layer=norm_layer, - downsample= None - ) - self.norm_layer = nn.LayerNorm(chns) - self.pos_drop = nn.Dropout(p=drop_rate) - - def forward(self, x): - """Forward function.""" - x1 = self.conv(x) - C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.pos_drop(x) - x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) - # x2 = self.norm_layer(x2) - x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() - return x1 + x2 - -# only using the conv block -class ConvTransBlock(nn.Module): - def __init__(self, - input_resolution= [32, 32, 32], - chns=96, - depth=2, - num_head=4, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.2, - norm_layer=nn.LayerNorm, - patch_norm=True, - ): - super().__init__() - self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) - # self.trans = BasicLayer( - # dim= chns, - # input_resolution= input_resolution, - # depth=depth, - # num_heads=num_head, - # window_size=window_size, - # mlp_ratio=mlp_ratio, - # qkv_bias=qkv_bias, - # qk_scale=qk_scale, - # drop=drop_rate, - # attn_drop=attn_drop_rate, - # drop_path=drop_path_rate, - # norm_layer=norm_layer, - # downsample= None - # ) - # self.norm_layer = nn.LayerNorm(chns) - # self.pos_drop = nn.Dropout(p=drop_rate) - - def forward(self, x): - """Forward function.""" - x1 = self.conv(x) - return x1 - # C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) - # x = x.flatten(2).transpose(1, 2).contiguous() - # x = self.pos_drop(x) - # x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) - # # x2 = self.norm_layer(x2) - # x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() - # return x1 + x2 - -class UpCatBlock(nn.Module): - """ - 3D upsampling followed by ConvBlock - - :param in_channels1: (int) Channel number of high-level features. - :param in_channels2: (int) Channel number of low-level features. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. - """ - def __init__(self, chns_l, chns_h, up_dim = 3, conv_dim = 3): - super(UpCatBlock, self).__init__() - assert(up_dim == 2 or up_dim == 3) - if(up_dim == 2): - kernel_size, stride = [1, 2, 2], [1, 2, 2] - else: - kernel_size, stride = 2, 2 - self.up = nn.ConvTranspose3d(chns_h, chns_l, - kernel_size = kernel_size, stride=stride) - - if(conv_dim == 2): - kernel_size, padding = [1, 3, 3], [0, 1, 1] - else: - kernel_size, padding = 3, 1 - self.conv = nn.Sequential( - nn.BatchNorm3d(chns_l*2), - nn.PReLU(), - nn.Conv3d(chns_l*2, chns_l, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x_l, x_h): - # print("input shapes", x1.shape, x2.shape) - # print("after upsample", x1.shape) - y = torch.cat([x_l, self.up(x_h)], dim=1) - return self.conv(y) - -class Encoder(nn.Module): - def __init__(self, - in_chns = 1 , - ft_chns = [48, 192, 384, 768], - input_size= [32, 128, 128], - down_dims = [2, 2, 3, 3], - conv_dims = [2, 3, 3, 3], - dropout = [0, 0.2, 0.2, 0.2], - depths = [2, 2, 2], - num_heads = [4, 8, 16], - window_sizes = [6, 6, 6], - high_res = False, - ): - super().__init__() - self.high_res = high_res - - self.down1 = DownSample(in_chns, ft_chns[0], down_dims[0], first_layer=True) - self.down2 = DownSample(ft_chns[0], ft_chns[1], down_dims[1]) - self.down3 = DownSample(ft_chns[1], ft_chns[2], down_dims[2]) - self.down4 = DownSample(ft_chns[2], ft_chns[3], down_dims[3]) - - if(high_res): - self.conv0 = ConvBlock(in_chns, ft_chns[0] // 2, 3, 0) - self.conv1 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) - self.conv2 = ConvBlock(ft_chns[1], ft_chns[1], conv_dims[1], dropout[1]) - - down_scales = [] - for i in range(4): - down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] - down_scales.append(down_scale) - - r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] - r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] - r_t4 = [r_t3[i] // down_scales[3][i] for i in range(3)] - - self.conv_t2 = ConvTransBlock(chns = ft_chns[1], - input_resolution = r_t2, - window_size = window_sizes[0], - depth = depths[0], - num_head = num_heads[0], - drop_rate = dropout[1], - attn_drop_rate=dropout[1] - ) - self.conv_t3 = ConvTransBlock(chns = ft_chns[2], - input_resolution = r_t3, - window_size = window_sizes[1], - depth = depths[1], - num_head = num_heads[1], - drop_rate = dropout[2], - attn_drop_rate=dropout[2] - ) - self.conv_t4 = ConvTransBlock(chns = ft_chns[3], - input_resolution = r_t4, - window_size = window_sizes[2], - depth = depths[2], - num_head = num_heads[2], - drop_rate = dropout[3], - attn_drop_rate=dropout[3] - ) - - - - def forward(self, x): - """Forward function.""" - if(self.high_res): - x0 = self.conv0(x) - x1 = self.conv1(self.down1(x)) - x2 = self.conv2(self.down2(x1)) - x2 = self.conv_t2(x2) - x3 = self.conv_t3(self.down3(x2)) - x4 = self.conv_t4(self.down4(x3)) - if(self.high_res): - return x0, x1, x2, x3, x4 - else: - return x1, x2, x3, x4 - -class Decoder(nn.Module): - """ - Decoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - :param multiscale_pred: (bool) Get multi-scale prediction. - """ - def __init__(self, - ft_chns = [48, 192, 384, 768], - input_size = [32, 128, 128], - down_dims = [2, 2, 3, 3], - conv_dims = [2, 3, 3, 3], - dropout = [0, 0, 0.2, 0.2], - depths = [2, 2, 2], - num_heads = [4, 8, 16], - window_sizes = [6, 6, 6], - high_res = False, - class_num = 2, - multiscale_pred = False - ): - super(Decoder, self).__init__() - self.high_res = high_res - if(self.high_res): - self.up0 = UpCatBlock(ft_chns[0] // 2, ft_chns[0], down_dims[0], 3) - self.conv0 = ConvBlock(ft_chns[0] // 2, ft_chns[0] // 2, 3, 0) - self.up1 = UpCatBlock(ft_chns[0], ft_chns[1], down_dims[1], conv_dims[0]) - self.up2 = UpCatBlock(ft_chns[1], ft_chns[2], down_dims[2], conv_dims[1]) - self.up3 = UpCatBlock(ft_chns[2], ft_chns[3], down_dims[3], conv_dims[2]) - - down_scales = [] - for i in range(4): - down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] - down_scales.append(down_scale) - - r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] - r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] - - self.conv1 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) - self.conv2 = ConvTransBlock(chns = ft_chns[1], - input_resolution = r_t2, - window_size = window_sizes[0], - depth = depths[0], - num_head = num_heads[0], - drop_rate = dropout[1], - attn_drop_rate=dropout[1] - ) - self.conv3 = ConvTransBlock(chns = ft_chns[2], - input_resolution = r_t3, - window_size = window_sizes[1], - depth = depths[1], - num_head = num_heads[1], - drop_rate = dropout[2], - attn_drop_rate=dropout[2] - ) - - kernel_size, stride = 2, 2 - if down_dims[0] == 2: - kernel_size, stride = [1, 2, 2], [1, 2, 2] - if(self.high_res): - self.out_conv0 = nn.Conv3d(ft_chns[0] // 2, class_num, - kernel_size = [1, 3, 3], padding = [0, 1, 1]) - else: - self.out_conv0 = nn.ConvTranspose3d(ft_chns[0], class_num, - kernel_size = kernel_size, stride= stride) - - self.mul_pred = multiscale_pred - if(self.mul_pred): - self.out_conv1 = nn.Conv3d(ft_chns[0], class_num, kernel_size = 1) - self.out_conv2 = nn.Conv3d(ft_chns[1], class_num, kernel_size = 1) - self.out_conv3 = nn.Conv3d(ft_chns[2], class_num, kernel_size = 1) - - def forward(self, x): - if(self.high_res): - x0, x1, x2, x3, x4 = x - else: - x1, x2, x3, x4 = x - x_d3 = self.conv3(self.up3(x3, x4)) - x_d2 = self.conv2(self.up2(x2, x_d3)) - x_d1 = self.conv1(self.up1(x1, x_d2)) - if(self.high_res): - x_d0 = self.conv0(self.up0(x0, x_d1)) - output = self.out_conv0(x_d0) - else: - output = self.out_conv0(x_d1) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - return output - -class HiFormer_v3(nn.Module): - def __init__(self, params): - """ - replace the embedding layer with convolutional blocks - """ - super(HiFormer_v3, self).__init__() - in_chns = params["in_chns"] - class_num = params["class_num"] - input_size = params["input_size"] - ft_chns = params.get("feature_chns", [48, 192, 384, 764]) - down_dims = params.get("down_dims", [2, 2, 3, 3]) - conv_dims = params.get("conv_dims", [2, 3, 3, 3]) - dropout = params.get('dropout', [0, 0.2, 0.2, 0.2]) - high_res = params.get("high_res", False) - depths = params.get("depths", [2, 2, 2]) - num_heads = params.get("num_heads", [4, 8, 16]) - window_sizes= params.get("window_sizes", [6, 6, 6]) - multiscale_pred = params.get("multiscale_pred", False) - - self.encoder = Encoder(in_chns, - ft_chns = ft_chns, - input_size = input_size, - down_dims = down_dims, - conv_dims = conv_dims, - dropout = dropout, - depths = depths, - num_heads = num_heads, - window_sizes= window_sizes, - high_res = high_res) - - self.decoder = Decoder(ft_chns = ft_chns, - input_size = input_size, - down_dims = down_dims, - conv_dims = conv_dims, - dropout = dropout, - depths = depths, - num_heads = num_heads, - window_sizes= window_sizes, - high_res = high_res, - class_num = class_num, - multiscale_pred = multiscale_pred - ) - - def forward(self, x): - x = self.encoder(x) - x = self.decoder(x) - return x - - -if __name__ == "__main__": - params = {"input_size": [64, 96, 96], - "in_chns": 1, - "down_dims": [3, 3, 3, 3], - "conv_dims": [3, 3, 3, 3], - "feature_chns": [96, 192, 384, 768], - "high_res": True, - "class_num": 5, - "multiscale_pred": True} - Net = HiFormer_v3(params) - Net = Net.double() - - x = np.random.rand(1, 1, 64, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - if(params['multiscale_pred']): - for yi in y: - print(yi.shape) - else: - print(y.shape) - - - diff --git a/pymic/net/net3d/trans3d/HiFormer_v4.py b/pymic/net/net3d/trans3d/HiFormer_v4.py deleted file mode 100644 index f0c6087..0000000 --- a/pymic/net/net3d/trans3d/HiFormer_v4.py +++ /dev/null @@ -1,455 +0,0 @@ - -import torch -import numpy as np -import torch.utils.checkpoint as checkpoint -from einops import rearrange -from copy import deepcopy -from torch import nn -from pymic.net.net3d.trans3d.HiFormer_v1 import BasicLayer - -class ConvBlock(nn.Module): - """ - 2D or 3D convolutional block - - :param in_channels: (int) Input channel number. - :param out_channels: (int) Output channel number. - :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. - :param dropout_p: (int) Dropout probability. - """ - def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0): - super(ConvBlock, self).__init__() - assert(dim == 2 or dim == 3) - if(dim == 2): - kernel_size = [1, 3, 3] - padding = [0, 1, 1] - else: - kernel_size = 3 - padding = 1 - - self.conv_conv = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.PReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), - nn.BatchNorm3d(out_channels), - nn.PReLU(), - nn.Dropout(dropout_p), - nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x): - return self.conv_conv(x) - - -class DownSample(nn.Module): - def __init__(self, in_channels, out_channels, down_dim = 3, conv_dim = 3): - super(DownSample, self).__init__() - assert(down_dim == 2 or down_dim == 3) - assert(conv_dim == 2 or conv_dim == 3) - - kernel_size = [1, 2, 2] if(down_dim == 2) else 2 - self.pool = nn.MaxPool3d(kernel_size) - - if(conv_dim == 2): - kernel_size = [1, 3, 3] - padding = [0, 1, 1] - else: - kernel_size = 3 - padding = 1 - - self.conv = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.PReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x): - return self.conv(self.pool(x)) - - - -# class ConvTransBlock(nn.Module): -# def __init__(self, -# input_resolution= [32, 32, 32], -# chns=96, -# depth=2, -# num_head=4, -# window_size=7, -# mlp_ratio=4., -# qkv_bias=True, -# qk_scale=None, -# drop_rate=0., -# attn_drop_rate=0., -# drop_path_rate=0.2, -# norm_layer=nn.LayerNorm, -# patch_norm=True, -# ): -# super().__init__() -# self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) -# self.trans = BasicLayer( -# dim= chns, -# input_resolution= input_resolution, -# depth=depth, -# num_heads=num_head, -# window_size=window_size, -# mlp_ratio=mlp_ratio, -# qkv_bias=qkv_bias, -# qk_scale=qk_scale, -# drop=drop_rate, -# attn_drop=attn_drop_rate, -# drop_path=drop_path_rate, -# norm_layer=norm_layer, -# downsample= None -# ) -# self.norm_layer = nn.LayerNorm(chns) -# self.pos_drop = nn.Dropout(p=drop_rate) - -# def forward(self, x): -# """Forward function.""" -# x1 = self.conv(x) -# C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) -# x = x.flatten(2).transpose(1, 2).contiguous() -# x = self.pos_drop(x) -# x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) -# # x2 = self.norm_layer(x2) -# x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() -# return x1 + x2 - -# only using the conv block -class ConvTransBlock(nn.Module): - def __init__(self, - input_resolution= [32, 32, 32], - chns=96, - depth=2, - num_head=4, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.2, - norm_layer=nn.LayerNorm, - patch_norm=True, - ): - super().__init__() - self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) - # self.trans = BasicLayer( - # dim= chns, - # input_resolution= input_resolution, - # depth=depth, - # num_heads=num_head, - # window_size=window_size, - # mlp_ratio=mlp_ratio, - # qkv_bias=qkv_bias, - # qk_scale=qk_scale, - # drop=drop_rate, - # attn_drop=attn_drop_rate, - # drop_path=drop_path_rate, - # norm_layer=norm_layer, - # downsample= None - # ) - # self.norm_layer = nn.LayerNorm(chns) - # self.pos_drop = nn.Dropout(p=drop_rate) - - def forward(self, x): - """Forward function.""" - x1 = self.conv(x) - return x1 - # C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) - # x = x.flatten(2).transpose(1, 2).contiguous() - # x = self.pos_drop(x) - # x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) - # # x2 = self.norm_layer(x2) - # x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() - # return x1 + x2 - -class ConvLayer(nn.Module): - """ - 2D or 3D convolutional block - - :param in_channels: (int) Input channel number. - :param out_channels: (int) Output channel number. - :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. - :param dropout_p: (int) Dropout probability. - """ - def __init__(self, in_channels, out_channels, kernel = 1, padding = 0): - super(ConvLayer, self).__init__() - - self.conv = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.PReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel, padding=padding), - ) - - def forward(self, x): - return self.conv(x) - -class UpCatBlock(nn.Module): - """ - 3D upsampling followed by ConvBlock - - :param in_channels1: (int) Channel number of high-level features. - :param in_channels2: (int) Channel number of low-level features. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. - """ - def __init__(self, chns_l, chns_h, up_dim = 3, conv_dim = 3): - super(UpCatBlock, self).__init__() - assert(up_dim == 2 or up_dim == 3) - if(up_dim == 2): - kernel_size, stride = [1, 2, 2], [1, 2, 2] - else: - kernel_size, stride = 2, 2 - - self.up = nn.Sequential( - nn.BatchNorm3d(chns_h), - nn.PReLU(), - nn.ConvTranspose3d(chns_h, chns_l, kernel_size = kernel_size, stride=stride) - ) - - if(conv_dim == 2): - kernel_size, padding = [1, 3, 3], [0, 1, 1] - else: - kernel_size, padding = 3, 1 - self.conv = nn.Sequential( - nn.BatchNorm3d(chns_l*2), - nn.PReLU(), - nn.Conv3d(chns_l*2, chns_l, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x_l, x_h): - # print("input shapes", x1.shape, x2.shape) - # print("after upsample", x1.shape) - y = torch.cat([x_l, self.up(x_h)], dim=1) - return self.conv(y) - -class Encoder(nn.Module): - def __init__(self, - in_chns = 1 , - ft_chns = [24, 48, 192, 384, 768], - input_size= [32, 128, 128], - down_dims = [3, 3, 3, 3, 3], - conv_dims = [3, 3, 3, 3, 3], - dropout = [0, 0, 0.2, 0.2, 0.2], - depths = [2, 2, 2], - num_heads = [4, 8, 16], - window_sizes = [6, 6, 6], - ): - super().__init__() - self.proj = nn.Conv3d(in_chns, ft_chns[0], kernel_size=3, padding=1) - self.conv0 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) - self.conv1 = ConvBlock(ft_chns[1], ft_chns[1], conv_dims[1], dropout[1]) - self.conv2 = ConvBlock(ft_chns[2], ft_chns[2], conv_dims[2], dropout[2]) - - self.down1 = DownSample(ft_chns[0], ft_chns[1], down_dims[0], conv_dims[1]) - self.down2 = DownSample(ft_chns[1], ft_chns[2], down_dims[1], conv_dims[2]) - self.down3 = DownSample(ft_chns[2], ft_chns[3], down_dims[2], conv_dims[3]) - self.down4 = DownSample(ft_chns[3], ft_chns[4], down_dims[3], conv_dims[4]) - - down_scales = [] - for i in range(4): - down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] - down_scales.append(down_scale) - - r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] - r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] - r_t4 = [r_t3[i] // down_scales[3][i] for i in range(3)] - - self.conv_t2 = ConvTransBlock(chns = ft_chns[2], - input_resolution = r_t2, - window_size = window_sizes[0], - depth = depths[0], - num_head = num_heads[0], - drop_rate = dropout[2], - attn_drop_rate=dropout[2] - ) - self.conv_t3 = ConvTransBlock(chns = ft_chns[3], - input_resolution = r_t3, - window_size = window_sizes[1], - depth = depths[1], - num_head = num_heads[1], - drop_rate = dropout[3], - attn_drop_rate=dropout[3] - ) - self.conv_t4 = ConvTransBlock(chns = ft_chns[4], - input_resolution = r_t4, - window_size = window_sizes[2], - depth = depths[2], - num_head = num_heads[2], - drop_rate = dropout[4], - attn_drop_rate=dropout[4] - ) - - - - def forward(self, x): - """Forward function.""" - x0 = self.conv0(self.proj(x)) - x1 = self.conv1(self.down1(x0)) - x2 = self.conv2(self.down2(x1)) - x2 = self.conv_t2(x2) - x3 = self.conv_t3(self.down3(x2)) - x4 = self.conv_t4(self.down4(x3)) - return x0, x1, x2, x3, x4 - - -class Decoder(nn.Module): - """ - Decoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - :param multiscale_pred: (bool) Get multi-scale prediction. - """ - def __init__(self, - ft_chns = [24, 48, 192, 384, 768], - input_size= [32, 128, 128], - down_dims = [3, 3, 3, 3, 3], - conv_dims = [3, 3, 3, 3, 3], - dropout = [0, 0, 0.2, 0.2, 0.2], - depths = [2, 2, 2], - num_heads = [4, 8, 16], - window_sizes = [6, 6, 6], - class_num = 2, - multiscale_pred = False - ): - super(Decoder, self).__init__() - # self.up0 = UpCatBlock(ft_chns[0] // 2, ft_chns[0], down_dims[0], 3) - # self.conv0 = ConvBlock(ft_chns[0] // 2, ft_chns[0] // 2, 3, 0) - self.up1 = UpCatBlock(ft_chns[0], ft_chns[1], down_dims[0], conv_dims[0]) - self.up2 = UpCatBlock(ft_chns[1], ft_chns[2], down_dims[1], conv_dims[1]) - self.up3 = UpCatBlock(ft_chns[2], ft_chns[3], down_dims[2], conv_dims[2]) - self.up4 = UpCatBlock(ft_chns[3], ft_chns[4], down_dims[3], conv_dims[3]) - - down_scales = [] - for i in range(4): - down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] - down_scales.append(down_scale) - - r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] - r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] - - self.conv0 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) - self.conv1 = ConvBlock(ft_chns[1], ft_chns[1], conv_dims[1], dropout[1]) - self.conv2 = ConvTransBlock(chns = ft_chns[2], - input_resolution = r_t2, - window_size = window_sizes[0], - depth = depths[0], - num_head = num_heads[0], - drop_rate = dropout[2], - attn_drop_rate=dropout[2] - ) - self.conv3 = ConvTransBlock(chns = ft_chns[3], - input_resolution = r_t3, - window_size = window_sizes[1], - depth = depths[1], - num_head = num_heads[1], - drop_rate = dropout[3], - attn_drop_rate=dropout[3] - ) - - self.out_conv0 = ConvLayer(ft_chns[0], class_num) - - self.mul_pred = multiscale_pred - if(self.mul_pred): - self.out_conv1 = ConvLayer(ft_chns[1], class_num) - self.out_conv2 = ConvLayer(ft_chns[2], class_num) - self.out_conv3 = ConvLayer(ft_chns[3], class_num) - - def forward(self, x): - x0, x1, x2, x3, x4 = x - - x_d3 = self.conv3(self.up4(x3, x4)) - x_d2 = self.conv2(self.up3(x2, x_d3)) - x_d1 = self.conv1(self.up2(x1, x_d2)) - x_d0 = self.conv0(self.up1(x0, x_d1)) - output = self.out_conv0(x_d0) - - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - return output - -class HiFormer_v4(nn.Module): - def __init__(self, params): - """ - replace the embedding layer with convolutional blocks - """ - super(HiFormer_v4, self).__init__() - in_chns = params["in_chns"] - class_num = params["class_num"] - input_size = params["input_size"] - ft_chns = params.get("feature_chns", [32, 64, 128, 256, 512]) - down_dims = params.get("down_dims", [3, 3, 3, 3, 3]) - conv_dims = params.get("conv_dims", [3, 3, 3, 3, 3]) - dropout = params.get('dropout', [0, 0, 0.2, 0.2, 0.2]) - depths = params.get("depths", [2, 2, 2]) - num_heads = params.get("num_heads", [4, 8, 16]) - window_sizes= params.get("window_sizes", [6, 6, 6]) - multiscale_pred = params.get("multiscale_pred", False) - - self.encoder = Encoder(in_chns, - ft_chns = ft_chns, - input_size = input_size, - down_dims = down_dims, - conv_dims = conv_dims, - dropout = dropout, - depths = depths, - num_heads = num_heads, - window_sizes= window_sizes) - - self.decoder = Decoder(ft_chns = ft_chns, - input_size = input_size, - down_dims = down_dims, - conv_dims = conv_dims, - dropout = dropout, - depths = depths, - num_heads = num_heads, - window_sizes= window_sizes, - class_num = class_num, - multiscale_pred = multiscale_pred - ) - - def forward(self, x): - x = self.encoder(x) - x = self.decoder(x) - return x - - -if __name__ == "__main__": - params = {"input_size": [64, 96, 96], - "in_chns": 1, - "down_dims": [3, 3, 3, 3, 3], - "conv_dims": [3, 3, 3, 3, 3], - "feature_chns": [32, 64, 128, 256, 512], - "class_num": 5, - "multiscale_pred": True} - Net = HiFormer_v4(params) - Net = Net.double() - - x = np.random.rand(1, 1, 64, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - if(params['multiscale_pred']): - for yi in y: - print(yi.shape) - else: - print(y.shape) - - - diff --git a/pymic/net/net3d/trans3d/HiFormer_v5.py b/pymic/net/net3d/trans3d/HiFormer_v5.py deleted file mode 100644 index 5fcef5a..0000000 --- a/pymic/net/net3d/trans3d/HiFormer_v5.py +++ /dev/null @@ -1,308 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import torch -import torch.nn as nn -import numpy as np -from torch.nn.functional import interpolate - - -class ConvBlock(nn.Module): - """ - 2D or 3D convolutional block - - :param in_channels: (int) Input channel number. - :param out_channels: (int) Output channel number. - :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. - :param dropout_p: (int) Dropout probability. - """ - def __init__(self, in_channels, out_channels, dropout_p = 0.0, dim = 3): - super(ConvBlock, self).__init__() - assert(dim == 2 or dim == 3) - if(dim == 2): - kernel_size = [1, 3, 3] - padding = [0, 1, 1] - else: - kernel_size = 3 - padding = 1 - - self.conv_conv = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.LeakyReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), - nn.BatchNorm3d(out_channels), - nn.LeakyReLU(), - nn.Dropout(dropout_p), - nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x): - return self.conv_conv(x) - -class ConvLayer(nn.Module): - """ - 2D or 3D convolutional block - - :param in_channels: (int) Input channel number. - :param out_channels: (int) Output channel number. - :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. - :param dropout_p: (int) Dropout probability. - """ - def __init__(self, in_channels, out_channels, kernel = 1, padding = 0): - super(ConvLayer, self).__init__() - - self.conv = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.LeakyReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel, padding=padding), - ) - - def forward(self, x): - return self.conv(x) - -class DownBlock(nn.Module): - """ - 3D downsampling followed by ConvBlock - - :param in_channels: (int) Input channel number. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - """ - def __init__(self, in_channels, out_channels, dropout_p): - super(DownBlock, self).__init__() - self.maxpool_conv = nn.Sequential( - nn.MaxPool3d(2), - ConvBlock(in_channels, out_channels, dropout_p) - ) - - def forward(self, x): - return self.maxpool_conv(x) - -class UpBlock(nn.Module): - """ - 3D upsampling followed by ConvBlock - - :param in_channels1: (int) Channel number of high-level features. - :param in_channels2: (int) Channel number of low-level features. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. - """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - trilinear=True): - super(UpBlock, self).__init__() - self.trilinear = trilinear - if trilinear: - self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) - else: - self.up = nn.Sequential( - nn.BatchNorm3d(in_channels1), - nn.LeakyReLU(), - nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) - ) - self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) - - def forward(self, x1, x2): - if self.trilinear: - x1 = self.conv1x1(x1) - x1 = self.up(x1) - x = torch.cat([x2, x1], dim=1) - return self.conv(x) - -class Encoder(nn.Module): - """ - Encoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - """ - def __init__(self, params): - super(Encoder, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - self.proj = nn.Conv3d(self.in_chns, self.ft_chns[0], kernel_size=3, padding=1) - self.in_conv= ConvBlock(self.ft_chns[0], self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - if(len(self.ft_chns) == 5): - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - - def forward(self, x): - x0 = self.in_conv(self.proj(x)) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - output = [x0, x1, x2, x3] - if(len(self.ft_chns) == 5): - x4 = self.down4(x3) - output.append(x4) - return output - -class Decoder(nn.Module): - """ - Decoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - :param multiscale_pred: (bool) Get multi-scale prediction. - """ - def __init__(self, params): - super(Decoder, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.trilinear = self.params.get('trilinear', True) - self.mul_pred = self.params.get('multiscale_pred', False) - - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - if(len(self.ft_chns) == 5): - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear) - self.out_conv = ConvLayer(self.ft_chns[0], self.n_class) - if(self.mul_pred): - self.out_conv1 = ConvLayer(self.ft_chns[1], self.n_class) - self.out_conv2 = ConvLayer(self.ft_chns[2], self.n_class) - self.out_conv3 = ConvLayer(self.ft_chns[3], self.n_class) - - def forward(self, x): - if(len(self.ft_chns) == 5): - assert(len(x) == 5) - x0, x1, x2, x3, x4 = x - x_d3 = self.up1(x4, x3) - else: - assert(len(x) == 4) - x0, x1, x2, x3 = x - x_d3 = x3 - x_d2 = self.up2(x_d3, x2) - x_d1 = self.up3(x_d2, x1) - x_d0 = self.up4(x_d1, x0) - output = self.out_conv(x_d0) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - return output - -class HiFormer_v5(nn.Module): - """ - An implementation of the U-Net. - - * Reference: Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: - 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. - `MICCAI (2) 2016: 424-432. `_ - - Note that there are some modifications from the original paper, such as - the use of batch normalization, dropout, leaky relu and deep supervision. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using trilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - :param multiscale_pred: (bool) Get multi-scale prediction. - """ - def __init__(self, params): - super(HiFormer_v5, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.trilinear = self.params['trilinear'] - self.mul_pred = self.params['multiscale_pred'] - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - self.proj = nn.Conv3d(self.in_chns, self.ft_chns[0], kernel_size=3, padding=1) - self.in_conv= ConvBlock(self.ft_chns[0], self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - if(len(self.ft_chns) == 5): - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], - dropout_p = self.dropout[3], trilinear=self.trilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], - dropout_p = self.dropout[2], trilinear=self.trilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], - dropout_p = self.dropout[1], trilinear=self.trilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], - dropout_p = self.dropout[0], trilinear=self.trilinear) - - self.out_conv = ConvLayer(self.ft_chns[0], self.n_class) - if(self.mul_pred): - self.out_conv1 = ConvLayer(self.ft_chns[1], self.n_class) - self.out_conv2 = ConvLayer(self.ft_chns[2], self.n_class) - self.out_conv3 = ConvLayer(self.ft_chns[3], self.n_class) - - def forward(self, x): - x0 = self.in_conv(self.proj(x)) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - if(len(self.ft_chns) == 5): - x4 = self.down4(x3) - x_d3 = self.up1(x4, x3) - else: - x_d3 = x3 - x_d2 = self.up2(x_d3, x2) - x_d1 = self.up3(x_d2, x1) - x_d0 = self.up4(x_d1, x0) - output = self.out_conv(x_d0) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'class_num': 2, - 'feature_chns':[32, 64, 128, 256, 512], - 'dropout' : [0, 0, 0, 0, 0.5], - 'trilinear': False, - 'multiscale_pred': False} - Net = HiFormer_v5(params) - Net = Net.double() - - x = np.random.rand(4, 4, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - y = y.detach().numpy() - print(y.shape) diff --git a/pymic/net/net3d/trans3d/MedFormer_v1.py b/pymic/net/net3d/trans3d/MedFormer_v1.py deleted file mode 100644 index 1f2ed54..0000000 --- a/pymic/net/net3d/trans3d/MedFormer_v1.py +++ /dev/null @@ -1,173 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import math -import torch -import torch.nn as nn -import numpy as np -from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm -from pymic.net.net3d.unet3d import Encoder, Decoder - -class Attention(nn.Module): - def __init__(self, params): - super(Attention, self).__init__() - hidden_size = params["attention_hidden_size"] - self.num_attention_heads = params["attention_num_heads"] - self.attention_head_size = int(hidden_size / self.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = Linear(hidden_size, self.all_head_size) - self.key = Linear(hidden_size, self.all_head_size) - self.value = Linear(hidden_size, self.all_head_size) - - self.out = Linear(hidden_size, hidden_size) - self.attn_dropout = Dropout(params["attention_dropout_rate"]) - self.proj_dropout = Dropout(params["attention_dropout_rate"]) - - self.softmax = Softmax(dim=-1) - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward(self, hidden_states): - mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) - - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) - - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - attention_probs = self.softmax(attention_scores) - # weights = attention_probs if self.vis else None - attention_probs = self.attn_dropout(attention_probs) - - context_layer = torch.matmul(attention_probs, value_layer) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) - attention_output = self.out(context_layer) - attention_output = self.proj_dropout(attention_output) - return attention_output - -class MLP(nn.Module): - def __init__(self, params): - super(MLP, self).__init__() - hidden_size = params["attention_hidden_size"] - mlp_dim = params["attention_mlp_dim"] - self.fc1 = Linear(hidden_size, mlp_dim) - self.fc2 = Linear(mlp_dim, hidden_size) - self.act_fn = torch.nn.functional.gelu - self.dropout = Dropout(params["attention_dropout_rate"]) - - self._init_weights() - - def _init_weights(self): - nn.init.xavier_uniform_(self.fc1.weight) - nn.init.xavier_uniform_(self.fc2.weight) - nn.init.normal_(self.fc1.bias, std=1e-6) - nn.init.normal_(self.fc2.bias, std=1e-6) - - def forward(self, x): - x = self.fc1(x) - x = self.act_fn(x) - x = self.dropout(x) - x = self.fc2(x) - x = self.dropout(x) - return x - -class Block(nn.Module): - def __init__(self, params): - super(Block, self).__init__() - hidden_size = params["attention_hidden_size"] - self.attention_norm = LayerNorm(hidden_size, eps=1e-6) - self.ffn_norm = LayerNorm(hidden_size, eps=1e-6) - self.ffn = MLP(params) - self.attn = Attention(params) - - def forward(self, x): - # convert the tensor shape from [B, C, D, H, W] to [B, DHW, C] - [B, C, D, H, W] = list(x.shape) - new_shape = [B, C, D*H*W] - x = torch.reshape(x, new_shape) - x = torch.transpose(x, 1, 2) - - h = x - x = self.attention_norm(x) - x = self.attn(x) - x = x + h - - h = x - x = self.ffn_norm(x) - x = self.ffn(x) - x = x + h - - # convert the result back to [B, C, D, H, W] - x = torch.transpose(x, 1, 2) - x = torch.reshape(x, [B, C, D, H, W]) - return x - -class MedFormerV1(nn.Module): - """ - An implementation of the U-Net. - - * Reference: Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: - 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. - `MICCAI (2) 2016: 424-432. `_ - - Note that there are some modifications from the original paper, such as - the use of batch normalization, dropout, leaky relu and deep supervision. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using trilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - :param deep_supervise: (bool) Using deep supervision for training or not. - """ - def __init__(self, params): - super(MedFormerV1, self).__init__() - self.params = params - self.encoder = Encoder(params) - self.decoder = Decoder(params) - self.attn = Block(params) - - def forward(self, x): - f = self.encoder(x) - f[-1] = self.attn(f[-1]) - output = self.decoder(f) - return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'class_num': 2, - 'feature_chns':[16, 32, 64, 128], - 'dropout' : [0, 0, 0, 0.5], - 'trilinear': True, - 'deep_supervise': True, - 'attention_hidden_size': 128, - 'attention_num_heads': 4, - 'attention_mlp_dim': 256, - 'attention_dropout_rate': 0.2} - Net = MedFormerV1(params) - Net = Net.double() - - x = np.random.rand(1, 4, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print("output length", len(y)) - for yi in y: - yi = yi.detach().numpy() - print(yi.shape) diff --git a/pymic/net/net3d/trans3d/MedFormer_v2.py b/pymic/net/net3d/trans3d/MedFormer_v2.py deleted file mode 100644 index 00cb295..0000000 --- a/pymic/net/net3d/trans3d/MedFormer_v2.py +++ /dev/null @@ -1,464 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import math -import copy -import torch -import torch.nn as nn -import numpy as np -import torch.nn.functional as F -from pymic.net.net3d.unet3d import ConvBlock, Encoder, Decoder -from pymic.net.net3d.trans3d.MedFormer_v1 import Block -from timm.models.layers import DropPath, to_3tuple, trunc_normal_ - - -# code from nnFormer -class Mlp(nn.Module): - """ Multilayer perceptron.""" - - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - - B, S, H, W, C = x.shape - x = x.view(B, S // window_size, window_size, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, S, H, W): - - B = int(windows.shape[0] / (S * H * W / window_size / window_size / window_size)) - x = windows.view(B, S // window_size, H // window_size, W // window_size, window_size, window_size, window_size, -1) - x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, S, H, W, -1) - return x - - -class WindowAttention(nn.Module): - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), - num_heads)) - - # get pair-wise relative position index for each token inside the window - coords_s = torch.arange(self.window_size[0]) - coords_h = torch.arange(self.window_size[1]) - coords_w = torch.arange(self.window_size[2]) - coords = torch.stack(torch.meshgrid([coords_s, coords_h, coords_w])) - coords_flatten = torch.flatten(coords, 1) - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] - relative_coords = relative_coords.permute(1, 2, 0).contiguous() - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 2] += self.window_size[2] - 1 - - relative_coords[:, :, 0] *= 3 * self.window_size[1] - 1 - relative_coords[:, :, 1] *= 2 * self.window_size[1] - 1 - - relative_position_index = relative_coords.sum(-1) - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - trunc_normal_(self.relative_position_bias_table, std=.02) - - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None,pos_embed=None): - - B_, N, C = x.shape - - qkv = self.qkv(x) - - qkv=qkv.reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1).contiguous()) - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1] * self.window_size[2], - self.window_size[0] * self.window_size[1] * self.window_size[2], -1) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous() - if pos_embed is not None: - x = x+pos_embed - x = self.proj(x) - x = self.proj_drop(x) - return x - -class SwinTransformerBlock(nn.Module): - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - - self.attn = WindowAttention( - dim, window_size=to_3tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - - def forward(self, x, mask_matrix): - - B, L, C = x.shape - S, H, W = self.input_resolution - - assert L == S * H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, S, H, W, C) - - # pad feature maps to multiples of window size - pad_r = (self.window_size - W % self.window_size) % self.window_size - pad_b = (self.window_size - H % self.window_size) % self.window_size - pad_g = (self.window_size - S % self.window_size) % self.window_size - - x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b, 0, pad_g)) - _, Sp, Hp, Wp, _ = x.shape - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size,-self.shift_size), dims=(1, 2,3)) - attn_mask = mask_matrix - else: - shifted_x = x - attn_mask = None - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size * self.window_size, - C) - - # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=attn_mask,pos_embed=None) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, Sp, Hp, Wp) - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size, self.shift_size), dims=(1, 2, 3)) - else: - x = shifted_x - - if pad_r > 0 or pad_b > 0 or pad_g > 0: - x = x[:, :S, :H, :W, :].contiguous() - - x = x.view(B, S * H * W, C) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - - return x - -class BasicLayer(nn.Module): - - def __init__(self, - dim, - input_resolution, - depth, - num_heads, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop=0., - attn_drop=0., - drop_path=0., - norm_layer=nn.LayerNorm, - downsample=True - ): - super().__init__() - self.window_size = window_size - self.shift_size = window_size // 2 - self.depth = depth - # build blocks - - self.blocks = nn.ModuleList([ - SwinTransformerBlock( - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, S, H, W): - - - # calculate attention mask for SW-MSA - Sp = int(np.ceil(S / self.window_size)) * self.window_size - Hp = int(np.ceil(H / self.window_size)) * self.window_size - Wp = int(np.ceil(W / self.window_size)) * self.window_size - img_mask = torch.zeros((1, Sp, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 - s_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for s in s_slices: - for h in h_slices: - for w in w_slices: - img_mask[:, s, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) - mask_windows = mask_windows.view(-1, - self.window_size * self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - for blk in self.blocks: - - x = blk(x, attn_mask) - if self.downsample is not None: - x_down = self.downsample(x, S, H, W) - Ws, Wh, Ww = (S + 1) // 2, (H + 1) // 2, (W + 1) // 2 - return x, S, H, W, x_down, Ws, Wh, Ww - else: - return x, S, H, W, x, S, H, W - - -class AttUpBlock(nn.Module): - """ - 3D upsampling followed by ConvBlock - - :param in_channels1: (int) Channel number of high-level features. - :param in_channels2: (int) Channel number of low-level features. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. - """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - trilinear=True, with_att = False, att_params = None): - super(AttUpBlock, self).__init__() - self.trilinear = trilinear - self.with_att = with_att - if trilinear: - self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) - else: - self.up = nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) - self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) - if(self.with_att): - input_resolution = att_params['input_resolution'] - depth = att_params['depth'] - num_heads = att_params['num_heads'] - self.attn = BasicLayer(out_channels, input_resolution, depth, num_heads, downsample=None) - - def forward(self, x1, x2): - if self.trilinear: - x1 = self.conv1x1(x1) - x1 = self.up(x1) - x = torch.cat([x2, x1], dim=1) - x = self.conv(x) - if(self.with_att): - [B, C, D, H, W] = list(x.shape) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.attn(x, D, H, W)[0] - x = x.view(-1, D, H, W, C).permute(0, 4, 1, 2, 3).contiguous() - return x - -class AttDecoder(nn.Module): - """ - Decoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - """ - def __init__(self, params): - super(AttDecoder, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.trilinear = self.params.get('trilinear', True) - self.mul_pred = self.params['multiscale_pred'] - - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - if(len(self.ft_chns) == 5): - self.up1 = AttUpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear) - att_params = {"input_resolution": [24, 24, 24], "depth": 2, "num_heads": 4} - self.up2 = AttUpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear, True, att_params) - att_params = {"input_resolution": [48, 48, 48], "depth": 2, "num_heads": 4} - self.up3 = AttUpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear, True, att_params) - self.up4 = AttUpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear) - self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) - if(self.mul_pred): - self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) - self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) - self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) - - def forward(self, x): - if(len(self.ft_chns) == 5): - assert(len(x) == 5) - x0, x1, x2, x3, x4 = x - x_d3 = self.up1(x4, x3) - else: - assert(len(x) == 4) - x0, x1, x2, x3 = x - x_d3 = x3 - x_d2 = self.up2(x_d3, x2) - x_d1 = self.up3(x_d2, x1) - x_d0 = self.up4(x_d1, x0) - output = self.out_conv(x_d0) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - return output - -class MedFormerV2(nn.Module): - """ - An implementation of the U-Net. - - * Reference: Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: - 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. - `MICCAI (2) 2016: 424-432. `_ - - Note that there are some modifications from the original paper, such as - the use of batch normalization, dropout, leaky relu and deep supervision. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using trilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - """ - def __init__(self, params): - super(MedFormerV2, self).__init__() - self.params = params - self.encoder = Encoder(params) - self.decoder = AttDecoder(params) - self.attn = Block(params) - - def forward(self, x): - f = self.encoder(x) - f[-1] = self.attn(f[-1]) - output = self.decoder(f) - return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'class_num': 2, - 'feature_chns':[16, 32, 64, 128], - 'dropout' : [0, 0, 0, 0.5], - 'trilinear': True, - 'multiscale_pred': True, - 'attention_hidden_size': 128, - 'attention_num_heads': 4, - 'attention_mlp_dim': 256, - 'attention_dropout_rate': 0.2} - - Net = MedFormerV2(params) - Net = Net.double() - - x = np.random.rand(1, 4, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print("output length", len(y)) - for yi in y: - yi = yi.detach().numpy() - print(yi.shape) diff --git a/pymic/net/net3d/trans3d/MedFormer_v3.py b/pymic/net/net3d/trans3d/MedFormer_v3.py deleted file mode 100644 index f119a9c..0000000 --- a/pymic/net/net3d/trans3d/MedFormer_v3.py +++ /dev/null @@ -1,255 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import numpy as np -import torch -import torch.nn as nn -from torch.nn.functional import interpolate -from pymic.net.net3d.unet3d import ConvBlock, Encoder -from pymic.net.net3d.trans3d.MedFormer_v1 import Block -from pymic.net.net3d.trans3d.MedFormer_v2 import SwinTransformerBlock, window_partition - -class GLAttLayer(nn.Module): - def __init__(self, - dim, - input_resolution, - num_heads, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop=0., - attn_drop=0., - drop_path=0., - norm_layer=nn.LayerNorm): - super().__init__() - self.window_size = window_size - self.shift_size = window_size // 2 - # build blocks - - self.lcl_att = SwinTransformerBlock( - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - window_size=window_size, - shift_size=0, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path) - self.adpool = nn.AdaptiveAvgPool3d([12, 12, 12]) - - params = {'attention_hidden_size': dim, - 'attention_num_heads': 4, - 'attention_mlp_dim': dim, - 'attention_dropout_rate': 0.2} - self.glb_att = Block(params) - self.conv1x1 = nn.Sequential( - nn.Conv3d(2*dim, dim, kernel_size=1), - nn.BatchNorm3d(dim), - nn.LeakyReLU()) - - def forward(self, x): - [B, C, S, H, W] = list(x.shape) - # calculate attention mask for SW-MSA - Sp = int(np.ceil(S / self.window_size)) * self.window_size - Hp = int(np.ceil(H / self.window_size)) * self.window_size - Wp = int(np.ceil(W / self.window_size)) * self.window_size - img_mask = torch.zeros((1, Sp, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 - s_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for s in s_slices: - for h in h_slices: - for w in w_slices: - img_mask[:, s, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) - mask_windows = mask_windows.view(-1, - self.window_size * self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - # for local attention - xl = x.flatten(2).transpose(1, 2).contiguous() - xl = self.lcl_att(xl, attn_mask) - xl = xl.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() - - # for global attention - xg = self.adpool(x) - xg = self.glb_att(xg) - xg = interpolate(xg, [S, H, W], mode = 'trilinear') - out = torch.cat([xl, xg], dim=1) - out = self.conv1x1(out) - return out - -class AttUpBlock(nn.Module): - """ - 3D upsampling followed by ConvBlock - - :param in_channels1: (int) Channel number of high-level features. - :param in_channels2: (int) Channel number of low-level features. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. - """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - trilinear=True, with_att = False, att_params = None): - super(AttUpBlock, self).__init__() - self.trilinear = trilinear - self.with_att = with_att - if trilinear: - self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) - else: - self.up = nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) - self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) - if(self.with_att): - input_resolution = att_params['input_resolution'] - num_heads = att_params['num_heads'] - window_size = att_params['window_size'] - self.attn = GLAttLayer(out_channels, input_resolution, num_heads, window_size, 2.0) - - def forward(self, x1, x2): - if self.trilinear: - x1 = self.conv1x1(x1) - x1 = self.up(x1) - x = torch.cat([x2, x1], dim=1) - x = self.conv(x) - if(self.with_att): - x = self.attn(x) - return x - - -class AttDecoder(nn.Module): - """ - Decoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - """ - def __init__(self, params): - super(AttDecoder, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.trilinear = self.params.get('trilinear', True) - self.mul_pred = self.params['multiscale_pred'] - - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - if(len(self.ft_chns) == 5): - self.up1 = AttUpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear) - att_params = {"input_resolution": [24, 24, 24], "num_heads": 4, "window_size": 7} - self.up2 = AttUpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear, True, att_params) - att_params = {"input_resolution": [48, 48, 48], "num_heads": 4, "window_size": 7} - self.up3 = AttUpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear, True, att_params) - self.up4 = AttUpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear) - self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) - if(self.mul_pred): - self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) - self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) - self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) - - def forward(self, x): - if(len(self.ft_chns) == 5): - assert(len(x) == 5) - x0, x1, x2, x3, x4 = x - x_d3 = self.up1(x4, x3) - else: - assert(len(x) == 4) - x0, x1, x2, x3 = x - x_d3 = x3 - x_d2 = self.up2(x_d3, x2) - x_d1 = self.up3(x_d2, x1) - x_d0 = self.up4(x_d1, x0) - output = self.out_conv(x_d0) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - return output - -class MedFormerV3(nn.Module): - """ - An implementation of the U-Net. - - * Reference: Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: - 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. - `MICCAI (2) 2016: 424-432. `_ - - Note that there are some modifications from the original paper, such as - the use of batch normalization, dropout, leaky relu and deep supervision. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using trilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - """ - def __init__(self, params): - super(MedFormerV3, self).__init__() - self.params = params - self.encoder = Encoder(params) - self.decoder = AttDecoder(params) - params["attention_hidden_size"] = params['feature_chns'][-1] - params["attention_mlp_dim"] = params['feature_chns'][-1] - self.attn = Block(params) - - def forward(self, x): - f = self.encoder(x) - f[-1] = self.attn(f[-1]) - output = self.decoder(f) - return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'class_num': 2, - 'feature_chns':[16, 32, 64, 128], - 'dropout' : [0, 0, 0, 0.5], - 'trilinear': True, - 'multiscale_pred': True, - 'attention_num_heads': 4, - 'attention_dropout_rate': 0.2} - - Net = MedFormerV3(params) - Net = Net.double() - - x = np.random.rand(2, 4, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print("output length", len(y)) - for yi in y: - yi = yi.detach().numpy() - print(yi.shape) diff --git a/pymic/net/net3d/trans3d/MedFormer_va1.py b/pymic/net/net3d/trans3d/MedFormer_va1.py deleted file mode 100644 index 27dfa3e..0000000 --- a/pymic/net/net3d/trans3d/MedFormer_va1.py +++ /dev/null @@ -1,105 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import math -import torch -import torch.nn as nn -import numpy as np -from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm -from pymic.net.net3d.unet3d import Decoder - -class EmbeddingBlock(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, padding, stride): - super(EmbeddingBlock, self).__init__() - self.out_channels = out_channels - self.conv1 = nn.Conv3d(in_channels, out_channels//2, kernel_size=kernel_size, padding=padding, stride = stride) - self.conv2 = nn.Conv3d(out_channels//2, out_channels, kernel_size=1) - self.act = nn.GELU() - self.norm1 = nn.LayerNorm(out_channels//2) - self.norm2 = nn.LayerNorm(out_channels) - - - def forward(self, x): - x = self.act(self.conv1(x)) - # norm 1 - Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.norm1(x) - x = x.transpose(1, 2).contiguous().view(-1, self.out_channels // 2, Ws, Wh, Ww) - - x = self.act(self.conv2(x)) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.norm2(x) - x = x.transpose(1, 2).contiguous().view(-1, self.out_channels, Ws, Wh, Ww) - - return x - -class Encoder(nn.Module): - """ - Encoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - """ - def __init__(self, params): - super(Encoder, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - assert(len(self.ft_chns) == 4) - - self.down0 = EmbeddingBlock(self.in_chns, self.ft_chns[0], 3, 1, 1) - self.down1 = EmbeddingBlock(self.in_chns, self.ft_chns[1], 2, 0, 2) - self.down2 = EmbeddingBlock(self.in_chns, self.ft_chns[2], 4, 0, 4) - self.down3 = EmbeddingBlock(self.in_chns, self.ft_chns[3], 8, 0, 8) - - def forward(self, x): - x0 = self.down0(x) - x1 = self.down1(x) - x2 = self.down2(x) - x3 = self.down3(x) - output = [x0, x1, x2, x3] - return output - -class MedFormerVA1(nn.Module): - def __init__(self, params): - super(MedFormerVA1, self).__init__() - self.params = params - self.encoder = Encoder(params) - self.decoder = Decoder(params) - - def forward(self, x): - f = self.encoder(x) - output = self.decoder(f) - return output - - -if __name__ == "__main__": - params = {'in_chns':1, - 'class_num': 8, - 'feature_chns':[16, 32, 64, 128], - 'dropout' : [0, 0, 0, 0.5], - 'trilinear': True, - 'deep_supervise': True, - 'attention_hidden_size': 128, - 'attention_num_heads': 4, - 'attention_mlp_dim': 256, - 'attention_dropout_rate': 0.2} - Net = MedFormerVA1(params) - Net = Net.double() - - x = np.random.rand(1, 1, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print("output length", len(y)) - for yi in y: - yi = yi.detach().numpy() - print(yi.shape) \ No newline at end of file diff --git a/pymic/net/net3d/trans3d/__init__.py b/pymic/net/net3d/trans3d/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/pymic/net/net3d/trans3d/nnFormer_wrap.py b/pymic/net/net3d/trans3d/nnFormer_wrap.py deleted file mode 100644 index 35593a4..0000000 --- a/pymic/net/net3d/trans3d/nnFormer_wrap.py +++ /dev/null @@ -1,43 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import math -import torch -import torch.nn as nn -import numpy as np -from nnformer.network_architecture.nnFormer_tumor import nnFormer - -class nnFormer_wrap(nn.Module): - def __init__(self, params): - super(nnFormer_wrap, self).__init__() - patch_size = params["patch_size"] # 96x96x96 - n_class = params['class_num'] - in_chns = params['in_chns'] - # https://github.com/282857341/nnFormer/blob/main/nnformer/network_architecture/nnFormer_tumor.py - self.nnformer = nnFormer(crop_size = patch_size, - embedding_dim=192, - input_channels = in_chns, - num_classes = n_class, - conv_op=nn.Conv3d, - depths =[2,2,2,2], - num_heads = [6, 12, 24, 48], - patch_size = [4,4,4], - window_size= [4,4,8,4], - deep_supervision=False) - - def forward(self, x): - return self.nnformer(x) - -if __name__ == "__main__": - params = {"patch_size": [96, 96, 96], - "in_chns": 1, - "class_num": 5} - Net = nnFormer_wrap(params) - Net = Net.double() - - x = np.random.rand(1, 1, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print(y.shape) diff --git a/pymic/net/net3d/trans3d/unetr.py b/pymic/net/net3d/trans3d/unetr.py deleted file mode 100644 index ea90b2f..0000000 --- a/pymic/net/net3d/trans3d/unetr.py +++ /dev/null @@ -1,227 +0,0 @@ -from __future__ import print_function, division - -import torch -import torch.nn as nn - -from monai.networks.blocks import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock -from monai.networks.blocks.dynunet_block import UnetOutBlock -from monai.networks.nets import ViT - - -class UNETR(nn.Module): - """ - UNETR based on: "Hatamizadeh et al., - UNETR: Transformers for 3D Medical Image Segmentation " - """ - - def __init__(self, params): - # in_channels: int, - # out_channels: int, - # img_size: Tuple[int, int, int], - # feature_size: int = 16, - # hidden_size: int = 768, - # mlp_dim: int = 3072, - # num_heads: int = 12, - # pos_embed: str = "perceptron", - # norm_name: Union[Tuple, str] = "instance", - # conv_block: bool = False, - # res_block: bool = True, - # dropout_rate: float = 0.0, - # ) -> None: - """ - Args: - in_channels: dimension of input channels. - out_channels: dimension of output channels. - img_size: dimension of input image. - feature_size: dimension of network feature size. - hidden_size: dimension of hidden layer. - mlp_dim: dimension of feedforward layer. - num_heads: number of attention heads. - pos_embed: position embedding layer type. - norm_name: feature normalization type and arguments. - conv_block: bool argument to determine if convolutional block is used. - res_block: bool argument to determine if residual block is used. - dropout_rate: faction of the input units to drop. - Examples:: - # for single channel input 4-channel output with patch size of (96,96,96), feature size of 32 and batch norm - >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch') - # for 4-channel input 3-channel output with patch size of (128,128,128), conv position embedding and instance norm - >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance') - """ - - super().__init__() - in_channels = params['in_chns'] - out_channels = params['class_num'] - img_size = params['img_size'] - feature_size = 16 - hidden_size = 768 - mlp_dim = 3072 - num_heads = 12 - pos_embed = "perceptron" - norm_name = "instance" - conv_block = False - res_block = True - dropout_rate = 0.0 - - if not (0 <= dropout_rate <= 1): - raise AssertionError("dropout_rate should be between 0 and 1.") - - if hidden_size % num_heads != 0: - raise AssertionError("hidden size should be divisible by num_heads.") - - if pos_embed not in ["conv", "perceptron"]: - raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") - - self.num_layers = 12 - self.patch_size = (16, 16, 16) - self.feat_size = ( - img_size[0] // self.patch_size[0], - img_size[1] // self.patch_size[1], - img_size[2] // self.patch_size[2], - ) - self.hidden_size = hidden_size - self.classification = False - self.vit = ViT( - in_channels=in_channels, - img_size=img_size, - patch_size=self.patch_size, - hidden_size=hidden_size, - mlp_dim=mlp_dim, - num_layers=self.num_layers, - num_heads=num_heads, - pos_embed=pos_embed, - classification=self.classification, - dropout_rate=dropout_rate, - ) - self.encoder1 = UnetrBasicBlock( - spatial_dims=3, - in_channels=in_channels, - out_channels=feature_size, - kernel_size=3, - stride=1, - norm_name=norm_name, - res_block=res_block, - ) - self.encoder2 = UnetrPrUpBlock( - spatial_dims=3, - in_channels=hidden_size, - out_channels=feature_size * 2, - num_layer=2, - kernel_size=3, - stride=1, - upsample_kernel_size=2, - norm_name=norm_name, - conv_block=conv_block, - res_block=res_block, - ) - self.encoder3 = UnetrPrUpBlock( - spatial_dims=3, - in_channels=hidden_size, - out_channels=feature_size * 4, - num_layer=1, - kernel_size=3, - stride=1, - upsample_kernel_size=2, - norm_name=norm_name, - conv_block=conv_block, - res_block=res_block, - ) - self.encoder4 = UnetrPrUpBlock( - spatial_dims=3, - in_channels=hidden_size, - out_channels=feature_size * 8, - num_layer=0, - kernel_size=3, - stride=1, - upsample_kernel_size=2, - norm_name=norm_name, - conv_block=conv_block, - res_block=res_block, - ) - self.decoder5 = UnetrUpBlock( - spatial_dims=3, - in_channels=hidden_size, - out_channels=feature_size * 8, - kernel_size=3, - upsample_kernel_size=2, - norm_name=norm_name, - res_block=res_block, - ) - self.decoder4 = UnetrUpBlock( - spatial_dims=3, - in_channels=feature_size * 8, - out_channels=feature_size * 4, - kernel_size=3, - upsample_kernel_size=2, - norm_name=norm_name, - res_block=res_block, - ) - self.decoder3 = UnetrUpBlock( - spatial_dims=3, - in_channels=feature_size * 4, - out_channels=feature_size * 2, - kernel_size=3, - upsample_kernel_size=2, - norm_name=norm_name, - res_block=res_block, - ) - self.decoder2 = UnetrUpBlock( - spatial_dims=3, - in_channels=feature_size * 2, - out_channels=feature_size, - kernel_size=3, - upsample_kernel_size=2, - norm_name=norm_name, - res_block=res_block, - ) - self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) # type: ignore - - def proj_feat(self, x, hidden_size, feat_size): - x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) - x = x.permute(0, 4, 1, 2, 3).contiguous() - return x - - def load_from(self, weights): - with torch.no_grad(): - res_weight = weights - # copy weights from patch embedding - for i in weights["state_dict"]: - print(i) - self.vit.patch_embedding.position_embeddings.copy_( - weights["state_dict"]["module.transformer.patch_embedding.position_embeddings_3d"] - ) - self.vit.patch_embedding.cls_token.copy_( - weights["state_dict"]["module.transformer.patch_embedding.cls_token"] - ) - self.vit.patch_embedding.patch_embeddings[1].weight.copy_( - weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings.1.weight"] - ) - self.vit.patch_embedding.patch_embeddings[1].bias.copy_( - weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings.1.bias"] - ) - - # copy weights from encoding blocks (default: num of blocks: 12) - for bname, block in self.vit.blocks.named_children(): - print(block) - block.loadFrom(weights, n_block=bname) - # last norm layer of transformer - self.vit.norm.weight.copy_(weights["state_dict"]["module.transformer.norm.weight"]) - self.vit.norm.bias.copy_(weights["state_dict"]["module.transformer.norm.bias"]) - - def forward(self, x_in): - x, hidden_states_out = self.vit(x_in) - enc1 = self.encoder1(x_in) - x2 = hidden_states_out[3] - enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size)) - x3 = hidden_states_out[6] - enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size)) - x4 = hidden_states_out[9] - enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size)) - dec4 = self.proj_feat(x, self.hidden_size, self.feat_size) - dec3 = self.decoder5(dec4, enc4) - dec2 = self.decoder4(dec3, enc3) - dec1 = self.decoder3(dec2, enc2) - out = self.decoder2(dec1, enc1) - logits = self.out(out) - return logits - diff --git a/pymic/net/net3d/trans3d/unetr_pp.py b/pymic/net/net3d/trans3d/unetr_pp.py deleted file mode 100644 index a4ab7e6..0000000 --- a/pymic/net/net3d/trans3d/unetr_pp.py +++ /dev/null @@ -1,469 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from typing import Optional, Sequence, Tuple, Union -from pymic.net.net3d.trans3d.unetr_pp_block import UnetOutBlock, UnetResBlock, get_conv_layer -from timm.models.layers import trunc_normal_ -from monai.utils import optional_import -from monai.networks.blocks.convolutions import Convolution -from monai.networks.layers.factories import Act, Norm -from monai.networks.layers.utils import get_act_layer, get_norm_layer - -einops, _ = optional_import("einops") - -class LayerNorm(nn.Module): - def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): - super().__init__() - self.weight = nn.Parameter(torch.ones(normalized_shape)) - self.bias = nn.Parameter(torch.zeros(normalized_shape)) - self.eps = eps - self.data_format = data_format - if self.data_format not in ["channels_last", "channels_first"]: - raise NotImplementedError - self.normalized_shape = (normalized_shape,) - - def forward(self, x): - if self.data_format == "channels_last": - return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - elif self.data_format == "channels_first": - u = x.mean(1, keepdim=True) - s = (x - u).pow(2).mean(1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.eps) - x = self.weight[:, None, None] * x + self.bias[:, None, None] - return x - -class EPA(nn.Module): - """ - Efficient Paired Attention Block, based on: "Shaker et al., - UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation" - """ - def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False, - channel_attn_drop=0.1, spatial_attn_drop=0.1): - super().__init__() - self.num_heads = num_heads - self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) - self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1)) - - # qkvv are 4 linear layers (query_shared, key_shared, value_spatial, value_channel) - self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias) - - # E and F are projection matrices with shared weights used in spatial attention module to project - # keys and values from HWD-dimension to P-dimension - self.E = self.F = nn.Linear(input_size, proj_size) - - self.attn_drop = nn.Dropout(channel_attn_drop) - self.attn_drop_2 = nn.Dropout(spatial_attn_drop) - - self.out_proj = nn.Linear(hidden_size, int(hidden_size // 2)) - self.out_proj2 = nn.Linear(hidden_size, int(hidden_size // 2)) - - def forward(self, x): - B, N, C = x.shape - - qkvv = self.qkvv(x).reshape(B, N, 4, self.num_heads, C // self.num_heads) - - qkvv = qkvv.permute(2, 0, 3, 1, 4) - - q_shared, k_shared, v_CA, v_SA = qkvv[0], qkvv[1], qkvv[2], qkvv[3] - - q_shared = q_shared.transpose(-2, -1) - k_shared = k_shared.transpose(-2, -1) - v_CA = v_CA.transpose(-2, -1) - v_SA = v_SA.transpose(-2, -1) - - k_shared_projected = self.E(k_shared) - - v_SA_projected = self.F(v_SA) - - q_shared = torch.nn.functional.normalize(q_shared, dim=-1) - k_shared = torch.nn.functional.normalize(k_shared, dim=-1) - - attn_CA = (q_shared @ k_shared.transpose(-2, -1)) * self.temperature - - attn_CA = attn_CA.softmax(dim=-1) - attn_CA = self.attn_drop(attn_CA) - - x_CA = (attn_CA @ v_CA).permute(0, 3, 1, 2).reshape(B, N, C) - - attn_SA = (q_shared.permute(0, 1, 3, 2) @ k_shared_projected) * self.temperature2 - - attn_SA = attn_SA.softmax(dim=-1) - attn_SA = self.attn_drop_2(attn_SA) - - x_SA = (attn_SA @ v_SA_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(B, N, C) - - # Concat fusion - x_SA = self.out_proj(x_SA) - x_CA = self.out_proj2(x_CA) - x = torch.cat((x_SA, x_CA), dim=-1) - return x - - @torch.jit.ignore - def no_weight_decay(self): - return {'temperature', 'temperature2'} - - -class TransformerBlock(nn.Module): - """ - A transformer block, based on: "Shaker et al., - UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation" - """ - - def __init__( - self, - input_size: int, - hidden_size: int, - proj_size: int, - num_heads: int, - dropout_rate: float = 0.0, - pos_embed=False, - ) -> None: - """ - Args: - input_size: the size of the input for each stage. - hidden_size: dimension of hidden layer. - proj_size: projection size for keys and values in the spatial attention module. - num_heads: number of attention heads. - dropout_rate: faction of the input units to drop. - pos_embed: bool argument to determine if positional embedding is used. - - """ - - super().__init__() - - if not (0 <= dropout_rate <= 1): - raise ValueError("dropout_rate should be between 0 and 1.") - - if hidden_size % num_heads != 0: - print("Hidden size is ", hidden_size) - print("Num heads is ", num_heads) - raise ValueError("hidden_size should be divisible by num_heads.") - - self.norm = nn.LayerNorm(hidden_size) - self.gamma = nn.Parameter(1e-6 * torch.ones(hidden_size), requires_grad=True) - self.epa_block = EPA(input_size=input_size, hidden_size=hidden_size, proj_size=proj_size, num_heads=num_heads, channel_attn_drop=dropout_rate,spatial_attn_drop=dropout_rate) - self.conv51 = UnetResBlock(3, hidden_size, hidden_size, kernel_size=3, stride=1, norm_name="batch") - self.conv8 = nn.Sequential(nn.Dropout3d(0.1, False), nn.Conv3d(hidden_size, hidden_size, 1)) - - self.pos_embed = None - if pos_embed: - self.pos_embed = nn.Parameter(torch.zeros(1, input_size, hidden_size)) - - def forward(self, x): - B, C, H, W, D = x.shape - x = x.reshape(B, C, H * W * D).permute(0, 2, 1) - - if self.pos_embed is not None: - x = x + self.pos_embed - attn = x + self.gamma * self.epa_block(self.norm(x)) - - attn_skip = attn.reshape(B, H, W, D, C).permute(0, 4, 1, 2, 3) # (B, C, H, W, D) - attn = self.conv51(attn_skip) - x = attn_skip + self.conv8(attn) - - return x - -class UnetrPPEncoder(nn.Module): - def __init__(self, input_size=[32 * 32 * 32, 16 * 16 * 16, 8 * 8 * 8, 4 * 4 * 4],dims=[32, 64, 128, 256], - proj_size =[64,64,64,32], depths=[3, 3, 3, 3], num_heads=4, spatial_dims=3, - in_channels=1, dropout=0.0, transformer_dropout_rate=0.15, kernel_size=(2,4,4), **kwargs): - super().__init__() - - self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers - stem_layer = nn.Sequential( - get_conv_layer(spatial_dims, in_channels, dims[0], kernel_size=kernel_size, stride=kernel_size, - dropout=dropout, conv_only=True, ), - get_norm_layer(name=("group", {"num_groups": in_channels}), channels=dims[0]), - ) - self.downsample_layers.append(stem_layer) - for i in range(3): - downsample_layer = nn.Sequential( - get_conv_layer(spatial_dims, dims[i], dims[i + 1], kernel_size=(2, 2, 2), stride=(2, 2, 2), - dropout=dropout, conv_only=True, ), - get_norm_layer(name=("group", {"num_groups": dims[i]}), channels=dims[i + 1]), - ) - self.downsample_layers.append(downsample_layer) - - self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple Transformer blocks - for i in range(4): - stage_blocks = [] - for j in range(depths[i]): - stage_blocks.append(TransformerBlock(input_size=input_size[i], hidden_size=dims[i], proj_size=proj_size[i], num_heads=num_heads, - dropout_rate=transformer_dropout_rate, pos_embed=True)) - self.stages.append(nn.Sequential(*stage_blocks)) - self.hidden_states = [] - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv2d, nn.Linear)): - trunc_normal_(m.weight, std=.02) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, (LayerNorm, nn.LayerNorm)): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def forward_features(self, x): - hidden_states = [] - x = self.downsample_layers[0](x) - x = self.stages[0](x) - - hidden_states.append(x) - - for i in range(1, 4): - x = self.downsample_layers[i](x) - x = self.stages[i](x) - if i == 3: # Reshape the output of the last stage - x = einops.rearrange(x, "b c h w d -> b (h w d) c") - hidden_states.append(x) - return x, hidden_states - - def forward(self, x): - x, hidden_states = self.forward_features(x) - return x, hidden_states - - -class UnetrUpBlock(nn.Module): - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], - upsample_kernel_size: Union[Sequence[int], int], - norm_name: Union[Tuple, str], - proj_size: int = 64, - num_heads: int = 4, - out_size: int = 0, - depth: int = 3, - conv_decoder: bool = False, - ) -> None: - """ - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - kernel_size: convolution kernel size. - upsample_kernel_size: convolution kernel size for transposed convolution layers. - norm_name: feature normalization type and arguments. - proj_size: projection size for keys and values in the spatial attention module. - num_heads: number of heads inside each EPA module. - out_size: spatial size for each decoder. - depth: number of blocks for the current decoder stage. - """ - - super().__init__() - upsample_stride = upsample_kernel_size - self.transp_conv = get_conv_layer( - spatial_dims, - in_channels, - out_channels, - kernel_size=upsample_kernel_size, - stride=upsample_stride, - conv_only=True, - is_transposed=True, - ) - - # 4 feature resolution stages, each consisting of multiple residual blocks - self.decoder_block = nn.ModuleList() - - # If this is the last decoder, use ConvBlock(UnetResBlock) instead of EPA_Block (see suppl. material in the paper) - if conv_decoder == True: - self.decoder_block.append( - UnetResBlock(spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, - norm_name=norm_name, )) - else: - stage_blocks = [] - for j in range(depth): - stage_blocks.append(TransformerBlock(input_size=out_size, hidden_size= out_channels, proj_size=proj_size, num_heads=num_heads, - dropout_rate=0.15, pos_embed=True)) - self.decoder_block.append(nn.Sequential(*stage_blocks)) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv2d, nn.Linear)): - trunc_normal_(m.weight, std=.02) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, (nn.LayerNorm)): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def forward(self, inp, skip): - - out = self.transp_conv(inp) - out = out + skip - out = self.decoder_block[0](out) - - return out - - -class UNETR_PP(nn.Module): - """ - UNETR++ based on: "Shaker et al., - UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation" - """ - - def __init__(self, params): - """ - Args: - in_channels: dimension of input channels. - out_channels: dimension of output channels. - img_size: dimension of input image. - feature_size: dimension of network feature size. - hidden_size: dimension of the last encoder. - num_heads: number of attention heads. - pos_embed: position embedding layer type. - norm_name: feature normalization type and arguments. - dropout_rate: faction of the input units to drop. - depths: number of blocks for each stage. - dims: number of channel maps for the stages. - conv_op: type of convolution operation. - do_ds: use deep supervision to compute the loss. - - """ - super().__init__() - in_channels = params['in_chns'] - out_channels = params['class_num'] - img_size = params['img_size'] - self.res_mode= params.get("resolution_mode", 1) - feature_size = params.get('feature_size', 16) - hidden_size = params.get('hidden_size', 256) - num_heads = params.get('num_heads', 4) - pos_embed = params.get('pos_embed', "perceptron") - norm_name = params.get('norm_name', "instance") - dropout_rate = params.get('dropout_rate', 0.0) - depths = params.get('depths', [3, 3, 3, 3]) - dims = params.get('dims', [32, 64, 128, 256]) - conv_op = nn.Conv3d - do_ds = params.get('deep_supervise', True) - - self.do_ds = do_ds - self.conv_op = conv_op - self.num_classes = out_channels - if not (0 <= dropout_rate <= 1): - raise AssertionError("dropout_rate should be between 0 and 1.") - - if pos_embed not in ["conv", "perceptron"]: - raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") - - kernel_ds = [4, 2, 1] - kernel_d = kernel_ds[self.res_mode] - self.patch_size = (kernel_d, 4, 4) - - self.feat_size = ( - img_size[0] // self.patch_size[0] // 8, # 8 is the downsampling happened through the four encoders stages - img_size[1] // self.patch_size[1] // 8, # 8 is the downsampling happened through the four encoders stages - img_size[2] // self.patch_size[2] // 8, # 8 is the downsampling happened through the four encoders stages - ) - - self.hidden_size = hidden_size - - self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads, - in_channels=in_channels, kernel_size=self.patch_size) - - self.encoder1 = UnetResBlock( - spatial_dims=3, - in_channels=in_channels, - out_channels=feature_size, - kernel_size=3, - stride=1, - norm_name=norm_name, - ) - self.decoder5 = UnetrUpBlock( - spatial_dims=3, - in_channels=feature_size * 16, - out_channels=feature_size * 8, - kernel_size=3, - upsample_kernel_size=2, - norm_name=norm_name, - out_size=8 * 8 * 8, - ) - self.decoder4 = UnetrUpBlock( - spatial_dims=3, - in_channels=feature_size * 8, - out_channels=feature_size * 4, - kernel_size=3, - upsample_kernel_size=2, - norm_name=norm_name, - out_size=16 * 16 * 16, - ) - self.decoder3 = UnetrUpBlock( - spatial_dims=3, - in_channels=feature_size * 4, - out_channels=feature_size * 2, - kernel_size=3, - upsample_kernel_size=2, - norm_name=norm_name, - out_size=32 * 32 * 32, - ) - - self.decoder2 = UnetrUpBlock( - spatial_dims=3, - in_channels=feature_size * 2, - out_channels=feature_size, - kernel_size=3, - upsample_kernel_size= self.patch_size, - norm_name=norm_name, - out_size= kernel_d*32 * 128 * 128, - conv_decoder=True, - ) - self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) - # if self.do_ds: - self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels) - self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels) - - def proj_feat(self, x, hidden_size, feat_size): - x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) - x = x.permute(0, 4, 1, 2, 3).contiguous() - return x - - def forward(self, x_in): - x_output, hidden_states = self.unetr_pp_encoder(x_in) - - convBlock = self.encoder1(x_in) - - # Four encoders - enc1 = hidden_states[0] - enc2 = hidden_states[1] - enc3 = hidden_states[2] - enc4 = hidden_states[3] - - # Four decoders - dec4 = self.proj_feat(enc4, self.hidden_size, self.feat_size) - dec3 = self.decoder5(dec4, enc3) - dec2 = self.decoder4(dec3, enc2) - dec1 = self.decoder3(dec2, enc1) - - out = self.decoder2(dec1, convBlock) - if self.do_ds: - logits = [self.out1(out), self.out2(dec1), self.out3(dec2)] - else: - logits = self.out1(out) - - return logits - - -if __name__ == "__main__": - depths = [128, 64, 32] - for i in range(3): - params = {'in_chns': 4, - 'class_num': 2, - 'img_size': [depths[i], 128, 128], - 'resolution_mode': i - } - net = UNETR_PP(params) - net.double() - - x = np.random.rand(2, 4, depths[i], 128, 128) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = net(xt) - print(len(y)) - for yi in y: - yi = yi.detach().numpy() - print(yi.shape) \ No newline at end of file diff --git a/pymic/net/net3d/trans3d/unetr_pp_block.py b/pymic/net/net3d/trans3d/unetr_pp_block.py deleted file mode 100644 index 89a8769..0000000 --- a/pymic/net/net3d/trans3d/unetr_pp_block.py +++ /dev/null @@ -1,278 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import numpy as np -import torch -import torch.nn as nn -from typing import Optional, Sequence, Tuple, Union -from monai.networks.blocks.convolutions import Convolution -from monai.networks.layers.factories import Act, Norm -from monai.networks.layers.utils import get_act_layer, get_norm_layer - - -class UnetResBlock(nn.Module): - """ - A skip-connection based module that can be used for DynUNet, based on: - `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. - `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - kernel_size: convolution kernel size. - stride: convolution stride. - norm_name: feature normalization type and arguments. - act_name: activation layer type and arguments. - dropout: dropout probability. - - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - norm_name: Union[Tuple, str], - act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), - dropout: Optional[Union[Tuple, str, float]] = None, - ): - super().__init__() - self.conv1 = get_conv_layer( - spatial_dims, - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - dropout=dropout, - conv_only=True, - ) - self.conv2 = get_conv_layer( - spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True - ) - self.lrelu = get_act_layer(name=act_name) - self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) - self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) - self.downsample = in_channels != out_channels - stride_np = np.atleast_1d(stride) - if not np.all(stride_np == 1): - self.downsample = True - if self.downsample: - self.conv3 = get_conv_layer( - spatial_dims, in_channels, out_channels, kernel_size=1, stride=stride, dropout=dropout, conv_only=True - ) - self.norm3 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) - - def forward(self, inp): - residual = inp - out = self.conv1(inp) - out = self.norm1(out) - out = self.lrelu(out) - out = self.conv2(out) - out = self.norm2(out) - if hasattr(self, "conv3"): - residual = self.conv3(residual) - if hasattr(self, "norm3"): - residual = self.norm3(residual) - out += residual - out = self.lrelu(out) - return out - - -class UnetBasicBlock(nn.Module): - """ - A CNN module module that can be used for DynUNet, based on: - `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. - `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - kernel_size: convolution kernel size. - stride: convolution stride. - norm_name: feature normalization type and arguments. - act_name: activation layer type and arguments. - dropout: dropout probability. - - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - norm_name: Union[Tuple, str], - act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), - dropout: Optional[Union[Tuple, str, float]] = None, - ): - super().__init__() - self.conv1 = get_conv_layer( - spatial_dims, - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - dropout=dropout, - conv_only=True, - ) - self.conv2 = get_conv_layer( - spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True - ) - self.lrelu = get_act_layer(name=act_name) - self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) - self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) - - def forward(self, inp): - out = self.conv1(inp) - out = self.norm1(out) - out = self.lrelu(out) - out = self.conv2(out) - out = self.norm2(out) - out = self.lrelu(out) - return out - - -class UnetUpBlock(nn.Module): - """ - An upsampling module that can be used for DynUNet, based on: - `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. - `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - kernel_size: convolution kernel size. - stride: convolution stride. - upsample_kernel_size: convolution kernel size for transposed convolution layers. - norm_name: feature normalization type and arguments. - act_name: activation layer type and arguments. - dropout: dropout probability. - trans_bias: transposed convolution bias. - - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - upsample_kernel_size: Union[Sequence[int], int], - norm_name: Union[Tuple, str], - act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), - dropout: Optional[Union[Tuple, str, float]] = None, - trans_bias: bool = False, - ): - super().__init__() - upsample_stride = upsample_kernel_size - self.transp_conv = get_conv_layer( - spatial_dims, - in_channels, - out_channels, - kernel_size=upsample_kernel_size, - stride=upsample_stride, - dropout=dropout, - bias=trans_bias, - conv_only=True, - is_transposed=True, - ) - self.conv_block = UnetBasicBlock( - spatial_dims, - out_channels + out_channels, - out_channels, - kernel_size=kernel_size, - stride=1, - dropout=dropout, - norm_name=norm_name, - act_name=act_name, - ) - - def forward(self, inp, skip): - # number of channels for skip should equals to out_channels - out = self.transp_conv(inp) - out = torch.cat((out, skip), dim=1) - out = self.conv_block(out) - return out - - -class UnetOutBlock(nn.Module): - def __init__( - self, spatial_dims: int, in_channels: int, out_channels: int, dropout: Optional[Union[Tuple, str, float]] = None - ): - super().__init__() - self.conv = get_conv_layer( - spatial_dims, in_channels, out_channels, kernel_size=1, stride=1, dropout=dropout, bias=True, conv_only=True - ) - - def forward(self, inp): - return self.conv(inp) - - -def get_conv_layer( - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int] = 3, - stride: Union[Sequence[int], int] = 1, - act: Optional[Union[Tuple, str]] = Act.PRELU, - norm: Union[Tuple, str] = Norm.INSTANCE, - dropout: Optional[Union[Tuple, str, float]] = None, - bias: bool = False, - conv_only: bool = True, - is_transposed: bool = False, -): - padding = get_padding(kernel_size, stride) - output_padding = None - if is_transposed: - output_padding = get_output_padding(kernel_size, stride, padding) - return Convolution( - spatial_dims, - in_channels, - out_channels, - strides=stride, - kernel_size=kernel_size, - act=act, - norm=norm, - dropout=dropout, - bias=bias, - conv_only=conv_only, - is_transposed=is_transposed, - padding=padding, - output_padding=output_padding, - ) - - -def get_padding( - kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int] -) -> Union[Tuple[int, ...], int]: - - kernel_size_np = np.atleast_1d(kernel_size) - stride_np = np.atleast_1d(stride) - padding_np = (kernel_size_np - stride_np + 1) / 2 - if np.min(padding_np) < 0: - raise AssertionError("padding value should not be negative, please change the kernel size and/or stride.") - padding = tuple(int(p) for p in padding_np) - - return padding if len(padding) > 1 else padding[0] - - -def get_output_padding( - kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], padding: Union[Sequence[int], int] -) -> Union[Tuple[int, ...], int]: - kernel_size_np = np.atleast_1d(kernel_size) - stride_np = np.atleast_1d(stride) - padding_np = np.atleast_1d(padding) - - out_padding_np = 2 * padding_np + stride_np - kernel_size_np - if np.min(out_padding_np) < 0: - raise AssertionError("out_padding value should not be negative, please change the kernel size and/or stride.") - out_padding = tuple(int(p) for p in out_padding_np) - - return out_padding if len(out_padding) > 1 else out_padding[0] diff --git a/pymic/net_run/agent_preprocess.py b/pymic/net_run/agent_preprocess.py index c681de9..67b1262 100644 --- a/pymic/net_run/agent_preprocess.py +++ b/pymic/net_run/agent_preprocess.py @@ -8,8 +8,8 @@ from pymic.io.image_read_write import save_nd_array_as_image from pymic.io.nifty_dataset import NiftyDataset from pymic.transform.trans_dict import TransformDict - - +from pymic.net_run.agent_abstract import seed_torch +from pymic.net_run.self_sup.util import volume_fusion class PreprocessAgent(object): def __init__(self, config): @@ -19,9 +19,14 @@ def __init__(self, config): self.task_type = config['dataset']['task_type'] self.dataloader = None self.dataloader_unlab= None + + deterministic = config['dataset'].get('deterministic', True) + if(deterministic): + random_seed = config['dataset'].get('random_seed', 1) + seed_torch(random_seed) def get_dataset_from_config(self): - root_dir = self.config['dataset']['root_dir'] + root_dir = self.config['dataset']['data_dir'] modal_num = self.config['dataset'].get('modal_num', 1) transform_names = self.config['dataset']["transform"] @@ -40,6 +45,8 @@ def get_dataset_from_config(self): data_csv = self.config['dataset'].get('data_csv', None) data_csv_unlab = self.config['dataset'].get('data_csv_unlab', None) + batch_size = self.config['dataset'].get('batch_size', 1) + data_shuffle = self.config['dataset'].get('data_shuffle', False) if(data_csv is not None): dataset = NiftyDataset(root_dir = root_dir, csv_file = data_csv, @@ -48,7 +55,7 @@ def get_dataset_from_config(self): transform = data_transform, task = self.task_type) self.dataloader = torch.utils.data.DataLoader(dataset, - batch_size = 1, shuffle=False, num_workers= 8, + batch_size = batch_size, shuffle=data_shuffle, num_workers= 8, worker_init_fn=None, generator = torch.Generator()) if(data_csv_unlab is not None): dataset_unlab = NiftyDataset(root_dir = root_dir, @@ -58,7 +65,7 @@ def get_dataset_from_config(self): transform = data_transform, task = self.task_type) self.dataloader_unlab = torch.utils.data.DataLoader(dataset_unlab, - batch_size = 1, shuffle=False, num_workers= 8, + batch_size = batch_size, shuffle=data_shuffle, num_workers= 8, worker_init_fn=None, generator = torch.Generator()) def run(self): @@ -67,38 +74,35 @@ def run(self): """ self.get_dataset_from_config() out_dir = self.config['dataset']['output_dir'] + if(not os.path.isdir(out_dir)): + os.mkdir(out_dir) + batch_operation = self.config['dataset'].get('batch_operation', None) for dataloader in [self.dataloader, self.dataloader_unlab]: - for item in dataloader: - img = item['image'][0] # the batch size is 1 - # save differnt modaliteis - img_names = item['names'] - spacing = [x.numpy()[0] for x in item['spacing']] - for i in range(img.shape[0]): - image_name = out_dir + "/" + img_names[i][0] - print(image_name) - save_nd_array_as_image(img[i], image_name, reference_name = None, spacing=spacing) - if('label' in item): - lab = item['label'][0] - label_name = out_dir + "/" + img_names[-1][0] - print(label_name) - save_nd_array_as_image(lab[0], label_name, reference_name = None, spacing=spacing) - -def main(): - """ - The main function for data preprocessing. - """ - if(len(sys.argv) < 2): - print('Number of arguments should be 2. e.g.') - print(' pymic_preprocess config.cfg') - exit() - cfg_file = str(sys.argv[1]) - if(not os.path.isfile(cfg_file)): - raise ValueError("The config file does not exist: " + cfg_file) - config = parse_config(cfg_file) - config = synchronize_config(config) - agent = PreprocessAgent(config) - agent.run() + if(dataloader is None): + continue + for data in dataloader: + inputs = data['image'] + labels = data.get('label', None) + img_names = data['names'] + lab_names = img_names[-1] + B, C = inputs.shape[0], inputs.shape[1] + spacing = [x.numpy()[0] for x in data['spacing']] + + if(batch_operation is not None and 'VolumeFusion' in batch_operation): + class_num = self.config['dataset']['VolumeFusion_cls_num'.lower()] + block_range = self.config['dataset']['VolumeFusion_block_range'.lower()] + size_min = self.config['dataset']['VolumeFusion_size_min'.lower()] + size_max = self.config['dataset']['VolumeFusion_size_max'.lower()] + if(labels is None): + lab_names = [item.replace(".nii.gz", "_lab.nii.gz") for item in img_names[0]] + inputs, labels = volume_fusion(inputs, class_num - 1, block_range, size_min, size_max) -if __name__ == "__main__": - main() - + for b in range(B): + for c in range(C): + image_name = out_dir + "/" + img_names[c][b] + print(image_name) + save_nd_array_as_image(inputs[b][c], image_name, reference_name = None, spacing=spacing) + if(labels is not None): + label_name = out_dir + "/" + lab_names[b] + print(label_name) + save_nd_array_as_image(labels[b][0], label_name, reference_name = None, spacing=spacing) diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index 50a5fb7..3a4571f 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -14,7 +14,6 @@ from pymic.net_run.weak_sup import WSLMethodDict from pymic.net_run.self_sup import SelfSupMethodDict from pymic.net_run.noisy_label import NLLMethodDict -# from pymic.net_run.self_sup import SelfSLSegAgent def get_seg_rec_agent(config, sup_type): assert(sup_type in ['fully_sup', 'semi_sup', 'self_sup', 'weak_sup', 'noisy_label']) diff --git a/pymic/util/evaluation_seg.py b/pymic/util/evaluation_seg.py index a9b114b..82939ce 100644 --- a/pymic/util/evaluation_seg.py +++ b/pymic/util/evaluation_seg.py @@ -286,63 +286,65 @@ def evaluation(config): label_list = [label_list] label_fuse = config.get('label_fuse', False) output_name = config.get('output_name', None) - gt_root = config['ground_truth_folder_root'] - seg_root = config['segmentation_folder_root'] + gt_dir = config['ground_truth_folder'] + seg_dirs = config['segmentation_folder'] image_pair_csv = config.get('evaluation_image_pair', None) + if(not isinstance(seg_dirs, (tuple, list))): + seg_dirs = [seg_dirs] if(image_pair_csv is not None): image_pair = pd.read_csv(image_pair_csv) gt_names, seg_names = image_pair.iloc[:, 0], image_pair.iloc[:, 1] else: - seg_names = sorted(os.listdir(seg_root)) + seg_names = sorted(os.listdir(seg_dirs[0])) seg_names = [item for item in seg_names if is_image_name(item)] gt_names = seg_names + for seg_dir in seg_dirs: + for metric in metric_list: + print(metric) + score_all_data = [] + name_score_list= [] + for i in range(len(gt_names)): + gt_full_name = join(gt_dir, gt_names[i]) + seg_full_name = join(seg_dir, seg_names[i]) + s_dict = load_image_as_nd_array(seg_full_name) + g_dict = load_image_as_nd_array(gt_full_name) + s_volume = s_dict["data_array"]; s_spacing = s_dict["spacing"] + g_volume = g_dict["data_array"]; g_spacing = g_dict["spacing"] + # for dim in range(len(s_spacing)): + # assert(s_spacing[dim] == g_spacing[dim]) + + score_vector = get_multi_class_evaluation_score(s_volume, g_volume, label_list, + label_fuse, s_spacing, metric ) + if(len(label_list) > 1): + score_vector.append(np.asarray(score_vector).mean()) + score_all_data.append(score_vector) + name_score_list.append([seg_names[i]] + score_vector) + print(seg_names[i], score_vector) + score_all_data = np.asarray(score_all_data) + score_mean = score_all_data.mean(axis = 0) + score_std = score_all_data.std(axis = 0) + name_score_list.append(['mean'] + list(score_mean)) + name_score_list.append(['std'] + list(score_std)) - for metric in metric_list: - print(metric) - score_all_data = [] - name_score_list= [] - for i in range(len(gt_names)): - gt_full_name = join(gt_root, gt_names[i]) - seg_full_name = join(seg_root, seg_names[i]) - s_dict = load_image_as_nd_array(seg_full_name) - g_dict = load_image_as_nd_array(gt_full_name) - s_volume = s_dict["data_array"]; s_spacing = s_dict["spacing"] - g_volume = g_dict["data_array"]; g_spacing = g_dict["spacing"] - # for dim in range(len(s_spacing)): - # assert(s_spacing[dim] == g_spacing[dim]) - - score_vector = get_multi_class_evaluation_score(s_volume, g_volume, label_list, - label_fuse, s_spacing, metric ) - if(len(label_list) > 1): - score_vector.append(np.asarray(score_vector).mean()) - score_all_data.append(score_vector) - name_score_list.append([seg_names[i]] + score_vector) - print(seg_names[i], score_vector) - score_all_data = np.asarray(score_all_data) - score_mean = score_all_data.mean(axis = 0) - score_std = score_all_data.std(axis = 0) - name_score_list.append(['mean'] + list(score_mean)) - name_score_list.append(['std'] + list(score_std)) - - # save the result as csv - if(output_name is None): - metric_output_name = "{0:}/eval_{1:}.csv".format(seg_root, metric) - else: - metric_output_name = output_name - with open(metric_output_name, mode='w') as csv_file: - csv_writer = csv.writer(csv_file, delimiter=',', - quotechar='"',quoting=csv.QUOTE_MINIMAL) - head = ['image'] + ["class_{0:}".format(i) for i in label_list] - if(len(label_list) > 1): - head = head + ["average"] - csv_writer.writerow(head) - for item in name_score_list: - csv_writer.writerow(item) - - print("{0:} mean ".format(metric), score_mean) - print("{0:} std ".format(metric), score_std) + # save the result as csv + if(output_name is None): + metric_output_name = "{0:}/eval_{1:}.csv".format(seg_dir, metric) + else: + metric_output_name = output_name + with open(metric_output_name, mode='w') as csv_file: + csv_writer = csv.writer(csv_file, delimiter=',', + quotechar='"',quoting=csv.QUOTE_MINIMAL) + head = ['image'] + ["class_{0:}".format(i) for i in label_list] + if(len(label_list) > 1): + head = head + ["average"] + csv_writer.writerow(head) + for item in name_score_list: + csv_writer.writerow(item) + + print("{0:} mean ".format(metric), score_mean) + print("{0:} std ".format(metric), score_std) def main(): """ diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index c813e5d..158569c 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -38,6 +38,18 @@ def get_ND_bounding_box(volume, margin = None): bb_max[i] = min(bb_max[i] + margin[i], input_shape[i]) return bb_min, bb_max +def get_human_region_from_ct(image, threshold_i = -600, threshold_z = 0.6): + input_shape = image.shape + mask = np.asarray(image > threshold_i) + mask2d = np.mean(mask, axis = 0) > threshold_z + se = np.ones([3,3]) + mask2d = ndimage.binary_opening(mask2d, se, iterations = 2) + mask2d = get_largest_k_components(mask2d, 1) + bbmin, bbmax = get_ND_bounding_box(mask2d, margin = [0, 0]) + bb_min = [0] + bbmin + bb_max = list(input_shape[:1]) + bbmax + return bb_min, bb_max + def crop_ND_volume_with_bounding_box(volume, bb_min, bb_max): """ Extract a subregion form an ND image. diff --git a/pymic/util/parse_config.py b/pymic/util/parse_config.py index a12cc76..0e38b91 100644 --- a/pymic/util/parse_config.py +++ b/pymic/util/parse_config.py @@ -102,24 +102,35 @@ def parse_config(filename): def synchronize_config(config): data_cfg = config['dataset'] - net_cfg = config['network'] - # data_cfg["modal_num"] = net_cfg["in_chns"] data_cfg["task_type"] = TaskDict[data_cfg["task_type"]] - data_cfg["LabelToProbability_class_num".lower()] = net_cfg["class_num"] - if "PartialLabelToProbability" in data_cfg['train_transform']: + if('network' in config): + net_cfg = config['network'] + # data_cfg["modal_num"] = net_cfg["in_chns"] + data_cfg["LabelToProbability_class_num".lower()] = net_cfg["class_num"] + transform = [] + if('transform' in data_cfg and data_cfg['transform'] is not None): + transform.extend(data_cfg['transform']) + if('train_transform' in data_cfg and data_cfg['train_transform'] is not None): + transform.extend(data_cfg['train_transform']) + if('valid_transform' in data_cfg and data_cfg['valid_transform'] is not None): + transform.extend(data_cfg['valid_transform']) + if('test_transform' in data_cfg and data_cfg['test_transform'] is not None): + transform.extend(data_cfg['test_transform']) + if ( "PartialLabelToProbability" in transform and 'network' in config): data_cfg["PartialLabelToProbability_class_num".lower()] = net_cfg["class_num"] patch_size = data_cfg.get('patch_size', None) if(patch_size is not None): - if('Pad' in data_cfg['train_transform']): + if('Pad' in transform and 'Pad_output_size'.lower() not in data_cfg): data_cfg['Pad_output_size'.lower()] = patch_size - if('CenterCrop' in data_cfg['train_transform']): + if('CenterCrop' in transform and 'CenterCrop_output_size'.lower() not in data_cfg): data_cfg['CenterCrop_output_size'.lower()] = patch_size - if('RandomCrop' in data_cfg['train_transform']): + if('RandomCrop' in transform and 'RandomCrop_output_size'.lower() not in data_cfg): data_cfg['RandomCrop_output_size'.lower()] = patch_size - if('RandomResizedCrop' in data_cfg['train_transform']): + if('RandomResizedCrop' in transform and \ + 'RandomResizedCrop_output_size'.lower() not in data_cfg): data_cfg['RandomResizedCrop_output_size'.lower()] = patch_size config['dataset'] = data_cfg - config['network'] = net_cfg + # config['network'] = net_cfg return config def logging_config(config): diff --git a/pymic/util/preprocess.py b/pymic/util/preprocess.py deleted file mode 100644 index c0dc9a1..0000000 --- a/pymic/util/preprocess.py +++ /dev/null @@ -1,63 +0,0 @@ -# -*- coding: utf-8 -*- -import os -import numpy as np -import SimpleITK as sitk -from pymic.io.image_read_write import load_image_as_nd_array -from pymic.transform.trans_dict import TransformDict -from pymic.util.parse_config import parse_config - -def get_transform_list(trans_config_file): - """ - Create a list of transforms given a configuration file. - """ - config = parse_config(trans_config_file) - transform_list = [] - - transform_param = config['dataset'] - transform_param['task'] = 'segmentation' - transform_names = config['dataset']['transform'] - for name in transform_names: - print(name) - if(name not in TransformDict): - raise(ValueError("Undefined transform {0:}".format(name))) - one_transform = TransformDict[name](transform_param) - transform_list.append(one_transform) - return transform_list - -def preprocess_with_transform(transforms, img_in_name, img_out_name, - lab_in_name = None, lab_out_name = None): - """ - Using a list of data transforms for preprocessing, - such as image normalization, cropping, etc. - TODO: support multip-modality preprocessing. - - :param transforms: (list) A list of transform objects. - :param img_in_name: (str) Input file name. - :param img_out_name: (str) Output file name. - :param lab_in_name: (optional, str) If None, load the image's - corresponding label for preprocessing as well. - :param lab_out_name: (optional, str) The output label name. - """ - image_dict = load_image_as_nd_array(img_in_name) - sample = {'image': np.asarray(image_dict['data_array'], np.float32), - 'origin':image_dict['origin'], - 'spacing': image_dict['spacing'], - 'direction':image_dict['direction']} - if(lab_in_name is not None): - label_dict = load_image_as_nd_array(lab_in_name) - sample['label'] = label_dict['data_array'] - for transform in transforms: - sample = transform(sample) - - out_img = sitk.GetImageFromArray(sample['image'][0]) - out_img.SetSpacing(sample['spacing']) - out_img.SetOrigin(sample['origin']) - out_img.SetDirection(sample['direction']) - sitk.WriteImage(out_img, img_out_name) - if(lab_in_name is not None and lab_out_name is not None): - out_lab = sitk.GetImageFromArray(sample['label'][0]) - out_lab.CopyInformation(out_img) - sitk.WriteImage(out_lab, lab_out_name) - - - diff --git a/setup.py b/setup.py index 22406a0..cfb5634 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.4.0.1", + version = "0.4.0.2", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, From 613923b8f7e0fa943f456943c855a6d667ebd5e0 Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 10 Jan 2024 14:47:47 +0800 Subject: [PATCH 177/225] addd pymic_preprocess to setup --- pymic/net_run/agent_seg.py | 1 + setup.py | 1 + 2 files changed, 2 insertions(+) diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 2d6d489..2d61d0d 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -177,6 +177,7 @@ def training(self): inputs, labels_prob = mixup(inputs, labels_prob) # for debug + # print("current iteration", it) # if(it > 10): # break # for i in range(inputs.shape[0]): diff --git a/setup.py b/setup.py index cfb5634..1cad2ff 100644 --- a/setup.py +++ b/setup.py @@ -41,6 +41,7 @@ python_requires = '>=3.6', entry_points = { 'console_scripts': [ + 'pymic_preprocess = pymic.net_run.agent_preprocess:main' 'pymic_train = pymic.net_run.train:main', 'pymic_test = pymic.net_run.predict:main', 'pymic_eval_cls = pymic.util.evaluation_cls:main', From 6b08588ee46e5a3c0c45a6d832c015443fa52108 Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 10 Jan 2024 16:24:28 +0800 Subject: [PATCH 178/225] update config for multi-network models --- pymic/net_run/noisy_label/nll_co_teaching.py | 24 ---------------- pymic/net_run/noisy_label/nll_dast.py | 28 ++++++++---------- pymic/net_run/noisy_label/nll_trinet.py | 26 ----------------- pymic/net_run/preprocess.py | 30 ++++++++++++++++++++ setup.py | 2 +- 5 files changed, 43 insertions(+), 67 deletions(-) create mode 100644 pymic/net_run/preprocess.py diff --git a/pymic/net_run/noisy_label/nll_co_teaching.py b/pymic/net_run/noisy_label/nll_co_teaching.py index ec8e230..c60616e 100644 --- a/pymic/net_run/noisy_label/nll_co_teaching.py +++ b/pymic/net_run/noisy_label/nll_co_teaching.py @@ -18,22 +18,6 @@ from pymic.util.parse_config import * from pymic.util.ramps import get_rampup_ratio -class BiNet(nn.Module): - def __init__(self, params): - super(BiNet, self).__init__() - net_name = params['net_type'] - self.net1 = SegNetDict[net_name](params) - self.net2 = SegNetDict[net_name](params) - - def forward(self, x): - out1 = self.net1(x) - out2 = self.net2(x) - - if(self.training): - return out1, out2 - else: - return (out1 + out2) / 2 - class NLLCoTeaching(SegmentationAgent): """ Co-teaching for noisy-label learning. @@ -58,14 +42,6 @@ def __init__(self, config, stage = 'train'): logging.warn("only CrossEntropyLoss supported for" + " coteaching, the specified loss {0:} is ingored".format(loss_type)) - def create_network(self): - if(self.net is None): - self.net = BiNet(self.config['network']) - if(self.tensor_type == 'float'): - self.net.float() - else: - self.net.double() - def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] diff --git a/pymic/net_run/noisy_label/nll_dast.py b/pymic/net_run/noisy_label/nll_dast.py index 1921e9c..a90747c 100644 --- a/pymic/net_run/noisy_label/nll_dast.py +++ b/pymic/net_run/noisy_label/nll_dast.py @@ -117,31 +117,27 @@ def get_noisy_dataset_from_config(self): """ Create a dataset for images with noisy labels based on configuraiton. """ - root_dir = self.config['dataset']['root_dir'] - modal_num = self.config['dataset'].get('modal_num', 1) - transform_names = self.config['dataset']['train_transform'] - - self.transform_list = [] - if(transform_names is None or len(transform_names) == 0): - data_transform = None - else: - transform_param = self.config['dataset'] - transform_param['task'] = 'segmentation' - for name in transform_names: + trans_names, trans_params = self.get_transform_names_and_parameters('train') + transform_list = [] + if(trans_names is not None and len(trans_names) > 0): + for name in trans_names: if(name not in self.transform_dict): raise(ValueError("Undefined transform {0:}".format(name))) - one_transform = self.transform_dict[name](transform_param) - self.transform_list.append(one_transform) - data_transform = transforms.Compose(self.transform_list) + one_transform = self.transform_dict[name](trans_params) + transform_list.append(one_transform) + data_transform = transforms.Compose(transform_list) + modal_num = self.config['dataset'].get('modal_num', 1) csv_file = self.config['dataset'].get('train_csv_noise', None) - dataset = NiftyDataset(root_dir=root_dir, + dataset = NiftyDataset(root_dir = self.config['dataset']['train_dir'], csv_file = csv_file, modal_num = modal_num, with_label= True, - transform = data_transform ) + transform = data_transform , + task = self.task_type) return dataset + def create_dataset(self): super(NLLDAST, self).create_dataset() if(self.stage == 'train'): diff --git a/pymic/net_run/noisy_label/nll_trinet.py b/pymic/net_run/noisy_label/nll_trinet.py index 25c90cf..64d87b6 100644 --- a/pymic/net_run/noisy_label/nll_trinet.py +++ b/pymic/net_run/noisy_label/nll_trinet.py @@ -17,24 +17,6 @@ from pymic.util.parse_config import * from pymic.util.ramps import get_rampup_ratio -class TriNet(nn.Module): - def __init__(self, params): - super(TriNet, self).__init__() - net_name = params['net_type'] - self.net1 = SegNetDict[net_name](params) - self.net2 = SegNetDict[net_name](params) - self.net3 = SegNetDict[net_name](params) - - def forward(self, x): - out1 = self.net1(x) - out2 = self.net2(x) - out3 = self.net3(x) - - if(self.training): - return out1, out2, out3 - else: - return (out1 + out2 + out3) / 3 - class NLLTriNet(SegmentationAgent): """ Implementation of trinet for learning from noisy samples for @@ -56,14 +38,6 @@ class NLLTriNet(SegmentationAgent): def __init__(self, config, stage = 'train'): super(NLLTriNet, self).__init__(config, stage) - def create_network(self): - if(self.net is None): - self.net = TriNet(self.config['network']) - if(self.tensor_type == 'float'): - self.net.float() - else: - self.net.double() - def get_loss_and_confident_mask(self, pred, labels_prob, conf_ratio): prob = nn.Softmax(dim = 1)(pred) prob_2d = reshape_tensor_to_2D(prob) * 0.999 + 5e-4 diff --git a/pymic/net_run/preprocess.py b/pymic/net_run/preprocess.py new file mode 100644 index 0000000..3b34887 --- /dev/null +++ b/pymic/net_run/preprocess.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import os +import sys +from datetime import datetime +from pymic.util.parse_config import * +from pymic.net_run.agent_preprocess import PreprocessAgent + + +def main(): + """ + The main function for data preprocessing. + """ + if(len(sys.argv) < 2): + print('Number of arguments should be 2. e.g.') + print(' pymic_preprocess config.cfg') + exit() + cfg_file = str(sys.argv[1]) + if(not os.path.isfile(cfg_file)): + raise ValueError("The config file does not exist: " + cfg_file) + config = parse_config(cfg_file) + config = synchronize_config(config) + agent = PreprocessAgent(config) + agent.run() + +if __name__ == "__main__": + main() + + + diff --git a/setup.py b/setup.py index 1cad2ff..8030dc2 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ python_requires = '>=3.6', entry_points = { 'console_scripts': [ - 'pymic_preprocess = pymic.net_run.agent_preprocess:main' + 'pymic_preprocess = pymic.net_run.preprocess:main', 'pymic_train = pymic.net_run.train:main', 'pymic_test = pymic.net_run.predict:main', 'pymic_eval_cls = pymic.util.evaluation_cls:main', From 57e5257b431968fb3e858838870b7dd6e81fdc5f Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 11 Jan 2024 22:26:52 +0800 Subject: [PATCH 179/225] update code for nll_clslsr --- pymic/loss/seg/slsr.py | 4 +-- pymic/net_run/noisy_label/nll_clslsr.py | 33 +++++++++++-------------- 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/pymic/loss/seg/slsr.py b/pymic/loss/seg/slsr.py index d5c4151..92adea7 100644 --- a/pymic/loss/seg/slsr.py +++ b/pymic/loss/seg/slsr.py @@ -38,8 +38,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) if(pix_w is not None): diff --git a/pymic/net_run/noisy_label/nll_clslsr.py b/pymic/net_run/noisy_label/nll_clslsr.py index 0148621..836272a 100644 --- a/pymic/net_run/noisy_label/nll_clslsr.py +++ b/pymic/net_run/noisy_label/nll_clslsr.py @@ -142,9 +142,9 @@ def test_time_dropout(m): print(gt.shape, pred_cat.shape) conf = get_confident_map(gt, pred_cat) conf = conf.reshape(-1, 256, 256).astype(np.uint8) * 255 - save_dir = self.config['dataset']['root_dir'] + "/slsr_conf" + save_dir = self.config['dataset']['train_dir'] + "/slsr_conf" for idx in range(len(filename_list)): - filename = filename_list[idx][0].split('/')[-1] + filename = filename_list[idx][0][0].split('/')[-1] conf_map = Image.fromarray(conf[idx]) dst_path = os.path.join(save_dir, filename) conf_map.save(dst_path) @@ -152,32 +152,29 @@ def test_time_dropout(m): def get_confidence_map(cfg_file): config = parse_config(cfg_file) config = synchronize_config(config) + agent = NLLCLSLSR(config, 'test') - # set dataset - transform_names = config['dataset']['valid_transform'] + # set customized dataset for testing, i.e,. inference with training images + trans_names, trans_params = agent.get_transform_names_and_parameters('valid') transform_list = [] - transform_dict = TransformDict - if(transform_names is None or len(transform_names) == 0): - data_transform = None - else: - transform_param = config['dataset'] - transform_param['task'] = 'segmentation' - for name in transform_names: - if(name not in transform_dict): + if(trans_names is not None and len(trans_names) > 0): + for name in trans_names: + if(name not in agent.transform_dict): raise(ValueError("Undefined transform {0:}".format(name))) - one_transform = transform_dict[name](transform_param) + one_transform = agent.transform_dict[name](trans_params) transform_list.append(one_transform) - data_transform = transforms.Compose(transform_list) + data_transform = transforms.Compose(transform_list) csv_file = config['dataset']['train_csv'] modal_num = config['dataset'].get('modal_num', 1) - dataset = NiftyDataset(root_dir = config['dataset']['root_dir'], + stage_dir = config['dataset']['train_dir'] + dataset = NiftyDataset(root_dir = stage_dir, csv_file = csv_file, modal_num = modal_num, with_label= True, - transform = data_transform ) + transform = data_transform, + task = agent.task_type) - agent = NLLCLSLSR(config, 'test') agent.set_datasets(None, None, dataset) agent.transform_list = transform_list agent.create_dataset() @@ -196,4 +193,4 @@ def get_confidence_map(cfg_file): "label": df_train["label"]} train_cl_csv = csv_file.replace(".csv", "_clslsr.csv") df_cl = pd.DataFrame.from_dict(train_cl_dict) - df_cl.to_csv(train_cl_csv, index = False) \ No newline at end of file + df_cl.to_csv(train_cl_csv, index = False) False) \ No newline at end of file From 1f44153ea5ae9d4ea500d24c5d0847d818279956 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 11 Jan 2024 23:04:07 +0800 Subject: [PATCH 180/225] Update nll_clslsr.py --- pymic/net_run/noisy_label/nll_clslsr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymic/net_run/noisy_label/nll_clslsr.py b/pymic/net_run/noisy_label/nll_clslsr.py index 836272a..c977eba 100644 --- a/pymic/net_run/noisy_label/nll_clslsr.py +++ b/pymic/net_run/noisy_label/nll_clslsr.py @@ -193,4 +193,4 @@ def get_confidence_map(cfg_file): "label": df_train["label"]} train_cl_csv = csv_file.replace(".csv", "_clslsr.csv") df_cl = pd.DataFrame.from_dict(train_cl_dict) - df_cl.to_csv(train_cl_csv, index = False) False) \ No newline at end of file + df_cl.to_csv(train_cl_csv, index = False) \ No newline at end of file From 59a6f70612f474e39120306244501b9f839d22d6 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 12 Jan 2024 10:32:55 +0800 Subject: [PATCH 181/225] update preprocess file fix issues for label name in output --- pymic/net_run/agent_preprocess.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pymic/net_run/agent_preprocess.py b/pymic/net_run/agent_preprocess.py index 67b1262..db8b10b 100644 --- a/pymic/net_run/agent_preprocess.py +++ b/pymic/net_run/agent_preprocess.py @@ -73,7 +73,8 @@ def run(self): Do preprocessing for labeled and unlabeled data. """ self.get_dataset_from_config() - out_dir = self.config['dataset']['output_dir'] + out_dir = self.config['dataset']['output_dir'] + modal_num = self.config['dataset']['modal_num'] if(not os.path.isdir(out_dir)): os.mkdir(out_dir) batch_operation = self.config['dataset'].get('batch_operation', None) @@ -82,9 +83,12 @@ def run(self): continue for data in dataloader: inputs = data['image'] - labels = data.get('label', None) + labels = data.get('label', None) img_names = data['names'] - lab_names = img_names[-1] + if(len(img_names) == modal_num): # for unlabeled dataset + lab_names = [item.replace(".nii.gz", "_lab.nii.gz") for item in img_names[0]] + else: + lab_names = img_names[-1] B, C = inputs.shape[0], inputs.shape[1] spacing = [x.numpy()[0] for x in data['spacing']] @@ -93,8 +97,6 @@ def run(self): block_range = self.config['dataset']['VolumeFusion_block_range'.lower()] size_min = self.config['dataset']['VolumeFusion_size_min'.lower()] size_max = self.config['dataset']['VolumeFusion_size_max'.lower()] - if(labels is None): - lab_names = [item.replace(".nii.gz", "_lab.nii.gz") for item in img_names[0]] inputs, labels = volume_fusion(inputs, class_num - 1, block_range, size_min, size_max) for b in range(B): From 4a3eeb407d5cdfb0896777b49a1c3d48f8b196c0 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 13 Jan 2024 20:22:02 +0800 Subject: [PATCH 182/225] update activation function for loss, and ema --- pymic/loss/seg/ssl.py | 8 ++++---- pymic/net_run/infer_func.py | 2 +- pymic/net_run/semi_sup/ssl_mt.py | 2 +- pymic/net_run/semi_sup/ssl_uamt.py | 2 +- pymic/net_run/weak_sup/wsl_ustm.py | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pymic/loss/seg/ssl.py b/pymic/loss/seg/ssl.py index 0bf276f..3a7430a 100644 --- a/pymic/loss/seg/ssl.py +++ b/pymic/loss/seg/ssl.py @@ -34,8 +34,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) # for numeric stability predict = predict * 0.999 + 5e-4 @@ -70,8 +70,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) # for numeric stability predict = predict * 0.999 + 5e-4 diff --git a/pymic/net_run/infer_func.py b/pymic/net_run/infer_func.py index b0190ad..e0e466e 100644 --- a/pymic/net_run/infer_func.py +++ b/pymic/net_run/infer_func.py @@ -104,7 +104,7 @@ def __infer_with_sliding_window(self, image): weight = torch.zeros(output_shape).to(image.device) temp_w = self.__get_gaussian_weight_map(window_size) temp_w = np.broadcast_to(temp_w, [batch_size, class_num] + window_size) - temp_w = torch.from_numpy(temp_w).to(image.device) + temp_w = torch.from_numpy(np.array(temp_w)).to(image.device) temp_in_shape = img_full_shape[:2] + window_size tempx = torch.ones(temp_in_shape).to(image.device) out_num, scale_list = self.__get_prediction_number_and_scales(tempx) diff --git a/pymic/net_run/semi_sup/ssl_mt.py b/pymic/net_run/semi_sup/ssl_mt.py index 2a2abb8..409af19 100644 --- a/pymic/net_run/semi_sup/ssl_mt.py +++ b/pymic/net_run/semi_sup/ssl_mt.py @@ -106,7 +106,7 @@ def training(self): alpha = ssl_cfg.get('ema_decay', 0.99) alpha = min(1 - 1 / (self.glob_it / iter_valid + 1), alpha) for ema_param, param in zip(self.net_ema.parameters(), self.net.parameters()): - ema_param.data.mul_(alpha).add_(1 - alpha, param.data) + ema_param.data.mul_(alpha).add(param.data, alpha = 1.0 - alpha) train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run/semi_sup/ssl_uamt.py b/pymic/net_run/semi_sup/ssl_uamt.py index 6222fe3..053a012 100644 --- a/pymic/net_run/semi_sup/ssl_uamt.py +++ b/pymic/net_run/semi_sup/ssl_uamt.py @@ -108,7 +108,7 @@ def training(self): alpha = ssl_cfg.get('ema_decay', 0.99) alpha = min(1 - 1 / (self.glob_it / iter_valid + 1), alpha) for ema_param, param in zip(self.net_ema.parameters(), self.net.parameters()): - ema_param.data.mul_(alpha).add_(1 - alpha, param.data) + ema_param.data.mul_(alpha).add(param.data, alpha = 1.0 - alpha) train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run/weak_sup/wsl_ustm.py b/pymic/net_run/weak_sup/wsl_ustm.py index 0ea3fbc..31a6644 100644 --- a/pymic/net_run/weak_sup/wsl_ustm.py +++ b/pymic/net_run/weak_sup/wsl_ustm.py @@ -125,7 +125,7 @@ def training(self): alpha = wsl_cfg.get('ema_decay', 0.99) alpha = min(1 - 1 / (self.glob_it / iter_valid + 1), alpha) for ema_param, param in zip(self.net_ema.parameters(), self.net.parameters()): - ema_param.data.mul_(alpha).add_(1 - alpha, param.data) + ema_param.data.mul_(alpha).add(param.data, alpha = 1.0 - alpha) train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() From be140dec60b1ed37529ac57d8edb270b2cdf0add Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 16 Jan 2024 21:06:24 +0800 Subject: [PATCH 183/225] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 8030dc2..ebb738f 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.4.0.2", + version = "0.4.1", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, From 0e6b60cc5d355651b714a3af2db6a5361a0917ec Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 16 Jan 2024 21:22:18 +0800 Subject: [PATCH 184/225] Update README.md update version to 0.4.1 --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4af5eff..90a6de7 100644 --- a/README.md +++ b/README.md @@ -47,10 +47,10 @@ Run the following command to install the latest released version of PyMIC: ```bash pip install PYMIC ``` -To install a specific version of PYMIC such as 0.4.0, run: +To install a specific version of PYMIC such as 0.4.1, run: ```bash -pip install PYMIC==0.4.0 +pip install PYMIC==0.4.1 ``` Alternatively, you can download the source code for the latest version. Run the following command to compile and install: From 84d8ed084819cb2c76971dc03de44d75d71344ce Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 16 Jan 2024 21:25:14 +0800 Subject: [PATCH 185/225] Update __init__.py update version --- pymic/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymic/__init__.py b/pymic/__init__.py index 1520d82..33943e4 100644 --- a/pymic/__init__.py +++ b/pymic/__init__.py @@ -1,7 +1,7 @@ from __future__ import absolute_import from enum import Enum -__version__ = "0.4.0" +__version__ = "0.4.1" class TaskType(Enum): CLASSIFICATION_ONE_HOT = 1 From ca4ebb0a66a69fbe3bec19fdc042671fc4116465 Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 1 May 2024 16:18:41 +0800 Subject: [PATCH 186/225] update config files use argparse for configuration --- pymic/io/image_read_write.py | 10 +++- pymic/io/nifty_dataset.py | 32 ++++++++++--- pymic/net_run/agent_abstract.py | 5 +- pymic/net_run/agent_cls.py | 10 ++-- pymic/net_run/agent_preprocess.py | 34 +++++++++++--- pymic/net_run/agent_seg.py | 7 +-- pymic/net_run/predict.py | 32 +++++++++---- pymic/net_run/preprocess.py | 13 ++++-- pymic/net_run/train.py | 32 +++++++++---- pymic/transform/affine.py | 6 ++- pymic/transform/intensity.py | 77 +++++++++++++++++++++++++++++-- pymic/transform/label_convert.py | 2 + pymic/util/evaluation_seg.py | 8 ++-- pymic/util/parse_config.py | 26 ++++++++++- 14 files changed, 239 insertions(+), 55 deletions(-) diff --git a/pymic/io/image_read_write.py b/pymic/io/image_read_write.py index 3aa87bd..efbe656 100644 --- a/pymic/io/image_read_write.py +++ b/pymic/io/image_read_write.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division - +import logging import os import numpy as np import SimpleITK as sitk @@ -53,10 +53,16 @@ def load_rgb_image_as_3d_array(filename): image = np.expand_dims(image, axis = 0) else: # transpose rgb image from [H, W, C] to [C, H, W] - assert(image_shape[2] == 3 or image_shape[2] == 4) + # logging.warning("The image is expected to have 1 or three channels, but it has a different channel number") + # logging.warning("({0:} {1:}".format(filename, image_shape)) if(image_shape[2] == 4): image = image[:, :, range(3)] + elif(image_shape[2] == 2): + image = image[:, :, 0:1] + elif(image_shape[2] != 3): + raise ValueError("invalid channel number {0:}", image_shape[2]) image = np.transpose(image, axes = [2, 0, 1]) + output = {} output['data_array'] = image output['origin'] = (0, 0) diff --git a/pymic/io/nifty_dataset.py b/pymic/io/nifty_dataset.py index aefe4da..c3253c9 100644 --- a/pymic/io/nifty_dataset.py +++ b/pymic/io/nifty_dataset.py @@ -118,7 +118,7 @@ def __getitem__(self, idx): return sample -class ClassificationDataset(NiftyDataset): +class ClassificationDataset(Dataset): """ Dataset for loading images for classification. It generates 4D tensors with dimention order [C, D, H, W] for 3D images, and 3D tensors @@ -134,16 +134,32 @@ class ClassificationDataset(NiftyDataset): """ def __init__(self, root_dir, csv_file, modal_num = 1, class_num = 2, with_label = False, transform=None, task = TaskType.CLASSIFICATION_ONE_HOT): - super(ClassificationDataset, self).__init__(root_dir, - csv_file, modal_num, with_label, transform) + # super(ClassificationDataset, self).__init__(root_dir, + # csv_file, modal_num, with_label, transform, task) + self.root_dir = root_dir + self.csv_items = pd.read_csv(csv_file) + self.modal_num = modal_num + self.with_label = with_label + self.transform = transform self.class_num = class_num self.task = task assert self.task in [TaskType.CLASSIFICATION_ONE_HOT, TaskType.CLASSIFICATION_COEXIST] + + csv_keys = list(self.csv_items.keys()) + self.image_weight_idx = None + if('image_weight' in csv_keys): + self.image_weight_idx = csv_keys.index('image_weight') + + def __len__(self): + return len(self.csv_items) def __getlabel__(self, idx): csv_keys = list(self.csv_items.keys()) - label_idx = csv_keys.index('label') - label = self.csv_items.iloc[idx, label_idx] + if self.task == TaskType.CLASSIFICATION_ONE_HOT: + label_idx = csv_keys.index('label') + label = self.csv_items.iloc[idx, label_idx] + else: + label = np.asarray(self.csv_items.iloc[idx, 1:self.class_num + 1], np.float32) return label def __getweight__(self, idx): @@ -161,13 +177,15 @@ def __getitem__(self, idx): names_list.append(image_name) image_list.append(image_data) image = np.concatenate(image_list, axis = 0) - image = np.asarray(image, np.float32) + image = np.asarray(image, np.float32) sample = {'image': image, 'names' : names_list[0], 'origin':image_dict['origin'], 'spacing': image_dict['spacing'], 'direction':image_dict['direction']} + if (self.with_label): - sample['label'] = self.__getlabel__(idx) + label = self.__getlabel__(idx) + sample['label'] = label #np.asarray(label, np.float32) if (self.image_weight_idx is not None): sample['image_weight'] = self.__getweight__(idx) if self.transform: diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index f9575ab..d8abb19 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -63,6 +63,7 @@ def __init__(self, config, stage = 'train'): if(self.deterministic): seed_torch(self.random_seed) logging.info("deterministric is true") + def set_datasets(self, train_set, valid_set, test_set): """ @@ -139,7 +140,7 @@ def get_checkpoint_name(self): """ ckpt_mode = self.config['testing']['ckpt_mode'] if(ckpt_mode == 0 or ckpt_mode == 1): - ckpt_dir = self.config['training']['ckpt_save_dir'] + ckpt_dir = self.config['training']['ckpt_dir'] ckpt_prefix = self.config['training'].get('ckpt_prefix', None) if(ckpt_prefix is None): ckpt_prefix = ckpt_dir.split('/')[-1] @@ -326,3 +327,5 @@ def run(self): else: self.infer() + + diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index ee4e25b..a31df84 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -73,7 +73,8 @@ def get_stage_dataset_from_config(self, stage): modal_num = modal_num, class_num = class_num, with_label= not (stage == 'test'), - transform = data_transform ) + transform = data_transform, + task = self.task_type) return dataset def create_network(self): @@ -97,6 +98,8 @@ def create_loss_calculator(self): if(self.loss_dict is None): self.loss_dict = PyMICClsLossDict loss_name = self.config['training']['loss_type'] + if(loss_name != "SigmoidCELoss" and self.task_type == TaskType.CLASSIFICATION_COEXIST): + raise ValueError("SigmoidCELoss should be used when task_type is cls_coexist") if(loss_name in self.loss_dict): self.loss_calculater = self.loss_dict[loss_name](self.config['training']) else: @@ -218,7 +221,7 @@ def train_valid(self): self.device = torch.device("cuda:{0:}".format(device_ids[0])) self.net.to(self.device) - ckpt_dir = self.config['training']['ckpt_save_dir'] + ckpt_dir = self.config['training']['ckpt_dir'] if(ckpt_dir[-1] == "/"): ckpt_dir = ckpt_dir[:-1] ckpt_prefix = self.config['training'].get('ckpt_prefix', None) @@ -259,7 +262,7 @@ def train_valid(self): self.trainIter = iter(self.train_loader) logging.info("{0:} training start".format(str(datetime.now())[:-7])) - self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) + self.summ_writer = SummaryWriter(self.config['training']['ckpt_dir']) self.glob_it = iter_start for it in range(iter_start, iter_max, iter_valid): lr_value = self.optimizer.param_groups[0]['lr'] @@ -267,6 +270,7 @@ def train_valid(self): train_scalars = self.training() t1 = time.time() valid_scalars = self.validation() + t2 = time.time() if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step(valid_scalars[metrics]) diff --git a/pymic/net_run/agent_preprocess.py b/pymic/net_run/agent_preprocess.py index db8b10b..c53421f 100644 --- a/pymic/net_run/agent_preprocess.py +++ b/pymic/net_run/agent_preprocess.py @@ -9,7 +9,7 @@ from pymic.io.nifty_dataset import NiftyDataset from pymic.transform.trans_dict import TransformDict from pymic.net_run.agent_abstract import seed_torch -from pymic.net_run.self_sup.util import volume_fusion +from pymic.net_run.self_sup.util import volume_fusion, nonlienar_volume_fusion,augmented_volume_fusion,self_volume_fusion class PreprocessAgent(object): def __init__(self, config): @@ -92,17 +92,37 @@ def run(self): B, C = inputs.shape[0], inputs.shape[1] spacing = [x.numpy()[0] for x in data['spacing']] - if(batch_operation is not None and 'VolumeFusion' in batch_operation): - class_num = self.config['dataset']['VolumeFusion_cls_num'.lower()] - block_range = self.config['dataset']['VolumeFusion_block_range'.lower()] - size_min = self.config['dataset']['VolumeFusion_size_min'.lower()] - size_max = self.config['dataset']['VolumeFusion_size_max'.lower()] - inputs, labels = volume_fusion(inputs, class_num - 1, block_range, size_min, size_max) + if(batch_operation is not None): + if('VolumeFusion' in batch_operation): + class_num = self.config['dataset']['VolumeFusion_cls_num'.lower()] + block_range = self.config['dataset']['VolumeFusion_block_range'.lower()] + size_min = self.config['dataset']['VolumeFusion_size_min'.lower()] + size_max = self.config['dataset']['VolumeFusion_size_max'.lower()] + inputs, labels = volume_fusion(inputs, class_num - 1, block_range, size_min, size_max) + elif('SelfVolumeFusion' in batch_operation): + class_num = self.config['dataset']['SelfVolumeFusion_cls_num'.lower()] + fuse_ratio = self.config['dataset']['SelfVolumeFusion_fuse_ratio'.lower()] + size_min = self.config['dataset']['SelfVolumeFusion_size_min'.lower()] + size_max = self.config['dataset']['SelfVolumeFusion_size_max'.lower()] + inputs, labels = self_volume_fusion(inputs, class_num - 1, fuse_ratio, size_min, size_max) + elif('NonLinearVolumeFusion' in batch_operation): + block_range = self.config['dataset']['NonLinearVolumeFusion_block_range'.lower()] + size_min = self.config['dataset']['NonLinearVolumeFusion_size_min'.lower()] + size_max = self.config['dataset']['NonLinearVolumeFusion_size_max'.lower()] + inputs, labels = nonlienar_volume_fusion(inputs, block_range, size_min, size_max) + elif('AugmentedVolumeFusion' in batch_operation): + size_min = self.config['dataset']['AugmentedVolumeFusion_size_min'.lower()] + size_max = self.config['dataset']['AugmentedVolumeFusion_size_max'.lower()] + inputs, labels = augmented_volume_fusion(inputs, size_min, size_max) for b in range(B): for c in range(C): image_name = out_dir + "/" + img_names[c][b] print(image_name) + out_dir_full = "/".join(image_name.split("/")[:-1]) + print(out_dir_full) + if(not os.path.exists(out_dir_full)): + os.mkdir(out_dir_full) save_nd_array_as_image(inputs[b][c], image_name, reference_name = None, spacing=spacing) if(labels is not None): label_name = out_dir + "/" + lab_names[b] diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 2d61d0d..4a80298 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -321,7 +321,7 @@ def train_valid(self): self.device = torch.device("cuda:{0:}".format(device_ids[0])) self.net.to(self.device) - ckpt_dir = self.config['training']['ckpt_save_dir'] + ckpt_dir = self.config['training']['ckpt_dir'] if(ckpt_dir[-1] == "/"): ckpt_dir = ckpt_dir[:-1] ckpt_prefix = self.config['training'].get('ckpt_prefix', None) @@ -365,7 +365,7 @@ def train_valid(self): self.trainIter = iter(self.train_loader) logging.info("{0:} training start".format(str(datetime.now())[:-7])) - self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) + self.summ_writer = SummaryWriter(self.config['training']['ckpt_dir']) self.glob_it = iter_start for it in range(iter_start, iter_max, iter_valid): lr_value = self.optimizer.param_groups[0]['lr'] @@ -581,7 +581,7 @@ def save_outputs(self, data): if(test_dir is None): test_dir = self.config['dataset']['train_dir'] - for i in range(len(names)): + for i in range(output.shape[0]): save_name = names[i][0].split('/')[-1] if ignore_dir else \ names[i][0].replace('/', '_') if((filename_replace_source is not None) and (filename_replace_target is not None)): @@ -607,3 +607,4 @@ def save_outputs(self, data): if(len(temp_prob.shape) == 2): temp_prob = np.asarray(temp_prob * 255, np.uint8) save_nd_array_as_image(temp_prob, prob_save_name, test_dir + '/' + names[i][0]) +0]) diff --git a/pymic/net_run/predict.py b/pymic/net_run/predict.py index 80134d8..e618be6 100644 --- a/pymic/net_run/predict.py +++ b/pymic/net_run/predict.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import argparse import logging import os import sys @@ -12,16 +13,30 @@ def main(): """ - The main function for running a network for training or inference. + The main function for running a network for inference. """ if(len(sys.argv) < 2): - print('Number of arguments should be 2. e.g.') - print(' pymic_test config.cfg') + print('Number of arguments should be at least 2. e.g.') + print(' pymic_test config.cfg -test_csv train.csv -output_dir result_dir -ckpt_mode 1') exit() - cfg_file = str(sys.argv[1]) - if(not os.path.isfile(cfg_file)): - raise ValueError("The config file does not exist: " + cfg_file) - config = parse_config(cfg_file) + parser = argparse.ArgumentParser() + parser.add_argument("cfg", help="configuration file for testing") + parser.add_argument("-test_csv", help="the csv file for testing images", + required=False, default=None) + parser.add_argument("-output_dir", help="the output dir for inference results", + required=False, default=None) + parser.add_argument("-ckpt_dir", help="the dir for trained model", + required=False, default=None) + parser.add_argument("-ckpt_mode", help="the mode for chekpoint: 0-latest, 1-best, 2-customized", + required=False, default=None) + parser.add_argument("-ckpt_name", help="the name chekpoint if ckpt_mode = 2", + required=False, default=None) + parser.add_argument("-gpus", help="the gpus for runing, e.g., [0]", + required=False, default=None) + args = parser.parse_args() + if(not os.path.isfile(args.cfg)): + raise ValueError("The config file does not exist: " + args.cfg) + config = parse_config(args) config = synchronize_config(config) log_dir = config['testing']['output_dir'] if(not os.path.exists(log_dir)): @@ -34,7 +49,8 @@ def main(): logging.basicConfig(filename=log_dir+"/log_test.txt", level=logging.INFO, format='%(message)s') # for python 3.6 logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging_config(config) + dst_cfg = args.cfg if "/" not in args.cfg else args.cfg.split("/")[-1] + wrtie_config(config, log_dir + "/" + dst_cfg) task = config['dataset']['task_type'] if(task == TaskType.CLASSIFICATION_ONE_HOT or task == TaskType.CLASSIFICATION_COEXIST): agent = ClassificationAgent(config, 'test') diff --git a/pymic/net_run/preprocess.py b/pymic/net_run/preprocess.py index 3b34887..63410b5 100644 --- a/pymic/net_run/preprocess.py +++ b/pymic/net_run/preprocess.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import argparse import os import sys from datetime import datetime @@ -15,11 +16,13 @@ def main(): print('Number of arguments should be 2. e.g.') print(' pymic_preprocess config.cfg') exit() - cfg_file = str(sys.argv[1]) - if(not os.path.isfile(cfg_file)): - raise ValueError("The config file does not exist: " + cfg_file) - config = parse_config(cfg_file) - config = synchronize_config(config) + parser = argparse.ArgumentParser() + parser.add_argument("cfg", help="configuration file for preprocessing") + args = parser.parse_args() + if(not os.path.isfile(args.cfg)): + raise ValueError("The config file does not exist: " + args.cfg) + config = parse_config(args) + config = synchronize_config(config) agent = PreprocessAgent(config) agent.run() diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index 3a4571f..ed60fa1 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import argparse import logging import os import sys @@ -48,19 +49,28 @@ def main(): The main function for running a network for training. """ if(len(sys.argv) < 2): - print('Number of arguments should be 2. e.g.') - print(' pymic_train config.cfg') + print('Number of arguments should be at least 2. e.g.') + print(' pymic_train config.cfg -train_csv train.csv') exit() - cfg_file = str(sys.argv[1]) - if(not os.path.isfile(cfg_file)): - raise ValueError("The config file does not exist: " + cfg_file) - config = parse_config(cfg_file) + parser = argparse.ArgumentParser() + parser.add_argument("cfg", help="configuration file for training") + parser.add_argument("-train_csv", help="the csv file for training images", + required=False, default=None) + parser.add_argument("-valid_csv", help="the csv file for validation images", + required=False, default=None) + parser.add_argument("-ckpt_dir", help="the output dir for trained model", + required=False, default=None) + parser.add_argument("-gpus", help="the gpus for runing, e.g., [0]", + required=False, default=None) + args = parser.parse_args() + if(not os.path.isfile(args.cfg)): + raise ValueError("The config file does not exist: " + args.cfg) + config = parse_config(args) config = synchronize_config(config) - log_dir = config['training']['ckpt_save_dir'] + + log_dir = config['training']['ckpt_dir'] if(not os.path.exists(log_dir)): os.makedirs(log_dir, exist_ok=True) - dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] - shutil.copy(cfg_file, log_dir + "/" + dst_cfg) datetime_str = str(datetime.now())[:-7].replace(":", "_") if sys.version.startswith("3.9"): logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), @@ -69,7 +79,9 @@ def main(): logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), level=logging.INFO, format='%(message)s') # for python 3.6 logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging_config(config) + dst_cfg = args.cfg if "/" not in args.cfg else args.cfg.split("/")[-1] + wrtie_config(config, log_dir + "/" + dst_cfg) + task = config['dataset']['task_type'] if(task == TaskType.CLASSIFICATION_ONE_HOT or task == TaskType.CLASSIFICATION_COEXIST): agent = ClassificationAgent(config, 'train') diff --git a/pymic/transform/affine.py b/pymic/transform/affine.py index 552516f..2efd586 100644 --- a/pymic/transform/affine.py +++ b/pymic/transform/affine.py @@ -86,7 +86,7 @@ def _get_affine_param(self, sample, output_shape): # sample['Affine_Param'] = json.dumps((input_shape, tform["matrix"])) return sample, tform - def _apply_affine_to_ND_volume(self, image, output_shape, tform, order = 3): + def _apply_affine_to_ND_volume(self, image, output_shape, tform, order = 2): """ output_shape should only has two dimensions, e.g., (H, W) """ @@ -152,5 +152,9 @@ def _get_param_for_inverse_transform(self, sample): # aff_out_shape = origin_shape[-2:] # output_predict = self._apply_affine_to_ND_volume(predict, aff_out_shape, tform.inverse) + # sample['predict'] = output_predict + # return sample + = self._apply_affine_to_ND_volume(predict, aff_out_shape, tform.inverse) + # sample['predict'] = output_predict # return sample diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index 2b19ebc..2829e76 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -1,10 +1,13 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import copy +import itertools import json import math import random import numpy as np +from scipy import ndimage +from skimage import exposure from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * try: # SciPy >= 0.19 @@ -82,7 +85,37 @@ def __call__(self, sample): image[chn] = np.clip(image[chn], lower_c, upper_c) sample['image'] = image return sample + +class HistEqual(AbstractTransform): + """ + Histogram equalization. Note that the output will be in the range of [0, 1]. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `HistEqual_channels`: (list) A list of int for specifying the channels. + :param `HistEqual_bin`: (int) The number of bins. + :param `HistEqual_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. + """ + def __init__(self, params): + super(HistEqual, self).__init__(params) + self.channels = params.get('HistEqual_channels'.lower(), None) + # self.min = params.get('HistEqual_min'.lower(), None) + # self.max = params.get('HistEqual_max'.lower(), None) + self.bin = params.get('HistEqual_bin'.lower(), 2000) + self.inverse = params.get('HistEqual_inverse'.lower(), False) + def __call__(self, sample): + image = sample['image'] + C = image.shape[0] + chns = range(C) if self.channels is None else self.channels + for i in range(len(chns)): + c = chns[i] + image[c] = exposure.equalize_hist(image[c],nbins= self.bin) + sample['image'] = image + return sample + class GammaCorrection(AbstractTransform): """ Apply random gamma correction to given channels. @@ -189,7 +222,7 @@ def __init__(self, params): self.block_size = params.get('NonLinearTransform_block_size'.lower(), [8, 16, 16]) - def __apply_nonlinear_transform(self, img): + def apply_nonlinear_transform(self, img): """ the input img should be normlized to [0, 1]""" points = [[0, 0], [random.random(), random.random()], [random.random(), random.random()], [1, 1]] @@ -217,7 +250,7 @@ def __call__(self, sample): if(v_min < v_max): img_c = (img_c - v_min)/(v_max - v_min) if(self.block_range is None): # apply non-linear transform to the entire image - img_c = self.__apply_nonlinear_transform(img_c) + img_c = self.apply_nonlinear_transform(img_c) else: # non-linear transform to random blocks img_c_sr = copy.deepcopy(img_c) for n in range(self.block_range[0], self.block_range[1]): @@ -229,7 +262,7 @@ def __call__(self, sample): img_c[coord_min[0]:coord_min[0] + self.block_size[0], coord_min[1]:coord_min[1] + self.block_size[1], coord_min[2]:coord_min[2] + self.block_size[2]] = \ - self.__apply_nonlinear_transform(window) + self.apply_nonlinear_transform(window) image[chn] = img_c * (v_max - v_min) + v_min sample['image'] = image return sample @@ -423,6 +456,44 @@ def __call__(self, sample): img_out[:, pos_b0[0]:pos_b1[0], pos_b0[1]:pos_b1[1], pos_b0[2]:pos_b1[2]] = \ image[:, pos_a0[0]:pos_a1[0], pos_a0[1]:pos_a1[1], pos_a0[2]:pos_a1[2]] + sample['image'] = img_out + sample['label'] = image + return sample + +class MaskedImageModeling(AbstractTransform): + """ + Apply masking for context restoration in self-supervised learning. + Reference: Zekai Chen et al., Masked Image Modeling Advances 3D Medical Image Analysis, + WACV, 2023 . + """ + def __init__(self, params): + super(MaskedImageModeling, self).__init__(params) + self.ratio = params.get('MaskedImageModeling_ratio'.lower(), 0.45) + self.block_size = params.get('MaskedImageModeling_block_size'.lower(), [8, 16, 16]) + self.inverse = params.get('MaskedImageModeling_inverse'.lower(), False) + + def __call__(self, sample): + image= sample['image'] + C, D, H, W = image.shape + img_out = copy.deepcopy(image) + + block = np.zeros([C] + list(self.block_size)) + for d in range(0, D, self.block_size[0]): + d1 = d + self.block_size[0] + if d1 > D: + continue + for h in range(0, H, self.block_size[1]): + h1 = h + self.block_size[1] + if h1 > H: + continue + for w in range(0, W, self.block_size[2]): + w1 = w + self.block_size[2] + if w1 > W: + continue + r = random.random() + if ( r < self.ratio): + img_out[:, d:d1, h:h1, w:w1] = block + sample['image'] = img_out sample['label'] = image return sample \ No newline at end of file diff --git a/pymic/transform/label_convert.py b/pymic/transform/label_convert.py index 00c505e..afbbaaf 100644 --- a/pymic/transform/label_convert.py +++ b/pymic/transform/label_convert.py @@ -92,6 +92,8 @@ def __call__(self, sample): label_prob = np.zeros((self.class_num,), np.float32) label_prob[label_idx] = 1.0 sample['label_prob'] = label_prob + elif(self.task == TaskType.CLASSIFICATION_COEXIST): + sample['label_prob'] = sample['label'] return sample class LabelSmooth(AbstractTransform): diff --git a/pymic/util/evaluation_seg.py b/pymic/util/evaluation_seg.py index 82939ce..926ff0e 100644 --- a/pymic/util/evaluation_seg.py +++ b/pymic/util/evaluation_seg.py @@ -387,14 +387,14 @@ def main(): args = parser.parse_args() print(args) if(args.cfg is not None): - config = parse_config(args.cfg)['evaluation'] + config = parse_config(args)['evaluation'] else: config = {} config['metric_list'] = parse_value_from_string(args.metric) config['label_list'] = None if args.cls_index is None else parse_value_from_string(args.cls_index) config['class_number']= None if args.cls_num is None else parse_value_from_string(args.cls_num) - config['ground_truth_folder_root'] = args.gt_dir - config['segmentation_folder_root'] = args.seg_dir + config['ground_truth_folder'] = args.gt_dir + config['segmentation_folder'] = args.seg_dir config['evaluation_image_pair'] = args.name_pair config['output_name'] = args.out print(config) @@ -402,3 +402,5 @@ def main(): if __name__ == '__main__': main() + + main() diff --git a/pymic/util/parse_config.py b/pymic/util/parse_config.py index 0e38b91..09f6db0 100644 --- a/pymic/util/parse_config.py +++ b/pymic/util/parse_config.py @@ -84,14 +84,18 @@ def parse_value_from_string(val_str): val = val_str return val -def parse_config(filename): +def parse_config(args): config = configparser.ConfigParser() - config.read(filename) + config.read(args.cfg) output = {} for section in config.sections(): output[section] = {} for key in config[section]: val_str = str(config[section][key]) + if hasattr(args, key): + args_key = getattr(args, key) + if(args_key is not None): + val_str = args_key if(len(val_str)>0): val = parse_value_from_string(val_str) output[section][key] = val @@ -133,6 +137,24 @@ def synchronize_config(config): # config['network'] = net_cfg return config +def wrtie_config(config, output_name): + logging.info("The running configuations are: ") + with open(output_name, 'w') as f: + for section in config: + if(isinstance(config[section], dict)): + line = '[' + section + ']' + f.write('\n' + line + '\n') + logging.info(line) + for key in config[section]: + value = config[section][key] + line = "{0:} = {1:}".format(key, value) + f.write(line + '\n') + logging.info(line) + else: + line = "{0:} = {1:}".format(section, config[section]) + f.write(line + "\n") + logging.info(line) + def logging_config(config): for section in config: if(isinstance(config[section], dict)): From f823eb8ba5e23327f8556f432a8f114a792731ca Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 6 Aug 2024 17:02:39 +0800 Subject: [PATCH 187/225] update transform --- README.md | 2 +- pymic/transform/crop.py | 10 ++-- pymic/transform/intensity.py | 100 ++++++++++++++++++++++++++++++++--- pymic/transform/normalize.py | 17 ++++-- 4 files changed, 112 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 90a6de7..6f34f92 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # PyMIC: A Pytorch-Based Toolkit for Medical Image Computing -PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with high dimension and large volume, multiple modalities and difficulies in annotating. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised, self-supervised, and weakly supervised learning, and learning with noisy annotations. +PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with high dimension and large volume, multiple modalities and difficulies in annotating. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised, self-supervised, and weakly supervised learning, and learning with noisy annotations. Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. If you use this toolkit, please cite the following paper: diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index b821bb2..9e1c077 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -428,9 +428,9 @@ def __call__(self, sample): return sample -class CropHumanRegionFromCT(CenterCrop): +class CropHumanRegion(CenterCrop): """ - Crop the human region from a CT volume. + Crop the human region from a CT for MRI volume. The arguments should be written in the `params` dictionary, and it has the following fields: @@ -447,9 +447,9 @@ class CropHumanRegionFromCT(CenterCrop): Default is `True`. """ def __init__(self, params): - self.threshold_i = params.get('CropHumanRegionFromCT_intensity_threshold'.lower(), -600) - self.threshold_z = params.get('CropHumanRegionFromCT_zaxis_threshold'.lower(), 0.5) - self.inverse = params.get('CropHumanRegionFromCT_inverse'.lower(), True) + self.threshold_i = params.get('CropHumanRegion_intensity_threshold'.lower(), -600) + self.threshold_z = params.get('CropHumanRegion_zaxis_threshold'.lower(), 0.5) + self.inverse = params.get('CropHumanRegion_inverse'.lower(), True) self.task = params['task'] def _get_crop_param(self, sample): diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index 2829e76..f05b95b 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -47,6 +47,7 @@ def bezier_curve(points, nTimes=1000): return xvals, yvals + class IntensityClip(AbstractTransform): """ Clip the intensity for input image @@ -161,6 +162,48 @@ def __call__(self, sample): sample['image'] = image return sample +def gaussian_noise(image, std_min, std_max,): + """ + The input has a shape of [C, D, H, W] or [D, H, W]. + In the former case, volume-level noise will be added. + In the latter case, slice-level noise will ba added. + """ + v_min = image.min() + v_max = image.max() + std = random.random() * (std_max - std_min) + std_min + noise = np.random.normal(0, std, image.shape) + out = image + noise + out = np.clip(out, v_min, v_max) + return out + +def gaussian_blur(image, sigma_min, sigma_max): + sigma = random.random() * (sigma_max - sigma_min) + sigma_min + out = ndimage.gaussian_filter(image, sigma, order = 0) + return out + +def gaussian_sharpen(image, sigma_min, sigma_max, alpha = 10.0): + blurred = gaussian_blur(image, sigma_min, sigma_max) + out = image + (image - blurred) * alpha + return out + +def window_level_augment(image, offset = 0.1): + v_min = image.min() + v_max = image.max() + margin = (v_max - v_min) * offset + v0 = random.uniform(v_min - margin, v_min + margin) + v1 = random.uniform(v_max - margin, v_max + margin) + out = np.clip((image - v0) / (v1 - v0), 0, 1) + return out + +def gamma_correction(image, gamma_min, gamma_max): + v_min = image.min() + v_max = image.max() + if(v_min < v_max): + image = (image - v_min)/(v_max - v_min) + gamma = random.random() * (gamma_max - gamma_min) + gamma_min + image = np.power(image, gamma)*(v_max - v_min) + v_min + return image + class GaussianNoise(AbstractTransform): """ Add Gaussian Noise to given channels. @@ -179,8 +222,8 @@ class GaussianNoise(AbstractTransform): def __init__(self, params): super(GaussianNoise, self).__init__(params) self.channels = params.get('GaussianNoise_channels'.lower(), None) - self.mean = params['GaussianNoise_mean'.lower()] - self.std = params['GaussianNoise_std'.lower()] + self.std_min = params.get('GaussianNoise_std_min'.lower(), 0.02) + self.std_max = params.get('GaussianNoise_std_max'.lower(), 0.1) self.prob = params.get('GaussianNoise_probability'.lower(), 0.5) self.inverse = params.get('GaussianNoise_inverse'.lower(), False) @@ -190,10 +233,53 @@ def __call__(self, sample): self.channels = range(image.shape[0]) for chn in self.channels: if(np.random.uniform() < self.prob): - img_c = image[chn] - noise = np.random.normal(self.mean, self.std, img_c.shape) - image[chn] = img_c + noise + image[chn] = gaussian_noise(image[chn], self.std_min, self.std_max) + sample['image'] = image + return sample + +def adaptive_contrast_adjust(image, p0=0.1, p1=99.9): + v_min = image.min() + v_max = image.max() + v0 = np.percentile(image, p0) + v1 = np.percentile(image, p1) + mask_l = image < v0 + mask_m = (image >= v0) * (image <= v1) + mask_u = image > v1 + image[mask_l] = (image[mask_l] - v_min) * 0.1 / (v0 - v_min) + image[mask_m] = (image[mask_m] - v0) / (v1 - v0)*0.8 + 0.1 + image[mask_u] = 0.9 + 0.1 * (image[mask_u] - v1) / (v_max - v1) + return image + +class AdaptiveContrastAdjust(AbstractTransform): + """ + Add Gaussian Noise to given channels. + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `GaussianNoise_channels`: (list) A list of int for specifying the channels. + :param `GaussianNoise_mean`: (float) The mean value of noise. + :param `GaussianNoise_std`: (float) The std of noise. + :param `GaussianNoise_probability`: (optional, float) + The probability of applying GaussianNoise. Default is 0.5. + :param `GaussianNoise_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. + """ + def __init__(self, params): + super(AdaptiveContrastAdjust, self).__init__(params) + self.channels = params.get('AdaptiveContrastAdjust_channels'.lower(), None) + self.p0 = params.get('AdaptiveContrastAdjust_percent_lower'.lower(), 2) + self.p1 = params.get('AdaptiveContrastAdjust_percent_upper'.lower(), 98) + self.prob = params.get('AdaptiveContrastAdjust_probability'.lower(), 0.5) + self.inverse = params.get('AdaptiveContrastAdjust_inverse'.lower(), False) + + def __call__(self, sample): + image = sample['image'] * 1.0 + if(self.channels is None): + self.channels = range(image.shape[0]) + for chn in self.channels: + if(np.random.uniform() < self.prob): + image[chn] = adaptive_contrast_adjust(image[chn], self.p0, self.p1) sample['image'] = image return sample @@ -219,7 +305,7 @@ def __init__(self, params): self.prob = params.get('NonLinearTransform_probability'.lower(), 0.5) self.inverse = params.get('NonLinearTransform_inverse'.lower(), False) self.block_range = params.get('NonLinearTransform_block_range'.lower(), None) - self.block_size = params.get('NonLinearTransform_block_size'.lower(), [8, 16, 16]) + self.block_size = params.get('NonLinearTransform_block_size'.lower(), [4, 8, 8]) def apply_nonlinear_transform(self, img): @@ -326,7 +412,7 @@ def __init__(self, params): self.inverse = params.get('InPainting_inverse'.lower(), False) self.prob = params.get('InPainting_probability'.lower(), 0.5) self.block_range = params.get('InPainting_block_range'.lower(), (20, 40)) - self.block_size = params.get('InPainting_block_size'.lower(), [8, 16, 16]) + self.block_size = params.get('InPainting_block_size'.lower(), [4, 8, 8]) def __call__(self, sample): if(random.random() > self.prob): diff --git a/pymic/transform/normalize.py b/pymic/transform/normalize.py index 643c12e..35c5dc4 100644 --- a/pymic/transform/normalize.py +++ b/pymic/transform/normalize.py @@ -131,14 +131,17 @@ class NormalizeWithPercentiles(AbstractTransform): The min percentile, which must be between 0 and 100 inclusive. :param `NormalizeWithPercentiles_percentile_upper`: (float) The max percentile, which must be between 0 and 100 inclusive. + :param `NormalizeWithPercentiles_output_mode`: (int) 0: the output is in the range [0,1] + Otherwise the output is in the range of [-1, 1] :param `NormalizeWithMinMax_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `False`. """ def __init__(self, params): super(NormalizeWithPercentiles, self).__init__(params) - self.chns = params['NormalizeWithPercentiles_channels'.lower()] - self.percent_lower = params['NormalizeWithPercentiles_percentile_lower'.lower()] - self.percent_upper = params['NormalizeWithPercentiles_percentile_upper'.lower()] + self.chns = params.get('NormalizeWithPercentiles_channels'.lower(), None) + self.percent_lower = params.get('NormalizeWithPercentiles_percentile_lower'.lower(), 0.1) + self.percent_upper = params.get('NormalizeWithPercentiles_percentile_upper'.lower(), 99.9) + self.out_mode = params.get('NormalizeWithPercentiles_output_mode'.lower(), 0) self.inverse = params.get('NormalizeWithPercentiles_inverse'.lower(), False) def __call__(self, sample): @@ -152,7 +155,13 @@ def __call__(self, sample): img_chn[img_chn < v0] = v0 img_chn[img_chn > v1] = v1 - img_chn = 2.0* (img_chn - v0) / (v1 - v0) -1.0 + if(self.out_mode == 0): + img_chn = (img_chn - v0) / (v1 - v0) + img_chn = np.clip(img_chn, 0, 1) + else: + img_chn = 2.0* (img_chn - v0) / (v1 - v0) -1.0 + img_chn = np.clip(img_chn, -1, 1) + image[chn] = img_chn sample['image'] = image return sample \ No newline at end of file From f4a2dcea7800f503772380835fefda782b578766 Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 4 Sep 2024 16:24:59 +0800 Subject: [PATCH 188/225] allow timing --- pymic/net/net3d/unet2d5.py | 258 ++++++++++++------- pymic/net/net3d/unet3d.py | 102 +++++--- pymic/net/net3d/unet3d_dual_branch.py | 3 +- pymic/net/net3d/unet3d_scse.py | 133 ++++------ pymic/net_run/noisy_label/nll_co_teaching.py | 31 ++- pymic/net_run/noisy_label/nll_dast.py | 24 +- pymic/net_run/noisy_label/nll_trinet.py | 26 +- pymic/net_run/self_sup/self_volume_fusion.py | 245 +++++++++++++++++- pymic/net_run/self_sup/util.py | 159 +++++++++++- pymic/net_run/semi_sup/__init__.py | 2 + pymic/net_run/semi_sup/ssl_abstract.py | 3 + pymic/net_run/semi_sup/ssl_cct.py | 22 +- pymic/net_run/semi_sup/ssl_cps.py | 23 +- pymic/net_run/semi_sup/ssl_em.py | 24 +- pymic/net_run/semi_sup/ssl_mcnet.py | 24 +- pymic/net_run/semi_sup/ssl_mt.py | 23 +- pymic/net_run/semi_sup/ssl_uamt.py | 22 +- pymic/net_run/semi_sup/ssl_urpc.py | 22 +- pymic/net_run/weak_sup/wsl_abstract.py | 9 +- pymic/net_run/weak_sup/wsl_dmpls.py | 24 +- pymic/net_run/weak_sup/wsl_em.py | 24 +- pymic/net_run/weak_sup/wsl_gatedcrf.py | 22 +- pymic/net_run/weak_sup/wsl_mumford_shah.py | 22 +- pymic/net_run/weak_sup/wsl_tv.py | 22 +- pymic/net_run/weak_sup/wsl_ustm.py | 22 +- 25 files changed, 951 insertions(+), 340 deletions(-) diff --git a/pymic/net/net3d/unet2d5.py b/pymic/net/net3d/unet2d5.py index 308fdde..4e70393 100644 --- a/pymic/net/net3d/unet2d5.py +++ b/pymic/net/net3d/unet2d5.py @@ -1,9 +1,16 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division + +import logging import torch import torch.nn as nn import numpy as np +ConvND = {2: nn.Conv2d, 3: nn.Conv3d} +BatchNormND = {2: nn.BatchNorm2d, 3: nn.BatchNorm3d} +MaxPoolND = {2: nn.MaxPool2d, 3: nn.MaxPool3d} +ConvTransND = {2: nn.ConvTranspose2d, 3: nn.ConvTranspose3d} + class ConvBlockND(nn.Module): """ 2D or 3D convolutional block @@ -13,29 +20,17 @@ class ConvBlockND(nn.Module): :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. :param dropout_p: (int) Dropout probability. """ - def __init__(self, in_channels, out_channels, - dim = 2, dropout_p = 0.0): + def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0): super(ConvBlockND, self).__init__() assert(dim == 2 or dim == 3) self.dim = dim - if(self.dim == 2): - self.conv_conv = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), - nn.PReLU(), - nn.Dropout(dropout_p), - nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), - nn.PReLU() - ) - else: - self.conv_conv = nn.Sequential( - nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm3d(out_channels), + self.conv_conv = nn.Sequential( + ConvND[dim](in_channels, out_channels, kernel_size=3, padding=1), + BatchNormND[dim](out_channels), nn.PReLU(), nn.Dropout(dropout_p), - nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm3d(out_channels), + ConvND[dim](out_channels, out_channels, kernel_size=3, padding=1), + BatchNormND[dim](out_channels), nn.PReLU() ) @@ -52,17 +47,12 @@ class DownBlock(nn.Module): :param dropout_p: (int) Dropout probability. :param downsample: (bool) Use downsample or not after convolution. """ - def __init__(self, in_channels, out_channels, - dim = 2, dropout_p = 0.0, downsample = True): + def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0, downsample = True): super(DownBlock, self).__init__() self.downsample = downsample self.dim = dim self.conv = ConvBlockND(in_channels, out_channels, dim, dropout_p) - if(downsample): - if(self.dim == 2): - self.down_layer = nn.MaxPool2d(kernel_size = 2, stride = 2) - else: - self.down_layer = nn.MaxPool3d(kernel_size = 2, stride = 2) + self.down_layer = MaxPoolND[dim](kernel_size = 2, stride = 2) def forward(self, x): x_shape = list(x.shape) @@ -95,28 +85,31 @@ class UpBlock(nn.Module): :param out_channels: (int) Output channel number. :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. :param dropout_p: (int) Dropout probability. - :param bilinear: (bool) Use bilinear for up-sampling or not. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear` for 3D and `Bilinear` for 2D). + The default value is 2. """ def __init__(self, in_channels1, in_channels2, out_channels, - dim = 2, dropout_p = 0.0, bilinear=True): + dim = 2, dropout_p = 0.0, up_mode= 2): super(UpBlock, self).__init__() - self.bilinear = bilinear - self.dim = dim - if bilinear: - if(dim == 2): - self.up = nn.Sequential( - nn.Conv2d(in_channels1, in_channels2, kernel_size = 1), - nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)) - else: - self.up = nn.Sequential( - nn.Conv3d(in_channels1, in_channels2, kernel_size = 1), - nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)) + if(isinstance(up_mode, int)): + up_mode_values = ["transconv", "nearest", "trilinear"] + if(up_mode > 2): + raise ValueError("The upsample mode should be 0-2, but {0:} is given.".format(up_mode)) + self.up_mode = up_mode_values[up_mode] else: - if(dim == 2): - self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2) + self.up_mode = up_mode.lower() + + self.dim = dim + if (self.up_mode == "transconv"): + self.up = ConvTransND[dim](in_channels1, in_channels2, kernel_size=2, stride=2) + else: + self.conv1x1 = ConvND[dim](in_channels1, in_channels2, kernel_size = 1) + if(self.up_mode == "nearest"): + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode) else: - self.up = nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) - + mode = "trilinear" if dim == 3 else "bilinear" + self.up = nn.Upsample(scale_factor=2, mode=mode, align_corners=True) self.conv = ConvBlockND(in_channels2 * 2, out_channels, dim, dropout_p) def forward(self, x1, x2): @@ -132,6 +125,8 @@ def forward(self, x1, x2): x2 = torch.transpose(x2, 1, 2) x2 = torch.reshape(x2, new_shape) + if self.up_mode != "transconv": + x1 = self.conv1x1(x1) x1 = self.up(x1) output = torch.cat([x2, x1], dim=1) output = self.conv(output) @@ -141,6 +136,98 @@ def forward(self, x1, x2): output = torch.transpose(output, 1, 2) return output +class Encoder(nn.Module): + """ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. + """ + def __init__(self, params): + super(Encoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.n_class = self.params['class_num'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + self.dims = self.params['conv_dims'] + + self.block0 = DownBlock(self.in_chns, self.ft_chns[0], self.dims[0], self.dropout[0], True) + self.block1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dims[1], self.dropout[1], True) + self.block2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dims[2], self.dropout[2], True) + self.block3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dims[3], self.dropout[3], True) + self.block4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dims[4], self.dropout[4], False) + + def forward(self, x): + x0, x0_d = self.block0(x) + x1, x1_d = self.block1(x0_d) + x2, x2_d = self.block2(x1_d) + x3, x3_d = self.block3(x2_d) + x4, x4_d = self.block4(x3_d) + return [x0, x1, x2, x3, x4] + +class Decoder(nn.Module): + """ + Decoder of 3D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear` for 3D and `Bilinear` for 2D). + The default value is 2. + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, params): + super(Decoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.n_class = self.params['class_num'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + self.dims = self.params['conv_dims'] + self.up_mode = self.params.get('up_mode', 2) + self.mul_pred = self.params.get('multiscale_pred', False) + + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], + self.dims[3], dropout_p = self.dropout[3], up_mode=self.up_mode) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], + self.dims[2], dropout_p = self.dropout[2], up_mode=self.up_mode) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], + self.dims[1], dropout_p = self.dropout[1], up_mode=self.up_mode) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], + self.dims[0], dropout_p = self.dropout[0], up_mode=self.up_mode) + + self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) + if(self.mul_pred): + self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) + self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) + self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) + self.stage = 'train' + + def set_stage(self, stage): + self.stage = stage + + def forward(self, x): + x0, x1, x2, x3, x4 = x + x_d3 = self.up1(x4, x3) + x_d2 = self.up2(x_d3, x2) + x_d1 = self.up3(x_d2, x1) + x_d0 = self.up4(x_d1, x0) + output = self.out_conv(x_d0) + if(self.mul_pred and self.stage == 'train'): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) + output = [output, output1, output2, output3] + return output + class UNet2D5(nn.Module): """ A 2.5D network combining 3D convolutions with 2D convolutions. @@ -164,68 +251,39 @@ class UNet2D5(nn.Module): :param conv_dims: (list) The convolution dimension (2 or 3) for each resolution level. The length should be the same as that of `feature_chns`. :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). + :param multiscale_pred: (bool) Get multi-scale prediction. """ def __init__(self, params): super(UNet2D5, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.dims = self.params['conv_dims'] - self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] - - assert(len(self.ft_chns) == 5) + params = self.get_default_parameters(params) + for p in params: + print(p, params[p]) + self.stage = 'train' + self.encoder = Encoder(params) + self.decoder = Decoder(params) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], + 'conv_dims':[2, 2, 3, 3, 3], + 'up_mode': 2, + 'multiscale_pred': False + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params - self.block0 = DownBlock(self.in_chns, self.ft_chns[0], self.dims[0], self.dropout[0], True) - self.block1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dims[1], self.dropout[1], True) - self.block2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dims[2], self.dropout[2], True) - self.block3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dims[3], self.dropout[3], True) - self.block4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dims[4], self.dropout[4], False) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], - self.dims[3], dropout_p = self.dropout[3], bilinear = self.bilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], - self.dims[2], dropout_p = self.dropout[2], bilinear = self.bilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], - self.dims[1], dropout_p = self.dropout[1], bilinear = self.bilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], - self.dims[0], dropout_p = self.dropout[0], bilinear = self.bilinear) - - self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, - kernel_size = (1, 3, 3), padding = (0, 1, 1)) + def set_stage(self, stage): + self.stage = stage + self.decoder.set_stage(stage) def forward(self, x): - x0, x0_d = self.block0(x) - x1, x1_d = self.block1(x0_d) - x2, x2_d = self.block2(x1_d) - x3, x3_d = self.block3(x2_d) - x4, x4_d = self.block4(x3_d) - - x = self.up1(x4, x3) - x = self.up2(x, x2) - x = self.up3(x, x1) - x = self.up4(x, x0) - output = self.out_conv(x) + f = self.encoder(x) + output = self.decoder(f) return output - - -if __name__ == "__main__": - params = {'in_chns':4, - 'feature_chns':[2, 8, 32, 48, 64], - 'conv_dims': [2, 2, 3, 3, 3], - 'dropout': [0, 0, 0.3, 0.4, 0.5], - 'class_num': 2, - 'bilinear': False} - Net = UNet2D5(params) - Net = Net.double() - - x = np.random.rand(4, 4, 32, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print(len(y.size())) - y = y.detach().numpy() - print(y.shape) diff --git a/pymic/net/net3d/unet3d.py b/pymic/net/net3d/unet3d.py index e383e77..b66ea38 100644 --- a/pymic/net/net3d/unet3d.py +++ b/pymic/net/net3d/unet3d.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn import numpy as np +from pymic.net.net_init import Initialization_He, Initialization_XavierUniform class ConvBlock(nn.Module): @@ -16,15 +17,23 @@ class ConvBlock(nn.Module): :param out_channels: (int) Output channel number. :param dropout_p: (int) Dropout probability. """ - def __init__(self, in_channels, out_channels, dropout_p): + def __init__(self, in_channels, out_channels, dropout_p, norm_type = 'batch_norm'): super(ConvBlock, self).__init__() + if(norm_type == 'batch_norm'): + norm1 = nn.BatchNorm3d(out_channels, affine = True) + norm2 = nn.BatchNorm3d(out_channels, affine = True) + elif(norm_type == 'instance_norm'): + norm1 = nn.InstanceNorm3d(out_channels, affine = True) + norm2 = nn.InstanceNorm3d(out_channels, affine = True) + else: + raise ValueError("norm_type {0:} not supported, it should be batch_norm or instance_norm".format(norm_type)) self.conv_conv = nn.Sequential( nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm3d(out_channels), + norm1, nn.LeakyReLU(), nn.Dropout(dropout_p), nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm3d(out_channels), + norm2, nn.LeakyReLU() ) @@ -39,11 +48,11 @@ class DownBlock(nn.Module): :param out_channels: (int) Output channel number. :param dropout_p: (int) Dropout probability. """ - def __init__(self, in_channels, out_channels, dropout_p): + def __init__(self, in_channels, out_channels, dropout_p, norm_type = 'batch_norm'): super(DownBlock, self).__init__() self.maxpool_conv = nn.Sequential( nn.MaxPool3d(2), - ConvBlock(in_channels, out_channels, dropout_p) + ConvBlock(in_channels, out_channels, dropout_p, norm_type) ) def forward(self, x): @@ -61,7 +70,8 @@ class UpBlock(nn.Module): 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value is 2 (`Trilinear`). """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, up_mode=2): + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, + up_mode=2, norm_type = 'batch_norm'): super(UpBlock, self).__init__() if(isinstance(up_mode, int)): up_mode_values = ["transconv", "nearest", "trilinear"] @@ -79,7 +89,7 @@ def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, up_mode= self.up = nn.Upsample(scale_factor=2, mode=self.up_mode) else: self.up = nn.Upsample(scale_factor=2, mode=self.up_mode, align_corners=True) - self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) + self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p, norm_type) def forward(self, x1, x2): if self.up_mode != "transconv": @@ -104,17 +114,19 @@ class Encoder(nn.Module): def __init__(self, params): super(Encoder, self).__init__() self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - if(len(self.ft_chns) == 5): - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) + in_chns = self.params['in_chns'] + ft_chns = self.params['feature_chns'] + dropout = self.params['dropout'] + norm_type = self.params['norm_type'] + assert(len(ft_chns) == 5 or len(ft_chns) == 4) + + self.ft_chns= ft_chns + self.in_conv= ConvBlock(in_chns, ft_chns[0], dropout[0], norm_type) + self.down1 = DownBlock(ft_chns[0], ft_chns[1], dropout[1], norm_type) + self.down2 = DownBlock(ft_chns[1], ft_chns[2], dropout[2], norm_type) + self.down3 = DownBlock(ft_chns[2], ft_chns[3], dropout[3], norm_type) + if(len(ft_chns) == 5): + self.down4 = DownBlock(ft_chns[3], ft_chns[4], dropout[4]) def forward(self, x): x0 = self.in_conv(x) @@ -148,25 +160,26 @@ class Decoder(nn.Module): def __init__(self, params): super(Decoder, self).__init__() self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.up_mode = self.params.get('up_mode', 2) - self.mul_pred = self.params.get('multiscale_pred', False) - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + ft_chns = self.params['feature_chns'] + dropout = self.params['dropout'] + n_class = self.params['class_num'] + norm_type = self.params['norm_type'] + up_mode = self.params.get('up_mode', 2) + self.ft_chns = ft_chns + self.mul_pred = self.params.get('multiscale_pred', False) + assert(len(ft_chns) == 5 or len(ft_chns) == 4) - if(len(self.ft_chns) == 5): - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.up_mode) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.up_mode) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.up_mode) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.up_mode) - self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) + if(len(ft_chns) == 5): + self.up1 = UpBlock(ft_chns[4], ft_chns[3], ft_chns[3], dropout[3], up_mode, norm_type) + self.up2 = UpBlock(ft_chns[3], ft_chns[2], ft_chns[2], dropout[2], up_mode, norm_type) + self.up3 = UpBlock(ft_chns[2], ft_chns[1], ft_chns[1], dropout[1], up_mode, norm_type) + self.up4 = UpBlock(ft_chns[1], ft_chns[0], ft_chns[0], dropout[0], up_mode, norm_type) + self.out_conv = nn.Conv3d(ft_chns[0], n_class, kernel_size = 1) if(self.mul_pred): - self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) - self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) - self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) + self.out_conv1 = nn.Conv3d(ft_chns[1], n_class, kernel_size = 1) + self.out_conv2 = nn.Conv3d(ft_chns[2], n_class, kernel_size = 1) + self.out_conv3 = nn.Conv3d(ft_chns[3], n_class, kernel_size = 1) self.stage = 'train' def set_stage(self, stage): @@ -223,14 +236,21 @@ def __init__(self, params): for p in params: print(p, params[p]) self.stage = 'train' + self.update_mode = params.get("update_mode", "all") self.encoder = Encoder(params) - self.decoder = Decoder(params) - + self.decoder = Decoder(params) + + init = params['initialization'].lower() + weightInitializer = Initialization_He(1e-2) if init == 'he' else Initialization_XavierUniform() + self.apply(weightInitializer) + def get_default_parameters(self, params): default_param = { 'feature_chns': [32, 64, 128, 256, 512], 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], 'up_mode': 2, + 'initialization': 'he', + 'norm_type': 'batch_norm', 'multiscale_pred': False } for key in default_param: @@ -243,6 +263,16 @@ def set_stage(self, stage): self.stage = stage self.decoder.set_stage(stage) + def get_parameters_to_update(self): + if(self.update_mode == "all"): + return self.parameters() + elif(self.update_mode == "decoder"): + print("only update parameters in decoder") + params = self.decoder.parameters() + return params + else: + raise(ValueError("update_mode can only be 'all' or 'decoder'.")) + def forward(self, x): f = self.encoder(x) output = self.decoder(f) diff --git a/pymic/net/net3d/unet3d_dual_branch.py b/pymic/net/net3d/unet3d_dual_branch.py index 3bede4e..54b01a0 100644 --- a/pymic/net/net3d/unet3d_dual_branch.py +++ b/pymic/net/net3d/unet3d_dual_branch.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division -import torch import torch.nn as nn from pymic.net.net3d.unet3d import * @@ -20,7 +19,7 @@ class UNet3D_DualBranch(nn.Module): :param output_mode: (str) How to obtain the result during the inference. `average`: taking average of the two branches. - `first`: takeing the result in the first branch. + `first`: taking the result in the first branch. `second`: taking the result in the second branch. """ def __init__(self, params): diff --git a/pymic/net/net3d/unet3d_scse.py b/pymic/net/net3d/unet3d_scse.py index b2da0dc..79abf9c 100644 --- a/pymic/net/net3d/unet3d_scse.py +++ b/pymic/net/net3d/unet3d_scse.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn import numpy as np +from pymic.net.net3d.unet3d import UpBlock, Encoder, Decoder, UNet3D from pymic.net.net3d.scse3d import * class ConvScSEBlock3D(nn.Module): @@ -48,108 +49,70 @@ def __init__(self, in_channels, out_channels, dropout_p): def forward(self, x): return self.maxpool_conv(x) -class UpBlock(nn.Module): +class UpBlockScSE(UpBlock): """3D Up-sampling followed by `ConvScSEBlock3D` in UNet3D_ScSE. :param in_channels1: (int) Input channel number for low-resolution feature map. :param in_channels2: (int) Input channel number for high-resolution feature map. :param out_channels: (int) Output channel number. :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling or not. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - trilinear=True): - super(UpBlock, self).__init__() - self.trilinear = trilinear - if trilinear: - self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) - else: - self.up = nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, up_mode=2): + super(UpBlockScSE, self).__init__(in_channels1, in_channels2, + out_channels, dropout_p, up_mode) self.conv = ConvScSEBlock3D(in_channels2 * 2, out_channels, dropout_p) - def forward(self, x1, x2): - if self.trilinear: - x1 = self.conv1x1(x1) - x1 = self.up(x1) - x = torch.cat([x2, x1], dim=1) - return self.conv(x) -class UNet3D_ScSE(nn.Module): +class EncoderScSE(Encoder): """ - Combining 3D U-Net with SCSE module. - - * Reference: Abhijit Guha Roy, Nassir Navab, Christian Wachinger: - Recalibrating Fully Convolutional Networks With Spatial and Channel - "Squeeze and Excitation" Blocks. - `IEEE Trans. Med. Imaging 38(2): 540-549 (2019). `_ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using trilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. """ def __init__(self, params): - super(UNet3D_ScSE, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.bilinear = self.params['trilinear'] + super(EncoderScSE, self).__init__(params) - assert(len(self.ft_chns) == 5) - self.in_conv= ConvScSEBlock3D(self.in_chns, self.ft_chns[0], self.dropout[0]) self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = self.dropout[3]) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p = self.dropout[2]) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p = self.dropout[1]) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p = self.dropout[0]) - - self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, - kernel_size = 3, padding = 1) + if(len(self.ft_chns) == 5): + self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - def forward(self, x): - - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - x4 = self.down4(x3) - - x = self.up1(x4, x3) - x = self.up2(x, x2) - x = self.up3(x, x1) - x = self.up4(x, x0) - output = self.out_conv(x) - - return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'feature_chns':[2, 8, 32, 48, 64], - 'dropout': [0, 0, 0.3, 0.4, 0.5], - 'class_num': 2, - 'trilinear': True} - Net = UNet3D_ScSE(params) - Net = Net.double() - - x = np.random.rand(4, 4, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print(len(y.size())) - y = y.detach().numpy() - print(y.shape) \ No newline at end of file +class DecoderScSE(Decoder): + """ + A modification of the decoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Decoder` for details. + """ + def __init__(self, params): + super(DecoderScSE, self).__init__(params) + + if(len(self.ft_chns) == 5): + self.up1 = UpBlockScSE(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.up_mode) + self.up2 = UpBlockScSE(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.up_mode) + self.up3 = UpBlockScSE(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.up_mode) + self.up4 = UpBlockScSE(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.up_mode) + + +class UNet3D_ScSE(UNet3D): + """ + Combining 3D U-Net with SCSE module. + + * Reference: Abhijit Guha Roy, Nassir Navab, Christian Wachinger: + Recalibrating Fully Convolutional Networks With Spatial and Channel + "Squeeze and Excitation" Blocks. + `IEEE Trans. Med. Imaging 38(2): 540-549 (2019). `_ + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.UNet3D` for details. + """ + def __init__(self, params): + super(UNet3D_ScSE, self).__init__(params) + self.encoder = EncoderScSE(params) + self.decoder = DecoderScSE(params) diff --git a/pymic/net_run/noisy_label/nll_co_teaching.py b/pymic/net_run/noisy_label/nll_co_teaching.py index c60616e..d46b05b 100644 --- a/pymic/net_run/noisy_label/nll_co_teaching.py +++ b/pymic/net_run/noisy_label/nll_co_teaching.py @@ -5,16 +5,14 @@ import os import sys import numpy as np +import time import torch import torch.nn as nn -import torch.optim as optim -from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.util import reshape_tensor_to_2D from pymic.net_run.agent_seg import SegmentationAgent -from pymic.net.net_dict_seg import SegNetDict from pymic.util.parse_config import * from pymic.util.ramps import get_rampup_ratio @@ -51,19 +49,19 @@ def training(self): rampup_start = nll_cfg.get('rampup_start', 0) rampup_end = nll_cfg.get('rampup_end', iter_max) - train_loss_no_select1 = 0 - train_loss_no_select2 = 0 - train_loss1 = 0 - train_loss2 = 0 + train_loss_no_select1, train_loss_no_select2 = 0, 0 + train_loss1, train_avg_loss2 = 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) labels_prob = self.convert_tensor_type(data['label_prob']) @@ -74,7 +72,7 @@ def training(self): # forward + backward + optimize outputs1, outputs2 = self.net(inputs) - + t2 = time.time() prob1 = nn.Softmax(dim = 1)(outputs1) prob2 = nn.Softmax(dim = 1)(outputs2) prob1_2d = reshape_tensor_to_2D(prob1) * 0.999 + 5e-4 @@ -101,8 +99,9 @@ def training(self): loss2_select = loss2[ind_1_update] loss = loss1_select.mean() + loss2_select.mean() - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss_no_select1 = train_loss_no_select1 + loss1.mean().item() @@ -115,6 +114,11 @@ def training(self): soft_out1, labels_prob = reshape_prediction_and_ground_truth(soft_out1, labels_prob) dice_list = get_classwise_dice(soft_out1, labels_prob).detach().cpu().numpy() train_dice_list.append(dice_list) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss_no_select1 = train_loss_no_select1 / iter_valid train_avg_loss_no_select2 = train_loss_no_select2 / iter_valid train_avg_loss1 = train_loss1 / iter_valid @@ -126,7 +130,9 @@ def training(self): 'loss1':train_avg_loss1, 'loss2': train_avg_loss2, 'loss_no_select1':train_avg_loss_no_select1, 'loss_no_select2':train_avg_loss_no_select2, - 'select_ratio':remb_ratio, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'select_ratio':remb_ratio, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): @@ -153,3 +159,6 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format( valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) diff --git a/pymic/net_run/noisy_label/nll_dast.py b/pymic/net_run/noisy_label/nll_dast.py index a90747c..938e10a 100644 --- a/pymic/net_run/noisy_label/nll_dast.py +++ b/pymic/net_run/noisy_label/nll_dast.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import numpy as np import random import torch -import numpy as np +import time import torch.nn as nn import torchvision.transforms as transforms from pymic.io.nifty_dataset import NiftyDataset @@ -163,15 +164,15 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = nll_cfg.get('rampup_start', 0) rampup_end = nll_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() rank_length = nll_cfg.get("dast_rank_length", 20) consist_loss = ConsistLoss() for it in range(iter_valid): + t0 = time.time() try: data_cl = next(self.trainIter) except StopIteration: @@ -182,7 +183,7 @@ def training(self): except StopIteration: self.trainIter_noise = iter(self.train_loader_noise) data_no = next(self.trainIter_noise) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_cl['image']) # clean sample y0 = self.convert_tensor_type(data_cl['label_prob']) @@ -196,6 +197,7 @@ def training(self): # forward + backward + optimize b0_pred, b1_pred = self.net(inputs) + t2 = time.time() n0 = list(x0.shape)[0] # number of clean samples b0_x0_pred = b0_pred[:n0] # predication of clean samples from clean branch b0_x1_pred = b0_pred[n0:] # predication of noisy samples from clean branch @@ -231,8 +233,9 @@ def training(self): b0_x1_prob = nn.Softmax(dim = 1)(b0_x1_pred) loss_st = torch.mean(torch.abs(b0_x1_prob - sharpen(pseudo_label, 0.5))) loss = loss + loss_st * w_st - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -248,6 +251,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -256,7 +264,9 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':w_dbc, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers def train_valid(self): diff --git a/pymic/net_run/noisy_label/nll_trinet.py b/pymic/net_run/noisy_label/nll_trinet.py index 64d87b6..4c4198f 100644 --- a/pymic/net_run/noisy_label/nll_trinet.py +++ b/pymic/net_run/noisy_label/nll_trinet.py @@ -2,9 +2,8 @@ from __future__ import print_function, division import logging -import os -import sys import numpy as np +import time import torch import torch.nn as nn import torch.optim as optim @@ -62,14 +61,16 @@ def training(self): train_loss_no_select2 = 0 train_loss1, train_loss2, train_loss3 = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) labels_prob = self.convert_tensor_type(data['label_prob']) @@ -80,7 +81,7 @@ def training(self): # forward + backward + optimize outputs1, outputs2, outputs3 = self.net(inputs) - + t2 = time.time() rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end) forget_ratio = (1 - select_ratio) * rampup_ratio remb_ratio = 1 - forget_ratio @@ -95,8 +96,9 @@ def training(self): loss2_avg = torch.sum(loss2 * mask13) / mask13.sum() loss3_avg = torch.sum(loss3 * mask12) / mask12.sum() loss = (loss1_avg + loss2_avg + loss3_avg) / 3 - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss_no_select1 = train_loss_no_select1 + loss1.mean().item() @@ -109,6 +111,11 @@ def training(self): soft_out1, labels_prob = reshape_prediction_and_ground_truth(soft_out1, labels_prob) dice_list = get_classwise_dice(soft_out1, labels_prob).detach().cpu().numpy() train_dice_list.append(dice_list) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss_no_select1 = train_loss_no_select1 / iter_valid train_avg_loss_no_select2 = train_loss_no_select2 / iter_valid train_avg_loss1 = train_loss1 / iter_valid @@ -120,7 +127,9 @@ def training(self): 'loss1':train_avg_loss1, 'loss2': train_avg_loss2, 'loss_no_select1':train_avg_loss_no_select1, 'loss_no_select2':train_avg_loss_no_select2, - 'select_ratio':remb_ratio, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'select_ratio':remb_ratio, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): @@ -146,4 +155,7 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format( valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) diff --git a/pymic/net_run/self_sup/self_volume_fusion.py b/pymic/net_run/self_sup/self_volume_fusion.py index 91fe088..ad2d640 100644 --- a/pymic/net_run/self_sup/self_volume_fusion.py +++ b/pymic/net_run/self_sup/self_volume_fusion.py @@ -30,11 +30,11 @@ from pymic.loss.seg.util import get_classwise_dice from pymic.transform.trans_dict import TransformDict from pymic.util.post_process import PostProcessDict -from pymic.util.image_process import convert_label from pymic.util.parse_config import * from pymic.util.general import get_one_hot_seg from pymic.io.image_read_write import save_nd_array_as_image -from pymic.net_run.self_sup.util import volume_fusion +from pymic.net_run.self_sup.util import volume_fusion, nonlienar_volume_fusion, augmented_volume_fusion +from pymic.net_run.self_sup.util import self_volume_fusion from pymic.net_run.agent_seg import SegmentationAgent @@ -57,7 +57,6 @@ def __init__(self, config, stage = 'train'): def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] - cls_num = self.config['network']['class_num'] block_range = self.config['self_supervised_learning']['VolumeFusion_block_range'.lower()] size_min = self.config['self_supervised_learning']['VolumeFusion_size_min'.lower()] size_max = self.config['self_supervised_learning']['VolumeFusion_size_max'.lower()] @@ -73,8 +72,8 @@ def training(self): data = next(self.trainIter) # get the inputs inputs = self.convert_tensor_type(data['image']) - inputs, labels = volume_fusion(inputs, cls_num - 1, block_range, size_min, size_max) - labels_prob = get_one_hot_seg(labels, cls_num) + inputs, labels = volume_fusion(inputs, class_num - 1, block_range, size_min, size_max) + labels_prob = get_one_hot_seg(labels, class_num) # for debug # if(it==10): @@ -117,3 +116,239 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ 'class_dice': train_cls_dice} return train_scalers + +class SelfSupSelfVolumeFusion(SegmentationAgent): + """ + Abstract class for self-supervised segmentation. + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. + """ + def __init__(self, config, stage = 'train'): + super(SelfSupSelfVolumeFusion, self).__init__(config, stage) + + def training(self): + class_num = self.config['network']['class_num'] + iter_valid = self.config['training']['iter_valid'] + fuse_ratio = self.config['self_supervised_learning']['SelfVolumeFusion_fuse_ratio'.lower()] + size_min = self.config['self_supervised_learning']['SelfVolumeFusion_size_min'.lower()] + size_max = self.config['self_supervised_learning']['SelfVolumeFusion_size_max'.lower()] + + train_loss = 0 + train_dice_list = [] + self.net.train() + for it in range(iter_valid): + try: + data = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data = next(self.trainIter) + # get the inputs + inputs = self.convert_tensor_type(data['image']) + inputs, labels = self_volume_fusion(inputs, class_num - 1, fuse_ratio, size_min, size_max) + labels_prob = get_one_hot_seg(labels, class_num) + + # for debug + # if(it==10): + # break + # for i in range(inputs.shape[0]): + # image_i = inputs[i][0] + # label_i = np.argmax(labels_prob[i], axis = 0) + # # pixw_i = pix_w[i][0] + # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) + # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) + # # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) + # save_nd_array_as_image(image_i, image_name, reference_name = None) + # save_nd_array_as_image(label_i, label_name, reference_name = None) + # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) + # continue + + inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + outputs = self.net(inputs) + loss = self.get_loss_value(data, outputs, labels_prob) + loss.backward() + self.optimizer.step() + train_loss = train_loss + loss.item() + # get dice evaluation for each class + if(isinstance(outputs, tuple) or isinstance(outputs, list)): + outputs = outputs[0] + outputs_argmax = torch.argmax(outputs, dim = 1, keepdim = True) + soft_out = get_soft_label(outputs_argmax, class_num, self.tensor_type) + soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob) + dice_list = get_classwise_dice(soft_out, labels_prob) + train_dice_list.append(dice_list.cpu().numpy()) + train_avg_loss = train_loss / iter_valid + train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) + train_avg_dice = train_cls_dice[1:].mean() + + train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ + 'class_dice': train_cls_dice} + return train_scalers + +class SelfSupNonLinearVolumeFusion(SegmentationAgent): + """ + Abstract class for self-supervised segmentation. + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. + """ + def __init__(self, config, stage = 'train'): + super(SelfSupNonLinearVolumeFusion, self).__init__(config, stage) + + def training(self): + class_num = 3 + iter_valid = self.config['training']['iter_valid'] + block_range = self.config['self_supervised_learning']['NonLinearVolumeFusion_block_range'.lower()] + size_min = self.config['self_supervised_learning']['NonLinearVolumeFusion_size_min'.lower()] + size_max = self.config['self_supervised_learning']['NonLinearVolumeFusion_size_max'.lower()] + + train_loss = 0 + train_dice_list = [] + self.net.train() + for it in range(iter_valid): + try: + data = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data = next(self.trainIter) + # get the inputs + inputs = self.convert_tensor_type(data['image']) + inputs, labels = nonlienar_volume_fusion(inputs, block_range, size_min, size_max) + labels_prob = get_one_hot_seg(labels, class_num) + + # for debug + # if(it==10): + # break + # for i in range(inputs.shape[0]): + # image_i = inputs[i][0] + # label_i = np.argmax(labels_prob[i], axis = 0) + # # pixw_i = pix_w[i][0] + # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) + # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) + # # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) + # save_nd_array_as_image(image_i, image_name, reference_name = None) + # save_nd_array_as_image(label_i, label_name, reference_name = None) + # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) + # continue + + inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + outputs = self.net(inputs) + loss = self.get_loss_value(data, outputs, labels_prob) + loss.backward() + self.optimizer.step() + train_loss = train_loss + loss.item() + # get dice evaluation for each class + if(isinstance(outputs, tuple) or isinstance(outputs, list)): + outputs = outputs[0] + outputs_argmax = torch.argmax(outputs, dim = 1, keepdim = True) + soft_out = get_soft_label(outputs_argmax, class_num, self.tensor_type) + soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob) + dice_list = get_classwise_dice(soft_out, labels_prob) + train_dice_list.append(dice_list.cpu().numpy()) + train_avg_loss = train_loss / iter_valid + train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) + train_avg_dice = train_cls_dice[1:].mean() + + train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ + 'class_dice': train_cls_dice} + return train_scalers + +class SelfSupAugmentedVolumeFusion(SegmentationAgent): + """ + Abstract class for self-supervised segmentation. + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. + """ + def __init__(self, config, stage = 'train'): + super(SelfSupAugmentedVolumeFusion, self).__init__(config, stage) + + def training(self): + class_num = 5 + iter_valid = self.config['training']['iter_valid'] + size_min = self.config['self_supervised_learning']['AugmentedVolumeFusion_size_min'.lower()] + size_max = self.config['self_supervised_learning']['AugmentedVolumeFusion_size_max'.lower()] + + train_loss = 0 + train_dice_list = [] + self.net.train() + for it in range(iter_valid): + try: + data = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data = next(self.trainIter) + # get the inputs + inputs = self.convert_tensor_type(data['image']) + inputs, labels = augmented_volume_fusion(inputs, size_min, size_max) + labels_prob = get_one_hot_seg(labels, class_num) + + # for debug + # if(it==10): + # break + # for i in range(inputs.shape[0]): + # image_i = inputs[i][0] + # label_i = np.argmax(labels_prob[i], axis = 0) + # # pixw_i = pix_w[i][0] + # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) + # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) + # # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) + # save_nd_array_as_image(image_i, image_name, reference_name = None) + # save_nd_array_as_image(label_i, label_name, reference_name = None) + # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) + # continue + + inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + outputs = self.net(inputs) + loss = self.get_loss_value(data, outputs, labels_prob) + loss.backward() + self.optimizer.step() + train_loss = train_loss + loss.item() + # get dice evaluation for each class + if(isinstance(outputs, tuple) or isinstance(outputs, list)): + outputs = outputs[0] + outputs_argmax = torch.argmax(outputs, dim = 1, keepdim = True) + soft_out = get_soft_label(outputs_argmax, class_num, self.tensor_type) + soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob) + dice_list = get_classwise_dice(soft_out, labels_prob) + train_dice_list.append(dice_list.cpu().numpy()) + train_avg_loss = train_loss / iter_valid + train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) + train_avg_dice = train_cls_dice[1:].mean() + + train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ + 'class_dice': train_cls_dice} + return train_scalers \ No newline at end of file diff --git a/pymic/net_run/self_sup/util.py b/pymic/net_run/self_sup/util.py index db27702..d6adcc1 100644 --- a/pymic/net_run/self_sup/util.py +++ b/pymic/net_run/self_sup/util.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import os +import copy import torch import random import numpy as np @@ -136,7 +137,6 @@ def get_human_body_mask_and_crop(input_dir, out_img_dir, out_mask_dir): mask_obj.CopyInformation(out_img_obj) sitk.WriteImage(mask_obj, mask_name) - def volume_fusion(x, fg_num, block_range, size_min, size_max): """ Fuse a subregion of an impage with another one to generate @@ -145,7 +145,7 @@ def volume_fusion(x, fg_num, block_range, size_min, size_max): """ #n_min, n_max, N, C, D, H, W = list(x.shape) - fg_mask = torch.zeros_like(x).to(torch.int32) + fg_mask = torch.zeros_like(x[:, :1, :, :, :]).to(torch.int32) # generate mask for n in range(N): p_num = random.randint(block_range[0], block_range[1]) @@ -163,10 +163,163 @@ def volume_fusion(x, fg_num, block_range, size_min, size_max): h1 = min(H, h0 + h) w1 = min(W, w0 + w) d0, h0, w0 = max(0, d0), max(0, h0), max(0, w0) - temp_m = torch.ones([C, d1 - d0, h1 - h0, w1 - w0]) * random.randint(1, fg_num) + temp_m = torch.ones([1, d1 - d0, h1 - h0, w1 - w0]) * random.randint(1, fg_num) fg_mask[n, :, d0:d1, h0:h1, w0:w1] = temp_m fg_w = fg_mask * 1.0 / fg_num x_roll = torch.roll(x, 1, 0) x_fuse = fg_w*x_roll + (1.0 - fg_w)*x # y_prob = get_one_hot_seg(fg_mask.to(torch.int32), fg_num + 1) return x_fuse, fg_mask + +def nonlinear_transform(x): + v_min = torch.min(x) + v_max = torch.max(x) + x = (x - v_min)/(v_max - v_min) + a = random.random() * 0.7 + 0.15 + b = random.random() * 0.7 + 0.15 + alpha = b / a + beta = (1 - b) / (1 - a) + if(alpha < 1.0 ): + y = torch.maximum(alpha*x, beta*x + 1 - beta) + else: + y = torch.minimum(alpha*x, beta*x + 1 - beta) + if(random.random() < 0.5): + y = 1.0 - y + y = y * (v_max - v_min) + v_min + return y + +def nonlienar_volume_fusion(x, block_range, size_min, size_max): + """ + Fuse a subregion of an impage with another one to generate + images and labels for self-supervised segmentation. + input x should be a batch of tensors + """ + #n_min, n_max, + N, C, D, H, W = list(x.shape) + # apply nonlinear transform to x: + x_nl1 = torch.zeros_like(x).to(torch.float32) + x_nl2 = torch.zeros_like(x).to(torch.float32) + for n in range(N): + x_nl1[n] = nonlinear_transform(x[n]) + x_nl2[n] = nonlinear_transform(x[n]) + x_roll = torch.roll(x_nl2, 1, 0) + mask = torch.zeros_like(x).to(torch.int32) + p_num = random.randint(block_range[0], block_range[1]) + for n in range(N): + for i in range(p_num): + d = random.randint(size_min[0], size_max[0]) + h = random.randint(size_min[1], size_max[1]) + w = random.randint(size_min[2], size_max[2]) + dc = random.randint(0, D - 1) + hc = random.randint(0, H - 1) + wc = random.randint(0, W - 1) + d0 = dc - d // 2 + h0 = hc - h // 2 + w0 = wc - w // 2 + d1 = min(D, d0 + d) + h1 = min(H, h0 + h) + w1 = min(W, w0 + w) + d0, h0, w0 = max(0, d0), max(0, h0), max(0, w0) + temp_m = torch.ones([C, d1 - d0, h1 - h0, w1 - w0]) + if(random.random() < 0.5): + temp_m = temp_m * 2 + mask[n, :, d0:d1, h0:h1, w0:w1] = temp_m + + mask1 = (mask == 1).to(torch.int32) + mask2 = (mask == 2).to(torch.int32) + y = x_nl1 * (1.0 - mask1) + x_nl2 * mask1 + y = y * (1.0 - mask2) + x_roll * mask2 + return y, mask + +def augmented_volume_fusion(x, size_min, size_max): + """ + Fuse a subregion of an impage with another one to generate + images and labels for self-supervised segmentation. + input x should be a batch of tensors + """ + #n_min, n_max, + N, C, D, H, W = list(x.shape) + # apply nonlinear transform to x: + x1 = torch.zeros_like(x).to(torch.float32) + y = torch.zeros_like(x).to(torch.float32) + mask = torch.zeros_like(x).to(torch.int32) + for n in range(N): + x1[n] = nonlinear_transform(x[n]) + y[n] = nonlinear_transform(x[n]) + x2 = torch.roll(x1, 1, 0) + + for n in range(N): + block_size = [random.randint(size_min[i], size_max[i]) for i in range(3)] + d_start = random.randint(0, block_size[0] // 2) + h_start = random.randint(0, block_size[1] // 2) + w_stat = random.randint(0, block_size[2] // 2) + for d in range(d_start, D, block_size[0]): + if(D - d < block_size[0] // 2): + continue + d1 = min(d + block_size[0], D) + for h in range(h_start, H, block_size[1]): + if(H - h < block_size[1] // 2): + continue + h1 = min(h + block_size[1], H) + for w in range(w_stat, W, block_size[2]): + if(W - w < block_size[2] // 2): + continue + w1 = min(w + block_size[2], W) + p = random.random() + if(p < 0.15): # nonlinear intensity augmentation + mask[n, :, d:d1, h:h1, w:w1] = 1 + y[n, :, d:d1, h:h1, w:w1] = x1[n, :, d:d1, h:h1, w:w1] + elif(p < 0.3): # random flip across a certain axis + mask[n, :, d:d1, h:h1, w:w1] = 2 + flip_axis = random.randint(-3, -1) + y[n, :, d:d1, h:h1, w:w1] = torch.flip(y[n, :, d:d1, h:h1, w:w1], (flip_axis,)) + elif(p < 0.45): # nonlinear intensity augmentation and random flip across a certain axis + mask[n, :, d:d1, h:h1, w:w1] = 3 + flip_axis = random.randint(-3, -1) + y[n, :, d:d1, h:h1, w:w1] = torch.flip(x1[n, :, d:d1, h:h1, w:w1], (flip_axis,)) + elif(p < 0.6): # paste from another volume + mask[n, :, d:d1, h:h1, w:w1] = 4 + y[n, :, d:d1, h:h1, w:w1] = x2[n, :, d:d1, h:h1, w:w1] + return y, mask + +def self_volume_fusion(x, fg_num, fuse_ratio, size_min, size_max): + """ + Fuse a subregion of an impage with another one to generate + images and labels for self-supervised segmentation. + input x should be a batch of tensors + """ + #n_min, n_max, + N, C, D, H, W = list(x.shape) + y = 1.0 * x + fg_mask = torch.zeros_like(x[:, :1, :, :, :]).to(torch.int32) + + for n in range(N): + db = random.randint(size_min[0], size_max[0]) + hb = random.randint(size_min[1], size_max[1]) + wb = random.randint(size_min[2], size_max[2]) + d0 = random.randint(0, D % db) + h0 = random.randint(0, H % hb) + w0 = random.randint(0, W % wb) + coord_list_source = [] + for di in range(D // db): + for hi in range(H // hb): + for wi in range(W // wb): + coord_list_source.append([di, hi, wi]) + coord_list_target = copy.deepcopy(coord_list_source) + random.shuffle(coord_list_source) + random.shuffle(coord_list_target) + for i in range(int(len(coord_list_source)*fuse_ratio)): + ds_l = d0 + db * coord_list_source[i][0] + hs_l = h0 + hb * coord_list_source[i][1] + ws_l = w0 + wb * coord_list_source[i][2] + dt_l = d0 + db * coord_list_target[i][0] + ht_l = h0 + hb * coord_list_target[i][1] + wt_l = w0 + wb * coord_list_target[i][2] + s_crop = x[n, :, ds_l:ds_l+db, hs_l:hs_l+hb, ws_l:ws_l+wb] + t_crop = x[n, :, dt_l:dt_l+db, ht_l:ht_l+hb, wt_l:wt_l+wb] + fg_m = random.randint(1, fg_num) + fg_w = fg_m / (fg_num + 0.0) + y[n, :, dt_l:dt_l+db, ht_l:ht_l+hb, wt_l:wt_l+wb] = t_crop * (1.0 - fg_w) + s_crop * fg_w + fg_mask[n, 0, dt_l:dt_l+db, ht_l:ht_l+hb, wt_l:wt_l+wb] = \ + torch.ones([1, db, hb, wb]) * fg_m + return y, fg_mask \ No newline at end of file diff --git a/pymic/net_run/semi_sup/__init__.py b/pymic/net_run/semi_sup/__init__.py index d3095f6..cb5d1a3 100644 --- a/pymic/net_run/semi_sup/__init__.py +++ b/pymic/net_run/semi_sup/__init__.py @@ -3,6 +3,7 @@ from pymic.net_run.semi_sup.ssl_em import SSLEntropyMinimization from pymic.net_run.semi_sup.ssl_mt import SSLMeanTeacher from pymic.net_run.semi_sup.ssl_mcnet import SSLMCNet +from pymic.net_run.semi_sup.ssl_cdma import SSLCDMA from pymic.net_run.semi_sup.ssl_uamt import SSLUncertaintyAwareMeanTeacher from pymic.net_run.semi_sup.ssl_cct import SSLCCT from pymic.net_run.semi_sup.ssl_cps import SSLCPS @@ -12,6 +13,7 @@ SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization, 'MeanTeacher': SSLMeanTeacher, 'MCNet': SSLMCNet, + 'CDMA': SSLCDMA, 'UAMT': SSLUncertaintyAwareMeanTeacher, 'CCT': SSLCCT, 'CPS': SSLCPS, diff --git a/pymic/net_run/semi_sup/ssl_abstract.py b/pymic/net_run/semi_sup/ssl_abstract.py index 0e05281..69a09fd 100644 --- a/pymic/net_run/semi_sup/ssl_abstract.py +++ b/pymic/net_run/semi_sup/ssl_abstract.py @@ -101,6 +101,9 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format( valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) def train_valid(self): self.trainIter_unlab = iter(self.train_loader_unlab) diff --git a/pymic/net_run/semi_sup/ssl_cct.py b/pymic/net_run/semi_sup/ssl_cct.py index 1943608..81723bd 100644 --- a/pymic/net_run/semi_sup/ssl_cct.py +++ b/pymic/net_run/semi_sup/ssl_cct.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import logging +import time import torch import torch.nn as nn import torch.nn.functional as F @@ -88,13 +89,13 @@ def training(self): rampup_end = ssl_cfg.get('rampup_end', iter_max) unsup_loss_name = ssl_cfg.get('unsupervised_loss', "MSE") self.unsup_loss_f = unsup_loss_dict[unsup_loss_name] - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data_lab = next(self.trainIter) except StopIteration: @@ -105,7 +106,7 @@ def training(self): except StopIteration: self.trainIter_unlab = iter(self.train_loader_unlab) data_unlab = next(self.trainIter_unlab) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) @@ -118,6 +119,7 @@ def training(self): # forward pass output, aux_outputs = self.net(inputs) + t2 = time.time() n0 = list(x0.shape)[0] # get supervised loss @@ -135,8 +137,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -150,6 +153,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -158,5 +166,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers diff --git a/pymic/net_run/semi_sup/ssl_cps.py b/pymic/net_run/semi_sup/ssl_cps.py index 7acfe17..dc0d325 100644 --- a/pymic/net_run/semi_sup/ssl_cps.py +++ b/pymic/net_run/semi_sup/ssl_cps.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import logging +import time import numpy as np import torch from random import random @@ -44,8 +45,10 @@ def training(self): train_loss_sup1, train_loss_pseudo_sup1 = 0, 0 train_loss_sup2, train_loss_pseudo_sup2 = 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data_lab = next(self.trainIter) except StopIteration: @@ -56,7 +59,7 @@ def training(self): except StopIteration: self.trainIter_unlab = iter(self.train_loader_unlab) data_unlab = next(self.trainIter_unlab) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) @@ -84,6 +87,7 @@ def training(self): outputs1, outputs2 = self.net(inputs) outputs_soft1 = torch.softmax(outputs1, dim=1) outputs_soft2 = torch.softmax(outputs2, dim=1) + t2 = time.time() n0 = list(x0.shape)[0] p0 = outputs_soft1[:n0] @@ -105,8 +109,9 @@ def training(self): model1_loss = loss_sup1 + regular_w * pse_sup1 model2_loss = loss_sup2 + regular_w * pse_sup2 loss = model1_loss + model2_loss - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -123,6 +128,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup1 = train_loss_sup1 / iter_valid train_avg_loss_sup2 = train_loss_sup2 / iter_valid @@ -134,7 +144,9 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup1':train_avg_loss_sup1, 'loss_sup2': train_avg_loss_sup2, 'loss_pse_sup1':train_avg_loss_pse_sup1, 'loss_pse_sup2': train_avg_loss_pse_sup2, - 'regular_w':regular_w, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'regular_w':regular_w, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): @@ -162,4 +174,7 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") \ No newline at end of file + ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) \ No newline at end of file diff --git a/pymic/net_run/semi_sup/ssl_em.py b/pymic/net_run/semi_sup/ssl_em.py index fde941b..750e3d3 100644 --- a/pymic/net_run/semi_sup/ssl_em.py +++ b/pymic/net_run/semi_sup/ssl_em.py @@ -2,6 +2,7 @@ from __future__ import print_function, division import logging import numpy as np +import time import torch from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth @@ -40,12 +41,12 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = ssl_cfg.get('rampup_start', 0) rampup_end = ssl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data_lab = next(self.trainIter) except StopIteration: @@ -56,7 +57,7 @@ def training(self): except StopIteration: self.trainIter_unlab = iter(self.train_loader_unlab) data_unlab = next(self.trainIter_unlab) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) @@ -69,6 +70,8 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) + t2 = time.time() + n0 = list(x0.shape)[0] p0 = outputs[:n0] loss_sup = self.get_loss_value(data_lab, p0, y0) @@ -79,8 +82,10 @@ def training(self): regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - # if (self.config['training']['use']) + t3 = time.time() + loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -94,6 +99,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -102,5 +112,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers \ No newline at end of file diff --git a/pymic/net_run/semi_sup/ssl_mcnet.py b/pymic/net_run/semi_sup/ssl_mcnet.py index 66e1034..d374773 100644 --- a/pymic/net_run/semi_sup/ssl_mcnet.py +++ b/pymic/net_run/semi_sup/ssl_mcnet.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import logging +import time import torch import torch.nn as nn import torch.nn.functional as F @@ -44,13 +45,13 @@ def training(self): rampup_end = ssl_cfg.get('rampup_end', iter_max) temperature = ssl_cfg.get('temperature', 0.1) unsup_loss_name = ssl_cfg.get('unsupervised_loss', "MSE") - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data_lab = next(self.trainIter) except StopIteration: @@ -61,7 +62,7 @@ def training(self): except StopIteration: self.trainIter_unlab = iter(self.train_loader_unlab) data_unlab = next(self.trainIter_unlab) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) @@ -74,6 +75,7 @@ def training(self): # forward pass to obtain multiple predictions outputs = self.net(inputs) + t2 = time.time() num_outputs = len(outputs) n0 = list(x0.shape)[0] p0 = F.softmax(outputs[0], dim=1)[:n0] @@ -81,7 +83,7 @@ def training(self): p_ori = torch.zeros((num_outputs,) + outputs[0].shape) y_psu = torch.zeros((num_outputs,) + outputs[0].shape) - # get supervised loss + # get supervised loss loss_sup = 0 for idx in range(num_outputs): p0i = outputs[idx][:n0] @@ -102,8 +104,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -117,6 +120,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -125,5 +133,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers diff --git a/pymic/net_run/semi_sup/ssl_mt.py b/pymic/net_run/semi_sup/ssl_mt.py index 409af19..303eeff 100644 --- a/pymic/net_run/semi_sup/ssl_mt.py +++ b/pymic/net_run/semi_sup/ssl_mt.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import logging +import time import torch import numpy as np from pymic.loss.seg.util import get_soft_label @@ -50,13 +51,13 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = ssl_cfg.get('rampup_start', 0) rampup_end = ssl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() self.net_ema.to(self.device) for it in range(iter_valid): + t0 = time.time() try: data_lab = next(self.trainIter) except StopIteration: @@ -67,7 +68,7 @@ def training(self): except StopIteration: self.trainIter_unlab = iter(self.train_loader_unlab) data_unlab = next(self.trainIter_unlab) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) @@ -82,6 +83,8 @@ def training(self): self.optimizer.zero_grad() outputs = self.net(inputs) + t2 = time.time() + n0 = list(x0.shape)[0] p0 = outputs[:n0] loss_sup = self.get_loss_value(data_lab, p0, y0) @@ -98,8 +101,9 @@ def training(self): loss_reg = torch.nn.MSELoss()(p1_soft, p1_ema_soft) loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() # update EMA @@ -119,6 +123,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -127,5 +136,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time} return train_scalers \ No newline at end of file diff --git a/pymic/net_run/semi_sup/ssl_uamt.py b/pymic/net_run/semi_sup/ssl_uamt.py index 053a012..5888a8b 100644 --- a/pymic/net_run/semi_sup/ssl_uamt.py +++ b/pymic/net_run/semi_sup/ssl_uamt.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import logging +import time import torch import numpy as np from pymic.loss.seg.util import get_soft_label @@ -33,13 +34,13 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = ssl_cfg.get('rampup_start', 0) rampup_end = ssl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() self.net_ema.to(self.device) for it in range(iter_valid): + t0 = time.time() try: data_lab = next(self.trainIter) except StopIteration: @@ -50,7 +51,7 @@ def training(self): except StopIteration: self.trainIter_unlab = iter(self.train_loader_unlab) data_unlab = next(self.trainIter_unlab) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) @@ -64,6 +65,7 @@ def training(self): self.optimizer.zero_grad() outputs = self.net(inputs) + t2 = time.time() n0 = list(x0.shape)[0] p0, p1 = torch.tensor_split(outputs, [n0,], dim = 0) outputs_soft = torch.softmax(outputs, dim=1) @@ -100,8 +102,9 @@ def training(self): regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() # update EMA @@ -121,6 +124,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -129,5 +137,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers \ No newline at end of file diff --git a/pymic/net_run/semi_sup/ssl_urpc.py b/pymic/net_run/semi_sup/ssl_urpc.py index 56bb77e..8447709 100644 --- a/pymic/net_run/semi_sup/ssl_urpc.py +++ b/pymic/net_run/semi_sup/ssl_urpc.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import logging +import time import torch import torch.nn as nn import numpy as np @@ -35,13 +36,13 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = ssl_cfg.get('rampup_start', 0) rampup_end = ssl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() kl_distance = nn.KLDivLoss(reduction='none') for it in range(iter_valid): + t0 = time.time() try: data_lab = next(self.trainIter) except StopIteration: @@ -52,7 +53,7 @@ def training(self): except StopIteration: self.trainIter_unlab = iter(self.train_loader_unlab) data_unlab = next(self.trainIter_unlab) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) @@ -65,6 +66,7 @@ def training(self): # forward pass outputs_list = self.net(inputs) + t2 = time.time() n0 = list(x0.shape)[0] # get supervised loss @@ -95,8 +97,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -110,6 +113,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -118,5 +126,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time} return train_scalers diff --git a/pymic/net_run/weak_sup/wsl_abstract.py b/pymic/net_run/weak_sup/wsl_abstract.py index f290465..37b9445 100644 --- a/pymic/net_run/weak_sup/wsl_abstract.py +++ b/pymic/net_run/weak_sup/wsl_abstract.py @@ -24,7 +24,7 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): 'valid':valid_scalars['loss']} loss_sup_scalar = {'train':train_scalars['loss_sup']} loss_upsup_scalar = {'train':train_scalars['loss_reg']} - dice_scalar ={'train':train_scalars['avg_fg_dice'], 'valid':valid_scalars['avg_fg_dice']} + dice_scalar ={'valid':valid_scalars['avg_fg_dice']} self.summ_writer.add_scalars('loss', loss_scalar, glob_it) self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it) self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it) @@ -36,9 +36,10 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ 'valid':valid_scalars['class_dice'][c]} self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) - logging.info('train loss {0:.4f}, avg foreground dice {1:.4f} '.format( - train_scalars['loss'], train_scalars['avg_fg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") + logging.info('train loss {0:.4f}'.format(train_scalars['loss'])) logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format( valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) diff --git a/pymic/net_run/weak_sup/wsl_dmpls.py b/pymic/net_run/weak_sup/wsl_dmpls.py index 4212409..01902c7 100644 --- a/pymic/net_run/weak_sup/wsl_dmpls.py +++ b/pymic/net_run/weak_sup/wsl_dmpls.py @@ -3,6 +3,7 @@ import logging import numpy as np import random +import time import torch from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label @@ -44,18 +45,18 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = wsl_cfg.get('rampup_start', 0) rampup_end = wsl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) y = self.convert_tensor_type(data['label_prob']) @@ -67,6 +68,8 @@ def training(self): # forward + backward + optimize outputs1, outputs2 = self.net(inputs) + t2 = time.time() + loss_sup1 = self.get_loss_value(data, outputs1, y) loss_sup2 = self.get_loss_value(data, outputs2, y) loss_sup = 0.5 * (loss_sup1 + loss_sup2) @@ -88,8 +91,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -103,6 +107,11 @@ def training(self): p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) dice_list = get_classwise_dice(p_soft, y) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -111,7 +120,8 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} - + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time} return train_scalers \ No newline at end of file diff --git a/pymic/net_run/weak_sup/wsl_em.py b/pymic/net_run/weak_sup/wsl_em.py index adcd70c..987aa89 100644 --- a/pymic/net_run/weak_sup/wsl_em.py +++ b/pymic/net_run/weak_sup/wsl_em.py @@ -2,12 +2,12 @@ from __future__ import print_function, division import logging import numpy as np +import time import torch from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.ssl import EntropyLoss -from pymic.net_run.agent_seg import SegmentationAgent from pymic.net_run.weak_sup import WSLSegAgent from pymic.util.ramps import get_rampup_ratio @@ -38,18 +38,18 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = wsl_cfg.get('rampup_start', 0) rampup_end = wsl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) y = self.convert_tensor_type(data['label_prob']) @@ -61,6 +61,8 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) + t2 = time.time() + loss_sup = self.get_loss_value(data, outputs, y) loss_dict= {"prediction":outputs, 'softmax':True} loss_reg = EntropyLoss()(loss_dict) @@ -68,8 +70,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -83,6 +86,11 @@ def training(self): p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) dice_list = get_classwise_dice(p_soft, y) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -91,5 +99,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers \ No newline at end of file diff --git a/pymic/net_run/weak_sup/wsl_gatedcrf.py b/pymic/net_run/weak_sup/wsl_gatedcrf.py index 2ce1f95..1ecae4a 100644 --- a/pymic/net_run/weak_sup/wsl_gatedcrf.py +++ b/pymic/net_run/weak_sup/wsl_gatedcrf.py @@ -2,6 +2,7 @@ from __future__ import print_function, division import logging import numpy as np +import time import torch from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth @@ -48,20 +49,19 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = wsl_cfg.get('rampup_start', 0) rampup_end = wsl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] - + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 gatecrf_loss = GatedCRFLoss() self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) y = self.convert_tensor_type(data['label_prob']) @@ -94,8 +94,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -109,6 +110,11 @@ def training(self): p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) dice_list = get_classwise_dice(p_soft, y) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -117,6 +123,8 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers \ No newline at end of file diff --git a/pymic/net_run/weak_sup/wsl_mumford_shah.py b/pymic/net_run/weak_sup/wsl_mumford_shah.py index 2480fee..e917fb3 100644 --- a/pymic/net_run/weak_sup/wsl_mumford_shah.py +++ b/pymic/net_run/weak_sup/wsl_mumford_shah.py @@ -2,6 +2,7 @@ from __future__ import print_function, division import logging import numpy as np +import time import torch from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth @@ -37,20 +38,20 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = wsl_cfg.get('rampup_start', 0) rampup_end = wsl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 reg_loss_calculator = MumfordShahLoss(wsl_cfg) self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) y = self.convert_tensor_type(data['label_prob']) @@ -62,6 +63,7 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) + t2 = time.time() loss_sup = self.get_loss_value(data, outputs, y) loss_dict = {"prediction":outputs, 'image':inputs} loss_reg = reg_loss_calculator(loss_dict) @@ -69,8 +71,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - # if (self.config['training']['use']) + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -84,6 +87,11 @@ def training(self): p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) dice_list = get_classwise_dice(p_soft, y) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -92,6 +100,8 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers \ No newline at end of file diff --git a/pymic/net_run/weak_sup/wsl_tv.py b/pymic/net_run/weak_sup/wsl_tv.py index 9d13c5d..5a43ce0 100644 --- a/pymic/net_run/weak_sup/wsl_tv.py +++ b/pymic/net_run/weak_sup/wsl_tv.py @@ -2,6 +2,7 @@ from __future__ import print_function, division import logging import numpy as np +import time import torch from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth @@ -34,18 +35,18 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = wsl_cfg.get('rampup_start', 0) rampup_end = wsl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) y = self.convert_tensor_type(data['label_prob']) @@ -57,6 +58,7 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) + t2 = time.time() loss_sup = self.get_loss_value(data, outputs, y) loss_dict = {"prediction":outputs, 'softmax':True} loss_reg = TotalVariationLoss()(loss_dict) @@ -64,8 +66,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - # if (self.config['training']['use']) + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -79,6 +82,11 @@ def training(self): p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) dice_list = get_classwise_dice(p_soft, y) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -87,6 +95,8 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers \ No newline at end of file diff --git a/pymic/net_run/weak_sup/wsl_ustm.py b/pymic/net_run/weak_sup/wsl_ustm.py index 31a6644..ac1c6a5 100644 --- a/pymic/net_run/weak_sup/wsl_ustm.py +++ b/pymic/net_run/weak_sup/wsl_ustm.py @@ -3,6 +3,7 @@ import logging import numpy as np import random +import time import torch import torch.nn.functional as F from pymic.loss.seg.util import get_soft_label @@ -54,19 +55,19 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = wsl_cfg.get('rampup_start', 0) rampup_end = wsl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() self.net_ema.to(self.device) for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) y = self.convert_tensor_type(data['label_prob']) @@ -79,6 +80,7 @@ def training(self): # forward + backward + optimize noise = torch.clamp(torch.randn_like(inputs) * 0.1, -0.2, 0.2) outputs = self.net(inputs + noise) + t2 = time.time() out_prob= F.softmax(outputs, dim=1) loss_sup = self.get_loss_value(data, outputs, y) @@ -117,7 +119,7 @@ def training(self): regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() self.optimizer.step() @@ -126,6 +128,7 @@ def training(self): alpha = min(1 - 1 / (self.glob_it / iter_valid + 1), alpha) for ema_param, param in zip(self.net_ema.parameters(), self.net.parameters()): ema_param.data.mul_(alpha).add(param.data, alpha = 1.0 - alpha) + t4 = time.time() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() @@ -138,6 +141,11 @@ def training(self): p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) dice_list = get_classwise_dice(p_soft, y) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -146,5 +154,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers \ No newline at end of file From b11dd8c22239e533a756fd709c907cb2e07b8e7f Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 30 Sep 2024 16:17:35 +0800 Subject: [PATCH 189/225] update self-supervised learning methods --- pymic/net_run/agent_rec.py | 30 +- pymic/net_run/agent_seg.py | 25 +- pymic/net_run/self_sup/self_volf.py | 37 ++ pymic/net_run/self_sup/self_volume_fusion.py | 354 ------------------- pymic/transform/trans_dict.py | 17 +- pymic/transform/volume_fusion.py | 117 ++++++ 6 files changed, 213 insertions(+), 367 deletions(-) create mode 100644 pymic/net_run/self_sup/self_volf.py delete mode 100644 pymic/net_run/self_sup/self_volume_fusion.py create mode 100644 pymic/transform/volume_fusion.py diff --git a/pymic/net_run/agent_rec.py b/pymic/net_run/agent_rec.py index cd311ad..cc83526 100644 --- a/pymic/net_run/agent_rec.py +++ b/pymic/net_run/agent_rec.py @@ -29,6 +29,9 @@ class ReconstructionAgent(SegmentationAgent): """ def __init__(self, config, stage = 'train'): super(ReconstructionAgent, self).__init__(config, stage) + if (self.config['network']['class_num'] != 1): + raise ValueError("For reconstruction tasks, the output channel number should be 1, " + + "but {} was given.".format(self.config['network']['class_num'])) def create_loss_calculator(self): if(self.loss_dict is None): @@ -55,14 +58,17 @@ def create_loss_calculator(self): def training(self): iter_valid = self.config['training']['iter_valid'] train_loss = 0 + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) # get the inputs + t1 = time.time() inputs = self.convert_tensor_type(data['image']) label = self.convert_tensor_type(data['label']) @@ -87,7 +93,7 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) - + t2 = time.time() # for debug # if it < 5: # outputs = nn.Tanh()(outputs) @@ -100,15 +106,23 @@ def training(self): # break loss = self.get_loss_value(data, outputs, label) + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() - # get dice evaluation for each class + if(isinstance(outputs, tuple) or isinstance(outputs, list)): outputs = outputs[0] + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid - train_scalers = {'loss': train_avg_loss} + train_scalers = {'loss': train_avg_loss, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time} return train_scalers def validation(self): @@ -163,6 +177,9 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) logging.info('train loss {0:.4f}'.format(train_scalars['loss'])) logging.info('valid loss {0:.4f}'.format(valid_scalars['loss'])) + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) def train_valid(self): device_ids = self.config['training']['gpus'] @@ -173,7 +190,7 @@ def train_valid(self): self.device = torch.device("cuda:{0:}".format(device_ids[0])) self.net.to(self.device) - ckpt_dir = self.config['training']['ckpt_save_dir'] + ckpt_dir = self.config['training']['ckpt_dir'] ckpt_prefix = self.config['training'].get('ckpt_prefix', None) if(ckpt_prefix is None): ckpt_prefix = ckpt_dir.split('/')[-1] @@ -224,7 +241,7 @@ def train_valid(self): self.trainIter = iter(self.train_loader) logging.info("{0:} training start".format(str(datetime.now())[:-7])) - self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) + self.summ_writer = SummaryWriter(self.config['training']['ckpt_dir']) self.glob_it = iter_start for it in range(iter_start, iter_max, iter_valid): lr_value = self.optimizer.param_groups[0]['lr'] @@ -242,6 +259,9 @@ def train_valid(self): logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) logging.info('learning rate {0:}'.format(lr_value)) logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1)) + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it) if(valid_scalars['loss'] < self.min_val_loss): self.min_val_loss = valid_scalars['loss'] diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 4a80298..59ffe05 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -74,6 +74,7 @@ def get_stage_dataset_from_config(self, stage): stage_dir = self.config['dataset']['valid_dir'] if(stage == 'test' and "test_dir" in self.config['dataset']): stage_dir = self.config['dataset']['test_dir'] + logging.info("Creating dataset for {0:}".format(stage)) dataset = NiftyDataset(root_dir = stage_dir, csv_file = csv_file, modal_num = modal_num, @@ -163,14 +164,16 @@ def training(self): mixup_prob = self.config['training'].get('mixup_probability', 0.0) train_loss = 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - # get the inputs + t1 = time.time() inputs = self.convert_tensor_type(data['image']) labels_prob = self.convert_tensor_type(data['label_prob']) if(mixup_prob > 0 and random() < mixup_prob): @@ -196,14 +199,17 @@ def training(self): inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) - + # zero the parameter gradients self.optimizer.zero_grad() # forward + backward + optimize outputs = self.net(inputs) + t2 = time.time() loss = self.get_loss_value(data, outputs, labels_prob) + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() # get dice evaluation for each class @@ -214,12 +220,19 @@ def training(self): soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob) dice_list = get_classwise_dice(soft_out, labels_prob) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) train_avg_dice = train_cls_dice[1:].mean() train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ - 'class_dice': train_cls_dice} + 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time} return train_scalers def validation(self): @@ -289,7 +302,10 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format( valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) def load_pretrained_weights(self, network, pretrained_dict, device_ids): if(len(device_ids) > 1): @@ -607,4 +623,3 @@ def save_outputs(self, data): if(len(temp_prob.shape) == 2): temp_prob = np.asarray(temp_prob * 255, np.uint8) save_nd_array_as_image(temp_prob, prob_save_name, test_dir + '/' + names[i][0]) -0]) diff --git a/pymic/net_run/self_sup/self_volf.py b/pymic/net_run/self_sup/self_volf.py new file mode 100644 index 0000000..4615979 --- /dev/null +++ b/pymic/net_run/self_sup/self_volf.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +from pymic.net_run.agent_seg import SegmentationAgent + + +class SelfSupVolumeFusion(SegmentationAgent): + """ + Abstract class for self-supervised segmentation. + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. + """ + def __init__(self, config, stage = 'train'): + super(SelfSupVolumeFusion, self).__init__(config, stage) + + def get_transform_names_and_parameters(self, stage): + trans_names, trans_params = super(SelfSupVolumeFusion, self).get_transform_names_and_parameters(stage) + if(stage == 'train'): + print('training transforms:', trans_names) + if("Crop4VolumeFusion" not in trans_names): + raise ValueError("Crop4VolumeFusion is required for VolF, \ + but it is not given in training transform") + if("VolumeFusion" not in trans_names): + raise ValueError("VolumeFusion is required for VolF, \ + but it is not given in training transform") + if("LabelToProbability" not in trans_names): + raise ValueError("LabelToProbability is required for VolF, \ + but it is not given in training transform") + return trans_names, trans_params + + \ No newline at end of file diff --git a/pymic/net_run/self_sup/self_volume_fusion.py b/pymic/net_run/self_sup/self_volume_fusion.py deleted file mode 100644 index ad2d640..0000000 --- a/pymic/net_run/self_sup/self_volume_fusion.py +++ /dev/null @@ -1,354 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division -import copy -import os -import sys -import shutil -import time -import logging -import scipy -import torch -import torchvision.transforms as transforms -import numpy as np -import torch.nn as nn -import torch.optim as optim -import torch.nn.functional as F -from datetime import datetime -from random import random -from torch.optim import lr_scheduler -from tensorboardX import SummaryWriter -from pymic.io.image_read_write import save_nd_array_as_image -from pymic.io.nifty_dataset import NiftyDataset -from pymic.net.net_dict_seg import SegNetDict -from pymic.net_run.agent_abstract import NetRunAgent -from pymic.net_run.infer_func import Inferer -from pymic.loss.loss_dict_seg import SegLossDict -from pymic.loss.seg.combined import CombinedLoss -from pymic.loss.seg.deep_sup import DeepSuperviseLoss -from pymic.loss.seg.util import get_soft_label -from pymic.loss.seg.util import reshape_prediction_and_ground_truth -from pymic.loss.seg.util import get_classwise_dice -from pymic.transform.trans_dict import TransformDict -from pymic.util.post_process import PostProcessDict -from pymic.util.parse_config import * -from pymic.util.general import get_one_hot_seg -from pymic.io.image_read_write import save_nd_array_as_image -from pymic.net_run.self_sup.util import volume_fusion, nonlienar_volume_fusion, augmented_volume_fusion -from pymic.net_run.self_sup.util import self_volume_fusion -from pymic.net_run.agent_seg import SegmentationAgent - - -class SelfSupVolumeFusion(SegmentationAgent): - """ - Abstract class for self-supervised segmentation. - - :param config: (dict) A dictionary containing the configuration. - :param stage: (str) One of the stage in `train` (default), `inference` or `test`. - - .. note:: - - In the configuration dictionary, in addition to the four sections (`dataset`, - `network`, `training` and `inference`) used in fully supervised learning, an - extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. - """ - def __init__(self, config, stage = 'train'): - super(SelfSupVolumeFusion, self).__init__(config, stage) - - def training(self): - class_num = self.config['network']['class_num'] - iter_valid = self.config['training']['iter_valid'] - block_range = self.config['self_supervised_learning']['VolumeFusion_block_range'.lower()] - size_min = self.config['self_supervised_learning']['VolumeFusion_size_min'.lower()] - size_max = self.config['self_supervised_learning']['VolumeFusion_size_max'.lower()] - - train_loss = 0 - train_dice_list = [] - self.net.train() - for it in range(iter_valid): - try: - data = next(self.trainIter) - except StopIteration: - self.trainIter = iter(self.train_loader) - data = next(self.trainIter) - # get the inputs - inputs = self.convert_tensor_type(data['image']) - inputs, labels = volume_fusion(inputs, class_num - 1, block_range, size_min, size_max) - labels_prob = get_one_hot_seg(labels, class_num) - - # for debug - # if(it==10): - # break - # for i in range(inputs.shape[0]): - # image_i = inputs[i][0] - # label_i = np.argmax(labels_prob[i], axis = 0) - # # pixw_i = pix_w[i][0] - # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) - # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) - # # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) - # save_nd_array_as_image(image_i, image_name, reference_name = None) - # save_nd_array_as_image(label_i, label_name, reference_name = None) - # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) - # continue - - inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) - - # zero the parameter gradients - self.optimizer.zero_grad() - - # forward + backward + optimize - outputs = self.net(inputs) - loss = self.get_loss_value(data, outputs, labels_prob) - loss.backward() - self.optimizer.step() - train_loss = train_loss + loss.item() - # get dice evaluation for each class - if(isinstance(outputs, tuple) or isinstance(outputs, list)): - outputs = outputs[0] - outputs_argmax = torch.argmax(outputs, dim = 1, keepdim = True) - soft_out = get_soft_label(outputs_argmax, class_num, self.tensor_type) - soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob) - dice_list = get_classwise_dice(soft_out, labels_prob) - train_dice_list.append(dice_list.cpu().numpy()) - train_avg_loss = train_loss / iter_valid - train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice[1:].mean() - - train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ - 'class_dice': train_cls_dice} - return train_scalers - -class SelfSupSelfVolumeFusion(SegmentationAgent): - """ - Abstract class for self-supervised segmentation. - - :param config: (dict) A dictionary containing the configuration. - :param stage: (str) One of the stage in `train` (default), `inference` or `test`. - - .. note:: - - In the configuration dictionary, in addition to the four sections (`dataset`, - `network`, `training` and `inference`) used in fully supervised learning, an - extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. - """ - def __init__(self, config, stage = 'train'): - super(SelfSupSelfVolumeFusion, self).__init__(config, stage) - - def training(self): - class_num = self.config['network']['class_num'] - iter_valid = self.config['training']['iter_valid'] - fuse_ratio = self.config['self_supervised_learning']['SelfVolumeFusion_fuse_ratio'.lower()] - size_min = self.config['self_supervised_learning']['SelfVolumeFusion_size_min'.lower()] - size_max = self.config['self_supervised_learning']['SelfVolumeFusion_size_max'.lower()] - - train_loss = 0 - train_dice_list = [] - self.net.train() - for it in range(iter_valid): - try: - data = next(self.trainIter) - except StopIteration: - self.trainIter = iter(self.train_loader) - data = next(self.trainIter) - # get the inputs - inputs = self.convert_tensor_type(data['image']) - inputs, labels = self_volume_fusion(inputs, class_num - 1, fuse_ratio, size_min, size_max) - labels_prob = get_one_hot_seg(labels, class_num) - - # for debug - # if(it==10): - # break - # for i in range(inputs.shape[0]): - # image_i = inputs[i][0] - # label_i = np.argmax(labels_prob[i], axis = 0) - # # pixw_i = pix_w[i][0] - # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) - # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) - # # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) - # save_nd_array_as_image(image_i, image_name, reference_name = None) - # save_nd_array_as_image(label_i, label_name, reference_name = None) - # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) - # continue - - inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) - - # zero the parameter gradients - self.optimizer.zero_grad() - - # forward + backward + optimize - outputs = self.net(inputs) - loss = self.get_loss_value(data, outputs, labels_prob) - loss.backward() - self.optimizer.step() - train_loss = train_loss + loss.item() - # get dice evaluation for each class - if(isinstance(outputs, tuple) or isinstance(outputs, list)): - outputs = outputs[0] - outputs_argmax = torch.argmax(outputs, dim = 1, keepdim = True) - soft_out = get_soft_label(outputs_argmax, class_num, self.tensor_type) - soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob) - dice_list = get_classwise_dice(soft_out, labels_prob) - train_dice_list.append(dice_list.cpu().numpy()) - train_avg_loss = train_loss / iter_valid - train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice[1:].mean() - - train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ - 'class_dice': train_cls_dice} - return train_scalers - -class SelfSupNonLinearVolumeFusion(SegmentationAgent): - """ - Abstract class for self-supervised segmentation. - - :param config: (dict) A dictionary containing the configuration. - :param stage: (str) One of the stage in `train` (default), `inference` or `test`. - - .. note:: - - In the configuration dictionary, in addition to the four sections (`dataset`, - `network`, `training` and `inference`) used in fully supervised learning, an - extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. - """ - def __init__(self, config, stage = 'train'): - super(SelfSupNonLinearVolumeFusion, self).__init__(config, stage) - - def training(self): - class_num = 3 - iter_valid = self.config['training']['iter_valid'] - block_range = self.config['self_supervised_learning']['NonLinearVolumeFusion_block_range'.lower()] - size_min = self.config['self_supervised_learning']['NonLinearVolumeFusion_size_min'.lower()] - size_max = self.config['self_supervised_learning']['NonLinearVolumeFusion_size_max'.lower()] - - train_loss = 0 - train_dice_list = [] - self.net.train() - for it in range(iter_valid): - try: - data = next(self.trainIter) - except StopIteration: - self.trainIter = iter(self.train_loader) - data = next(self.trainIter) - # get the inputs - inputs = self.convert_tensor_type(data['image']) - inputs, labels = nonlienar_volume_fusion(inputs, block_range, size_min, size_max) - labels_prob = get_one_hot_seg(labels, class_num) - - # for debug - # if(it==10): - # break - # for i in range(inputs.shape[0]): - # image_i = inputs[i][0] - # label_i = np.argmax(labels_prob[i], axis = 0) - # # pixw_i = pix_w[i][0] - # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) - # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) - # # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) - # save_nd_array_as_image(image_i, image_name, reference_name = None) - # save_nd_array_as_image(label_i, label_name, reference_name = None) - # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) - # continue - - inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) - - # zero the parameter gradients - self.optimizer.zero_grad() - - # forward + backward + optimize - outputs = self.net(inputs) - loss = self.get_loss_value(data, outputs, labels_prob) - loss.backward() - self.optimizer.step() - train_loss = train_loss + loss.item() - # get dice evaluation for each class - if(isinstance(outputs, tuple) or isinstance(outputs, list)): - outputs = outputs[0] - outputs_argmax = torch.argmax(outputs, dim = 1, keepdim = True) - soft_out = get_soft_label(outputs_argmax, class_num, self.tensor_type) - soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob) - dice_list = get_classwise_dice(soft_out, labels_prob) - train_dice_list.append(dice_list.cpu().numpy()) - train_avg_loss = train_loss / iter_valid - train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice[1:].mean() - - train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ - 'class_dice': train_cls_dice} - return train_scalers - -class SelfSupAugmentedVolumeFusion(SegmentationAgent): - """ - Abstract class for self-supervised segmentation. - - :param config: (dict) A dictionary containing the configuration. - :param stage: (str) One of the stage in `train` (default), `inference` or `test`. - - .. note:: - - In the configuration dictionary, in addition to the four sections (`dataset`, - `network`, `training` and `inference`) used in fully supervised learning, an - extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. - """ - def __init__(self, config, stage = 'train'): - super(SelfSupAugmentedVolumeFusion, self).__init__(config, stage) - - def training(self): - class_num = 5 - iter_valid = self.config['training']['iter_valid'] - size_min = self.config['self_supervised_learning']['AugmentedVolumeFusion_size_min'.lower()] - size_max = self.config['self_supervised_learning']['AugmentedVolumeFusion_size_max'.lower()] - - train_loss = 0 - train_dice_list = [] - self.net.train() - for it in range(iter_valid): - try: - data = next(self.trainIter) - except StopIteration: - self.trainIter = iter(self.train_loader) - data = next(self.trainIter) - # get the inputs - inputs = self.convert_tensor_type(data['image']) - inputs, labels = augmented_volume_fusion(inputs, size_min, size_max) - labels_prob = get_one_hot_seg(labels, class_num) - - # for debug - # if(it==10): - # break - # for i in range(inputs.shape[0]): - # image_i = inputs[i][0] - # label_i = np.argmax(labels_prob[i], axis = 0) - # # pixw_i = pix_w[i][0] - # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) - # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) - # # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) - # save_nd_array_as_image(image_i, image_name, reference_name = None) - # save_nd_array_as_image(label_i, label_name, reference_name = None) - # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) - # continue - - inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) - - # zero the parameter gradients - self.optimizer.zero_grad() - - # forward + backward + optimize - outputs = self.net(inputs) - loss = self.get_loss_value(data, outputs, labels_prob) - loss.backward() - self.optimizer.step() - train_loss = train_loss + loss.item() - # get dice evaluation for each class - if(isinstance(outputs, tuple) or isinstance(outputs, list)): - outputs = outputs[0] - outputs_argmax = torch.argmax(outputs, dim = 1, keepdim = True) - soft_out = get_soft_label(outputs_argmax, class_num, self.tensor_type) - soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob) - dice_list = get_classwise_dice(soft_out, labels_prob) - train_dice_list.append(dice_list.cpu().numpy()) - train_avg_loss = train_loss / iter_valid - train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice[1:].mean() - - train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ - 'class_dice': train_cls_dice} - return train_scalers \ No newline at end of file diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index ed5ad0c..2a15857 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -15,6 +15,7 @@ 'LabelConvertNonzero': LabelConvertNonzero, 'LabelToProbability': LabelToProbability, 'IntensityClip': IntensityClip, + 'NonLinearTransform': NonLinearTransform, 'NormalizeWithMeanStd': NormalizeWithMeanStd, 'NormalizeWithMinMax': NormalizeWithMinMax, 'NormalizeWithPercentiles': NormalizeWithPercentiles, @@ -41,20 +42,28 @@ from pymic.transform.threshold import * from pymic.transform.normalize import * from pymic.transform.crop import * -from pymic.transform.mix import * +from pymic.transform.crop4dino import Crop4Dino +from pymic.transform.crop4vox2vec import Crop4Vox2Vec +from pymic.transform.crop4vf import Crop4VolumeFusion, VolumeFusion, VolumeFusionShuffle +from pymic.transform.volume_fusion import * from pymic.transform.label_convert import * TransformDict = { 'Affine': Affine, + 'AdaptiveContrastAdjust': AdaptiveContrastAdjust, 'ChannelWiseThreshold': ChannelWiseThreshold, 'ChannelWiseThresholdWithNormalize': ChannelWiseThresholdWithNormalize, 'CropWithBoundingBox': CropWithBoundingBox, 'CropWithForeground': CropWithForeground, - 'CropHumanRegionFromCT': CropHumanRegionFromCT, + 'CropHumanRegion': CropHumanRegion, 'CenterCrop': CenterCrop, + 'Crop4Dino': Crop4Dino, + 'Crop4Vox2Vec': Crop4Vox2Vec, + 'Crop4VolumeFusion': Crop4VolumeFusion, 'GrayscaleToRGB': GrayscaleToRGB, 'GammaCorrection': GammaCorrection, 'GaussianNoise': GaussianNoise, + 'HistEqual': HistEqual, 'InPainting': InPainting, 'InOutPainting': InOutPainting, 'LabelConvert': LabelConvert, @@ -62,6 +71,7 @@ 'LabelToProbability': LabelToProbability, 'LocalShuffling': LocalShuffling, 'IntensityClip': IntensityClip, + 'MaskedImageModeling': MaskedImageModeling, 'NonLinearTransform': NonLinearTransform, 'NormalizeWithMeanStd': NormalizeWithMeanStd, 'NormalizeWithMinMax': NormalizeWithMinMax, @@ -82,5 +92,6 @@ 'OutPainting': OutPainting, 'Pad': Pad, 'PatchSwaping':PatchSwaping, - 'PatchMix': PatchMix + 'VolumeFusion': VolumeFusion, + 'VolumeFusionShuffle': VolumeFusionShuffle } diff --git a/pymic/transform/volume_fusion.py b/pymic/transform/volume_fusion.py new file mode 100644 index 0000000..38c74c0 --- /dev/null +++ b/pymic/transform/volume_fusion.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import random +import numpy as np +from pymic.transform.abstract_transform import AbstractTransform +from pymic.util.image_process import * +try: # SciPy >= 0.19 + from scipy.special import comb +except ImportError: + from scipy.misc import comb + +def random_resized_crop(x, output_shape): + img_shape = x.shape[1:] + ratio = [img_shape[i] / output_shape[i] for i in range(3)] + r_max = [min(ratio[i], 1.25) for i in range(3)] + r_min = (0.8, 0.8, 0.8) + scale = [r_min[i] + random.random() * (r_max[i] - r_min[i]) for i in range(3)] + crop_size = [int(output_shape[i] * scale[i]) for i in range(3)] + + bb_min = [random.randint(0, img_shape[i] - crop_size[i]) for i in range(3)] + bb_max = [bb_min[i] + crop_size[i] for i in range(3)] + bb_min = [0] + bb_min + bb_max = [x.shape[0]] + bb_max + crop_volume = crop_ND_volume_with_bounding_box(x, bb_min, bb_max) + + scale = [(output_shape[i] + 0.0)/crop_size[i] for i in range(3)] + scale = [1.0] + scale + y = ndimage.interpolation.zoom(crop_volume, scale, order = 1) + return y + +def nonlinear_transform(x): + v_min = np.min(x) + v_max = np.max(x) + x = (x - v_min)/(v_max - v_min) + a = random.random() * 0.7 + 0.15 + b = random.random() * 0.7 + 0.15 + alpha = b / a + beta = (1 - b) / (1 - a) + if(alpha < 1.0 ): + y = np.maximum(alpha*x, beta*x + 1 - beta) + else: + y = np.minimum(alpha*x, beta*x + 1 - beta) + if(random.random() < 0.5): + y = 1.0 - y + y = y * (v_max - v_min) + v_min + return y + +def random_flip(x): + flip_axis = [] + if(random.random() > 0.5): + flip_axis.append(-1) + if(random.random() > 0.5): + flip_axis.append(-2) + if(random.random() > 0.5): + flip_axis.append(-3) + + if(len(flip_axis) > 0): + # use .copy() to avoid negative strides of numpy array + # current pytorch does not support negative strides + y = np.flip(x, flip_axis).copy() + else: + y = x + return y + + +class VolumeFusion(AbstractTransform): + """ + fusing two subvolumes of an image, used for self-supervised learning + """ + def __init__(self, params): + super(VolumeFusion, self).__init__(params) + self.inverse = params.get('VolumeFusion_inverse'.lower(), False) + self.crop_size = params.get('VolumeFusion_crop_size'.lower(), [64, 128, 128]) + self.block_range = params.get('VolumeFusion_block_range'.lower(), [20, 40]) + self.size_min = params.get('VolumeFusion_size_min'.lower(), [8, 16, 16]) + self.size_max = params.get('VolumeFusion_size_max'.lower(), [16, 32, 32]) + + def __call__(self, sample): + x = sample['image'] + x0 = random_resized_crop(x, self.crop_size) + x1 = random_resized_crop(x, self.crop_size) + x0 = random_flip(x0) + x1 = random_flip(x1) + # nonlinear transform + x0a = nonlinear_transform(x0) + x0b = nonlinear_transform(x0) + x1 = nonlinear_transform(x1) + + D, H, W = x0.shape[1:] + mask = np.zeros_like(x0, np.uint8) + p_num = random.randint(self.block_range[0], self.block_range[1]) + for i in range(p_num): + d = random.randint(self.size_min[0], self.size_max[0]) + h = random.randint(self.size_min[1], self.size_max[1]) + w = random.randint(self.size_min[2], self.size_max[2]) + dc = random.randint(0, D - 1) + hc = random.randint(0, H - 1) + wc = random.randint(0, W - 1) + d0 = dc - d // 2 + h0 = hc - h // 2 + w0 = wc - w // 2 + d1 = min(D, d0 + d) + h1 = min(H, h0 + h) + w1 = min(W, w0 + w) + d0, h0, w0 = max(0, d0), max(0, h0), max(0, w0) + temp_m = np.ones([d1 - d0, h1 - h0, w1 - w0]) + if(random.random() < 0.5): + temp_m = temp_m * 2 + mask[:, d0:d1, h0:h1, w0:w1] = temp_m + + mask1 = np.asarray(mask == 1, np.uint8) + mask2 = np.asarray(mask == 2, np.uint8) + y = x0a * (1.0 - mask1) + x0b * mask1 + y = y * (1.0 - mask2) + x1 * mask2 + sample['image'] = y + sample['label'] = mask + return sample From 53db91d5066f2ff20a9487fd4118f9e2750f6908 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 30 Sep 2024 16:20:24 +0800 Subject: [PATCH 190/225] update transform --- pymic/transform/affine.py | 4 - pymic/transform/crop.py | 50 +++++-- pymic/transform/crop4dino.py | 175 +++++++++++++++++++++++ pymic/transform/crop4vf.py | 232 +++++++++++++++++++++++++++++++ pymic/transform/crop4vox2vec.py | 160 +++++++++++++++++++++ pymic/transform/flip.py | 3 +- pymic/transform/rescale.py | 11 +- pymic/transform/volume_fusion.py | 117 ---------------- 8 files changed, 612 insertions(+), 140 deletions(-) create mode 100644 pymic/transform/crop4dino.py create mode 100644 pymic/transform/crop4vf.py create mode 100644 pymic/transform/crop4vox2vec.py delete mode 100644 pymic/transform/volume_fusion.py diff --git a/pymic/transform/affine.py b/pymic/transform/affine.py index 2efd586..1717a97 100644 --- a/pymic/transform/affine.py +++ b/pymic/transform/affine.py @@ -152,9 +152,5 @@ def _get_param_for_inverse_transform(self, sample): # aff_out_shape = origin_shape[-2:] # output_predict = self._apply_affine_to_ND_volume(predict, aff_out_shape, tform.inverse) - # sample['predict'] = output_predict - # return sample - = self._apply_affine_to_ND_volume(predict, aff_out_shape, tform.inverse) - # sample['predict'] = output_predict # return sample diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index 9e1c077..c444acd 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -260,7 +260,7 @@ def _get_crop_param(self, sample): crop_min = [0 if item == 0 else random.randint(0, item) for item in crop_margin] crop_max = [crop_min[i] + output_size[i] for i in range(input_dim)] - label_exist = False if ('label' not in sample or sample['label']) is None else True + label_exist = True if ('label' in sample and sample['label'].sum() > 0) else False if(label_exist and self.fg_focus and random.random() < self.fg_ratio): label = sample['label'][0] if(self.mask_label is None): @@ -398,6 +398,9 @@ class RandomSlice(AbstractTransform): """ def __init__(self, params): self.output_size = params['RandomSlice_output_size'.lower()] + self.fg_focus = params.get('RandomSlice_foreground_focus'.lower(), False) + self.fg_ratio = params.get('RandomSlice_foreground_ratio'.lower(), 0.5) + self.mask_label = params.get('RandomSlice_mask_label'.lower(), None) self.shuffle = params.get('RandomSlice_shuffle'.lower(), False) self.inverse = params.get('RandomSlice_inverse'.lower(), False) self.task = params['Task'.lower()] @@ -406,13 +409,30 @@ def __call__(self, sample): image = sample['image'] D = image.shape[1] assert( D >= self.output_size) + out_half = self.output_size // 2 + + label_exist = True if ('label' in sample and sample['label'].sum() > 0) else False + if(label_exist and self.fg_focus and random.random() < self.fg_ratio): + label = sample['label'][0] + if(self.mask_label is None): + mask_label = np.unique(label)[1:] + else: + mask_label = self.mask_label + random_label = random.choice(mask_label) + mask = label == random_label + dc = random.choice(np.nonzero(mask)[0]) + else: + dc = random.choice(range(out_half, D - out_half)) + slice_idx = list(range(D)) if(self.shuffle): random.shuffle(slice_idx) - slice_idx = slice_idx[:self.output_size] - else: d0 = random.randint(0, D - self.output_size) d1 = d0 + self.output_size + slice_idx = slice_idx[d0:d1-1] + [dc] + else: + d0 = max(0, dc - out_half) + d1 = d0 + self.output_size slice_idx = slice_idx[d0:d1] sample['image'] = image[:, slice_idx, :, :] @@ -448,6 +468,7 @@ class CropHumanRegion(CenterCrop): """ def __init__(self, params): self.threshold_i = params.get('CropHumanRegion_intensity_threshold'.lower(), -600) + self.threshold_mode = params.get('CropHumanRegion_threshold_mode'.lower(), 'mean') self.threshold_z = params.get('CropHumanRegion_zaxis_threshold'.lower(), 0.5) self.inverse = params.get('CropHumanRegion_inverse'.lower(), True) self.task = params['task'] @@ -456,20 +477,27 @@ def _get_crop_param(self, sample): image = sample['image'] input_shape = image.shape mask = np.asarray(image[0] > self.threshold_i) - mask2d = np.mean(mask, axis = 0) > self.threshold_z + if(self.threshold_mode == "mean"): + mask2d = np.mean(mask, axis = 0) > self.threshold_z + else: + mask2d = np.max(mask, axis = 0) se = np.ones([3,3]) mask2d = ndimage.binary_opening(mask2d, se, iterations = 2) - mask2d = get_largest_k_components(mask2d, 1) - bbmin, bbmax = get_ND_bounding_box(mask2d, margin = [0, 0]) + if(mask2d.sum() > 0): + mask2d = get_largest_k_components(mask2d, 1) + bbmin, bbmax = get_ND_bounding_box(mask2d, margin = [0, 0]) + else: + bbmin = [0] * (image.ndim - 2) + bbmax = list(input_shape[2:]) crop_min = [0, 0] + bbmin crop_max = list(input_shape[:2]) + bbmax - sample['CropHumanRegionFromCT_Param'] = json.dumps((input_shape, crop_min, crop_max)) + sample['CropHumanRegion_Param'] = json.dumps((input_shape, crop_min, crop_max)) return sample, crop_min, crop_max def _get_param_for_inverse_transform(self, sample): - if(isinstance(sample['CropHumanRegionFromCT_Param'], list) or \ - isinstance(sample['CropHumanRegionFromCT_Param'], tuple)): - params = json.loads(sample['CropHumanRegionFromCT_Param'][0]) + if(isinstance(sample['CropHumanRegion_Param'], list) or \ + isinstance(sample['CropHumanRegion_Param'], tuple)): + params = json.loads(sample['CropHumanRegion_Param'][0]) else: - params = json.loads(sample['CropHumanRegionFromCT_Param']) + params = json.loads(sample['CropHumanRegion_Param']) return params \ No newline at end of file diff --git a/pymic/transform/crop4dino.py b/pymic/transform/crop4dino.py new file mode 100644 index 0000000..df86787 --- /dev/null +++ b/pymic/transform/crop4dino.py @@ -0,0 +1,175 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import json +import math +import random +import numpy as np +from scipy import ndimage +from pymic import TaskType +from pymic.transform.abstract_transform import AbstractTransform +from pymic.transform.crop import CenterCrop +from pymic.transform.intensity import * +from pymic.util.image_process import * + +class Crop4Dino(CenterCrop): + """ + Randomly crop an volume into two views with augmentation. This is used for + self-supervised pretraining such as DeSD. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `DualViewCrop_output_size`: (list/tuple) Desired output size [D, H, W]. + The output channel is the same as the input channel. + :param `DualViewCrop_scale_lower_bound`: (list/tuple) Lower bound of the range of scale + for each dimension. e.g. (1.0, 0.5, 0.5). + param `DualViewCrop_scale_upper_bound`: (list/tuple) Upper bound of the range of scale + for each dimension. e.g. (1.0, 2.0, 2.0). + :param `DualViewCrop_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `False`. Currently, the inverse transform is not supported, and + this transform is assumed to be used only during training stage. + """ + def __init__(self, params): + self.output_size = params['Crop4Dino_output_size'.lower()] + self.scale_lower = params['Crop4Dino_resize_lower_bound'.lower()] + self.scale_upper = params['Crop4Dino_resize_upper_bound'.lower()] + self.prob = params.get('Crop4Dino_resize_prob'.lower(), 0.5) + self.noise_std_range = params.get('Crop4Dino_noise_std_range'.lower(), (0.05, 0.1)) + self.blur_sigma_range = params.get('Crop4Dino_blur_sigma_range'.lower(), (1.0, 3.0)) + self.gamma_range = params.get('Crop4Dino_gamma_range'.lower(), (0.75, 1.25)) + self.inverse = params.get('Crop4Dino_inverse'.lower(), False) + self.task = params['Task'.lower()] + assert isinstance(self.output_size, (list, tuple)) + assert isinstance(self.scale_lower, (list, tuple)) + assert isinstance(self.scale_upper, (list, tuple)) + + def __call__(self, sample): + image = sample['image'] + channel, input_size = image.shape[0], image.shape[1:] + input_dim = len(input_size) + assert(input_dim == len(self.output_size)) + + # # center crop first + # crop_size = self.output_size + # crop_margin = [input_size[i] - crop_size[i] for i in range(input_dim)] + # crop_min = [int(item/2) for item in crop_margin] + # crop_max = [crop_min[i] + crop_size[i] for i in range(input_dim)] + # crop_min = [0] + crop_min + # crop_max = [channel] + crop_max + # crop0 = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + + crop_num = 2 + crop_img = [] + for crop_i in range(crop_num): + resize = random.random() < self.prob + if(resize): + scale = [self.scale_lower[i] + (self.scale_upper[i] - self.scale_lower[i]) * random.random() \ + for i in range(input_dim)] + crop_size = [int(self.output_size[i] * scale[i]) for i in range(input_dim)] + else: + crop_size = self.output_size + + crop_margin = [input_size[i] - crop_size[i] for i in range(input_dim)] + pad_image = min(crop_margin) < 0 + if(pad_image): # pad the image if necessary + pad_size = [max(0, -crop_margin[i]) for i in range(input_dim)] + pad_lower = [int(pad_size[i] / 2) for i in range(input_dim)] + pad_upper = [pad_size[i] - pad_lower[i] for i in range(input_dim)] + pad = [(pad_lower[i], pad_upper[i]) for i in range(input_dim)] + pad = tuple([(0, 0)] + pad) + image = np.pad(image, pad, 'reflect') + crop_margin = [max(0, crop_margin[i]) for i in range(input_dim)] + + + crop_min = [random.randint(0, item) for item in crop_margin] + crop_max = [crop_min[i] + crop_size[i] for i in range(input_dim)] + crop_min = [0] + crop_min + crop_max = [channel] + crop_max + + crop_out = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + if(resize): + scale = [(self.output_size[i] + 0.0)/crop_size[i] for i in range(input_dim)] + scale = [1.0] + scale + crop_out = ndimage.interpolation.zoom(crop_out, scale, order = 1) + + # add intensity augmentation + C = crop_out.shape[0] + for c in range(C): + if(random.random() < 0.8): + crop_out[c] = gaussian_noise(crop_out[c], self.noise_std_range[0], self.noise_std_range[1]) + + if(random.uniform(0, 1) < 0.5): + crop_out[c] = gaussian_blur(crop_out[c], self.blur_sigma_range[0], self.blur_sigma_range[1]) + else: + alpha = random.uniform(0.0, 2.0) + crop_out[c] = gaussian_sharpen(crop_out[c], self.blur_sigma_range[0], self.blur_sigma_range[1], alpha) + if(random.random() < 0.8): + crop_out[c] = gamma_correction(crop_out[c], self.gamma_range[0], self.gamma_range[1]) + if(random.random() < 0.8): + crop_out[c] = window_level_augment(crop_out[c]) + crop_img.append(crop_out) + sample['image'] = crop_img + return sample + + def __call__backup(self, sample): + image = sample['image'] + channel, input_size = image.shape[0], image.shape[1:] + input_dim = len(input_size) + assert(input_dim == len(self.output_size)) + + # center crop first + crop_size = self.output_size + crop_margin = [input_size[i] - crop_size[i] for i in range(input_dim)] + crop_min = [int(item/2) for item in crop_margin] + crop_max = [crop_min[i] + crop_size[i] for i in range(input_dim)] + crop_min = [0] + crop_min + crop_max = [channel] + crop_max + crop0 = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + + # crop_num = 2 + # crop_img = [] + # for crop_i in range(crop_num): + # get another resized crop size + resize = random.random() < self.prob + if(resize): + scale = [self.scale_lower[i] + (self.scale_upper[i] - self.scale_lower[i]) * random.random() \ + for i in range(input_dim)] + crop_size = [int(self.output_size[i] * scale[i]) for i in range(input_dim)] + else: + crop_size = self.output_size + + crop_margin = [input_size[i] - crop_size[i] for i in range(input_dim)] + pad_image = min(crop_margin) < 0 + if(pad_image): # pad the image if necessary + pad_size = [max(0, -crop_margin[i]) for i in range(input_dim)] + pad_lower = [int(pad_size[i] / 2) for i in range(input_dim)] + pad_upper = [pad_size[i] - pad_lower[i] for i in range(input_dim)] + pad = [(pad_lower[i], pad_upper[i]) for i in range(input_dim)] + pad = tuple([(0, 0)] + pad) + image = np.pad(image, pad, 'reflect') + crop_margin = [max(0, crop_margin[i]) for i in range(input_dim)] + + + crop_min = [random.randint(0, item) for item in crop_margin] + crop_max = [crop_min[i] + crop_size[i] for i in range(input_dim)] + crop_min = [0] + crop_min + crop_max = [channel] + crop_max + + crop_out = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + if(resize): + scale = [(self.output_size[i] + 0.0)/crop_size[i] for i in range(input_dim)] + scale = [1.0] + scale + crop_out = ndimage.interpolation.zoom(crop_out, scale, order = 1) + # crop_img.append(crop_out) + crop_img = [crop0, crop_out] + # add intensity augmentation + # image_t = gaussian_noise(image_t, self.noise_std_range[0], self.noise_std_range[1], 0.8) + # image_t = gaussian_blur(image_t, self.blur_sigma_range[0], self.blur_sigma_range[1], 0.8) + # image_t = brightness_multiplicative(image_t, self.inten_multi_range[0], self.inten_multi_range[1], 0.8) + # image_t = brightness_additive(image_t, self.inten_add_range[0], self.inten_add_range[1], 0.8) + # image_t = contrast_augment(image_t, self.contrast_f_range[0], self.contrast_f_range[1], 0.8) + # image_t = gamma_correction(image_t, self.gamma_range[0], self.gamma_range[1], 0.8) + sample['image'] = crop_img + return sample diff --git a/pymic/transform/crop4vf.py b/pymic/transform/crop4vf.py new file mode 100644 index 0000000..4e07357 --- /dev/null +++ b/pymic/transform/crop4vf.py @@ -0,0 +1,232 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch + +import json +import math +import random +import numpy as np +from imops import crop_to_box +from typing import * +from scipy import ndimage +from pymic import TaskType +from pymic.transform.abstract_transform import AbstractTransform +from pymic.transform.crop import CenterCrop +from pymic.util.image_process import * +from pymic.transform.intensity import * + + +def random_resized_crop(image, output_size, scale_lower, scale_upper): + input_size = image.shape + scale = [scale_lower[i] + (scale_upper[i] - scale_lower[i]) * random.random() \ + for i in range(3)] + crop_size = [min(int(output_size[i] * scale[i]), input_size[1+i]) for i in range(3)] + crop_margin = [input_size[1+i] - crop_size[i] for i in range(3)] + crop_min = [random.randint(0, item) for item in crop_margin] + crop_max = [crop_min[i] + crop_size[i] for i in range(3)] + crop_min = [0] + crop_min + crop_max = [input_size[0]] + crop_max + + image_t = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + scale = [(output_size[i] + 0.0)/crop_size[i] for i in range(3)] + scale = [1.0] + scale + image_t = ndimage.interpolation.zoom(image_t, scale, order = 1) + return image_t + +def random_flip(image): + flip_axis = [] + if(random.random() > 0.5): + flip_axis.append(-1) + if(random.random() > 0.5): + flip_axis.append(-2) + if(random.random() > 0.5): + flip_axis.append(-3) + if(len(flip_axis) > 0): + image = np.flip(image , flip_axis) + return image + + +class Crop4VolumeFusion(AbstractTransform): + """ + Randomly crop an volume into two views with augmentation. This is used for + self-supervised pretraining in Vox2vec. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `Crop4VolumeFusion_output_size`: (list/tuple) Desired output size [D, H, W]. + The output channel is the same as the input channel. + :param `Crop4VolumeFusion_rescale_lower_bound`: (list/tuple) Lower bound of the range of scale + for each dimension. e.g. (1.0, 0.5, 0.5). + param `Crop4VolumeFusion_rescale_upper_bound`: (list/tuple) Upper bound of the range of scale + for each dimension. e.g. (1.0, 2.0, 2.0). + :param `Crop4VolumeFusion_augentation_mode`: (optional, int) The mode for augmentation of cropped volume. + 0: no spatial or intensity augmentatin. + 1: intensity augmentation only +` 2: spatial augmentation only + 3: Both intensity and spatial augmentation (default). + """ + def __init__(self, params): + self.output_size = params['Crop4VolumeFusion_output_size'.lower()] + self.scale_lower = params.get('Crop4VolumeFusion_rescale_lower_bound'.lower(), [0.7, 0.7, 0.7]) + self.scale_upper = params.get('Crop4VolumeFusion_rescale_upper_bound'.lower(), [1.5, 1.5, 1.5]) + self.aug_mode = params.get('Crop4VolumeFusion_augentation_mode'.lower(), 3) + self.task = params['Task'.lower()] + assert isinstance(self.output_size, (list, tuple)) + + def __call__(self, sample): + image = sample['image'] + channel, input_size = image.shape[0], image.shape[1:] + input_dim = len(input_size) + assert channel == 1 + assert(input_dim == len(self.output_size)) + + if(self.aug_mode == 0 or self.aug_mode == 1): + self.scale_lower = [1.0, 1.0, 1.0] + self.scale_upper = [1.0, 1.0, 1.0] + patch_1 = random_resized_crop(image, self.output_size, self.scale_lower, self.scale_upper) + patch_2 = random_resized_crop(image, self.output_size, self.scale_lower, self.scale_upper) + if(self.aug_mode > 1): + patch_1 = random_flip(patch_1) + patch_2 = random_flip(patch_2) + if(self.aug_mode == 1 or self.aug_mode == 3): + p0, p1 = random.uniform(0.1, 2.0), random.uniform(98, 99.9) + patch_1 = adaptive_contrast_adjust(patch_1, p0, p1) + patch_1 = gamma_correction(patch_1, 0.7, 1.5) + + p0, p1 = random.uniform(0.1, 2.0), random.uniform(98, 99.9) + patch_2 = adaptive_contrast_adjust(patch_2, p0, p1) + patch_2 = gamma_correction(patch_2, 0.7, 1.5) + + if(random.random() < 0.25): + patch_1 = 1.0 - patch_1 + patch_2 = 1.0 - patch_2 + + sample['image'] = patch_1, patch_2 + return sample + +class VolumeFusion(AbstractTransform): + """ + Randomly crop an volume into two views with augmentation. This is used for + self-supervised pretraining in Vox2vec. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `DualViewCrop_output_size`: (list/tuple) Desired output size [D, H, W]. + The output channel is the same as the input channel. + :param `DualViewCrop_scale_lower_bound`: (list/tuple) Lower bound of the range of scale + for each dimension. e.g. (1.0, 0.5, 0.5). + param `DualViewCrop_scale_upper_bound`: (list/tuple) Upper bound of the range of scale + for each dimension. e.g. (1.0, 2.0, 2.0). + :param `DualViewCrop_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `False`. Currently, the inverse transform is not supported, and + this transform is assumed to be used only during training stage. + """ + def __init__(self, params): + self.cls_num = params.get('VolumeFusion_cls_num'.lower(), 5) + self.ratio = params.get('VolumeFusion_foreground_ratio'.lower(), 0.7) + self.size_min = params.get('VolumeFusion_patchsize_min'.lower(), [8, 8, 8]) + self.size_max = params.get('VolumeFusion_patchsize_max'.lower(), [32, 32, 32]) + self.task = params['Task'.lower()] + + def __call__(self, sample): + K = self.cls_num - 1 + image1, image2 = sample['image'] + C, D, H, W = image1.shape + db = random.randint(self.size_min[0], self.size_max[0]) + hb = random.randint(self.size_min[1], self.size_max[1]) + wb = random.randint(self.size_min[2], self.size_max[2]) + d_offset = random.randint(0, D % db) + h_offset = random.randint(0, H % hb) + w_offset = random.randint(0, W % wb) + d_n = D // db + h_n = H // hb + w_n = W // wb + Nblock = d_n * h_n * w_n + Nfg = int(d_n * h_n * w_n * self.ratio) + list_fg = [1] * Nfg + [0] * (Nblock - Nfg) + random.shuffle(list_fg) + mask = np.zeros([1, D, H, W], np.uint8) + for d in range(d_n): + for h in range(h_n): + for w in range(w_n): + d0, h0, w0 = d*db + d_offset, h*hb + h_offset, w*wb + w_offset + d1, h1, w1 = d0 + db, h0 + hb, w0 + wb + idx = d*h_n*w_n + h*w_n + w + if(list_fg[idx]> 0): + cls_k = random.randint(1, K) + mask[:, d0:d1, h0:h1, w0:w1] = cls_k + alpha = mask * 1.0 / K + x_fuse = alpha*image1 + (1.0 - alpha)*image2 + sample['image'] = x_fuse + sample['label'] = mask + return sample + +class VolumeFusionShuffle(AbstractTransform): + """ + Randomly crop an volume into two views with augmentation. This is used for + self-supervised pretraining in Vox2vec. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `DualViewCrop_output_size`: (list/tuple) Desired output size [D, H, W]. + The output channel is the same as the input channel. + :param `DualViewCrop_scale_lower_bound`: (list/tuple) Lower bound of the range of scale + for each dimension. e.g. (1.0, 0.5, 0.5). + param `DualViewCrop_scale_upper_bound`: (list/tuple) Upper bound of the range of scale + for each dimension. e.g. (1.0, 2.0, 2.0). + :param `DualViewCrop_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `False`. Currently, the inverse transform is not supported, and + this transform is assumed to be used only during training stage. + """ + def __init__(self, params): + self.cls_num = params.get('VolumeFusionShuffle_cls_num'.lower(), 5) + self.ratio = params.get('VolumeFusionShuffle_foreground_ratio'.lower(), 0.7) + self.size_min = params.get('VolumeFusionShuffle_patchsize_min'.lower(), [8, 8, 8]) + self.size_max = params.get('VolumeFusionShuffle_patchsize_max'.lower(), [32, 32, 32]) + self.task = params['Task'.lower()] + + def __call__(self, sample): + K = self.cls_num - 1 + image1, image2 = sample['image'] + C, D, H, W = image1.shape + x_fuse = image2 * 1.0 + mask = np.zeros([1, D, H, W], np.uint8) + db = random.randint(self.size_min[0], self.size_max[0]) + hb = random.randint(self.size_min[1], self.size_max[1]) + wb = random.randint(self.size_min[2], self.size_max[2]) + d_offset = random.randint(0, D % db) + h_offset = random.randint(0, H % hb) + w_offset = random.randint(0, W % wb) + d_n = D // db + h_n = H // hb + w_n = W // wb + coord_list_source = [] + for di in range(d_n): + for hi in range(h_n): + for wi in range(w_n): + coord_list_source.append([di, hi, wi]) + coord_list_target = copy.deepcopy(coord_list_source) + random.shuffle(coord_list_source) + random.shuffle(coord_list_target) + for i in range(int(len(coord_list_source)*self.ratio)): + ds_l = d_offset + db * coord_list_source[i][0] + hs_l = h_offset + hb * coord_list_source[i][1] + ws_l = w_offset + wb * coord_list_source[i][2] + dt_l = d_offset + db * coord_list_target[i][0] + ht_l = h_offset + hb * coord_list_target[i][1] + wt_l = w_offset + wb * coord_list_target[i][2] + s_crop = image1[:, ds_l:ds_l+db, hs_l:hs_l+hb, ws_l:ws_l+wb] + t_crop = image2[:, dt_l:dt_l+db, ht_l:ht_l+hb, wt_l:wt_l+wb] + fg_m = random.randint(1, K) + fg_w = fg_m / (K + 0.0) + x_fuse[:, dt_l:dt_l+db, ht_l:ht_l+hb, wt_l:wt_l+wb] = t_crop * (1.0 - fg_w) + s_crop * fg_w + mask[0, dt_l:dt_l+db, ht_l:ht_l+hb, wt_l:wt_l+wb] = \ + np.ones([1, db, hb, wb]) * fg_m + sample['image'] = x_fuse + sample['label'] = mask + return sample + diff --git a/pymic/transform/crop4vox2vec.py b/pymic/transform/crop4vox2vec.py new file mode 100644 index 0000000..6fdcf83 --- /dev/null +++ b/pymic/transform/crop4vox2vec.py @@ -0,0 +1,160 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch + +import json +import math +import random +import numpy as np +from imops import crop_to_box +from typing import * +from scipy import ndimage +from pymic import TaskType +from pymic.transform.abstract_transform import AbstractTransform +from pymic.transform.crop import CenterCrop +from pymic.util.image_process import * +from pymic.transform.intensity import * + +def normalize_axis_list(axis, ndim): + return list(np.core.numeric.normalize_axis_tuple(axis, ndim)) + +def scale_hu(image_hu: np.ndarray, window_hu: Tuple[float, float]) -> np.ndarray: + min_hu, max_hu = window_hu + assert min_hu < max_hu + return np.clip((image_hu - min_hu) / (max_hu - min_hu), 0, 1) + +# def gaussian_filter( +# x: np.ndarray, +# sigma: Union[float, Sequence[float]], +# axis: Union[int, Sequence[int]] +# ) -> np.ndarray: +# axis = normalize_axis_list(axis, x.ndim) +# sigma = np.broadcast_to(sigma, len(axis)) +# for sgm, ax in zip(sigma, axis): +# x = ndimage.gaussian_filter1d(x, sgm, ax) +# return x + +# def gaussian_sharpen( +# x: np.ndarray, +# sigma_1: Union[float, Sequence[float]], +# sigma_2: Union[float, Sequence[float]], +# alpha: float, +# axis: Union[int, Sequence[int]] +# ) -> np.ndarray: +# """ See https://docs.monai.io/en/stable/transforms.html#gaussiansharpen """ +# blurred = gaussian_filter(x, sigma_1, axis) +# return blurred + alpha * (blurred - gaussian_filter(blurred, sigma_2, axis)) + +def sample_box(image_size, patch_size, anchor_voxel=None): + image_size = np.array(image_size, ndmin=1) + patch_size = np.array(patch_size, ndmin=1) + + if not np.all(image_size >= patch_size): + raise ValueError(f'Can\'t sample patch of size {patch_size} from image of size {image_size}') + + min_start = 0 + max_start = image_size - patch_size + if anchor_voxel is not None: + anchor_voxel = np.array(anchor_voxel, ndmin=1) + min_start = np.maximum(min_start, anchor_voxel - patch_size + 1) + max_start = np.minimum(max_start, anchor_voxel) + start = np.random.randint(min_start, max_start + 1) + return np.array([start, start + patch_size]) + +def sample_views( + image: np.ndarray, + min_overlap: Tuple[int, int, int], + patch_size: Tuple[int, int, int], + max_num_voxels: int, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """ For 3D volumes, the image shape should be [C, D, H, W]. + """ + img_size = image.shape[1:] + overlap = [random.randint(min_overlap[i], patch_size[i]) for i in range(3)] + union_size = [2*patch_size[i] - overlap[i] for i in range(3)] + anchor_max = [img_size[i] - union_size[i] for i in range(3)] + crop_min_1 = [random.randint(0, anchor_max[i]) for i in range(3)] + crop_min_2 = [crop_min_1[i] + patch_size[i] - overlap[i] for i in range(3)] + patch_1 = sample_view(image, crop_min_1, patch_size) + patch_2 = sample_view(image, crop_min_2, patch_size) + + coords = [range(crop_min_2[i], crop_min_2[i] + overlap[i]) for i in range(3)] + coords = np.asarray(np.meshgrid(coords[0], coords[1], coords[2])) + coords = coords.reshape(3, -1).transpose() + roi_voxels_1 = coords - crop_min_1 + roi_voxels_2 = coords - crop_min_2 + + indices = range(coords.shape[0]) + if len(indices) > max_num_voxels: + indices = np.random.choice(indices, max_num_voxels, replace=False) + + return patch_1, patch_2, roi_voxels_1[indices], roi_voxels_2[indices] + + +def sample_view(image, crop_min, patch_size): + """ For 3D volumes, the image shape should be [C, D, H, W]. + """ + assert image.ndim == 4 + C = image.shape[0] + crop_max = [crop_min[i] + patch_size[i] for i in range(3)] + out = crop_ND_volume_with_bounding_box(image, [0] + crop_min, [C] + crop_max) + + # intensity augmentations + for c in range(C): + if(random.random() < 0.8): + out[c] = gaussian_noise(out[c], 0.05, 0.1) + if(random.random() < 0.5): + out[c] = gaussian_blur(out[c], 0.5, 1.5) + else: + alpha = random.uniform(0.0, 2.0) + out[c] = gaussian_sharpen(out[c], 0.5, 2.0, alpha) + if(random.random() < 0.8): + out[c] = gamma_correction(out[c], 0.5, 2.0) + if(random.random() < 0.8): + out[c] = window_level_augment(out[c]) + return out + +class Crop4Vox2Vec(CenterCrop): + """ + Randomly crop an volume into two views with augmentation. This is used for + self-supervised pretraining in Vox2vec. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `DualViewCrop_output_size`: (list/tuple) Desired output size [D, H, W]. + The output channel is the same as the input channel. + :param `DualViewCrop_scale_lower_bound`: (list/tuple) Lower bound of the range of scale + for each dimension. e.g. (1.0, 0.5, 0.5). + param `DualViewCrop_scale_upper_bound`: (list/tuple) Upper bound of the range of scale + for each dimension. e.g. (1.0, 2.0, 2.0). + :param `DualViewCrop_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `False`. Currently, the inverse transform is not supported, and + this transform is assumed to be used only during training stage. + """ + def __init__(self, params): + self.output_size = params['Crop4Vox2Vec_output_size'.lower()] + self.min_overlap = params.get('Crop4Vox2Vec_min_overlap'.lower(), [8, 12, 12]) + self.max_voxel = params.get('Crop4Vox2Vec_max_voxel'.lower(), 1024) + self.inverse = params.get('Crop4Vox2Vec_inverse'.lower(), False) + self.task = params['Task'.lower()] + assert isinstance(self.output_size, (list, tuple)) + + def __call__(self, sample): + image = sample['image'] + channel, input_size = image.shape[0], image.shape[1:] + input_dim = len(input_size) + assert channel == 1 + assert(input_dim == len(self.output_size)) + invalid_size = [input_size[i] < self.output_size[i]*2 - self.min_overlap[i] for i in range(3)] + if True in invalid_size: + raise ValueError("The overlap requirement {0:} is too weak for the given patch size \ + {1:} and input size {2:}".format( self.min_overlap, self.output_size,input_size)) + + patches_1, patches_2, voxels_1, voxels_2 = sample_views(image, + self.min_overlap, self.output_size, self.max_voxel) + sample['image'] = patches_1, patches_2, voxels_1, voxels_2 + return sample + + diff --git a/pymic/transform/flip.py b/pymic/transform/flip.py index 6ea017c..6ffd535 100644 --- a/pymic/transform/flip.py +++ b/pymic/transform/flip.py @@ -33,8 +33,7 @@ def __init__(self, params): def __call__(self, sample): image = sample['image'] - input_shape = image.shape - input_dim = len(input_shape) - 1 + input_dim = image.ndim flip_axis = [] if(self.flip_width): if(random.random() > 0.5): diff --git a/pymic/transform/rescale.py b/pymic/transform/rescale.py index 2896a4e..47271ec 100644 --- a/pymic/transform/rescale.py +++ b/pymic/transform/rescale.py @@ -165,11 +165,10 @@ class Resample(Rescale): The arguments should be written in the `params` dictionary, and it has the following fields: - :param `Rescale_output_size`: (list/tuple or int) The output size along each spatial axis, - such as [D, H, W] or [H, W]. If D is None, the input image is only reslcaled in 2D. - If int, the smallest axis is matched to output_size keeping aspect ratio the same - as the input. - :param `Rescale_inverse`: (optional, bool) + :param `Resample_output_spacing`: (list/tuple or int) The output spacing along each spatial axis, + such as [Ds, Hs, Ws] or [Hs, Ws]. If Ds is None, the input image is only reslcaled in 2D. + :param `Resample_ignore_zspacing_range`: (list/tuple) The range of zspacing that would be ingored. + :param `Resample_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `True`. """ def __init__(self, params): @@ -177,11 +176,11 @@ def __init__(self, params): self.output_spacing = params["Resample_output_spacing".lower()] self.ignore_zspacing= params.get("Resample_ignore_zspacing_range".lower(), None) self.inverse = params.get("Resample_inverse".lower(), True) - # assert isinstance(self.output_size, (int, list, tuple)) def __call__(self, sample): image = sample['image'] input_shape = image.shape + input_dim = len(input_shape) - 1 spacing = sample['spacing'] out_spacing = [item for item in self.output_spacing] diff --git a/pymic/transform/volume_fusion.py b/pymic/transform/volume_fusion.py deleted file mode 100644 index 38c74c0..0000000 --- a/pymic/transform/volume_fusion.py +++ /dev/null @@ -1,117 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division -import random -import numpy as np -from pymic.transform.abstract_transform import AbstractTransform -from pymic.util.image_process import * -try: # SciPy >= 0.19 - from scipy.special import comb -except ImportError: - from scipy.misc import comb - -def random_resized_crop(x, output_shape): - img_shape = x.shape[1:] - ratio = [img_shape[i] / output_shape[i] for i in range(3)] - r_max = [min(ratio[i], 1.25) for i in range(3)] - r_min = (0.8, 0.8, 0.8) - scale = [r_min[i] + random.random() * (r_max[i] - r_min[i]) for i in range(3)] - crop_size = [int(output_shape[i] * scale[i]) for i in range(3)] - - bb_min = [random.randint(0, img_shape[i] - crop_size[i]) for i in range(3)] - bb_max = [bb_min[i] + crop_size[i] for i in range(3)] - bb_min = [0] + bb_min - bb_max = [x.shape[0]] + bb_max - crop_volume = crop_ND_volume_with_bounding_box(x, bb_min, bb_max) - - scale = [(output_shape[i] + 0.0)/crop_size[i] for i in range(3)] - scale = [1.0] + scale - y = ndimage.interpolation.zoom(crop_volume, scale, order = 1) - return y - -def nonlinear_transform(x): - v_min = np.min(x) - v_max = np.max(x) - x = (x - v_min)/(v_max - v_min) - a = random.random() * 0.7 + 0.15 - b = random.random() * 0.7 + 0.15 - alpha = b / a - beta = (1 - b) / (1 - a) - if(alpha < 1.0 ): - y = np.maximum(alpha*x, beta*x + 1 - beta) - else: - y = np.minimum(alpha*x, beta*x + 1 - beta) - if(random.random() < 0.5): - y = 1.0 - y - y = y * (v_max - v_min) + v_min - return y - -def random_flip(x): - flip_axis = [] - if(random.random() > 0.5): - flip_axis.append(-1) - if(random.random() > 0.5): - flip_axis.append(-2) - if(random.random() > 0.5): - flip_axis.append(-3) - - if(len(flip_axis) > 0): - # use .copy() to avoid negative strides of numpy array - # current pytorch does not support negative strides - y = np.flip(x, flip_axis).copy() - else: - y = x - return y - - -class VolumeFusion(AbstractTransform): - """ - fusing two subvolumes of an image, used for self-supervised learning - """ - def __init__(self, params): - super(VolumeFusion, self).__init__(params) - self.inverse = params.get('VolumeFusion_inverse'.lower(), False) - self.crop_size = params.get('VolumeFusion_crop_size'.lower(), [64, 128, 128]) - self.block_range = params.get('VolumeFusion_block_range'.lower(), [20, 40]) - self.size_min = params.get('VolumeFusion_size_min'.lower(), [8, 16, 16]) - self.size_max = params.get('VolumeFusion_size_max'.lower(), [16, 32, 32]) - - def __call__(self, sample): - x = sample['image'] - x0 = random_resized_crop(x, self.crop_size) - x1 = random_resized_crop(x, self.crop_size) - x0 = random_flip(x0) - x1 = random_flip(x1) - # nonlinear transform - x0a = nonlinear_transform(x0) - x0b = nonlinear_transform(x0) - x1 = nonlinear_transform(x1) - - D, H, W = x0.shape[1:] - mask = np.zeros_like(x0, np.uint8) - p_num = random.randint(self.block_range[0], self.block_range[1]) - for i in range(p_num): - d = random.randint(self.size_min[0], self.size_max[0]) - h = random.randint(self.size_min[1], self.size_max[1]) - w = random.randint(self.size_min[2], self.size_max[2]) - dc = random.randint(0, D - 1) - hc = random.randint(0, H - 1) - wc = random.randint(0, W - 1) - d0 = dc - d // 2 - h0 = hc - h // 2 - w0 = wc - w // 2 - d1 = min(D, d0 + d) - h1 = min(H, h0 + h) - w1 = min(W, w0 + w) - d0, h0, w0 = max(0, d0), max(0, h0), max(0, w0) - temp_m = np.ones([d1 - d0, h1 - h0, w1 - w0]) - if(random.random() < 0.5): - temp_m = temp_m * 2 - mask[:, d0:d1, h0:h1, w0:w1] = temp_m - - mask1 = np.asarray(mask == 1, np.uint8) - mask2 = np.asarray(mask == 2, np.uint8) - y = x0a * (1.0 - mask1) + x0b * mask1 - y = y * (1.0 - mask2) + x1 * mask2 - sample['image'] = y - sample['label'] = mask - return sample From 2bb8b0813e96713af067f9a5eb325bde061d1fe3 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 30 Sep 2024 16:22:40 +0800 Subject: [PATCH 191/225] update util files --- pymic/transform/trans_dict.py | 1 - pymic/util/image_process.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index 2a15857..332e594 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -45,7 +45,6 @@ from pymic.transform.crop4dino import Crop4Dino from pymic.transform.crop4vox2vec import Crop4Vox2Vec from pymic.transform.crop4vf import Crop4VolumeFusion, VolumeFusion, VolumeFusionShuffle -from pymic.transform.volume_fusion import * from pymic.transform.label_convert import * TransformDict = { diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index 158569c..c31f28f 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -73,7 +73,7 @@ def crop_ND_volume_with_bounding_box(volume, bb_min, bb_max): output = volume[bb_min[0]:bb_max[0], bb_min[1]:bb_max[1], bb_min[2]:bb_max[2], bb_min[3]:bb_max[3], bb_min[4]:bb_max[4]] else: raise ValueError("the dimension number shoud be 2 to 5") - return output + return output * 1 def set_ND_volume_roi_with_bounding_box_range(volume, bb_min, bb_max, sub_volume, addition = True): """ From 9cc16a6e3e29c34259083c8d5f409affc7869316 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 30 Sep 2024 16:24:05 +0800 Subject: [PATCH 192/225] update loss function --- pymic/loss/cls/basic.py | 2 +- pymic/loss/cls/infoNCE.py | 39 +++++++++++++++++++++++++++++++++++++ pymic/loss/loss_dict_cls.py | 3 ++- pymic/loss/seg/deep_sup.py | 3 ++- 4 files changed, 44 insertions(+), 3 deletions(-) create mode 100644 pymic/loss/cls/infoNCE.py diff --git a/pymic/loss/cls/basic.py b/pymic/loss/cls/basic.py index 56925fc..6b71e23 100644 --- a/pymic/loss/cls/basic.py +++ b/pymic/loss/cls/basic.py @@ -39,7 +39,7 @@ def forward(self, loss_input_dict): class SigmoidCELoss(AbstractClassificationLoss): """ - Sigmoid-based CE loss. + Sigmoid-based CE loss, should be used when task_type = cls_coexist """ def __init__(self, params = None): super(SigmoidCELoss, self).__init__(params) diff --git a/pymic/loss/cls/infoNCE.py b/pymic/loss/cls/infoNCE.py new file mode 100644 index 0000000..fb6f1c1 --- /dev/null +++ b/pymic/loss/cls/infoNCE.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import torch.nn as nn + +class InfoNCELoss(nn.Module): + """ + Abstract Classification Loss. + """ + def __init__(self, params = None): + super(InfoNCELoss, self).__init__() + self.temp = params.get("temperature", 0.1) + + def forward(self, input_1, input_2): + """ + The arguments should be written in the `loss_input_dict` dictionary, and it has the + following fields. + + :param prediction: A prediction with shape of [N, C] where C is the class number. + :param ground_truth: The corresponding ground truth, with shape of [N, 1]. + + Note that `prediction` is the digit output of a network, before using softmax. + """ + B = list(input_1.shape)[0] + loss = 0.0 + for b in range(B): + embeds_1 = input_1[b] + embeds_2 = input_2[b] + logits_11 = torch.matmul(embeds_1, embeds_1.T) / self.temp + logits_11.fill_diagonal_(float('-inf')) + logits_12 = torch.matmul(embeds_1, embeds_2.T) / self.temp + logits_22 = torch.matmul(embeds_2, embeds_2.T) / self.temp + logits_22.fill_diagonal_(float('-inf')) + loss_1 = torch.mean(-logits_12.diag() + torch.logsumexp(torch.cat([logits_11, logits_12], dim=1), dim=1)) + loss_2 = torch.mean(-logits_12.diag() + torch.logsumexp(torch.cat([logits_12.T, logits_22], dim=1), dim=1)) + loss = loss + (loss_1 + loss_2) / 2 + loss = loss / B + return loss \ No newline at end of file diff --git a/pymic/loss/loss_dict_cls.py b/pymic/loss/loss_dict_cls.py index e07f46b..44744cb 100644 --- a/pymic/loss/loss_dict_cls.py +++ b/pymic/loss/loss_dict_cls.py @@ -11,9 +11,10 @@ """ from __future__ import print_function, division from pymic.loss.cls.basic import * - +from pymic.loss.cls.infoNCE import InfoNCELoss PyMICClsLossDict = {"CrossEntropyLoss": CrossEntropyLoss, "SigmoidCELoss": SigmoidCELoss, + 'InfoNCELoss': InfoNCELoss, "L1Loss": L1Loss, "MSELoss": MSELoss, "NLLLoss": NLLLoss} diff --git a/pymic/loss/seg/deep_sup.py b/pymic/loss/seg/deep_sup.py index da6d9ef..6669486 100644 --- a/pymic/loss/seg/deep_sup.py +++ b/pymic/loss/seg/deep_sup.py @@ -2,6 +2,7 @@ from __future__ import print_function, division import torch.nn as nn +import numpy as np from torch.nn.functional import interpolate from pymic.loss.seg.abstract import AbstractSegLoss @@ -69,7 +70,7 @@ def forward(self, loss_input_dict): be a list or a tuple""") pred_num = len(pred) if(self.deep_sup_weight is None): - self.deep_sup_weight = [1.0] * pred_num + self.deep_sup_weight = [1.0 / pow(2, i) for i in range(pred_num)] else: assert(pred_num == len(self.deep_sup_weight)) loss_sum, weight_sum = 0.0, 0.0 From 474188eaed3bcb4dd373054f10e8546e2450de73 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 30 Sep 2024 16:26:26 +0800 Subject: [PATCH 193/225] Update torch_pretrained_net.py --- pymic/net/cls/torch_pretrained_net.py | 48 +++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/pymic/net/cls/torch_pretrained_net.py b/pymic/net/cls/torch_pretrained_net.py index 5017f72..c1959a4 100644 --- a/pymic/net/cls/torch_pretrained_net.py +++ b/pymic/net/cls/torch_pretrained_net.py @@ -2,7 +2,6 @@ from __future__ import print_function, division import itertools -import torch import torch.nn as nn import torchvision.models as models @@ -61,7 +60,49 @@ class ResNet18(BuiltInNet): """ def __init__(self, params): super(ResNet18, self).__init__(params) - self.net = models.resnet18(pretrained = self.pretrain) + weights = 'IMAGENET1K_V1' if self.pretrain else None + self.net = models.resnet18(weights = weights) + + # replace the last layer + num_ftrs = self.net.fc.in_features + self.net.fc = nn.Linear(num_ftrs, params['class_num']) + + # replace the first layer when in_chns is not 3 + if(self.in_chns != 3): + self.net.conv1 = nn.Conv2d(self.in_chns, 64, kernel_size=(7, 7), + stride=(2, 2), padding=(3, 3), bias=False) + + def get_parameters_to_update(self): + if(self.update_mode == "all"): + return self.net.parameters() + elif(self.update_mode == "last"): + params = self.net.fc.parameters() + if(self.in_chns !=3): + # combining the two iterables into a single one + # see: https://dzone.com/articles/python-joining-multiple + params = itertools.chain() + for pram in [self.net.fc.parameters(), self.net.conv1.parameters()]: + params = itertools.chain(params, pram) + return params + else: + raise(ValueError("update_mode can only be 'all' or 'last'.")) + +class ResNet50(BuiltInNet): + """ + ResNet18 for classification. + Parameters should be set in the `params` dictionary that contains the + following fields: + + :param input_chns: (int) Input channel number, default is 3. + :param pretrain: (bool) Using pretrained model or not, default is True. + :param update_mode: (str) The strategy for updating layers: "`all`" means updating + all the layers, and "`last`" (by default) means updating the last layer, + as well as the first layer when `input_chns` is not 3. + """ + def __init__(self, params): + super(ResNet50, self).__init__(params) + weights = 'IMAGENET1K_V1' if self.pretrain else None + self.net = models.resnet50(weights = weights) # replace the last layer num_ftrs = self.net.fc.in_features @@ -101,7 +142,8 @@ class VGG16(BuiltInNet): """ def __init__(self, params): super(VGG16, self).__init__(params) - self.net = models.vgg16(pretrained = self.pretrain) + weights = 'IMAGENET1K_V1' if self.pretrain else None + self.net = models.vgg16(weights = weights) # replace the last layer num_ftrs = self.net.classifier[-1].in_features From 1615412608e86471f94f53c08895bb153f2de836 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 30 Sep 2024 16:28:32 +0800 Subject: [PATCH 194/225] Create canet.py --- pymic/net/net2d/canet.py | 229 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 229 insertions(+) create mode 100644 pymic/net/net2d/canet.py diff --git a/pymic/net/net2d/canet.py b/pymic/net/net2d/canet.py new file mode 100644 index 0000000..ab64ab5 --- /dev/null +++ b/pymic/net/net2d/canet.py @@ -0,0 +1,229 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import torch +import torch.nn as nn + +class ConvLayer(nn.Module): + """ + A combination of Conv2d, BatchNorm2d and LeakyReLU. + """ + def __init__(self, in_channels, out_channels, kernel_size = 1): + super(ConvLayer, self).__init__() + padding = int((kernel_size - 1) / 2) + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), + nn.BatchNorm2d(out_channels), + nn.LeakyReLU() + ) + + def forward(self, x): + return self.conv(x) + +class SEBlock(nn.Module): + """ + A Modified Squeeze-and-Excitation block for spatial attention. + """ + def __init__(self, in_channels, r): + super(SEBlock, self).__init__() + + redu_chns = int(in_channels / r) + self.se_layers = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, redu_chns, kernel_size=1, padding=0), + nn.LeakyReLU(), + nn.Conv2d(redu_chns, in_channels, kernel_size=1, padding=0), + nn.ReLU()) + + def forward(self, x): + f = self.se_layers(x) + return f*x + x + +class ASPPBlock(nn.Module): + """ + ASPP block. + """ + def __init__(self,in_channels, out_channels_list, kernel_size_list, dilation_list): + super(ASPPBlock, self).__init__() + self.conv_num = len(out_channels_list) + assert(self.conv_num == 4) + assert(self.conv_num == len(kernel_size_list) and self.conv_num == len(dilation_list)) + pad0 = int((kernel_size_list[0] - 1) / 2 * dilation_list[0]) + pad1 = int((kernel_size_list[1] - 1) / 2 * dilation_list[1]) + pad2 = int((kernel_size_list[2] - 1) / 2 * dilation_list[2]) + pad3 = int((kernel_size_list[3] - 1) / 2 * dilation_list[3]) + self.conv_1 = nn.Conv2d(in_channels, out_channels_list[0], kernel_size = kernel_size_list[0], + dilation = dilation_list[0], padding = pad0 ) + self.conv_2 = nn.Conv2d(in_channels, out_channels_list[1], kernel_size = kernel_size_list[1], + dilation = dilation_list[1], padding = pad1 ) + self.conv_3 = nn.Conv2d(in_channels, out_channels_list[2], kernel_size = kernel_size_list[2], + dilation = dilation_list[2], padding = pad2 ) + self.conv_4 = nn.Conv2d(in_channels, out_channels_list[3], kernel_size = kernel_size_list[3], + dilation = dilation_list[3], padding = pad3 ) + + out_channels = out_channels_list[0] + out_channels_list[1] + out_channels_list[2] + out_channels_list[3] + self.conv_1x1 = nn.Sequential( + nn.Conv2d(out_channels, out_channels, kernel_size=1, padding=0), + nn.BatchNorm2d(out_channels), + nn.LeakyReLU()) + + def forward(self, x): + x1 = self.conv_1(x) + x2 = self.conv_2(x) + x3 = self.conv_3(x) + x4 = self.conv_4(x) + + y = torch.cat([x1, x2, x3, x4], dim=1) + y = self.conv_1x1(y) + return y + +class ConvBNActBlock(nn.Module): + """ + Two convolution layers with batch norm, leaky relu, + dropout and SE block. + """ + def __init__(self,in_channels, out_channels, dropout_p): + super(ConvBNActBlock, self).__init__() + self.conv_conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.LeakyReLU(), + nn.Dropout(dropout_p), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.LeakyReLU(), + SEBlock(out_channels, 2) + ) + + def forward(self, x): + return self.conv_conv(x) + +class DownBlock(nn.Module): + """ + Downsampling by a concantenation of max-pool and avg-pool, + followed by ConvBNActBlock. + """ + def __init__(self, in_channels, out_channels, dropout_p): + super(DownBlock, self).__init__() + self.maxpool = nn.MaxPool2d(2) + self.avgpool = nn.AvgPool2d(2) + self.conv = ConvBNActBlock(2 * in_channels, out_channels, dropout_p) + + def forward(self, x): + x_max = self.maxpool(x) + x_avg = self.avgpool(x) + x_cat = torch.cat([x_max, x_avg], dim=1) + y = self.conv(x_cat) + return y + x_cat + +class UpBlock(nn.Module): + """ + Upssampling followed by ConvBNActBlock. + """ + def __init__(self, in_channels1, in_channels2, out_channels, + bilinear=True, dropout_p = 0.5): + super(UpBlock, self).__init__() + self.bilinear = bilinear + if bilinear: + self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) + self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + else: + self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2) + self.conv = ConvBNActBlock(in_channels2 * 2, out_channels, dropout_p) + + def forward(self, x1, x2): + if self.bilinear: + x1 = self.conv1x1(x1) + x1 = self.up(x1) + x_cat = torch.cat([x2, x1], dim=1) + y = self.conv(x_cat) + return y + x_cat + +class CANet(nn.Module): + """ + Implementation of of CA-Net for biomedical image segmentation. + + * Reference: R. Gu et al. `CA-Net: Comprehensive Attention Convolutional Neural Networks + for Explainable Medical Image Segmentation `_. + IEEE Transactions on Medical Imaging, 40(2),2021:699-711. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param bilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + """ + def __init__(self, params): + super(COPLENet, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + self.n_class = self.params['class_num'] + self.bilinear = self.params['bilinear'] + assert(len(self.ft_chns) == 5) + + f0_half = int(self.ft_chns[0] / 2) + f1_half = int(self.ft_chns[1] / 2) + f2_half = int(self.ft_chns[2] / 2) + f3_half = int(self.ft_chns[3] / 2) + self.in_conv= ConvBNActBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) + self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) + self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) + self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) + self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) + + self.bridge0= ConvLayer(self.ft_chns[0], f0_half) + self.bridge1= ConvLayer(self.ft_chns[1], f1_half) + self.bridge2= ConvLayer(self.ft_chns[2], f2_half) + self.bridge3= ConvLayer(self.ft_chns[3], f3_half) + + self.up1 = UpBlock(self.ft_chns[4], f3_half, self.ft_chns[3], dropout_p = self.dropout[3]) + self.up2 = UpBlock(self.ft_chns[3], f2_half, self.ft_chns[2], dropout_p = self.dropout[2]) + self.up3 = UpBlock(self.ft_chns[2], f1_half, self.ft_chns[1], dropout_p = self.dropout[1]) + self.up4 = UpBlock(self.ft_chns[1], f0_half, self.ft_chns[0], dropout_p = self.dropout[0]) + + f4 = self.ft_chns[4] + aspp_chns = [int(f4 / 4), int(f4 / 4), int(f4 / 4), int(f4 / 4)] + aspp_knls = [1, 3, 3, 3] + aspp_dila = [1, 2, 4, 6] + self.aspp = ASPPBlock(f4, aspp_chns, aspp_knls, aspp_dila) + + + self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, + kernel_size = 3, padding = 1) + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + x0 = self.in_conv(x) + x0b = self.bridge0(x0) + x1 = self.down1(x0) + x1b = self.bridge1(x1) + x2 = self.down2(x1) + x2b = self.bridge2(x2) + x3 = self.down3(x2) + x3b = self.bridge3(x3) + x4 = self.down4(x3) + x4 = self.aspp(x4) + + x = self.up1(x4, x3b) + x = self.up2(x, x2b) + x = self.up3(x, x1b) + x = self.up4(x, x0b) + output = self.out_conv(x) + + if(len(x_shape) == 5): + new_shape = [N, D] + list(output.shape)[1:] + output = torch.reshape(output, new_shape) + output = torch.transpose(output, 1, 2) + return output \ No newline at end of file From 11e1c48549387a0e343e5715630ba6e4c47a1cb5 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 30 Sep 2024 21:42:34 +0800 Subject: [PATCH 195/225] update 3d segmentation networks --- pymic/net/net3d/grunet.py | 264 ++++++ pymic/net/net3d/trans3d/__init__.py | 0 pymic/net/net3d/trans3d/transunet3d.py | 1053 ++++++++++++++++++++++++ pymic/net/net3d/unet3d.py | 22 +- pymic/net/net_dict_seg.py | 24 +- pymic/net_run/self_sup/__init__.py | 8 +- pymic/net_run/self_sup/self_vox2vec.py | 304 +++++++ 7 files changed, 1655 insertions(+), 20 deletions(-) create mode 100644 pymic/net/net3d/grunet.py create mode 100644 pymic/net/net3d/trans3d/__init__.py create mode 100644 pymic/net/net3d/trans3d/transunet3d.py create mode 100644 pymic/net_run/self_sup/self_vox2vec.py diff --git a/pymic/net/net3d/grunet.py b/pymic/net/net3d/grunet.py new file mode 100644 index 0000000..bae6447 --- /dev/null +++ b/pymic/net/net3d/grunet.py @@ -0,0 +1,264 @@ +# -*- coding: utf-8 -*- +# Note: this is a renamed version of fmunetv3. fmunetv3 will be removed in a later version. +# This network is originally used in the VolF paper. +from __future__ import print_function, division + +import itertools +import logging +import torch +import torch.nn as nn +from pymic.net.net_init import Initialization_He, Initialization_XavierUniform + +dim0 = {0:3, 1:2, 2:2} +dim1 = {0:3, 1:3, 2:2} +conv_knl = {2: (1, 3, 3), 3: 3} +conv_pad = {2: (0, 1, 1), 3: 1} +pool_knl = {2: (1, 2, 2), 3: 2} +down_stride = {2: (1, 2, 2), 3: 2} + +class ResConv(nn.Module): + def __init__(self, out_channels, dim = 3, dropout_p = 0.0, depth = 2): + super(ResConv, self).__init__() + assert(dim == 2 or dim == 3) + self.out_channels = out_channels + self.conv_list = nn.ModuleList([nn.Sequential( + nn.InstanceNorm3d(out_channels, affine = True), + nn.LeakyReLU(), + nn.Dropout(dropout_p), + nn.Conv3d(out_channels, out_channels, kernel_size=conv_knl[dim], padding=conv_pad[dim])) + for i in range(depth)]) + + def forward(self, x): + for conv in self.conv_list: + x = conv(x) + x + return x + +class DownSample(nn.Module): + """downsampling based on convolution + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + :param downsample: (bool) Use downsample or not after convolution. + """ + def __init__(self, in_channels, out_channels, dim = 3): + super(DownSample, self).__init__() + self.down = nn.Sequential( + nn.InstanceNorm3d(in_channels, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=conv_knl[dim], + padding=conv_pad[dim], stride = down_stride[dim]) + ) + + def forward(self, x): + return self.down(x) + +class UpCatConv(nn.Module): + """Upsampling followed by `ResConv` block + + :param in_channels1: (int) Input channel number for low-resolution feature map. + :param in_channels2: (int) Input channel number for high-resolution feature map. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear` for 3D and `Bilinear` for 2D). + The default value is 2. + """ + def __init__(self, in_channels1, in_channels2, out_channels, dim = 3): + super(UpCatConv, self).__init__() + + self.up = nn.Sequential( + nn.InstanceNorm3d(in_channels1, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels1, in_channels2, kernel_size=1, padding=0), + nn.Upsample(scale_factor=pool_knl[dim], mode='trilinear', align_corners=True) + ) + + self.conv = nn.Sequential( + nn.InstanceNorm3d(in_channels2*2, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels2 * 2, out_channels, kernel_size=conv_knl[dim], padding=conv_pad[dim]) + ) + + def forward(self, x_l, x_h): + """ + x_l: low-resolution feature map. + x_h: high-resolution feature map. + """ + y = torch.cat([x_h, self.up(x_l)], dim=1) + return self.conv(y) + +class Encoder(nn.Module): + """ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. + + res_mode: resolution mode: 0-- isotrpic, 1-- near isotrpic, 2-- isotropic + """ + def __init__(self, ft_chns, res_mode = 0, dropout_p = 0, depth = 2): + super(Encoder, self).__init__() + d0, d1 = dim0[res_mode], dim1[res_mode] + + self.en_conv0 = ResConv(ft_chns[0], d0, 0, depth) + self.en_conv1 = ResConv(ft_chns[1], d1, 0, depth) + self.en_conv2 = ResConv(ft_chns[2], 3, dropout_p, depth) + self.en_conv3 = ResConv(ft_chns[3], 3, dropout_p, depth) + self.en_conv4 = ResConv(ft_chns[4], 3, dropout_p, depth) + + self.down0 = DownSample(ft_chns[0], ft_chns[1], d0) + self.down1 = DownSample(ft_chns[1], ft_chns[2], d1) + self.down2 = DownSample(ft_chns[2], ft_chns[3], 3) + self.down3 = DownSample(ft_chns[3], ft_chns[4], 3) + + def forward(self, x): + x0 = self.en_conv0(x) + x1 = self.en_conv1(self.down0(x0)) + x2 = self.en_conv2(self.down1(x1)) + x3 = self.en_conv3(self.down2(x2)) + x4 = self.en_conv4(self.down3(x3)) + return [x0, x1, x2, x3, x4] + +class Decoder(nn.Module): + """ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. + """ + def __init__(self, ft_chns, res_mode = 0, dropout_p = 0, depth = 2): + super(Decoder, self).__init__() + d0, d1 = dim0[res_mode], dim1[res_mode] + + self.upcat0 = UpCatConv(ft_chns[1], ft_chns[0], ft_chns[0], d0) + self.upcat1 = UpCatConv(ft_chns[2], ft_chns[1], ft_chns[1], d1) + self.upcat2 = UpCatConv(ft_chns[3], ft_chns[2], ft_chns[2], 3) + self.upcat3 = UpCatConv(ft_chns[4], ft_chns[3], ft_chns[3], 3) + + self.de_conv0 = ResConv(ft_chns[0], d0, 0, depth) + self.de_conv1 = ResConv(ft_chns[1], d1, 0, depth) + self.de_conv2 = ResConv(ft_chns[2], 3, dropout_p, depth) + self.de_conv3 = ResConv(ft_chns[3], 3, dropout_p, depth) + self.de_conv4 = ResConv(ft_chns[4], 3, dropout_p, depth) + + def forward(self, x): + x0, x1, x2, x3, x4 = x + x4_de = self.de_conv4(x4) + x3_de = self.de_conv3(self.upcat3(x4_de, x3)) + x2_de = self.de_conv2(self.upcat2(x3_de, x2)) + x1_de = self.de_conv1(self.upcat1(x2_de, x1)) + x0_de = self.de_conv0(self.upcat0(x1_de, x0)) + return [x0_de, x1_de, x2_de, x3_de] + +class GRUNet(nn.Module): + """ + A General Residual UNet. + + * Reference: Guotai Wang, Jonathan Shapey, Wenqi Li, Reuben Dorent, Alex Demitriadis, + Sotirios Bisdas, Ian Paddick, Robert Bradford, Shaoting Zhang, Sébastien Ourselin, + Tom Vercauteren: Automatic Segmentation of Vestibular Schwannoma from T2-Weighted + MRI by Deep Spatial Attention with Hardness-Weighted Loss. + `MICCAI (2) 2019: 264-272. `_ + + Note that the attention module in the orininal paper is not used here. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param conv_dims: (list) The convolution dimension (2 or 3) for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, params): + super(GRUNet, self).__init__() + params = self.get_default_parameters(params) + + self.stage = 'train' + in_chns = params['in_chns'] + ft_chns = params['feature_chns'] + res_mode = params['res_mode'] + dropout = params['dropout'] + depth = params['depth'] + cls_num = params['class_num'] + self.mul_pred = params.get('multiscale_pred', True) + self.tune_mode= params.get('finetune_mode', 'all') + self.load_mode= params.get('weights_load_mode', 'all') + + d0 = dim0[res_mode] + self.project = nn.Conv3d(in_chns, ft_chns[0], kernel_size=conv_knl[d0], padding=conv_pad[d0]) + self.encoder = Encoder(ft_chns, res_mode, dropout, depth) + # self.decoder = Decoder(ft_chns, res_mode, dropout, depth = 2) + self.decoder = Decoder(ft_chns, res_mode, dropout, depth) + + self.out_layers = nn.ModuleList() + dims = [dim0[res_mode], dim1[res_mode], 3, 3] + for i in range(4): + out_layer = nn.Sequential( + nn.InstanceNorm3d(ft_chns[i], affine = True), + nn.LeakyReLU(), + nn.Conv3d(ft_chns[i], cls_num, kernel_size=conv_knl[dims[i]], padding=conv_pad[dims[i]])) + self.out_layers.append(out_layer) + + init = params['initialization'].lower() + weightInitializer = Initialization_He(1e-2) if init == 'he' else Initialization_XavierUniform() + self.apply(weightInitializer) + + def get_default_parameters(self, params): + default_param = { + 'finetune_mode': 'all', + 'initialization': 'he', + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': 0.2, + 'res_mode': 0, + 'depth': 2, + 'multiscale_pred': True + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def set_stage(self, stage): + self.stage = stage + + def forward(self, x): + x_en = self.encoder(self.project(x)) + x_de = self.decoder(x_en) + output = self.out_layers[0](x_de[0]) + if(self.mul_pred and self.stage == 'train'): + output = [output] + for i in range(1, len(x_de)): + output.append(self.out_layers[i](x_de[i])) + return output + + def get_parameters_to_update(self): + if(self.tune_mode == 'all'): + return self.parameters() + + up_params = itertools.chain() + if(self.tune_mode == 'decoder'): + up_blocks = [self.decoder, self.out_layers] + else: + raise ValueError("undefined fine-tune mode for GRUNet: {0:}".format(self.tune_mode)) + for block in up_blocks: + up_params = itertools.chain(up_params, block.parameters()) + return up_params + + def get_parameters_to_load(self): + state_dict = self.state_dict() + if(self.load_mode == 'encoder'): + state_dict = {k:v for k, v in state_dict.items() if "project" in k or "encoder" in k } + return state_dict \ No newline at end of file diff --git a/pymic/net/net3d/trans3d/__init__.py b/pymic/net/net3d/trans3d/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pymic/net/net3d/trans3d/transunet3d.py b/pymic/net/net3d/trans3d/transunet3d.py new file mode 100644 index 0000000..183834e --- /dev/null +++ b/pymic/net/net3d/trans3d/transunet3d.py @@ -0,0 +1,1053 @@ +# 3D version of TransUNet; Copyright Johns Hopkins University +# Modified from nnUNet + + + +import torch +import numpy as np +import torch.nn.functional +import torch.nn.functional as F + +from copy import deepcopy +from torch import nn +from torch.cuda.amp import autocast +from scipy.optimize import linear_sum_assignment + +from ..networks.neural_network import SegmentationNetwork +from .vit_modeling import Transformer +from .vit_modeling import CONFIGS as CONFIGS_ViT + +softmax_helper = lambda x: F.softmax(x, 1) + +class InitWeights_He(object): + def __init__(self, neg_slope=1e-2): + self.neg_slope = neg_slope + + def __call__(self, module): + if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d): + module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) + if module.bias is not None: + module.bias = nn.init.constant_(module.bias, 0) + +class ConvDropoutNormNonlin(nn.Module): + """ + fixes a bug in ConvDropoutNormNonlin where lrelu was used regardless of nonlin. Bad. + """ + + def __init__(self, input_channels, output_channels, + conv_op=nn.Conv2d, conv_kwargs=None, + norm_op=nn.BatchNorm2d, norm_op_kwargs=None, + dropout_op=nn.Dropout2d, dropout_op_kwargs=None, + nonlin=nn.LeakyReLU, nonlin_kwargs=None): + super(ConvDropoutNormNonlin, self).__init__() + if nonlin_kwargs is None: + nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True} + if dropout_op_kwargs is None: + dropout_op_kwargs = {'p': 0.5, 'inplace': True} + if norm_op_kwargs is None: + norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1} + if conv_kwargs is None: + conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True} + + self.nonlin_kwargs = nonlin_kwargs + self.nonlin = nonlin + self.dropout_op = dropout_op + self.dropout_op_kwargs = dropout_op_kwargs + self.norm_op_kwargs = norm_op_kwargs + self.conv_kwargs = conv_kwargs + self.conv_op = conv_op + self.norm_op = norm_op + + self.conv = self.conv_op(input_channels, output_channels, **self.conv_kwargs) + if self.dropout_op is not None and self.dropout_op_kwargs['p'] is not None and self.dropout_op_kwargs[ + 'p'] > 0: + self.dropout = self.dropout_op(**self.dropout_op_kwargs) + else: + self.dropout = None + self.instnorm = self.norm_op(output_channels, **self.norm_op_kwargs) + self.lrelu = self.nonlin(**self.nonlin_kwargs) + + def forward(self, x): + x = self.conv(x) + if self.dropout is not None: + x = self.dropout(x) + return self.lrelu(self.instnorm(x)) + + +class ConvDropoutNonlinNorm(ConvDropoutNormNonlin): + def forward(self, x): + x = self.conv(x) + if self.dropout is not None: + x = self.dropout(x) + return self.instnorm(self.lrelu(x)) + + +class StackedConvLayers(nn.Module): + def __init__(self, input_feature_channels, output_feature_channels, num_convs, + conv_op=nn.Conv2d, conv_kwargs=None, + norm_op=nn.BatchNorm2d, norm_op_kwargs=None, + dropout_op=nn.Dropout2d, dropout_op_kwargs=None, + nonlin=nn.LeakyReLU, nonlin_kwargs=None, first_stride=None, basic_block=ConvDropoutNormNonlin): + ''' + stacks ConvDropoutNormLReLU layers. initial_stride will only be applied to first layer in the stack. The other parameters affect all layers + :param input_feature_channels: + :param output_feature_channels: + :param num_convs: + :param dilation: + :param kernel_size: + :param padding: + :param dropout: + :param initial_stride: + :param conv_op: + :param norm_op: + :param dropout_op: + :param inplace: + :param neg_slope: + :param norm_affine: + :param conv_bias: + ''' + self.input_channels = input_feature_channels + self.output_channels = output_feature_channels + + if nonlin_kwargs is None: + nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True} + if dropout_op_kwargs is None: + dropout_op_kwargs = {'p': 0.5, 'inplace': True} + if norm_op_kwargs is None: + norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1} + if conv_kwargs is None: + conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True} + + self.nonlin_kwargs = nonlin_kwargs + self.nonlin = nonlin + self.dropout_op = dropout_op + self.dropout_op_kwargs = dropout_op_kwargs + self.norm_op_kwargs = norm_op_kwargs + self.conv_kwargs = conv_kwargs + self.conv_op = conv_op + self.norm_op = norm_op + + if first_stride is not None: + self.conv_kwargs_first_conv = deepcopy(conv_kwargs) + self.conv_kwargs_first_conv['stride'] = first_stride + else: + self.conv_kwargs_first_conv = conv_kwargs + + super(StackedConvLayers, self).__init__() + self.blocks = nn.Sequential( + *([basic_block(input_feature_channels, output_feature_channels, self.conv_op, + self.conv_kwargs_first_conv, + self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, + self.nonlin, self.nonlin_kwargs)] + + [basic_block(output_feature_channels, output_feature_channels, self.conv_op, + self.conv_kwargs, + self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, + self.nonlin, self.nonlin_kwargs) for _ in range(num_convs - 1)])) + + def forward(self, x): + return self.blocks(x) + + +def print_module_training_status(module): + if isinstance(module, nn.Conv2d) or isinstance(module, nn.Conv3d) or isinstance(module, nn.Dropout3d) or \ + isinstance(module, nn.Dropout2d) or isinstance(module, nn.Dropout) or isinstance(module, nn.InstanceNorm3d) \ + or isinstance(module, nn.InstanceNorm2d) or isinstance(module, nn.InstanceNorm1d) \ + or isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d) or isinstance(module, + nn.BatchNorm1d): + print(str(module), module.training) + + +class Upsample(nn.Module): + def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=False): + super(Upsample, self).__init__() + self.align_corners = align_corners + self.mode = mode + self.scale_factor = scale_factor + self.size = size + + def forward(self, x): + return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, + align_corners=self.align_corners) + +def c2_xavier_fill(module: nn.Module) -> None: + """ + Initialize `module.weight` using the "XavierFill" implemented in Caffe2. + Also initializes `module.bias` to 0. + Args: + module (torch.nn.Module): module to initialize. + """ + # Caffe2 implementation of XavierFill in fact + # corresponds to kaiming_uniform_ in PyTorch + nn.init.kaiming_uniform_(module.weight, a=1) + if module.bias is not None: + # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module, + # torch.Tensor]`. + nn.init.constant_(module.bias, 0) + +class Generic_TransUNet_max_ppbp(SegmentationNetwork): + DEFAULT_BATCH_SIZE_3D = 2 + DEFAULT_PATCH_SIZE_3D = (64, 192, 160) + SPACING_FACTOR_BETWEEN_STAGES = 2 + BASE_NUM_FEATURES_3D = 30 + MAX_NUMPOOL_3D = 999 + MAX_NUM_FILTERS_3D = 320 + + DEFAULT_PATCH_SIZE_2D = (256, 256) + BASE_NUM_FEATURES_2D = 30 + DEFAULT_BATCH_SIZE_2D = 50 + MAX_NUMPOOL_2D = 999 + MAX_FILTERS_2D = 480 + + use_this_for_batch_size_computation_2D = 19739648 + use_this_for_batch_size_computation_3D = 520000000 # 505789440 + + def __init__(self, input_channels, base_num_features, num_classes, num_pool, num_conv_per_stage=2, + feat_map_mul_on_downscale=2, conv_op=nn.Conv2d, + norm_op=nn.BatchNorm2d, norm_op_kwargs=None, + dropout_op=nn.Dropout2d, dropout_op_kwargs=None, + nonlin=nn.LeakyReLU, nonlin_kwargs=None, deep_supervision=True, dropout_in_localization=False, + final_nonlin=softmax_helper, weightInitializer=InitWeights_He(1e-2), pool_op_kernel_sizes=None, + conv_kernel_sizes=None, + upscale_logits=False, convolutional_pooling=False, convolutional_upsampling=False, # TODO default False + max_num_features=None, basic_block=ConvDropoutNormNonlin, + seg_output_use_bias=False, + patch_size=None, is_vit_pretrain=False, + vit_depth=12, vit_hidden_size=768, vit_mlp_dim=3072, vit_num_heads=12, + max_msda='', is_max_ms=True, is_max_ms_fpn=False, max_n_fpn=4, max_ms_idxs=[-4,-3,-2], max_ss_idx=0, + is_max_bottleneck_transformer=False, max_seg_weight=1.0, max_hidden_dim=256, max_dec_layers=10, + mw = 0.5, + is_max=True, is_masked_attn=False, is_max_ds=False, is_masking=False, is_masking_argmax=False, + is_fam=False, fam_k=5, fam_reduct_ratio=8, + is_max_hungarian=False, num_queries=None, is_max_cls=False, + point_rend=False, num_point_rend=None, no_object_weight=None, is_mhsa_float32=False, no_max_hw_pe=False, + max_infer=None, cost_weight=[2.0, 5.0, 5.0], vit_layer_scale=False, decoder_layer_scale=False): + + super(Generic_TransUNet_max_ppbp, self).__init__() + + # newly added + self.is_fam = is_fam + self.is_max, self.max_msda, self.is_max_ms, self.is_max_ms_fpn, self.max_n_fpn, self.max_ss_idx, self.mw = is_max, max_msda, is_max_ms, is_max_ms_fpn, max_n_fpn, max_ss_idx, mw + self.max_ms_idxs = max_ms_idxs + + self.is_max_cls = is_max_cls + self.is_masked_attn, self.is_max_ds = is_masked_attn, is_max_ds + self.is_max_bottleneck_transformer = is_max_bottleneck_transformer + + self.convolutional_upsampling = convolutional_upsampling + self.convolutional_pooling = convolutional_pooling + self.upscale_logits = upscale_logits + if nonlin_kwargs is None: + nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True} + if dropout_op_kwargs is None: + dropout_op_kwargs = {'p': 0.5, 'inplace': True} + if norm_op_kwargs is None: + norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1} + + self.conv_kwargs = {'stride': 1, 'dilation': 1, 'bias': True} + + self.nonlin = nonlin + self.nonlin_kwargs = nonlin_kwargs + self.dropout_op_kwargs = dropout_op_kwargs + self.norm_op_kwargs = norm_op_kwargs + self.weightInitializer = weightInitializer + self.conv_op = conv_op + self.norm_op = norm_op + self.dropout_op = dropout_op + self.num_classes = num_classes + self.final_nonlin = final_nonlin + self._deep_supervision = deep_supervision + self.do_ds = deep_supervision + + if conv_op == nn.Conv2d: + upsample_mode = 'bilinear' + pool_op = nn.MaxPool2d + transpconv = nn.ConvTranspose2d + if pool_op_kernel_sizes is None: + pool_op_kernel_sizes = [(2, 2)] * num_pool + if conv_kernel_sizes is None: + conv_kernel_sizes = [(3, 3)] * (num_pool + 1) + elif conv_op == nn.Conv3d: + upsample_mode = 'trilinear' + pool_op = nn.MaxPool3d + transpconv = nn.ConvTranspose3d + if pool_op_kernel_sizes is None: + pool_op_kernel_sizes = [(2, 2, 2)] * num_pool + if conv_kernel_sizes is None: + conv_kernel_sizes = [(3, 3, 3)] * (num_pool + 1) + else: + raise ValueError("unknown convolution dimensionality, conv op: %s" % str(conv_op)) + + self.input_shape_must_be_divisible_by = np.prod(pool_op_kernel_sizes, 0, dtype=np.int64) + self.pool_op_kernel_sizes = pool_op_kernel_sizes + self.conv_kernel_sizes = conv_kernel_sizes + + self.conv_pad_sizes = [] + for krnl in self.conv_kernel_sizes: + self.conv_pad_sizes.append([1 if i == 3 else 0 for i in krnl]) + + if max_num_features is None: + if self.conv_op == nn.Conv3d: + self.max_num_features = self.MAX_NUM_FILTERS_3D + else: + self.max_num_features = self.MAX_FILTERS_2D + else: + self.max_num_features = max_num_features + + self.conv_blocks_context = [] + self.conv_blocks_localization = [] + self.td = [] + self.tu = [] + + + self.fams = [] + + output_features = base_num_features + input_features = input_channels + + for d in range(num_pool): + # determine the first stride + if d != 0 and self.convolutional_pooling: + first_stride = pool_op_kernel_sizes[d - 1] + else: + first_stride = None + + self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[d] + self.conv_kwargs['padding'] = self.conv_pad_sizes[d] + # add convolutions + self.conv_blocks_context.append(StackedConvLayers(input_features, output_features, num_conv_per_stage, + self.conv_op, self.conv_kwargs, self.norm_op, + self.norm_op_kwargs, self.dropout_op, + self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs, + first_stride, basic_block=basic_block)) + if not self.convolutional_pooling: + self.td.append(pool_op(pool_op_kernel_sizes[d])) + input_features = output_features + output_features = int(np.round(output_features * feat_map_mul_on_downscale)) + + output_features = min(output_features, self.max_num_features) + + # now the bottleneck. + # determine the first stride + if self.convolutional_pooling: + first_stride = pool_op_kernel_sizes[-1] + else: + first_stride = None + + # the output of the last conv must match the number of features from the skip connection if we are not using + # convolutional upsampling. If we use convolutional upsampling then the reduction in feature maps will be + # done by the transposed conv + if self.convolutional_upsampling: + final_num_features = output_features + else: + final_num_features = self.conv_blocks_context[-1].output_channels + + self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[num_pool] + self.conv_kwargs['padding'] = self.conv_pad_sizes[num_pool] + self.conv_blocks_context.append(nn.Sequential( + StackedConvLayers(input_features, output_features, num_conv_per_stage - 1, self.conv_op, self.conv_kwargs, + self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin, + self.nonlin_kwargs, first_stride, basic_block=basic_block), + StackedConvLayers(output_features, final_num_features, 1, self.conv_op, self.conv_kwargs, + self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin, + self.nonlin_kwargs, basic_block=basic_block))) + + # if we don't want to do dropout in the localization pathway then we set the dropout prob to zero here + if not dropout_in_localization: + old_dropout_p = self.dropout_op_kwargs['p'] + self.dropout_op_kwargs['p'] = 0.0 + + # now lets build the localization pathway + for u in range(num_pool): + nfeatures_from_down = final_num_features + nfeatures_from_skip = self.conv_blocks_context[ + -(2 + u)].output_channels # self.conv_blocks_context[-1] is bottleneck, so start with -2 + n_features_after_tu_and_concat = nfeatures_from_skip * 2 + + # the first conv reduces the number of features to match those of skip + # the following convs work on that number of features + # if not convolutional upsampling then the final conv reduces the num of features again + if u != num_pool - 1 and not self.convolutional_upsampling: + final_num_features = self.conv_blocks_context[-(3 + u)].output_channels + else: + final_num_features = nfeatures_from_skip + + if not self.convolutional_upsampling: + self.tu.append(Upsample(scale_factor=pool_op_kernel_sizes[-(u + 1)], mode=upsample_mode)) + else: + self.tu.append(transpconv(nfeatures_from_down, nfeatures_from_skip, pool_op_kernel_sizes[-(u + 1)], + pool_op_kernel_sizes[-(u + 1)], bias=False)) + + self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[- (u + 1)] + self.conv_kwargs['padding'] = self.conv_pad_sizes[- (u + 1)] + self.conv_blocks_localization.append(nn.Sequential( + StackedConvLayers(n_features_after_tu_and_concat, nfeatures_from_skip, num_conv_per_stage - 1, + self.conv_op, self.conv_kwargs, self.norm_op, self.norm_op_kwargs, self.dropout_op, + self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs, basic_block=basic_block), + StackedConvLayers(nfeatures_from_skip, final_num_features, 1, self.conv_op, self.conv_kwargs, + self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, + self.nonlin, self.nonlin_kwargs, basic_block=basic_block) + )) + + + + if self.is_fam: + self.fams = nn.ModuleList(self.fams) + + if self.do_ds: + self.seg_outputs = [] + for ds in range(len(self.conv_blocks_localization)): + self.seg_outputs.append(conv_op(self.conv_blocks_localization[ds][-1].output_channels, num_classes, + 1, 1, 0, 1, 1, seg_output_use_bias)) + self.seg_outputs = nn.ModuleList(self.seg_outputs) + + self.upscale_logits_ops = [] + cum_upsample = np.cumprod(np.vstack(pool_op_kernel_sizes), axis=0)[::-1] + for usl in range(num_pool - 1): + if self.upscale_logits: + self.upscale_logits_ops.append(Upsample(scale_factor=tuple([int(i) for i in cum_upsample[usl + 1]]), + mode=upsample_mode)) + else: + self.upscale_logits_ops.append(lambda x: x) + + if not dropout_in_localization: + self.dropout_op_kwargs['p'] = old_dropout_p + + # register all modules properly + self.conv_blocks_localization = nn.ModuleList(self.conv_blocks_localization) + self.conv_blocks_context = nn.ModuleList(self.conv_blocks_context) + self.td = nn.ModuleList(self.td) + self.tu = nn.ModuleList(self.tu) + + if self.upscale_logits: + self.upscale_logits_ops = nn.ModuleList( + self.upscale_logits_ops) # lambda x:x is not a Module so we need to distinguish here + + if self.weightInitializer is not None: + self.apply(self.weightInitializer) + # self.apply(print_module_training_status) + + # Transformer configuration + if self.is_max_bottleneck_transformer: + self.patch_size = patch_size # e.g. [48, 192, 192] + config_vit = CONFIGS_ViT['R50-ViT-B_16'] + config_vit.transformer.num_layers = vit_depth + config_vit.hidden_size = vit_hidden_size # 768 + config_vit.transformer.mlp_dim = vit_mlp_dim # 3072 + config_vit.transformer.num_heads = vit_num_heads # 12 + self.conv_more = nn.Conv3d(config_vit.hidden_size, output_features, 1) + num_pool_per_axis = np.prod(np.array(pool_op_kernel_sizes), axis=0) + num_pool_per_axis = np.log2(num_pool_per_axis).astype(np.uint8) + feat_size = [int(self.patch_size[0]/2**num_pool_per_axis[0]), int(self.patch_size[1]/2**num_pool_per_axis[1]), int(self.patch_size[2]/2**num_pool_per_axis[2])] + self.transformer = Transformer(config_vit, feat_size=feat_size, vis=False, feat_channels=output_features, use_layer_scale=vit_layer_scale) + if is_vit_pretrain: + self.transformer.load_from(weights=np.load(config_vit.pretrained_path)) + + + if self.is_max: + # Max PPB+ configuration (i.e. MultiScaleStandardTransformerDecoder) + cfg = { + "num_classes": num_classes, + "hidden_dim": max_hidden_dim, + "num_queries": num_classes if num_queries is None else num_queries, # N=K if 'fixed matching', else default=100, + "nheads": 8, + "dim_feedforward": max_hidden_dim * 8, # 2048, + "dec_layers": max_dec_layers, # 9 decoder layers, add one for the loss on learnable query? + "pre_norm": False, + "enforce_input_project": False, + "mask_dim": max_hidden_dim, # input feat of segm head? + "non_object": False, + "use_layer_scale": decoder_layer_scale, + } + cfg['non_object'] = is_max_cls + input_proj_list = [] # from low resolution to high resolution (res4 -> res1), [1, 1024, 14, 14], [1, 512, 28, 28], 1, 256, 56, 56], [1, 64, 112, 112] + decoder_channels = [320, 320, 256, 128, 64, 32] + if self.is_max_ms: # use multi-scale feature as Transformer decoder input + if self.is_max_ms_fpn: + for idx, in_channels in enumerate(decoder_channels[:max_n_fpn]): # max_n_fpn=4: 1/32, 1/16, 1/8, 1/4 + input_proj_list.append(nn.Sequential( + nn.Conv3d(in_channels, max_hidden_dim, kernel_size=1), + nn.GroupNorm(32, max_hidden_dim), + nn.Upsample(size=(int(patch_size[0]/2), int(patch_size[1]/4), int(patch_size[2]/4)), mode='trilinear') + )) # proj to scale (1, 1/2, 1/2), TODO: init + self.input_proj = nn.ModuleList(input_proj_list) + self.linear_encoder_feature = nn.Conv3d(max_hidden_dim * max_n_fpn, max_hidden_dim, 1, 1) # concat four-level feature + else: + for idx, in_channels in enumerate([decoder_channels[i] for i in self.max_ms_idxs]): + input_proj_list.append(nn.Sequential( + nn.Conv3d(in_channels, max_hidden_dim, kernel_size=1), + nn.GroupNorm(32, max_hidden_dim), + )) + self.input_proj = nn.ModuleList(input_proj_list) + + # self.linear_mask_features =nn.Conv3d(decoder_channels[max_n_fpn-1], cfg["mask_dim"], kernel_size=1, stride=1, padding=0,) # low-level feat, dot product Trans-feat + self.linear_mask_features =nn.Conv3d(decoder_channels[-1], cfg["mask_dim"], kernel_size=1, stride=1, padding=0,) # following SingleScale, high-level feat, obtain seg_map + else: + self.linear_encoder_feature = nn.Conv3d(decoder_channels[max_ss_idx], cfg["mask_dim"], kernel_size=1) + self.linear_mask_features = nn.Conv3d(decoder_channels[-1], cfg["mask_dim"], kernel_size=1, stride=1, padding=0,) # low-level feat, dot product Trans-feat + + if self.is_masked_attn: + from .mask2former_modeling.transformer_decoder.mask2former_transformer_decoder3d import MultiScaleMaskedTransformerDecoder3d + cfg['num_feature_levels'] = 1 if not self.is_max_ms or self.is_max_ms_fpn else 3 + cfg["is_masking"] = True if is_masking else False + cfg["is_masking_argmax"] = True if is_masking_argmax else False + cfg["is_mhsa_float32"] = True if is_mhsa_float32 else False + cfg["no_max_hw_pe"] = True if no_max_hw_pe else False + self.predictor = MultiScaleMaskedTransformerDecoder3d(in_channels=max_hidden_dim, mask_classification=is_max_cls, **cfg) + else: + from .mask2former_modeling.transformer_decoder.maskformer_transformer_decoder3d import StandardTransformerDecoder + cfg["dropout"], cfg["enc_layers"], cfg["deep_supervision"] = 0.1, 0, False + self.predictor = StandardTransformerDecoder(in_channels=max_hidden_dim, mask_classification=is_max_cls, **cfg) + + def forward(self, x): + skips = [] + seg_outputs = [] + for d in range(len(self.conv_blocks_context) - 1): + x = self.conv_blocks_context[d](x) + skips.append(x) + if not self.convolutional_pooling: + x = self.td[d](x) + + x = self.conv_blocks_context[-1](x) + ######### TransUNet ######### + if self.is_max_bottleneck_transformer: + x, attn = self.transformer(x) # [b, hidden, d/8, h/16, w/16] + x = self.conv_more(x) + ############################# + + ds_feats = [] # obtain multi-scale feature + ds_feats.append(x) + for u in range(len(self.tu)): + if unm", inputs, targets) + denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :] + loss = 1 - (numerator + 1) / (denominator + 1) + return loss + + +batch_dice_loss_jit = torch.jit.script( + batch_dice_loss +) # type: torch.jit.ScriptModule + + + +def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor): + """ + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + Returns: + Loss tensor + """ + hw = inputs.shape[1] + + pos = F.binary_cross_entropy_with_logits( + inputs, torch.ones_like(inputs), reduction="none" + ) + neg = F.binary_cross_entropy_with_logits( + inputs, torch.zeros_like(inputs), reduction="none" + ) + + loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum( + "nc,mc->nm", neg, (1 - targets) + ) + + return loss / hw + + +batch_sigmoid_ce_loss_jit = torch.jit.script( + batch_sigmoid_ce_loss +) # type: torch.jit.ScriptModule + + +class HungarianMatcher3D(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). + """ + + def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1, ): + """Creates the matcher + + Params: + cost_class: This is the relative weight of the classification error in the matching cost + cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost + cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost + """ + super().__init__() + self.cost_class = cost_class + self.cost_mask = cost_mask + self.cost_dice = cost_dice + + def compute_cls_loss(self, inputs, targets): + """ Classification loss (NLL) + implemented in compute_loss() + """ + raise NotImplementedError + + + def compute_dice_loss(self, inputs, targets): + """ mask dice loss + inputs (B*K, C, H, W) + target (B*K, D, H, W) + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + targets = targets.flatten(1) + num_masks = len(inputs) + + numerator = 2 * (inputs * targets).sum(-1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_masks + + + def compute_ce_loss(self, inputs, targets): + """mask ce loss""" + num_masks = len(inputs) + loss = F.binary_cross_entropy_with_logits(inputs.flatten(1), targets.flatten(1), reduction="none") + loss = loss.mean(1).sum() / num_masks + return loss + + def compute_dice(self, inputs, targets): + """ output (N_q, C, H, W) + target (K, D, H, W) + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + targets = targets.flatten(1) + numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets) + denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :] + loss = 1 - (numerator + 1) / (denominator + 1) + return loss # [N_q, K] + + + def compute_ce(self, inputs, targets): + """ output (N_q, C, H, W) + target (K, D, H, W) + return (N_q, K) + """ + inputs = inputs.flatten(1) + targets = targets.flatten(1) + hw = inputs.shape[1] + + pos = F.binary_cross_entropy_with_logits( + inputs, torch.ones_like(inputs), reduction="none" + ) + + neg = F.binary_cross_entropy_with_logits( + inputs, torch.zeros_like(inputs), reduction="none" + ) + + loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum( + "nc,mc->nm", neg, (1 - targets) + ) + + return loss / hw + + # target_onehot = torch.zeros_like(output, device=output.device) + # target_onehot.scatter_(1, target.long(), 1) + # assert (torch.argmax(target_onehot, dim=1) == target[:, 0].long()).all() + # ce_loss = F.binary_cross_entropy_with_logits(output, target_onehot) + # return ce_loss + + + @torch.no_grad() + def memory_efficient_forward(self, outputs, targets): + """More memory-friendly matching for single aux, outputs: (b, q, d, h, w)""" + """suppose each crop must contain foreground class""" + bs, num_queries = outputs["pred_logits"].shape[:2] + indices = [] + + # Iterate through batch size + for b in range(bs): + out_prob = outputs["pred_logits"][b].softmax(-1) # [num_queries, num_classes+1] + out_mask = outputs["pred_masks"][b] # [num_queries, H_pred, W_pred] + + tgt_ids = targets[b]["labels"] + tgt_mask = targets[b]["masks"].to(out_mask) # [K, D, H, W], K is number of classes shown in this image, and K < n_class + + # target_onehot = torch.zeros_like(tgt_mask, device=out_mask.device) + # target_onehot.scatter_(1, targets.long(), 1) + + cost_class = -out_prob[:, tgt_ids] # [num_queries, K] + + with autocast(enabled=False): + out_mask = out_mask.float() + tgt_mask = tgt_mask.float() + cost_dice = self.compute_dice(out_mask, tgt_mask) + cost_mask = self.compute_ce(out_mask, tgt_mask) + + # Final cost matrix + C = ( + self.cost_class * cost_class + + self.cost_mask * cost_mask + + self.cost_dice * cost_dice + ) + + C = C.reshape(num_queries, -1).cpu() # (num_queries, K) + + # linear_sum_assignment return a tuple of two arrays: row_ind, col_ind, the length of array is min(N_q, K) + # The cost of the assignment can be computed as cost_matrix[row_ind, col_ind].sum() + + indices.append(linear_sum_assignment(C)) + + final_indices = [ + (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) + for i, j in indices + ] + + return final_indices + + @torch.no_grad() + def forward(self, outputs, targets): + """Performs the matching + + Params: + outputs: This is a dict that contains at least these entries: + "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks + + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth + objects in the target) containing the class labels + "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + + return self.memory_efficient_forward(outputs, targets) + + def __repr__(self, _repr_indent=4): + head = "Matcher " + self.__class__.__name__ + body = [ + "cost_class: {}".format(self.cost_class), + "cost_mask: {}".format(self.cost_mask), + "cost_dice: {}".format(self.cost_dice), + ] + lines = [head] + [" " * _repr_indent + line for line in body] + return "\n".join(lines) + + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + +def compute_loss_hungarian(outputs, targets, idx, matcher, num_classes, point_rend=False, num_points=12544, oversample_ratio=3.0, importance_sample_ratio=0.75, no_object_weight=None, cost_weight=[2,5,5]): + """output is a dict only contain keys ['pred_masks', 'pred_logits'] """ + # outputs_without_aux = {k: v for k, v in output.items() if k != "aux_outputs"} + + indices = matcher(outputs, targets) + src_idx = matcher._get_src_permutation_idx(indices) # return a tuple of (batch_idx, src_idx) + tgt_idx = matcher._get_tgt_permutation_idx(indices) # return a tuple of (batch_idx, tgt_idx) + assert len(tgt_idx[0]) == sum([len(t["masks"]) for t in targets]) # verify that all masks of (K1, K2, ..) are used + + # step2 : compute mask loss + src_masks = outputs["pred_masks"] + src_masks = src_masks[src_idx] # [len(src_idx[0]), D, H, W] -> (K1+K2+..., D, H, W) + target_masks = torch.cat([t["masks"] for t in targets], dim=0) # (K1+K2+..., D, H, W) actually + src_masks = src_masks[:, None] # [K..., 1, D, H, W] + target_masks = target_masks[:, None] + + if point_rend: # only calculate hard example + with torch.no_grad(): + # num_points=12544 config in cityscapes + + # sample point_coords + point_coords = get_uncertain_point_coords_with_randomness( + src_masks.float(), + lambda logits: calculate_uncertainty(logits), + num_points, + oversample_ratio, + importance_sample_ratio, + ) # [K, num_points=12544, 3] + + point_labels = point_sample_3d( + target_masks.float(), + point_coords.float(), + align_corners=False, + ).squeeze(1) # [K, 12544] + + point_logits = point_sample_3d( + src_masks.float(), + point_coords.float(), + align_corners=False, + ).squeeze(1) # [K, 12544] + + src_masks, target_masks = point_logits, point_labels + + loss_mask_ce = matcher.compute_ce_loss(src_masks, target_masks) + loss_mask_dice = matcher.compute_dice_loss(src_masks, target_masks) + + # step3: compute class loss + src_logits = outputs["pred_logits"].float() # (B, num_query, num_class+1) + target_classes_o = torch.cat([t["labels"] for t in targets], dim=0) # (K1+K2+, ) + target_classes = torch.full( + src_logits.shape[:2], num_classes, dtype=torch.int64, device=src_logits.device + ) # (B, num_query, num_class+1) + target_classes[src_idx] = target_classes_o + + + if no_object_weight is not None: + empty_weight = torch.ones(num_classes + 1).to(src_logits.device) + empty_weight[-1] = no_object_weight + loss_cls = F.cross_entropy(src_logits.transpose(1, 2), target_classes, empty_weight) + else: + loss_cls = F.cross_entropy(src_logits.transpose(1, 2), target_classes) + + loss = (cost_weight[0]/10)*loss_cls + (cost_weight[1]/10)*loss_mask_ce + (cost_weight[2]/10)*loss_mask_dice # 2:5:5, like hungarian matching + # print("idx {}, loss {}, loss_cls {}, loss_mask_ce {}, loss_mask_dice {}".format(idx, loss, loss_cls, loss_mask_ce, loss_mask_dice)) + return loss + + +def point_sample_3d(input, point_coords, **kwargs): + """ + from detectron2.projects.point_rend.point_features + A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors. + Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside + [0, 1] x [0, 1] square. + Args: + input (Tensor): A tensor of shape (N, C, D, H, W) that contains features map on a D x H x W grid. + point_coords (Tensor): A tensor of shape (N, P, 3) or (N, Dgrid, Hgrid, Wgrid, 3) that contains + [0, 1] x [0, 1] x [0, 1] normalized point coordinates. + Returns: + output (Tensor): A tensor of shape (N, C, P) or (N, C, Dgrid, Hgrid, Wgrid) that contains + features for points in `point_coords`. The features are obtained via bilinear + interplation from `input` the same way as :function:`torch.nn.functional.grid_sample`. + """ + add_dim = False + if point_coords.dim() == 3: + add_dim = True + point_coords = point_coords.unsqueeze(2).unsqueeze(2) # why + + # point_coords should be (N, D, H, W, 3) + output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs) + + if add_dim: + output = output.squeeze(3).squeeze(3) + + return output + + +def calculate_uncertainty(logits): + """ + We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the + foreground class in `classes`. + Args: + logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or + class-agnostic, where R is the total number of predicted masks in all images and C is + the number of foreground classes. The values are logits. + Returns: + scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with + the most uncertain locations having the highest uncertainty score. + """ + assert logits.shape[1] == 1 + gt_class_logits = logits.clone() + return -(torch.abs(gt_class_logits)) + + +# implemented! +def get_uncertain_point_coords_with_randomness( + coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio): + """ + Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The unceratinties + are calculated for each point using 'uncertainty_func' function that takes point's logit + prediction as input. + See PointRend paper for details. + Args: + coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for + class-specific or class-agnostic prediction. + uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that + contains logit predictions for P points and returns their uncertainties as a Tensor of + shape (N, 1, P). + num_points (int): The number of points P to sample. + oversample_ratio (int): Oversampling parameter. + importance_sample_ratio (float): Ratio of points that are sampled via importnace sampling. + Returns: + point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P + sampled points. + """ + assert oversample_ratio >= 1 + assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0 + n_dim = 3 + num_boxes = coarse_logits.shape[0] + num_sampled = int(num_points * oversample_ratio) # 12544 * 3, oversampled + point_coords = torch.rand(num_boxes, num_sampled, n_dim, device=coarse_logits.device) # (K, 37632, 3); uniform dist [0, 1) + point_logits = point_sample_3d(coarse_logits, point_coords, align_corners=False) # (K, 1, 37632) + + # It is crucial to calculate uncertainty based on the sampled prediction value for the points. + # Calculating uncertainties of the coarse predictions first and sampling them for points leads + # to incorrect results. + # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between + # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value. + # However, if we calculate uncertainties for the coarse predictions first, + # both will have -1 uncertainty, and the sampled point will get -1 uncertainty. + point_uncertainties = uncertainty_func(point_logits) + num_uncertain_points = int(importance_sample_ratio * num_points) # 9408 + + num_random_points = num_points - num_uncertain_points # 3136 + idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + + shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device) + idx += shift[:, None] # [K, 9408] + + point_coords = point_coords.view(-1, n_dim)[idx.view(-1), :].view( + num_boxes, num_uncertain_points, n_dim + ) # [K, 9408, 3] + + if num_random_points > 0: + # from detectron2.layers import cat + point_coords = torch.cat( + [ + point_coords, + torch.rand(num_boxes, num_random_points, n_dim, device=coarse_logits.device), + ], + dim=1, + ) # [K, 12544, 3] + + return point_coords \ No newline at end of file diff --git a/pymic/net/net3d/unet3d.py b/pymic/net/net3d/unet3d.py index b66ea38..5954869 100644 --- a/pymic/net/net3d/unet3d.py +++ b/pymic/net/net3d/unet3d.py @@ -236,7 +236,8 @@ def __init__(self, params): for p in params: print(p, params[p]) self.stage = 'train' - self.update_mode = params.get("update_mode", "all") + self.tune_mode= params.get('finetune_mode', 'all') + self.load_mode= params.get('weights_load_mode', 'all') self.encoder = Encoder(params) self.decoder = Decoder(params) @@ -263,17 +264,24 @@ def set_stage(self, stage): self.stage = stage self.decoder.set_stage(stage) + + def forward(self, x): + f = self.encoder(x) + output = self.decoder(f) + return output + def get_parameters_to_update(self): - if(self.update_mode == "all"): + if(self.tune_mode == "all"): return self.parameters() - elif(self.update_mode == "decoder"): + elif(self.tune_mode == "decoder"): print("only update parameters in decoder") params = self.decoder.parameters() return params else: raise(ValueError("update_mode can only be 'all' or 'decoder'.")) - def forward(self, x): - f = self.encoder(x) - output = self.decoder(f) - return output + def get_parameters_to_load(self): + state_dict = self.state_dict() + if(self.load_mode == 'encoder'): + state_dict = {k:v for k, v in state_dict.items() if "encoder" in k } + return state_dict diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index e381421..741eed7 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -16,10 +16,10 @@ """ from __future__ import print_function, division from pymic.net.net2d.unet2d import UNet2D -from pymic.net.net2d.unet2d_dual_branch import UNet2D_DualBranch +from pymic.net.net2d.unet2d_multi_decoder import UNet2D_DualBranch, MCNet2D from pymic.net.net2d.unet2d_canet import CANet from pymic.net.net2d.unet2d_cct import UNet2D_CCT -from pymic.net.net2d.unet2d_mcnet import MCNet2D +from pymic.net.net2d.unet2d_mtnet import MTNet2D from pymic.net.net2d.cople_net import COPLENet from pymic.net.net2d.unet2d_attention import AttentionUNet2D from pymic.net.net2d.unet2d_nest import NestedUNet2D @@ -28,8 +28,13 @@ from pymic.net.net2d.trans2d.swinunet import SwinUNet from pymic.net.net3d.unet2d5 import UNet2D5 from pymic.net.net3d.unet3d import UNet3D +from pymic.net.net3d.grunet import GRUNet +from pymic.net.net3d.fmunetv3 import FMUNetV3 from pymic.net.net3d.unet3d_scse import UNet3D_ScSE from pymic.net.net3d.unet3d_dual_branch import UNet3D_DualBranch +# from pymic.net.net3d.stunet_wrap import STUNet_wrap +# from pymic.net.net3d.mystunet import MySTUNet + # from pymic.net.net3d.trans3d.nnFormer_wrap import nnFormer_wrap # from pymic.net.net3d.trans3d.unetr import UNETR # from pymic.net.net3d.trans3d.unetr_pp import UNETR_PP @@ -49,6 +54,7 @@ 'UNet2D_DualBranch': UNet2D_DualBranch, 'UNet2D_CCT': UNet2D_CCT, 'MCNet2D': MCNet2D, + 'MTNet2D': MTNet2D, 'CANet': CANet, 'COPLENet': COPLENet, 'AttentionUNet2D': AttentionUNet2D, @@ -57,20 +63,14 @@ 'TransUNet': TransUNet, 'SwinUNet': SwinUNet, 'UNet2D5': UNet2D5, + 'GRUNet': GRUNet, + 'FMUNetV3': FMUNetV3, 'UNet3D': UNet3D, 'UNet3D_ScSE': UNet3D_ScSE, 'UNet3D_DualBranch': UNet3D_DualBranch, + # 'STUNet': STUNet_wrap, + # 'MySTUNet': MySTUNet, # 'nnFormer': nnFormer_wrap, # 'UNETR': UNETR, # 'UNETR_PP': UNETR_PP, - # 'MedFormerV1': MedFormerV1, - # 'MedFormerV2': MedFormerV2, - # 'MedFormerV3': MedFormerV3, - # 'MedFormerVA1':MedFormerVA1, - # 'HiFormer_v1': HiFormer_v1, - # 'HiFormer_v2': HiFormer_v2, - # 'HiFormer_v3': HiFormer_v3, - # 'HiFormer_v4': HiFormer_v4, - # 'HiFormer_v5': HiFormer_v5 - # 'SwitchNet': SwitchNet } diff --git a/pymic/net_run/self_sup/__init__.py b/pymic/net_run/self_sup/__init__.py index d73e42a..86482f9 100644 --- a/pymic/net_run/self_sup/__init__.py +++ b/pymic/net_run/self_sup/__init__.py @@ -1,10 +1,16 @@ from __future__ import absolute_import from pymic.net_run.self_sup.self_genesis import SelfSupModelGenesis from pymic.net_run.self_sup.self_patch_swapping import SelfSupPatchSwapping -from pymic.net_run.self_sup.self_volume_fusion import SelfSupVolumeFusion +# from pymic.net_run.self_sup.self_mim import SelfSupMIM +# from pymic.net_run.self_sup.self_dino import SelfSupDINO +from pymic.net_run.self_sup.self_vox2vec import SelfSupVox2Vec +from pymic.net_run.self_sup.self_volf import SelfSupVolumeFusion SelfSupMethodDict = { + # 'DINO': SelfSupDINO, + 'Vox2Vec': SelfSupVox2Vec, 'ModelGenesis': SelfSupModelGenesis, 'PatchSwapping': SelfSupPatchSwapping, 'VolumeFusion': SelfSupVolumeFusion + # 'MaskedImageModeling': SelfSupMIM } \ No newline at end of file diff --git a/pymic/net_run/self_sup/self_vox2vec.py b/pymic/net_run/self_sup/self_vox2vec.py new file mode 100644 index 0000000..94e2b16 --- /dev/null +++ b/pymic/net_run/self_sup/self_vox2vec.py @@ -0,0 +1,304 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import copy +import logging +import time +import logging +import torch +import torch.nn as nn +from datetime import datetime +from torch.optim import lr_scheduler +from tensorboardX import SummaryWriter +from pymic.io.image_read_write import save_nd_array_as_image +from pymic.net.net3d.fmunetv3 import FMUNetV3 +from pymic.net_run.agent_seg import SegmentationAgent +from pymic.loss.cls.infoNCE import InfoNCELoss + +def select_from_pyramid(feature_pyramid, indices): + """Select features from feature pyramid by their indices w.r.t. base feature map. + + Args: + feature_pyramid (Sequence[torch.Tensor]): Sequence of tensors of shapes ``(B, C_i, D_i, H_i, W_i)``. + indices (torch.Tensor): tensor of shape ``(B, N, 3)`` + + Returns: + torch.Tensor: tensor of shape ``(B, N, \sum_i c_i)`` + """ + out = [] + for i, x in enumerate(feature_pyramid): + batch_size = list(x.shape)[0] + x_move = x.moveaxis(1, -1) + index_i = indices // 2 ** i + x_i = [x_move[b][index_i[b][:, 0], index_i[b][:, 1], index_i[b][:, 2], :] for \ + b in range(batch_size)] + x_i = torch.stack(x_i) + out.append(x_i) + out = torch.cat(out, dim = -1) + return out + +class Vox2VecHead(nn.Module): + def __init__(self, params): + super(Vox2VecHead, self).__init__() + ft_chns = params['feature_chns'] + hidden_dim = params['hidden_dim'] + proj_dim = params['project_dim'] + embed_dim = sum(ft_chns) + self.proj_head = nn.Sequential( + nn.Linear(embed_dim, hidden_dim), + nn.BatchNorm1d(hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.BatchNorm1d(hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, proj_dim) + ) + + def forward(self, x): + output = self.proj_head(x) + output = nn.functional.normalize(output) + return output + +class Vox2VecWrapper(nn.Module): + """ + Perform forward pass separately on each resolution input. + The inputs corresponding to a single resolution are clubbed and single + forward is run on the same resolution inputs. Hence we do several + forward passes = number of different resolutions used. We then + concatenate all the output features and run the head forward on these + concatenated features. + """ + def __init__(self, backbone, head): + super(Vox2VecWrapper, self).__init__() + self.backbone = backbone + self.head = head + + def forward(self, x, vex_idx): + if(isinstance(self.backbone, FMUNetV3)): + x = self.backbone.project(x) + f = self.backbone.encoder(x) + B = list(f[0].shape)[0] + f_fpn = select_from_pyramid(f, vex_idx) + feature_dim = list(f_fpn.shape)[-1] + f_fpn = f_fpn.view(-1, feature_dim) + output = self.head(f_fpn) + proj_dim = list(output.shape)[-1] + output = output.view(B, -1, proj_dim) + return output + +class SelfSupVox2Vec(SegmentationAgent): + """ + An agent for image self-supervised learning with DeSD. + """ + def __init__(self, config, stage = 'train'): + super(SelfSupVox2Vec, self).__init__(config, stage) + + def create_network(self): + super(SelfSupVox2Vec, self).create_network() + proj_dim = self.config['self_supervised_learning'].get('project_dim', 1024) + hidden_dim = self.config['self_supervised_learning'].get('hidden_dim', 1024) + head_params= {'feature_chns': self.config['network']['feature_chns'], + 'hidden_dim':hidden_dim, + 'project_dim':proj_dim} + self.head = Vox2VecHead(head_params) + self.net_wrapper = Vox2VecWrapper(self.net, self.head) + + def create_loss_calculator(self): + # constrastive loss + self_sup_params = self.config['self_supervised_learning'] + self.loss_calculator = InfoNCELoss(self_sup_params) + + def get_parameters_to_update(self): + params = self.net_wrapper.parameters() + return params + + def training(self): + iter_valid = self.config['training']['iter_valid'] + train_loss = 0 + err_info = None + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 + self.net_wrapper.train() + for it in range(iter_valid): + t0 = time.time() + try: + data = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data = next(self.trainIter) + t1 = time.time() + patch1, patch2, vox_ids1, vox_ids2 = data['image'] + inputs = torch.cat([patch1, patch2], dim = 0) + vox_ids = torch.cat([vox_ids1, vox_ids2], dim = 0) + inputs = self.convert_tensor_type(inputs) + inputs = inputs.to(self.device) + vox_ids = vox_ids.to(self.device) + + # for debug + # for i in range(patch1.shape[0]): + # v1_i = patch1[i][0] + # v2_i = patch2[i][0] + # print("patch shape", v1_i.shape, v2_i.shape) + # image_name0 = "temp/image_{0:}_{1:}_v0.nii.gz".format(it, i) + # image_name1 = "temp/image_{0:}_{1:}_v1.nii.gz".format(it, i) + # save_nd_array_as_image(v1_i, image_name0, reference_name = None) + # save_nd_array_as_image(v2_i, image_name1, reference_name = None) + # if(it > 10): + # return + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + out = self.net_wrapper(inputs, vox_ids) + out1, out2 = out.chunk(2) + + t2 = time.time() + loss = self.loss_calculator(out1, out2) + t3 = time.time() + + loss.backward() + self.optimizer.step() + train_loss = train_loss + loss.item() + t4 = time.time() + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 + + train_avg_loss = train_loss / iter_valid + train_scalers = {'loss': train_avg_loss, 'data_time': data_time, + 'gpu_time':gpu_time, 'loss_time':loss_time, 'back_time':back_time, + 'err_info': err_info} + return train_scalers + + + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): + loss_scalar ={'train':train_scalars['loss']} + self.summ_writer.add_scalars('loss', loss_scalar, glob_it) + self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) + logging.info('train loss {0:.4f}'.format(train_scalars['loss'])) + + def train_valid(self): + device_ids = self.config['training']['gpus'] + if(len(device_ids) > 1): + self.device = torch.device("cuda:0") + self.net_wrapper = nn.DataParallel(self.net_wrapper, device_ids = device_ids) + else: + self.device = torch.device("cuda:{0:}".format(device_ids[0])) + self.net_wrapper.to(self.device) + + ckpt_dir = self.config['training']['ckpt_dir'] + ckpt_prefix = self.config['training'].get('ckpt_prefix', None) + if(ckpt_prefix is None): + ckpt_prefix = ckpt_dir.split('/')[-1] + # iter_start = self.config['training']['iter_start'] + iter_start = 0 + iter_max = self.config['training']['iter_max'] + iter_valid = self.config['training']['iter_valid'] + iter_save = self.config['training'].get('iter_save', None) + early_stop_it = self.config['training'].get('early_stop_patience', None) + if(iter_save is None): + iter_save_list = [iter_max] + elif(isinstance(iter_save, (tuple, list))): + iter_save_list = iter_save + else: + iter_save_list = range(0, iter_max + 1, iter_save) + + self.min_loss = 10000.0 + self.min_loss_it = 0 + self.best_model_wts = None + self.bett_head_wts = None + checkpoint = None + # initialize the network with pre-trained weights + ckpt_init_name = self.config['training'].get('ckpt_init_name', None) + ckpt_init_mode = self.config['training'].get('ckpt_init_mode', 0) + ckpt_for_optm = None + if(ckpt_init_name is not None): + checkpoint = torch.load(ckpt_dir + "/" + ckpt_init_name, map_location = self.device) + pretrained_dict = checkpoint['model_state_dict'] + pretrain_head_dict = checkpoint['head_state_dict'] + self.load_pretrained_weights(self.net, pretrained_dict, device_ids) + self.load_pretrained_weights(self.head, pretrain_head_dict, device_ids) + + if(ckpt_init_mode > 0): # Load other information + self.min_loss = checkpoint.get('train_loss', 10000) + iter_start = checkpoint['iteration'] + self.min_loss_it = iter_start + self.best_model_wts = checkpoint['model_state_dict'] + self.best_head_wts = checkpoint['head_state_dict'] + ckpt_for_optm = checkpoint + + self.create_optimizer(self.get_parameters_to_update(), ckpt_for_optm) + self.create_loss_calculator() + + self.trainIter = iter(self.train_loader) + + logging.info("{0:} training start".format(str(datetime.now())[:-7])) + self.summ_writer = SummaryWriter(self.config['training']['ckpt_dir']) + self.glob_it = iter_start + for it in range(iter_start, iter_max, iter_valid): + lr_value = self.optimizer.param_groups[0]['lr'] + + t0 = time.time() + train_scalars = self.training() + t1 = time.time() + if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step(-train_scalars['loss']) + else: + self.scheduler.step() + + self.glob_it = it + iter_valid + logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) + logging.info('learning rate {0:}'.format(lr_value)) + logging.info("training time: {0:.2f}s".format(t1-t0)) + logging.info("data: {0:.2f}s, gpu: {1:.2f}s, loss: {2:.2f}s, back: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['gpu_time'], + train_scalars['loss_time'], train_scalars['back_time'])) + + self.write_scalars(train_scalars, None, lr_value, self.glob_it) + if(train_scalars['loss'] < self.min_loss): + self.min_loss = train_scalars['loss'] + self.min_loss_it = self.glob_it + if(len(device_ids) > 1): + self.best_model_wts = copy.deepcopy(self.net.module.state_dict()) + self.best_head_wts = copy.deepcopy(self.head.module.state_dict()) + else: + self.best_model_wts = copy.deepcopy(self.net.state_dict()) + self.best_head_wts = copy.deepcopy(self.head.state_dict()) + + save_dict = {'iteration': self.min_loss_it, + 'train_loss': self.min_loss, + 'model_state_dict': self.best_model_wts, + 'head_state_dict': self.best_head_wts, + 'optimizer_state_dict': self.optimizer.state_dict()} + save_name = "{0:}/{1:}_best.pt".format(ckpt_dir, ckpt_prefix) + torch.save(save_dict, save_name) + txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') + txt_file.write(str(self.min_loss_it)) + txt_file.close() + + stop_now = True if(early_stop_it is not None and \ + self.glob_it - self.min_loss_it > early_stop_it) else False + if(train_scalars['err_info'] is not None): + logging.info("Early stopped due to error: {0:}".format(train_scalars['err_info'])) + stop_now = True + if ((self.glob_it in iter_save_list) or stop_now): + save_dict = {'iteration': self.glob_it, + 'train_loss': train_scalars['loss'], + 'model_state_dict': self.net.module.state_dict() \ + if len(device_ids) > 1 else self.net.state_dict(), + 'head_state_dict': self.head.module.state_dict() \ + if len(device_ids) > 1 else self.head.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict()} + save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.glob_it) + torch.save(save_dict, save_name) + txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt') + txt_file.write(str(self.glob_it)) + txt_file.close() + if(stop_now): + logging.info("The training is early stopped") + break + # save the best performing checkpoint + logging.info('The best performing iter is {0:}, train loss {1:}'.format(\ + self.min_loss_it, self.min_loss)) + self.summ_writer.close() \ No newline at end of file From e2f68d06fecc9d6f36d896f6ab6cd92ce8236c3c Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 30 Sep 2024 22:34:19 +0800 Subject: [PATCH 196/225] add vitb16 --- pymic/net/cls/torch_pretrained_net.py | 35 ++++++++++++++++++++++++++- pymic/net/net_dict_cls.py | 3 ++- pymic/net/net_init.py | 26 ++++++++++++++++++++ 3 files changed, 62 insertions(+), 2 deletions(-) create mode 100644 pymic/net/net_init.py diff --git a/pymic/net/cls/torch_pretrained_net.py b/pymic/net/cls/torch_pretrained_net.py index c1959a4..9d05c28 100644 --- a/pymic/net/cls/torch_pretrained_net.py +++ b/pymic/net/cls/torch_pretrained_net.py @@ -181,7 +181,8 @@ class MobileNetV2(BuiltInNet): """ def __init__(self, params): super(MobileNetV2, self).__init__(params) - self.net = models.mobilenet_v2(pretrained = self.pretrain) + weights = 'IMAGENET1K_V1' if self.pretrain else None + self.net = models.mobilenet_v2(weights = weights) # replace the last layer num_ftrs = self.net.last_channel @@ -204,6 +205,38 @@ def get_parameters_to_update(self): return params else: raise(ValueError("update_mode can only be 'all' or 'last'.")) + +class ViTB16(BuiltInNet): + """ + ViTB16 for classification. + Parameters should be set in the `params` dictionary that contains the + following fields: + + :param input_chns: (int) Input channel number, default is 3. + :param pretrain: (bool) Using pretrained model or not, default is True. + :param update_mode: (str) The strategy for updating layers: "`all`" means updating + all the layers, and "`last`" (by default) means updating the last layer, + as well as the first layer when `input_chns` is not 3. + """ + def __init__(self, params): + super(ViTB16, self).__init__(params) + weights = 'IMAGENET1K_V1' if self.pretrain else None + self.net = models.vit_b_16(weights = weights) + + # replace the last layer + num_ftrs = self.net.representation_size + if(num_ftrs is None): + num_ftrs = self.net.hidden_dim + self.net.heads[-1] = nn.Linear(num_ftrs, params['class_num']) + + def get_parameters_to_update(self): + if(self.update_mode == "all"): + return self.net.parameters() + elif(self.update_mode == "last"): + params = self.net.heads[-1].parameters() + return params + else: + raise(ValueError("update_mode can only be 'all' or 'last'.")) if __name__ == "__main__": params = {"class_num": 2, "pretrain": False, "input_chns": 3} diff --git a/pymic/net/net_dict_cls.py b/pymic/net/net_dict_cls.py index 3a7808b..a83334a 100644 --- a/pymic/net/net_dict_cls.py +++ b/pymic/net/net_dict_cls.py @@ -13,5 +13,6 @@ TorchClsNetDict = { 'resnet18': ResNet18, 'vgg16': VGG16, - 'mobilenetv2':MobileNetV2 + 'mobilenetv2':MobileNetV2, + 'vitb16': ViTB16 } diff --git a/pymic/net/net_init.py b/pymic/net/net_init.py new file mode 100644 index 0000000..1f9b48e --- /dev/null +++ b/pymic/net/net_init.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +from torch import nn + + +class Initialization_He(object): + def __init__(self, neg_slope=1e-2): + self.neg_slope = neg_slope + + def __call__(self, module): + if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)): + module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) + if module.bias is not None: + module.bias = nn.init.constant_(module.bias, 0) + + +class Initialization_XavierUniform(object): + def __init__(self, gain=1): + self.gain = gain + + def __call__(self, module): + if isinstance(module, (nn.Conv3d ,nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)): + module.weight = nn.init.xavier_uniform_(module.weight, self.gain) + if module.bias is not None: + module.bias = nn.init.constant_(module.bias, 0) From 2db9a3a0fc2429a38e44e1bad7376c42d4c5e77f Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 2 Oct 2024 15:48:05 +0800 Subject: [PATCH 197/225] add CANet for segmentation --- pymic/net/net2d/canet.py | 229 --------- pymic/net/net2d/canet_module.py | 578 ---------------------- pymic/net/net2d/unet2d_canet.py | 840 ++++++++++++++++++++++++++------ 3 files changed, 687 insertions(+), 960 deletions(-) delete mode 100644 pymic/net/net2d/canet.py delete mode 100644 pymic/net/net2d/canet_module.py diff --git a/pymic/net/net2d/canet.py b/pymic/net/net2d/canet.py deleted file mode 100644 index ab64ab5..0000000 --- a/pymic/net/net2d/canet.py +++ /dev/null @@ -1,229 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division -import torch -import torch.nn as nn - -class ConvLayer(nn.Module): - """ - A combination of Conv2d, BatchNorm2d and LeakyReLU. - """ - def __init__(self, in_channels, out_channels, kernel_size = 1): - super(ConvLayer, self).__init__() - padding = int((kernel_size - 1) / 2) - self.conv = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), - nn.BatchNorm2d(out_channels), - nn.LeakyReLU() - ) - - def forward(self, x): - return self.conv(x) - -class SEBlock(nn.Module): - """ - A Modified Squeeze-and-Excitation block for spatial attention. - """ - def __init__(self, in_channels, r): - super(SEBlock, self).__init__() - - redu_chns = int(in_channels / r) - self.se_layers = nn.Sequential( - nn.AdaptiveAvgPool2d(1), - nn.Conv2d(in_channels, redu_chns, kernel_size=1, padding=0), - nn.LeakyReLU(), - nn.Conv2d(redu_chns, in_channels, kernel_size=1, padding=0), - nn.ReLU()) - - def forward(self, x): - f = self.se_layers(x) - return f*x + x - -class ASPPBlock(nn.Module): - """ - ASPP block. - """ - def __init__(self,in_channels, out_channels_list, kernel_size_list, dilation_list): - super(ASPPBlock, self).__init__() - self.conv_num = len(out_channels_list) - assert(self.conv_num == 4) - assert(self.conv_num == len(kernel_size_list) and self.conv_num == len(dilation_list)) - pad0 = int((kernel_size_list[0] - 1) / 2 * dilation_list[0]) - pad1 = int((kernel_size_list[1] - 1) / 2 * dilation_list[1]) - pad2 = int((kernel_size_list[2] - 1) / 2 * dilation_list[2]) - pad3 = int((kernel_size_list[3] - 1) / 2 * dilation_list[3]) - self.conv_1 = nn.Conv2d(in_channels, out_channels_list[0], kernel_size = kernel_size_list[0], - dilation = dilation_list[0], padding = pad0 ) - self.conv_2 = nn.Conv2d(in_channels, out_channels_list[1], kernel_size = kernel_size_list[1], - dilation = dilation_list[1], padding = pad1 ) - self.conv_3 = nn.Conv2d(in_channels, out_channels_list[2], kernel_size = kernel_size_list[2], - dilation = dilation_list[2], padding = pad2 ) - self.conv_4 = nn.Conv2d(in_channels, out_channels_list[3], kernel_size = kernel_size_list[3], - dilation = dilation_list[3], padding = pad3 ) - - out_channels = out_channels_list[0] + out_channels_list[1] + out_channels_list[2] + out_channels_list[3] - self.conv_1x1 = nn.Sequential( - nn.Conv2d(out_channels, out_channels, kernel_size=1, padding=0), - nn.BatchNorm2d(out_channels), - nn.LeakyReLU()) - - def forward(self, x): - x1 = self.conv_1(x) - x2 = self.conv_2(x) - x3 = self.conv_3(x) - x4 = self.conv_4(x) - - y = torch.cat([x1, x2, x3, x4], dim=1) - y = self.conv_1x1(y) - return y - -class ConvBNActBlock(nn.Module): - """ - Two convolution layers with batch norm, leaky relu, - dropout and SE block. - """ - def __init__(self,in_channels, out_channels, dropout_p): - super(ConvBNActBlock, self).__init__() - self.conv_conv = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), - nn.LeakyReLU(), - nn.Dropout(dropout_p), - nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), - nn.LeakyReLU(), - SEBlock(out_channels, 2) - ) - - def forward(self, x): - return self.conv_conv(x) - -class DownBlock(nn.Module): - """ - Downsampling by a concantenation of max-pool and avg-pool, - followed by ConvBNActBlock. - """ - def __init__(self, in_channels, out_channels, dropout_p): - super(DownBlock, self).__init__() - self.maxpool = nn.MaxPool2d(2) - self.avgpool = nn.AvgPool2d(2) - self.conv = ConvBNActBlock(2 * in_channels, out_channels, dropout_p) - - def forward(self, x): - x_max = self.maxpool(x) - x_avg = self.avgpool(x) - x_cat = torch.cat([x_max, x_avg], dim=1) - y = self.conv(x_cat) - return y + x_cat - -class UpBlock(nn.Module): - """ - Upssampling followed by ConvBNActBlock. - """ - def __init__(self, in_channels1, in_channels2, out_channels, - bilinear=True, dropout_p = 0.5): - super(UpBlock, self).__init__() - self.bilinear = bilinear - if bilinear: - self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) - else: - self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2) - self.conv = ConvBNActBlock(in_channels2 * 2, out_channels, dropout_p) - - def forward(self, x1, x2): - if self.bilinear: - x1 = self.conv1x1(x1) - x1 = self.up(x1) - x_cat = torch.cat([x2, x1], dim=1) - y = self.conv(x_cat) - return y + x_cat - -class CANet(nn.Module): - """ - Implementation of of CA-Net for biomedical image segmentation. - - * Reference: R. Gu et al. `CA-Net: Comprehensive Attention Convolutional Neural Networks - for Explainable Medical Image Segmentation `_. - IEEE Transactions on Medical Imaging, 40(2),2021:699-711. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - """ - def __init__(self, params): - super(COPLENet, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] - assert(len(self.ft_chns) == 5) - - f0_half = int(self.ft_chns[0] / 2) - f1_half = int(self.ft_chns[1] / 2) - f2_half = int(self.ft_chns[2] / 2) - f3_half = int(self.ft_chns[3] / 2) - self.in_conv= ConvBNActBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - - self.bridge0= ConvLayer(self.ft_chns[0], f0_half) - self.bridge1= ConvLayer(self.ft_chns[1], f1_half) - self.bridge2= ConvLayer(self.ft_chns[2], f2_half) - self.bridge3= ConvLayer(self.ft_chns[3], f3_half) - - self.up1 = UpBlock(self.ft_chns[4], f3_half, self.ft_chns[3], dropout_p = self.dropout[3]) - self.up2 = UpBlock(self.ft_chns[3], f2_half, self.ft_chns[2], dropout_p = self.dropout[2]) - self.up3 = UpBlock(self.ft_chns[2], f1_half, self.ft_chns[1], dropout_p = self.dropout[1]) - self.up4 = UpBlock(self.ft_chns[1], f0_half, self.ft_chns[0], dropout_p = self.dropout[0]) - - f4 = self.ft_chns[4] - aspp_chns = [int(f4 / 4), int(f4 / 4), int(f4 / 4), int(f4 / 4)] - aspp_knls = [1, 3, 3, 3] - aspp_dila = [1, 2, 4, 6] - self.aspp = ASPPBlock(f4, aspp_chns, aspp_knls, aspp_dila) - - - self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, - kernel_size = 3, padding = 1) - - def forward(self, x): - x_shape = list(x.shape) - if(len(x_shape) == 5): - [N, C, D, H, W] = x_shape - new_shape = [N*D, C, H, W] - x = torch.transpose(x, 1, 2) - x = torch.reshape(x, new_shape) - x0 = self.in_conv(x) - x0b = self.bridge0(x0) - x1 = self.down1(x0) - x1b = self.bridge1(x1) - x2 = self.down2(x1) - x2b = self.bridge2(x2) - x3 = self.down3(x2) - x3b = self.bridge3(x3) - x4 = self.down4(x3) - x4 = self.aspp(x4) - - x = self.up1(x4, x3b) - x = self.up2(x, x2b) - x = self.up3(x, x1b) - x = self.up4(x, x0b) - output = self.out_conv(x) - - if(len(x_shape) == 5): - new_shape = [N, D] + list(output.shape)[1:] - output = torch.reshape(output, new_shape) - output = torch.transpose(output, 1, 2) - return output \ No newline at end of file diff --git a/pymic/net/net2d/canet_module.py b/pymic/net/net2d/canet_module.py deleted file mode 100644 index 097a4f1..0000000 --- a/pymic/net/net2d/canet_module.py +++ /dev/null @@ -1,578 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Building blcoks for CA-Net. - -Oringinal file is on `Github. -`_ -""" - -from __future__ import print_function, division -import torch -import torch.nn as nn -import functools -from torch.nn import functional as F - - -class conv_block(nn.Module): - def __init__(self, ch_in, ch_out, drop_out=False): - super(conv_block, self).__init__() - self.conv = nn.Sequential( - nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), - nn.BatchNorm2d(ch_out), - nn.ReLU(inplace=True), - nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True), - nn.BatchNorm2d(ch_out), - nn.ReLU(inplace=True), - ) - self.dropout = drop_out - - def forward(self, x): - x = self.conv(x) - if self.dropout: - x = nn.Dropout2d(0.5)(x) - return x - - -# # UpCat(nn.Module) for U-net UP convolution -class UpCat(nn.Module): - def __init__(self, in_feat, out_feat, is_deconv=True): - super(UpCat, self).__init__() - if is_deconv: - self.up = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2) - else: - self.up = nn.Upsample(scale_factor=2, mode='bilinear') - - def forward(self, inputs, down_outputs): - # TODO: Upsampling required after deconv? - outputs = self.up(down_outputs) - offset = inputs.size()[3] - outputs.size()[3] - if offset == 1: - addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2]), out=None).unsqueeze( - 3).cuda() - outputs = torch.cat([outputs, addition], dim=3) - elif offset > 1: - addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2], offset), out=None).cuda() - outputs = torch.cat([outputs, addition], dim=3) - out = torch.cat([inputs, outputs], dim=1) - - return out - - -# # UpCatconv(nn.Module) for up convolution -class UpCatconv(nn.Module): - def __init__(self, in_feat, out_feat, is_deconv=True, drop_out=False): - super(UpCatconv, self).__init__() - - if is_deconv: - self.conv = conv_block(in_feat, out_feat, drop_out=drop_out) - self.up = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2) - else: - self.conv = conv_block(in_feat + out_feat, out_feat, drop_out=drop_out) - self.up = nn.Upsample(scale_factor=2, mode='bilinear') - - def forward(self, inputs, down_outputs): - # TODO: Upsampling required after deconv - outputs = self.up(down_outputs) - offset = inputs.size()[3] - outputs.size()[3] - if offset == 1: - addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2]), out=None).unsqueeze( - 3).cuda() - outputs = torch.cat([outputs, addition], dim=3) - elif offset > 1: - addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2], offset), out=None).cuda() - outputs = torch.cat([outputs, addition], dim=3) - out = self.conv(torch.cat([inputs, outputs], dim=1)) - - return out - - -class UnetDsv3(nn.Module): - def __init__(self, in_size, out_size, scale_factor): - super(UnetDsv3, self).__init__() - self.dsv = nn.Sequential(nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0), - nn.Upsample(size=scale_factor, mode='bilinear'), ) - - def forward(self, input): - return self.dsv(input) - - -###### Intial weights ##### -def weights_init_normal(m): - classname = m.__class__.__name__ - #print(classname) - if classname.find('Conv') != -1: - nn.init.normal(m.weight.data, 0.0, 0.02) - elif classname.find('Linear') != -1: - nn.init.normal(m.weight.data, 0.0, 0.02) - elif classname.find('BatchNorm') != -1: - nn.init.normal(m.weight.data, 1.0, 0.02) - nn.init.constant(m.bias.data, 0.0) - - -def weights_init_xavier(m): - classname = m.__class__.__name__ - #print(classname) - if classname.find('Conv') != -1: - nn.init.xavier_normal(m.weight.data, gain=1) - elif classname.find('Linear') != -1: - nn.init.xavier_normal(m.weight.data, gain=1) - elif classname.find('BatchNorm') != -1: - nn.init.normal(m.weight.data, 1.0, 0.02) - nn.init.constant(m.bias.data, 0.0) - - -def weights_init_kaiming(m): - classname = m.__class__.__name__ - #print(classname) - if classname.find('Conv') != -1: - nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in') - elif classname.find('Linear') != -1: - nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in') - elif classname.find('BatchNorm') != -1: - nn.init.normal(m.weight.data, 1.0, 0.02) - nn.init.constant(m.bias.data, 0.0) - - -def weights_init_orthogonal(m): - classname = m.__class__.__name__ - #print(classname) - if classname.find('Conv') != -1: - nn.init.orthogonal(m.weight.data, gain=1) - elif classname.find('Linear') != -1: - nn.init.orthogonal(m.weight.data, gain=1) - elif classname.find('BatchNorm') != -1: - nn.init.normal(m.weight.data, 1.0, 0.02) - nn.init.constant(m.bias.data, 0.0) - - -def init_weights(net, init_type='normal'): - #print('initialization method [%s]' % init_type) - if init_type == 'normal': - net.apply(weights_init_normal) - elif init_type == 'xavier': - net.apply(weights_init_xavier) - elif init_type == 'kaiming': - net.apply(weights_init_kaiming) - elif init_type == 'orthogonal': - net.apply(weights_init_orthogonal) - else: - raise NotImplementedError('initialization method [%s] is not implemented' % init_type) - - -def get_norm_layer(norm_type='instance'): - if norm_type == 'batch': - norm_layer = functools.partial(nn.BatchNorm2d, affine=True) - elif norm_type == 'instance': - norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) - elif norm_type == 'none': - norm_layer = None - else: - raise NotImplementedError('normalization layer [%s] is not found' % norm_type) - return norm_layer - - -###### For attention ###### -class _GridAttentionBlockND(nn.Module): - def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation', - sub_sample_factor=(2,2,2)): - super(_GridAttentionBlockND, self).__init__() - - assert dimension in [2, 3] - assert mode in ['concatenation', 'concatenation_debug', 'concatenation_residual'] - - # Downsampling rate for the input featuremap - if isinstance(sub_sample_factor, tuple): self.sub_sample_factor = sub_sample_factor - elif isinstance(sub_sample_factor, list): self.sub_sample_factor = tuple(sub_sample_factor) - else: self.sub_sample_factor = tuple([sub_sample_factor]) * dimension - - # Default parameter set - self.mode = mode - self.dimension = dimension - self.sub_sample_kernel_size = self.sub_sample_factor - - # Number of channels (pixel dimensions) - self.in_channels = in_channels - self.gating_channels = gating_channels - self.inter_channels = inter_channels - - if self.inter_channels is None: - self.inter_channels = in_channels // 2 - if self.inter_channels == 0: - self.inter_channels = 1 - - if dimension == 3: - conv_nd = nn.Conv3d - bn = nn.BatchNorm3d - self.upsample_mode = 'trilinear' - elif dimension == 2: - conv_nd = nn.Conv2d - bn = nn.BatchNorm2d - self.upsample_mode = 'bilinear' - else: - raise NotImplemented - - # Output transform - self.W = nn.Sequential( - conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), - bn(self.in_channels), - ) - - # Theta^T * x_ij + Phi^T * gating_signal + bias - self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, - kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=True) - self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels, - kernel_size=(1, 1), stride=1, padding=0, bias=True) - self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True) - - # Initialise weights - for m in self.children(): - init_weights(m, init_type='kaiming') - - # Define the operation - if mode == 'concatenation': - self.operation_function = self._concatenation - elif mode == 'concatenation_debug': - self.operation_function = self._concatenation_debug - elif mode == 'concatenation_residual': - self.operation_function = self._concatenation_residual - else: - raise NotImplementedError('Unknown operation function.') - - - def forward(self, x, g): - ''' - :param x: (b, c, t, h, w) - :param g: (b, g_d) - :return: - ''' - - output = self.operation_function(x, g) - return output - - def _concatenation(self, x, g): - input_size = x.size() - batch_size = input_size[0] - assert batch_size == g.size(0) - - # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) - # phi => (b, g_d) -> (b, i_c) - theta_x = self.theta(x) - theta_x_size = theta_x.size() - - # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') - # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) - phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) - f = F.relu(theta_x + phi_g, inplace=True) - - # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) - sigm_psi_f = F.sigmoid(self.psi(f)) - - # upsample the attentions and multiply - sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) - y = sigm_psi_f.expand_as(x) * x - W_y = self.W(y) - - return W_y, sigm_psi_f - - def _concatenation_debug(self, x, g): - input_size = x.size() - batch_size = input_size[0] - assert batch_size == g.size(0) - - # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) - # phi => (b, g_d) -> (b, i_c) - theta_x = self.theta(x) - theta_x_size = theta_x.size() - - # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') - # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) - phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) - f = F.softplus(theta_x + phi_g) - - # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) - sigm_psi_f = F.sigmoid(self.psi(f)) - - # upsample the attentions and multiply - sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) - y = sigm_psi_f.expand_as(x) * x - W_y = self.W(y) - - return W_y, sigm_psi_f - - - def _concatenation_residual(self, x, g): - input_size = x.size() - batch_size = input_size[0] - assert batch_size == g.size(0) - - # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) - # phi => (b, g_d) -> (b, i_c) - theta_x = self.theta(x) - theta_x_size = theta_x.size() - - # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') - # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) - phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) - f = F.relu(theta_x + phi_g, inplace=True) - - # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) - f = self.psi(f).view(batch_size, 1, -1) - sigm_psi_f = F.softmax(f, dim=2).view(batch_size, 1, *theta_x.size()[2:]) - - # upsample the attentions and multiply - sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) - y = sigm_psi_f.expand_as(x) * x - W_y = self.W(y) - - return W_y, sigm_psi_f - - -class GridAttentionBlock2D(_GridAttentionBlockND): - def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', - sub_sample_factor=(2, 2)): - super(GridAttentionBlock2D, self).__init__(in_channels, - inter_channels=inter_channels, - gating_channels=gating_channels, - dimension=2, mode=mode, - sub_sample_factor=sub_sample_factor, - ) - - -class GridAttentionBlock3D(_GridAttentionBlockND): - def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', - sub_sample_factor=(2,2,2)): - super(GridAttentionBlock3D, self).__init__(in_channels, - inter_channels=inter_channels, - gating_channels=gating_channels, - dimension=3, mode=mode, - sub_sample_factor=sub_sample_factor, - ) - -class _GridAttentionBlockND_TORR(nn.Module): - def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation', - sub_sample_factor=(1,1,1), bn_layer=True, use_W=True, use_phi=True, use_theta=True, use_psi=True, nonlinearity1='relu'): - super(_GridAttentionBlockND_TORR, self).__init__() - - assert dimension in [2, 3] - assert mode in ['concatenation', 'concatenation_softmax', - 'concatenation_sigmoid', 'concatenation_mean', - 'concatenation_range_normalise', 'concatenation_mean_flow'] - - # Default parameter set - self.mode = mode - self.dimension = dimension - self.sub_sample_factor = sub_sample_factor if isinstance(sub_sample_factor, tuple) else tuple([sub_sample_factor])*dimension - self.sub_sample_kernel_size = self.sub_sample_factor - - # Number of channels (pixel dimensions) - self.in_channels = in_channels - self.gating_channels = gating_channels - self.inter_channels = inter_channels - - if self.inter_channels is None: - self.inter_channels = in_channels // 2 - if self.inter_channels == 0: - self.inter_channels = 1 - - if dimension == 3: - conv_nd = nn.Conv3d - bn = nn.BatchNorm3d - self.upsample_mode = 'trilinear' - elif dimension == 2: - conv_nd = nn.Conv2d - bn = nn.BatchNorm2d - self.upsample_mode = 'bilinear' - else: - raise NotImplemented - - # initialise id functions - # Theta^T * x_ij + Phi^T * gating_signal + bias - self.W = lambda x: x - self.theta = lambda x: x - self.psi = lambda x: x - self.phi = lambda x: x - self.nl1 = lambda x: x - - if use_W: - if bn_layer: - self.W = nn.Sequential( - conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), - bn(self.in_channels), - ) - else: - self.W = conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0) - - if use_theta: - self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, - kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=False) - - - if use_phi: - self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels, - kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=False) - - - if use_psi: - self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True) - - - if nonlinearity1: - if nonlinearity1 == 'relu': - self.nl1 = lambda x: F.relu(x, inplace=True) - - if 'concatenation' in mode: - self.operation_function = self._concatenation - else: - raise NotImplementedError('Unknown operation function.') - - # Initialise weights - for m in self.children(): - init_weights(m, init_type='kaiming') - - - if use_psi and self.mode == 'concatenation_sigmoid': - nn.init.constant(self.psi.bias.data, 3.0) - - if use_psi and self.mode == 'concatenation_softmax': - nn.init.constant(self.psi.bias.data, 10.0) - - # if use_psi and self.mode == 'concatenation_mean': - # nn.init.constant(self.psi.bias.data, 3.0) - - # if use_psi and self.mode == 'concatenation_range_normalise': - # nn.init.constant(self.psi.bias.data, 3.0) - - parallel = False - if parallel: - if use_W: self.W = nn.DataParallel(self.W) - if use_phi: self.phi = nn.DataParallel(self.phi) - if use_psi: self.psi = nn.DataParallel(self.psi) - if use_theta: self.theta = nn.DataParallel(self.theta) - - def forward(self, x, g): - ''' - :param x: (b, c, t, h, w) - :param g: (b, g_d) - :return: - ''' - - output = self.operation_function(x, g) - return output - - def _concatenation(self, x, g): - input_size = x.size() - batch_size = input_size[0] - assert batch_size == g.size(0) - - ############################# - # compute compatibility score - - # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) - # phi => (b, c, t, h, w) -> (b, i_c, t, h, w) - theta_x = self.theta(x) - theta_x_size = theta_x.size() - - # nl(theta.x + phi.g + bias) -> f = (b, i_c, t/s1, h/s2, w/s3) - phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) - - f = theta_x + phi_g - f = self.nl1(f) - - psi_f = self.psi(f) - - ############################################ - # normalisation -- scale compatibility score - # psi^T . f -> (b, 1, t/s1, h/s2, w/s3) - if self.mode == 'concatenation_softmax': - sigm_psi_f = F.softmax(psi_f.view(batch_size, 1, -1), dim=2) - sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:]) - elif self.mode == 'concatenation_mean': - psi_f_flat = psi_f.view(batch_size, 1, -1) - psi_f_sum = torch.sum(psi_f_flat, dim=2)#clamp(1e-6) - psi_f_sum = psi_f_sum[:,:,None].expand_as(psi_f_flat) - - sigm_psi_f = psi_f_flat / psi_f_sum - sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:]) - elif self.mode == 'concatenation_mean_flow': - psi_f_flat = psi_f.view(batch_size, 1, -1) - ss = psi_f_flat.shape - psi_f_min = psi_f_flat.min(dim=2)[0].view(ss[0],ss[1],1) - psi_f_flat = psi_f_flat - psi_f_min - psi_f_sum = torch.sum(psi_f_flat, dim=2).view(ss[0],ss[1],1).expand_as(psi_f_flat) - - sigm_psi_f = psi_f_flat / psi_f_sum - sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:]) - elif self.mode == 'concatenation_range_normalise': - psi_f_flat = psi_f.view(batch_size, 1, -1) - ss = psi_f_flat.shape - psi_f_max = torch.max(psi_f_flat, dim=2)[0].view(ss[0], ss[1], 1) - psi_f_min = torch.min(psi_f_flat, dim=2)[0].view(ss[0], ss[1], 1) - - sigm_psi_f = (psi_f_flat - psi_f_min) / (psi_f_max - psi_f_min).expand_as(psi_f_flat) - sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:]) - - elif self.mode == 'concatenation_sigmoid': - sigm_psi_f = F.sigmoid(psi_f) - else: - raise NotImplementedError - - # sigm_psi_f is attention map! upsample the attentions and multiply - sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) - y = sigm_psi_f.expand_as(x) * x - W_y = self.W(y) - - return W_y, sigm_psi_f - - -class GridAttentionBlock2D_TORR(_GridAttentionBlockND_TORR): - def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', - sub_sample_factor=(1,1), bn_layer=True, - use_W=True, use_phi=True, use_theta=True, use_psi=True, - nonlinearity1='relu'): - super(GridAttentionBlock2D_TORR, self).__init__(in_channels, - inter_channels=inter_channels, - gating_channels=gating_channels, - dimension=2, mode=mode, - sub_sample_factor=sub_sample_factor, - bn_layer=bn_layer, - use_W=use_W, - use_phi=use_phi, - use_theta=use_theta, - use_psi=use_psi, - nonlinearity1=nonlinearity1) - - -class GridAttentionBlock3D_TORR(_GridAttentionBlockND_TORR): - def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', - sub_sample_factor=(1,1,1), bn_layer=True): - super(GridAttentionBlock3D_TORR, self).__init__(in_channels, - inter_channels=inter_channels, - gating_channels=gating_channels, - dimension=3, mode=mode, - sub_sample_factor=sub_sample_factor, - bn_layer=bn_layer) - - -class MultiAttentionBlock(nn.Module): - def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor): - super(MultiAttentionBlock, self).__init__() - self.gate_block_1 = GridAttentionBlock2D(in_channels=in_size, gating_channels=gate_size, - inter_channels=inter_size, mode=nonlocal_mode, - sub_sample_factor=sub_sample_factor) - self.gate_block_2 = GridAttentionBlock2D(in_channels=in_size, gating_channels=gate_size, - inter_channels=inter_size, mode=nonlocal_mode, - sub_sample_factor=sub_sample_factor) - self.combine_gates = nn.Sequential(nn.Conv2d(in_size*2, in_size, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(in_size), - nn.ReLU(inplace=True)) - - # initialise the blocks - for m in self.children(): - if m.__class__.__name__.find('GridAttentionBlock2D') != -1: continue - init_weights(m, init_type='kaiming') - - def forward(self, input, gating_signal): - gate_1, attention_1 = self.gate_block_1(input, gating_signal) - gate_2, attention_2 = self.gate_block_2(input, gating_signal) - - return self.combine_gates(torch.cat([gate_1, gate_2], 1)), torch.cat([attention_1, attention_2], 1) \ No newline at end of file diff --git a/pymic/net/net2d/unet2d_canet.py b/pymic/net/net2d/unet2d_canet.py index defcb60..0af3159 100644 --- a/pymic/net/net2d/unet2d_canet.py +++ b/pymic/net/net2d/unet2d_canet.py @@ -1,18 +1,327 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division - -import numpy as np +import numpy as np import torch import torch.nn as nn +from torch.nn import init from torch.nn import functional as F -from pymic.net.net2d.canet_module import * + +## init +def weights_init_normal(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('Linear') != -1: + init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm') != -1: + init.normal_(m.weight.data, 1.0, 0.02) + init.constant_(m.bias.data, 0.0) + + +def weights_init_xavier(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + init.xavier_normal_(m.weight.data, gain=1) + elif classname.find('Linear') != -1: + init.xavier_normal_(m.weight.data, gain=1) + elif classname.find('BatchNorm') != -1: + init.normal_(m.weight.data, 1.0, 0.02) + init.constant_(m.bias.data, 0.0) + + +def weights_init_kaiming(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif classname.find('Linear') != -1: + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif classname.find('BatchNorm') != -1: + init.normal_(m.weight.data, 1.0, 0.02) + init.constant_(m.bias.data, 0.0) + + +def weights_init_orthogonal(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + init.orthogonal_(m.weight.data, gain=1) + elif classname.find('Linear') != -1: + init.orthogonal_(m.weight.data, gain=1) + elif classname.find('BatchNorm') != -1: + init.normal_(m.weight.data, 1.0, 0.02) + init.constant_(m.bias.data, 0.0) + + +def init_weights(net, init_type='normal'): + #print('initialization method [%s]' % init_type) + if init_type == 'normal': + net.apply(weights_init_normal) + elif init_type == 'xavier': + net.apply(weights_init_xavier) + elif init_type == 'kaiming': + net.apply(weights_init_kaiming) + elif init_type == 'orthogonal': + net.apply(weights_init_orthogonal) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + +## 1, modules +def conv1x1(in_planes, out_planes, stride=1, bias=False): + "1x1 convolution" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, + padding=0, bias=bias) def conv3x3(in_planes, out_planes, stride=1, bias=False, group=1): """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,padding=1, groups=group, bias=bias) +# conv_block(nn.Module) for U-net convolution block +class conv_block(nn.Module): + def __init__(self, ch_in, ch_out, drop_out=False): + super(conv_block, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), + nn.BatchNorm2d(ch_out), + nn.ReLU(inplace=True), + nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True), + nn.BatchNorm2d(ch_out), + nn.ReLU(inplace=True), + ) + self.dropout = drop_out + + def forward(self, x): + x = self.conv(x) + if self.dropout: + x = nn.Dropout2d(0.5)(x) + return x + + +# # UpCat(nn.Module) for U-net UP convolution +class UpCat(nn.Module): + def __init__(self, in_feat, out_feat, is_deconv=True): + super(UpCat, self).__init__() + + if is_deconv: + self.up = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2) + else: + self.up = nn.Upsample(scale_factor=2, mode='bilinear') + + def forward(self, inputs, down_outputs): + # TODO: Upsampling required after deconv? + outputs = self.up(down_outputs) + offset = inputs.size()[3] - outputs.size()[3] + if offset == 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2]), out=None).unsqueeze( + 3).cuda() + outputs = torch.cat([outputs, addition], dim=3) + elif offset > 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2], offset), out=None).cuda() + outputs = torch.cat([outputs, addition], dim=3) + out = torch.cat([inputs, outputs], dim=1) + + return out + + +# # UpCatconv(nn.Module) for up convolution +class UpCatconv(nn.Module): + def __init__(self, in_feat, out_feat, is_deconv=True, drop_out=False): + super(UpCatconv, self).__init__() + + if is_deconv: + self.conv = conv_block(in_feat, out_feat, drop_out=drop_out) + self.up = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2) + else: + self.conv = conv_block(in_feat + out_feat, out_feat, drop_out=drop_out) + self.up = nn.Upsample(scale_factor=2, mode='bilinear') + + def forward(self, inputs, down_outputs): + # TODO: Upsampling required after deconv + outputs = self.up(down_outputs) + offset = inputs.size()[3] - outputs.size()[3] + if offset == 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2]), out=None).unsqueeze( + 3).cuda() + outputs = torch.cat([outputs, addition], dim=3) + elif offset > 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2], offset), out=None).cuda() + outputs = torch.cat([outputs, addition], dim=3) + out = self.conv(torch.cat([inputs, outputs], dim=1)) + + return out + + + +class _GridAttentionBlockND(nn.Module): + def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation', + sub_sample_factor=(2,2,2)): + super(_GridAttentionBlockND, self).__init__() + + assert dimension in [2, 3] + assert mode in ['concatenation', 'concatenation_debug', 'concatenation_residual'] + + # Downsampling rate for the input featuremap + if isinstance(sub_sample_factor, tuple): self.sub_sample_factor = sub_sample_factor + elif isinstance(sub_sample_factor, list): self.sub_sample_factor = tuple(sub_sample_factor) + else: self.sub_sample_factor = tuple([sub_sample_factor]) * dimension + + # Default parameter set + self.mode = mode + self.dimension = dimension + self.sub_sample_kernel_size = self.sub_sample_factor + + # Number of channels (pixel dimensions) + self.in_channels = in_channels + self.gating_channels = gating_channels + self.inter_channels = inter_channels + + if self.inter_channels is None: + self.inter_channels = in_channels // 2 + if self.inter_channels == 0: + self.inter_channels = 1 + + if dimension == 3: + conv_nd = nn.Conv3d + bn = nn.BatchNorm3d + self.upsample_mode = 'trilinear' + elif dimension == 2: + conv_nd = nn.Conv2d + bn = nn.BatchNorm2d + self.upsample_mode = 'bilinear' + else: + raise NotImplemented + + # Output transform + self.W = nn.Sequential( + conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), + bn(self.in_channels), + ) + + # Theta^T * x_ij + Phi^T * gating_signal + bias + self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=True) + self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels, + kernel_size=(1, 1), stride=1, padding=0, bias=True) + self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True) + + # Initialise weights + for m in self.children(): + init_weights(m, init_type='kaiming') + + # Define the operation + if mode == 'concatenation': + self.operation_function = self._concatenation + elif mode == 'concatenation_debug': + self.operation_function = self._concatenation_debug + elif mode == 'concatenation_residual': + self.operation_function = self._concatenation_residual + else: + raise NotImplementedError('Unknown operation function.') + + + def forward(self, x, g): + ''' + :param x: (b, c, t, h, w) + :param g: (b, g_d) + :return: + ''' + + output = self.operation_function(x, g) + return output + + def _concatenation(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.relu(theta_x + phi_g, inplace=True) + + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + sigm_psi_f = torch.sigmoid(self.psi(f)) + + # upsample the attentions and multiply + sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + def _concatenation_debug(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.softplus(theta_x + phi_g) + + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + sigm_psi_f = torch.sigmoid(self.psi(f)) + + # upsample the attentions and multiply + sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + def _concatenation_residual(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.relu(theta_x + phi_g, inplace=True) + + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + f = self.psi(f).view(batch_size, 1, -1) + sigm_psi_f = F.softmax(f, dim=2).view(batch_size, 1, *theta_x.size()[2:]) + + # upsample the attentions and multiply + sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + +class GridAttentionBlock2D(_GridAttentionBlockND): + def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', + sub_sample_factor=(2, 2)): + super(GridAttentionBlock2D, self).__init__(in_channels, + inter_channels=inter_channels, + gating_channels=gating_channels, + dimension=2, mode=mode, + sub_sample_factor=sub_sample_factor, + ) + + +## 2, channel attention class SE_Conv_Block(nn.Module): expansion = 4 @@ -29,6 +338,9 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, drop_out=False): self.stride = stride self.dropout = drop_out + self.globalAvgPool = nn.AdaptiveAvgPool2d(1) + self.globalMaxPool = nn.AdaptiveMaxPool2d(1) + self.fc1 = nn.Linear(in_features=planes * 2, out_features=round(planes / 2)) self.fc2 = nn.Linear(in_features=round(planes / 2), out_features=planes * 2) self.sigmoid = nn.Sigmoid() @@ -54,7 +366,7 @@ def forward(self, x): original_out = out out1 = out # For global average pool - out = F.adaptive_avg_pool2d(out, (1,1)) + out = self.globalAvgPool(out) out = out.view(out.size(0), -1) out = self.fc1(out) out = self.relu(out) @@ -64,7 +376,7 @@ def forward(self, x): avg_att = out out = out * original_out # For global maximum pool - out1 = F.adaptive_max_pool2d(out1, (1,1)) + out1 = self.globalMaxPool(out1) out1 = out1.view(out1.size(0), -1) out1 = self.fc1(out1) out1 = self.relu(out1) @@ -87,175 +399,197 @@ def forward(self, x): return out, att_weight -# # CBAM Convolutional block attention module -class BasicConv(nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, - relu=True, bn=True, bias=False): - super(BasicConv, self).__init__() - self.out_channels = out_planes - self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, groups=groups, bias=bias) - self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None - self.relu = nn.ReLU() if relu else None +## 3, grid attention +class _GridAttentionBlockND(nn.Module): + def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation', + sub_sample_factor=(2,2,2)): + super(_GridAttentionBlockND, self).__init__() - def forward(self, x): - x = self.conv(x) - if self.bn is not None: - x = self.bn(x) - if self.relu is not None: - x = self.relu(x) - return x + assert dimension in [2, 3] + assert mode in ['concatenation', 'concatenation_debug', 'concatenation_residual'] + # Downsampling rate for the input featuremap + if isinstance(sub_sample_factor, tuple): self.sub_sample_factor = sub_sample_factor + elif isinstance(sub_sample_factor, list): self.sub_sample_factor = tuple(sub_sample_factor) + else: self.sub_sample_factor = tuple([sub_sample_factor]) * dimension -class Flatten(nn.Module): - def forward(self, x): - return x.view(x.size(0), -1) + # Default parameter set + self.mode = mode + self.dimension = dimension + self.sub_sample_kernel_size = self.sub_sample_factor + # Number of channels (pixel dimensions) + self.in_channels = in_channels + self.gating_channels = gating_channels + self.inter_channels = inter_channels -class ChannelGate(nn.Module): - def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): - super(ChannelGate, self).__init__() - self.gate_channels = gate_channels - self.mlp = nn.Sequential( - Flatten(), - nn.Linear(gate_channels, gate_channels // reduction_ratio), - nn.ReLU(), - nn.Linear(gate_channels // reduction_ratio, gate_channels) - ) - self.pool_types = pool_types + if self.inter_channels is None: + self.inter_channels = in_channels // 2 + if self.inter_channels == 0: + self.inter_channels = 1 - def forward(self, x): - channel_att_sum = None - for pool_type in self.pool_types: - if pool_type == 'avg': - avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) - channel_att_raw = self.mlp(avg_pool) - elif pool_type == 'max': - max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) - channel_att_raw = self.mlp(max_pool) - elif pool_type == 'lp': - lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) - channel_att_raw = self.mlp(lp_pool) - elif pool_type == 'lse': - # LSE pool only - lse_pool = logsumexp_2d(x) - channel_att_raw = self.mlp(lse_pool) + if dimension == 3: + conv_nd = nn.Conv3d + bn = nn.BatchNorm3d + self.upsample_mode = 'trilinear' + elif dimension == 2: + conv_nd = nn.Conv2d + bn = nn.BatchNorm2d + self.upsample_mode = 'bilinear' + else: + raise NotImplemented - if channel_att_sum is None: - channel_att_sum = channel_att_raw - else: - channel_att_sum = channel_att_sum + channel_att_raw + # Output transform + self.W = nn.Sequential( + conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), + bn(self.in_channels), + ) - # scalecoe = F.sigmoid(channel_att_sum) - # print("channel att_sum", channel_att_sum.shape) - # channel_att_sum = channel_att_sum.reshape(channel_att_sum.shape[0], 4, 4) - # avg_weight = torch.mean(channel_att_sum, dim=2).unsqueeze(2) - # avg_weight = avg_weight.expand(channel_att_sum.shape[0], 4, 4).reshape(channel_att_sum.shape[0], 16) - # scale = F.sigmoid(avg_weight).unsqueeze(2).unsqueeze(3).expand_as(x) - scale = F.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x) - return x * scale, scale + # Theta^T * x_ij + Phi^T * gating_signal + bias + self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=True) + self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels, + kernel_size=(1, 1), stride=1, padding=0, bias=True) + self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True) + # Initialise weights + for m in self.children(): + init_weights(m, init_type='kaiming') -def logsumexp_2d(tensor): - tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) - s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) - outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() - return outputs + # Define the operation + if mode == 'concatenation': + self.operation_function = self._concatenation + elif mode == 'concatenation_debug': + self.operation_function = self._concatenation_debug + elif mode == 'concatenation_residual': + self.operation_function = self._concatenation_residual + else: + raise NotImplementedError('Unknown operation function.') -class ChannelPool(nn.Module): - def forward(self, x): - return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1) + def forward(self, x, g): + ''' + :param x: (b, c, t, h, w) + :param g: (b, g_d) + :return: + ''' + output = self.operation_function(x, g) + return output -class SpatialGate(nn.Module): - def __init__(self): - super(SpatialGate, self).__init__() - kernel_size = 7 - self.compress = ChannelPool() - self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) + def _concatenation(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) - def forward(self, x): - x_compress = self.compress(x) - x_out = self.spatial(x_compress) - scale = F.sigmoid(x_out) # broadcasting - return x * scale, scale + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() -class SpatialAtten(nn.Module): - def __init__(self, in_size, out_size, kernel_size=3, stride=1): - super(SpatialAtten, self).__init__() - self.conv1 = BasicConv(in_size, out_size, kernel_size, stride=stride, - padding=(kernel_size-1) // 2, relu=True) - self.conv2 = BasicConv(out_size, in_size, kernel_size=1, stride=stride, - padding=0, relu=True, bn=False) + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.relu(theta_x + phi_g, inplace=True) - def forward(self, x): - residual = x - x_out = self.conv1(x) - x_out = self.conv2(x_out) - spatial_att = F.sigmoid(x_out) - # .unsqueeze(4).permute(0, 1, 4, 2, 3) - # spatial_att = spatial_att.expand(spatial_att.shape[0], 4, 4, spatial_att.shape[3], spatial_att.shape[4]).reshape( - # spatial_att.shape[0], 16, spatial_att.shape[3], spatial_att.shape[4]) - x_out = residual * spatial_att + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + sigm_psi_f = torch.sigmoid(self.psi(f)) - x_out += residual + # upsample the attentions and multiply + sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) - return x_out, spatial_att + return W_y, sigm_psi_f -class Scale_atten_block(nn.Module): - def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): - super(Scale_atten_block, self).__init__() - self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) - self.no_spatial = no_spatial - if not no_spatial: - self.SpatialGate = SpatialAtten(gate_channels, gate_channels //reduction_ratio) + def _concatenation_debug(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) - def forward(self, x): - x_out, ca_atten = self.ChannelGate(x) - if not self.no_spatial: - x_out, sa_atten = self.SpatialGate(x_out) + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() - return x_out, ca_atten, sa_atten + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.softplus(theta_x + phi_g) + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + sigm_psi_f = torch.sigmoid(self.psi(f)) -class scale_atten_convblock(nn.Module): - def __init__(self, in_size, out_size, stride=1, downsample=None, use_cbam=True, no_spatial=False, drop_out=False): - super(scale_atten_convblock, self).__init__() - self.downsample = downsample - self.stride = stride - self.no_spatial = no_spatial - self.dropout = drop_out + # upsample the attentions and multiply + sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) - self.relu = nn.ReLU(inplace=True) - self.conv3 = conv3x3(in_size, out_size) - self.bn3 = nn.BatchNorm2d(out_size) + return W_y, sigm_psi_f - if use_cbam: - self.cbam = Scale_atten_block(in_size, reduction_ratio=4, no_spatial=self.no_spatial) # out_size - else: - self.cbam = None - def forward(self, x): - residual = x + def _concatenation_residual(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) - if self.downsample is not None: - residual = self.downsample(x) + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() - if not self.cbam is None: - out, scale_c_atten, scale_s_atten = self.cbam(x) + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.relu(theta_x + phi_g, inplace=True) - out += residual - out = self.relu(out) - out = self.conv3(out) - out = self.bn3(out) - out = self.relu(out) + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + f = self.psi(f).view(batch_size, 1, -1) + sigm_psi_f = F.softmax(f, dim=2).view(batch_size, 1, *theta_x.size()[2:]) - if self.dropout: - out = nn.Dropout2d(0.5)(out) + # upsample the attentions and multiply + sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) - return out - + return W_y, sigm_psi_f + + +class GridAttentionBlock2D(_GridAttentionBlockND): + def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', + sub_sample_factor=(2, 2)): + super(GridAttentionBlock2D, self).__init__(in_channels, + inter_channels=inter_channels, + gating_channels=gating_channels, + dimension=2, mode=mode, + sub_sample_factor=sub_sample_factor, + ) + +class MultiAttentionBlock(nn.Module): + def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor): + super(MultiAttentionBlock, self).__init__() + self.gate_block_1 = GridAttentionBlock2D(in_channels=in_size, gating_channels=gate_size, + inter_channels=inter_size, mode=nonlocal_mode, + sub_sample_factor=sub_sample_factor) + self.gate_block_2 = GridAttentionBlock2D(in_channels=in_size, gating_channels=gate_size, + inter_channels=inter_size, mode=nonlocal_mode, + sub_sample_factor=sub_sample_factor) + self.combine_gates = nn.Sequential(nn.Conv2d(in_size*2, in_size, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(in_size), + nn.ReLU(inplace=True)) + + # initialise the blocks + for m in self.children(): + if m.__class__.__name__.find('GridAttentionBlock2D') != -1: continue + init_weights(m, init_type='kaiming') + + def forward(self, input, gating_signal): + gate_1, attention_1 = self.gate_block_1(input, gating_signal) + gate_2, attention_2 = self.gate_block_2(input, gating_signal) + + return self.combine_gates(torch.cat([gate_1, gate_2], 1)), torch.cat([attention_1, attention_2], 1) + +## 4, Non-local layers class _NonLocalBlockND(nn.Module): def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian', sub_sample_factor=4, bn_layer=True): @@ -300,13 +634,13 @@ def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded kernel_size=1, stride=1, padding=0), bn(self.in_channels) ) - nn.init.constant(self.W[1].weight, 0) - nn.init.constant(self.W[1].bias, 0) + nn.init.constant_(self.W[1].weight, 0) + nn.init.constant_(self.W[1].bias, 0) else: self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0) - nn.init.constant(self.W.weight, 0) - nn.init.constant(self.W.bias, 0) + nn.init.constant_(self.W.weight, 0) + nn.init.constant_(self.W.bias, 0) self.theta = None self.phi = None @@ -527,7 +861,7 @@ def _concatenation_proper_down(self, x): y = y.contiguous().view(batch_size, self.inter_channels, *downsampled_size[2:]) # upsample the final featuremaps # (b,0.5c,t/s1,h/s2,w/s3) - y = F.upsample(y, size=x.size()[2:], mode='trilinear') + y = F.interpolate(y, size=x.size()[2:], mode='trilinear') # attention block output W_y = self.W(y) @@ -544,7 +878,191 @@ def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', s sub_sample_factor=sub_sample_factor, bn_layer=bn_layer) - +## 5, scale attention +class BasicConv(nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, + relu=True, bn=True, bias=False): + super(BasicConv, self).__init__() + self.out_channels = out_planes + self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, groups=groups, bias=bias) + self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None + self.relu = nn.ReLU() if relu else None + + def forward(self, x): + x = self.conv(x) + if self.bn is not None: + x = self.bn(x) + if self.relu is not None: + x = self.relu(x) + return x + + +class Flatten(nn.Module): + def forward(self, x): + return x.view(x.size(0), -1) + + +class ChannelGate(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): + super(ChannelGate, self).__init__() + self.gate_channels = gate_channels + self.mlp = nn.Sequential( + Flatten(), + nn.Linear(gate_channels, gate_channels // reduction_ratio), + nn.ReLU(), + nn.Linear(gate_channels // reduction_ratio, gate_channels) + ) + self.pool_types = pool_types + + def forward(self, x): + channel_att_sum = None + for pool_type in self.pool_types: + if pool_type == 'avg': + avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp(avg_pool) + elif pool_type == 'max': + max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp(max_pool) + elif pool_type == 'lp': + lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp(lp_pool) + elif pool_type == 'lse': + # LSE pool only + lse_pool = logsumexp_2d(x) + channel_att_raw = self.mlp(lse_pool) + + if channel_att_sum is None: + channel_att_sum = channel_att_raw + else: + channel_att_sum = channel_att_sum + channel_att_raw + + # scalecoe = F.sigmoid(channel_att_sum) + channel_att_sum = channel_att_sum.reshape(channel_att_sum.shape[0], 4, 4) + avg_weight = torch.mean(channel_att_sum, dim=2).unsqueeze(2) + avg_weight = avg_weight.expand(channel_att_sum.shape[0], 4, 4).reshape(channel_att_sum.shape[0], 16) + scale = torch.sigmoid(avg_weight).unsqueeze(2).unsqueeze(3).expand_as(x) + + return x * scale, scale + + +def logsumexp_2d(tensor): + tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) + s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) + outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() + return outputs + + +class ChannelPool(nn.Module): + def forward(self, x): + return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1) + + +class SpatialGate(nn.Module): + def __init__(self): + super(SpatialGate, self).__init__() + kernel_size = 7 + self.compress = ChannelPool() + self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) + + def forward(self, x): + x_compress = self.compress(x) + x_out = self.spatial(x_compress) + scale = torch.sigmoid(x_out) # broadcasting + # spa_scale = scale.expand_as(x) + # print(spa_scale.shape) + return x * scale, scale + +class SpatialAtten(nn.Module): + def __init__(self, in_size, out_size, kernel_size=3, stride=1): + super(SpatialAtten, self).__init__() + self.conv1 = BasicConv(in_size, out_size, kernel_size, stride=stride, + padding=(kernel_size-1) // 2, relu=True) + self.conv2 = BasicConv(out_size, out_size, kernel_size=1, stride=stride, + padding=0, relu=True, bn=False) + + def forward(self, x): + residual = x + x_out = self.conv1(x) + x_out = self.conv2(x_out) + spatial_att = torch.sigmoid(x_out).unsqueeze(4).permute(0, 1, 4, 2, 3) + spatial_att = spatial_att.expand(spatial_att.shape[0], 4, 4, spatial_att.shape[3], spatial_att.shape[4]).reshape( + spatial_att.shape[0], 16, spatial_att.shape[3], spatial_att.shape[4]) + x_out = residual * spatial_att + + x_out += residual + + return x_out, spatial_att + +class Scale_atten_block(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): + super(Scale_atten_block, self).__init__() + self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) + self.no_spatial = no_spatial + if not no_spatial: + self.SpatialGate = SpatialAtten(gate_channels, gate_channels //reduction_ratio) + + def forward(self, x): + x_out, ca_atten = self.ChannelGate(x) + if not self.no_spatial: + x_out, sa_atten = self.SpatialGate(x_out) + + return x_out, ca_atten, sa_atten + + +class scale_atten_convblock(nn.Module): + def __init__(self, in_size, out_size, stride=1, downsample=None, use_cbam=True, no_spatial=False, drop_out=False): + super(scale_atten_convblock, self).__init__() + # if stride != 1 or in_size != out_size: + # downsample = nn.Sequential( + # nn.Conv2d(in_size, out_size, + # kernel_size=1, stride=stride, bias=False), + # nn.BatchNorm2d(out_size), + # ) + self.downsample = downsample + self.stride = stride + self.no_spatial = no_spatial + self.dropout = drop_out + + self.relu = nn.ReLU(inplace=True) + self.conv3 = conv3x3(in_size, out_size) + self.bn3 = nn.BatchNorm2d(out_size) + + if use_cbam: + self.cbam = Scale_atten_block(in_size, reduction_ratio=4, no_spatial=self.no_spatial) # out_size + else: + self.cbam = None + + def forward(self, x): + residual = x + + if self.downsample is not None: + residual = self.downsample(x) + + if not self.cbam is None: + out, scale_c_atten, scale_s_atten = self.cbam(x) + + # scale_c_atten = nn.Sigmoid()(scale_c_atten) + # scale_s_atten = nn.Sigmoid()(scale_s_atten) + # scale_atten = channel_atten_c * spatial_atten_s + + # scale_max = torch.argmax(scale_atten, dim=1, keepdim=True) + # scale_max_soft = get_soft_label(input_tensor=scale_max, num_class=8) + # scale_max_soft = scale_max_soft.permute(0, 3, 1, 2) + # scale_atten_soft = scale_atten * scale_max_soft + + out += residual + out = self.relu(out) + out = self.conv3(out) + out = self.bn3(out) + out = self.relu(out) + + if self.dropout: + out = nn.Dropout2d(0.5)(out) + + return out + +## 6, CANet class CANet(nn.Module): """ Implementation of CANet (Comprehensive Attention Network) for image segmentation. @@ -559,24 +1077,25 @@ class CANet(nn.Module): :param in_chns: (int) Input channel number. :param feature_chns: (list) Feature channel for each resolution level. The length should be 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + :param is_deconv: (bool) Using deconvolution for up-sampling or not. + If False, bilinear interpolation will be used for up-sampling. Default is True. + :param is_batchnorm: (bool) If batch normalization is or not. Default is True. + :param feature_scale: (int) The scale of resolution levels. Default is 4. """ def __init__(self, params): #args, in_ch=3, n_classes=2, feature_scale=4, is_deconv=True, is_batchnorm=True, # nonlocal_mode='concatenation', attention_dsample=(1, 1)): super(CANet, self).__init__() self.in_channels = params['in_chns'] self.num_classes = params['class_num'] + self.feature_chns= params.get('feature_chns', [32, 64, 128, 256, 512]) self.is_deconv = params.get('is_deconv', True) self.is_batchnorm = params.get('is_batchnorm', True) self.feature_scale = params.get('feature_scale', 4) nonlocal_mode = 'concatenation' attention_dsample = (1, 1) - filters = [64, 128, 256, 512, 1024] + filters = self.feature_chns filters = [int(x / self.feature_scale) for x in filters] # downsampling @@ -623,6 +1142,13 @@ def __init__(self, params): #args, in_ch=3, n_classes=2, feature_scale=4, is_de self.final = nn.Conv2d(filters[0], self.num_classes, kernel_size=1) def forward(self, inputs): + x_shape = list(inputs.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + inputs = torch.transpose(inputs, 1, 2) + inputs = torch.reshape(inputs, new_shape) + # Feature Extraction conv1 = self.conv1(inputs) maxpool1 = self.maxpool1(conv1) @@ -671,6 +1197,14 @@ def forward(self, inputs): out = self.scale_att(dsv_cat) out = self.final(out) + if(len(x_shape) == 5): + if(isinstance(out, (list,tuple))): + for i in range(len(out)): + new_shape = [N, D] + list(out[i].shape)[1:] + out[i] = torch.transpose(torch.reshape(out[i], new_shape), 1, 2) + else: + new_shape = [N, D] + list(out.shape)[1:] + out = torch.transpose(torch.reshape(out, new_shape), 1, 2) return out From 1aebb5e053309639bb7720f77281fd13e735d5f9 Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 2 Oct 2024 15:51:18 +0800 Subject: [PATCH 198/225] Update unet2d_canet.py --- pymic/net/net2d/unet2d_canet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymic/net/net2d/unet2d_canet.py b/pymic/net/net2d/unet2d_canet.py index 0af3159..f578025 100644 --- a/pymic/net/net2d/unet2d_canet.py +++ b/pymic/net/net2d/unet2d_canet.py @@ -1216,7 +1216,7 @@ def forward(self, inputs): x = np.random.rand(4, 3, 224, 224) xt = torch.from_numpy(x) - xt = torch.tensor(xt) + xt = xt.clone().detach() y = Net(xt) print(len(y.size())) From d0ebdc2f8c193d431224689406e70018ab4e93a6 Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 2 Oct 2024 21:05:32 +0800 Subject: [PATCH 199/225] update unetpp --- pymic/net/net2d/{unet2d_nest.py => unet2d_pp.py} | 8 ++++---- pymic/net/net_dict_seg.py | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) rename pymic/net/net2d/{unet2d_nest.py => unet2d_pp.py} (96%) diff --git a/pymic/net/net2d/unet2d_nest.py b/pymic/net/net2d/unet2d_pp.py similarity index 96% rename from pymic/net/net2d/unet2d_nest.py rename to pymic/net/net2d/unet2d_pp.py index efa048f..f2a003b 100644 --- a/pymic/net/net2d/unet2d_nest.py +++ b/pymic/net/net2d/unet2d_pp.py @@ -3,9 +3,9 @@ import torch.nn as nn from pymic.net.net2d.unet2d import * -class NestedUNet2D(nn.Module): +class UNet2Dpp(nn.Module): """ - An implementation of the Nested U-Net. + An implementation of the U-Net++. * Reference: Zongwei Zhou, et al.: `UNet++: A Nested U-Net Architecture for Medical Image Segmentation. `_ @@ -25,7 +25,7 @@ class NestedUNet2D(nn.Module): :param class_num: (int) The class number for segmentation task. """ def __init__(self, params): - super(NestedUNet2D, self).__init__() + super(UNet2Dpp, self).__init__() self.params = params self.in_chns = self.params['in_chns'] self.filters = self.params['feature_chns'] @@ -96,7 +96,7 @@ def forward(self, x): 'feature_chns':[2, 8, 32, 48, 64], 'dropout': [0, 0, 0.3, 0.4, 0.5], 'class_num': 2} - Net = NestedUNet2D(params) + Net = UNet2Dpp(params) Net = Net.double() x = np.random.rand(4, 4, 10, 96, 96) diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index 741eed7..99d1693 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -22,7 +22,7 @@ from pymic.net.net2d.unet2d_mtnet import MTNet2D from pymic.net.net2d.cople_net import COPLENet from pymic.net.net2d.unet2d_attention import AttentionUNet2D -from pymic.net.net2d.unet2d_nest import NestedUNet2D +from pymic.net.net2d.unet2d_pp import UNet2Dpp from pymic.net.net2d.unet2d_scse import UNet2D_ScSE from pymic.net.net2d.trans2d.transunet import TransUNet from pymic.net.net2d.trans2d.swinunet import SwinUNet @@ -50,15 +50,15 @@ # from pymic.net.net3d.trans3d.SwitchNet import SwitchNet SegNetDict = { + 'AttentionUNet2D': AttentionUNet2D, + 'CANet': CANet, + 'COPLENet': COPLENet, + 'MCNet2D': MCNet2D, + 'MTNet2D': MTNet2D, 'UNet2D': UNet2D, 'UNet2D_DualBranch': UNet2D_DualBranch, 'UNet2D_CCT': UNet2D_CCT, - 'MCNet2D': MCNet2D, - 'MTNet2D': MTNet2D, - 'CANet': CANet, - 'COPLENet': COPLENet, - 'AttentionUNet2D': AttentionUNet2D, - 'NestedUNet2D': NestedUNet2D, + 'UNet2Dpp': UNet2Dpp, 'UNet2D_ScSE': UNet2D_ScSE, 'TransUNet': TransUNet, 'SwinUNet': SwinUNet, From 9670b63225fbd0cc8fe1989bb08a5a670913bd96 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 3 Oct 2024 21:41:13 +0800 Subject: [PATCH 200/225] add lcovnet --- pymic/net/net3d/lcovnet.py | 246 +++++++++++++++++++++++++++++++++++++ pymic/net/net_dict_seg.py | 2 + pymic/test/test_net3d.py | 97 ++++++++++++++- 3 files changed, 344 insertions(+), 1 deletion(-) create mode 100644 pymic/net/net3d/lcovnet.py diff --git a/pymic/net/net3d/lcovnet.py b/pymic/net/net3d/lcovnet.py new file mode 100644 index 0000000..bd91878 --- /dev/null +++ b/pymic/net/net3d/lcovnet.py @@ -0,0 +1,246 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import logging +import torch +import torch.nn as nn +import numpy as np +from pymic.net.net_init import Initialization_He, Initialization_XavierUniform + +class UnetBlock_Encode(nn.Module): + def __init__(self, in_channels, out_channel): + super(UnetBlock_Encode, self).__init__() + + self.in_chns = in_channels + self.out_chns = out_channel + + self.conv1 = nn.Sequential( + nn.Conv3d(self.in_chns, self.out_chns, kernel_size=(1, 1, 3), + padding=(0, 0, 1)), + nn.BatchNorm3d(self.out_chns), + nn.ReLU6(inplace=True) + ) + + self.conv2_1 = nn.Sequential( + nn.Conv3d(self.out_chns, self.out_chns, kernel_size=(3, 3, 1), + padding=(1, 1, 0), groups=1), + nn.BatchNorm3d(self.out_chns), + nn.ReLU6(inplace=True), + nn.Dropout(p=0.2) + ) + + self.conv2_2 = nn.Sequential( + nn.AvgPool3d(kernel_size=4, stride=2, padding=1), + nn.Conv3d(self.out_chns, self.out_chns, kernel_size=1, + padding=0), + nn.BatchNorm3d(self.out_chns), + nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False) + ) + + def forward(self, x): + # print(x.shape) + x = self.conv1(x) + + x1 = self.conv2_1(x) + x2 = self.conv2_2(x) + x2 = torch.sigmoid(x2) + x = x1 + x2 * x + return x + + +class UnetBlock_Encode_BottleNeck(nn.Module): + def __init__(self, in_channels, out_channel): + super(UnetBlock_Encode_BottleNeck, self).__init__() + + self.in_chns = in_channels + self.out_chns = out_channel + + self.conv1 = nn.Sequential( + nn.Conv3d(self.in_chns, self.out_chns, kernel_size=(1, 1, 3), + padding=(0, 0, 1)), + nn.BatchNorm3d(self.out_chns), + nn.ReLU6(inplace=True) + ) + + self.conv2_1 = nn.Sequential( + nn.Conv3d(self.out_chns, self.out_chns, kernel_size=(3, 3, 1), + padding=(1, 1, 0), groups=self.out_chns), + nn.BatchNorm3d(self.out_chns), + nn.ReLU6(inplace=True), + nn.Dropout(p=0.2) + ) + + self.conv2_2 = nn.Sequential( + # nn.AvgPool3d(kernel_size=4, stride=2), + nn.Conv3d(self.out_chns, self.out_chns, kernel_size=1, + padding=0), + nn.BatchNorm3d(self.out_chns), + nn.ReLU6(inplace=True), + nn.Dropout(p=0.2) + ) + + def forward(self, x): + x = self.conv1(x) + + x1 = self.conv2_1(x) + x2 = self.conv2_2(x) + x2 = torch.sigmoid(x2) + x = x1 + x2 * x + return x + + +class UnetBlock_Down(nn.Module): + def __init__(self): + super(UnetBlock_Down, self).__init__() + self.avg_pool = nn.MaxPool3d(kernel_size=2, stride=2) + + def forward(self, x): + x = self.avg_pool(x) + return x + + +class UnetBlock_Up(nn.Module): + def __init__(self, in_channels, out_channel): + super(UnetBlock_Up, self).__init__() + self.conv = self.conv1 = nn.Sequential( + nn.Conv3d(in_channels, out_channel, kernel_size=1, + padding=0, groups=1), + nn.BatchNorm3d(out_channel), + nn.ReLU6(inplace=True), + nn.Dropout(p=0.2) + ) + + self.up = nn.Upsample( + scale_factor=2, mode='trilinear', align_corners=False) + + def forward(self, x): + x = self.conv(x) + x = self.up(x) + return x + + +class LCOVNet(nn.Module): + """ + An implementation of the LCOVNet. + + * Reference: Q. Zhao, L. Zhong, J. Xiao, J. Zhang, Y. Chen , W. Liao, S. Zhang, and G. Wang: + Efficient Multi-Organ Segmentation From 3D Abdominal CT Images With Lightweight Network and Knowledge Distillation. + `IEEE TMI 42(9) 2023: 2513 - 2523. `_ + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, params): + super(LCOVNet, self).__init__() + params = self.get_default_parameters(params) + for p in params: + print(p, params[p]) + self.stage = 'train' + # C_in=32, n_classes=17, m=1, is_ds=True): + + in_chns = params['in_chns'] + n_class = params['class_num'] + self.ft_chns = params['feature_chns'] + self.mul_pred = params.get('multiscale_pred', False) + + self.Encode_block1 = UnetBlock_Encode(in_chns, self.ft_chns[0]) + self.down1 = UnetBlock_Down() + + self.Encode_block2 = UnetBlock_Encode(self.ft_chns[0], self.ft_chns[1]) + self.down2 = UnetBlock_Down() + + self.Encode_block3 = UnetBlock_Encode(self.ft_chns[1], self.ft_chns[2]) + self.down3 = UnetBlock_Down() + + self.Encode_block4 = UnetBlock_Encode(self.ft_chns[2], self.ft_chns[3]) + self.down4 = UnetBlock_Down() + + self.Encode_BottleNeck_block5 = UnetBlock_Encode_BottleNeck( + self.ft_chns[3], self.ft_chns[4]) + + self.up1 = UnetBlock_Up(self.ft_chns[4], self.ft_chns[3]) + self.Decode_block1 = UnetBlock_Encode( + self.ft_chns[3]*2, self.ft_chns[3]) + self.segout1 = nn.Conv3d( + self.ft_chns[3], n_class, kernel_size=1, padding=0) + + self.up2 = UnetBlock_Up(self.ft_chns[3], self.ft_chns[2]) + self.Decode_block2 = UnetBlock_Encode( + self.ft_chns[2]*2, self.ft_chns[2]) + self.segout2 = nn.Conv3d( + self.ft_chns[2], n_class, kernel_size=1, padding=0) + + self.up3 = UnetBlock_Up(self.ft_chns[2], self.ft_chns[1]) + self.Decode_block3 = UnetBlock_Encode( + self.ft_chns[1]*2, self.ft_chns[1]) + self.segout3 = nn.Conv3d( + self.ft_chns[1], n_class, kernel_size=1, padding=0) + + self.up4 = UnetBlock_Up(self.ft_chns[1], self.ft_chns[0]) + self.Decode_block4 = UnetBlock_Encode( + self.ft_chns[0]*2, self.ft_chns[0]) + self.segout4 = nn.Conv3d( + self.ft_chns[0], n_class, kernel_size=1, padding=0) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'initialization': 'he', + 'multiscale_pred': False + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def forward(self, x): + _x1 = self.Encode_block1(x) + x1 = self.down1(_x1) + + _x2 = self.Encode_block2(x1) + x2 = self.down2(_x2) + + _x3 = self.Encode_block3(x2) + x3 = self.down2(_x3) + + _x4 = self.Encode_block4(x3) + x4 = self.down2(_x4) + + x5 = self.Encode_BottleNeck_block5(x4) + + x6 = self.up1(x5) + x6 = torch.cat((x6, _x4), dim=1) + x6 = self.Decode_block1(x6) + segout1 = self.segout1(x6) + + x7 = self.up2(x6) + x7 = torch.cat((x7, _x3), dim=1) + x7 = self.Decode_block2(x7) + segout2 = self.segout2(x7) + + x8 = self.up3(x7) + x8 = torch.cat((x8, _x2), dim=1) + x8 = self.Decode_block3(x8) + segout3 = self.segout3(x8) + + x9 = self.up4(x8) + x9 = torch.cat((x9, _x1), dim=1) + x9 = self.Decode_block4(x9) + segout4 = self.segout4(x9) + + if (self.mul_pred == True and self.stage == 'train'): + return [segout4, segout3, segout2, segout1] + else: + return segout4 \ No newline at end of file diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index 99d1693..687710b 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -30,6 +30,7 @@ from pymic.net.net3d.unet3d import UNet3D from pymic.net.net3d.grunet import GRUNet from pymic.net.net3d.fmunetv3 import FMUNetV3 +from pymic.net.net3d.lcovnet import LCOVNet from pymic.net.net3d.unet3d_scse import UNet3D_ScSE from pymic.net.net3d.unet3d_dual_branch import UNet3D_DualBranch # from pymic.net.net3d.stunet_wrap import STUNet_wrap @@ -64,6 +65,7 @@ 'SwinUNet': SwinUNet, 'UNet2D5': UNet2D5, 'GRUNet': GRUNet, + 'LCOVNet': LCOVNet, 'FMUNetV3': FMUNetV3, 'UNet3D': UNet3D, 'UNet3D_ScSE': UNet3D_ScSE, diff --git a/pymic/test/test_net3d.py b/pymic/test/test_net3d.py index 180dcff..058a6fd 100644 --- a/pymic/test/test_net3d.py +++ b/pymic/test/test_net3d.py @@ -6,6 +6,9 @@ from pymic.net.net3d.unet3d import UNet3D from pymic.net.net3d.unet3d_scse import UNet3D_ScSE from pymic.net.net3d.unet2d5 import UNet2D5 +from pymic.net.net3d.grunet import GRUNet +from pymic.net.net3d.lcovnet import LCOVNet +from pymic.net.net3d.trans3d.unetr_pp import UNETR_PP def test_unet3d(): params = {'in_chns':4, @@ -61,6 +64,22 @@ def test_unet3d_scse(): y = y.detach().numpy() print(y.shape) +def test_lcovnet(): + params = {'in_chns':4, + 'feature_chns':[16, 32, 64, 128, 256], + 'class_num': 2} + Net = LCOVNet(params) + Net = Net.double() + + x = np.random.rand(4, 4, 96, 96, 96) + xt = torch.from_numpy(x) + xt = xt.clone().detach() + + y = Net(xt) + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) + def test_unet2d5(): params = {'in_chns':4, 'feature_chns':[8, 16, 32, 64, 128], @@ -100,9 +119,85 @@ def test_unet2d5(): y = y.detach().numpy() print(y.shape) +def test_mystunet(): + in_chns = 4 + num_class = 4 + # input_channels, num_classes, depth=[1,1,1,1,1,1], dims=[32, 64, 128, 256, 512, 512], + # pool_op_kernel_sizes=None, conv_kernel_sizes=None) + dims=[16, 32, 64, 128, 256, 512] + Net = MySTUNet(in_chns, num_class, dims = dims, pool_op_kernel_sizes = [[2, 2, 2], [2,2,2], [2,2,2], [2,2,2], [1, 1, 1]], + conv_kernel_sizes = [[3, 3, 3], [3,3,3], [3,3,3], [3,3,3], [3,3,3], [3, 3, 3]]) + Net = Net.double() + + x = np.random.rand(4, 4, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + for y in out: + y = y.detach().numpy() + print(y.shape) + +def test_grunet(): + params = {'in_chns':4, + 'feature_chns':[8, 16, 32, 64, 128], + 'dims': [2, 3, 3, 3, 3], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'depth': 2, + 'multiscale_pred': True} + x = np.random.rand(4, 4, 64, 128, 128) + + # params = {'in_chns':4, + # 'feature_chns':[8, 16, 32, 64, 128], + # 'dims': [3, 3, 3, 3, 3], + # 'dropout': [0, 0, 0.3, 0.4, 0.5], + # 'class_num': 2, + # 'depth': 4, + # 'multiscale_pred': True} + # x = np.random.rand(4, 4, 96, 96, 96) + + Net = GRUNet(params) + Net = Net.double() + + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + for y in out: + y = y.detach().numpy() + print(y.shape) + +def test_unetr_pp(): + depths = [128, 64, 32] + for i in range(3): + params = {'in_chns': 4, + 'class_num': 2, + 'img_size': [depths[i], 128, 128], + 'resolution_mode': i + } + net = UNETR_PP(params) + net.double() + + x = np.random.rand(2, 4, depths[i], 128, 128) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = net(xt) + print(len(y)) + for yi in y: + yi = yi.detach().numpy() + print(yi.shape) + + + if __name__ == "__main__": # test_unet3d() # test_unet3d_scse() - test_unet2d5() + test_lcovnet() + # test_unetr_pp() + # test_unet2d5() + # test_mystunet() + # test_fmunetv2() \ No newline at end of file From a74efa89ff91b4da59bfcd78152c257bd6b01515 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 4 Oct 2024 15:23:41 +0800 Subject: [PATCH 201/225] Create unet2d_multi_decoder.py --- pymic/net/net2d/unet2d_multi_decoder.py | 162 ++++++++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 pymic/net/net2d/unet2d_multi_decoder.py diff --git a/pymic/net/net2d/unet2d_multi_decoder.py b/pymic/net/net2d/unet2d_multi_decoder.py new file mode 100644 index 0000000..2c7b3c4 --- /dev/null +++ b/pymic/net/net2d/unet2d_multi_decoder.py @@ -0,0 +1,162 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import torch.nn as nn +from pymic.net.net2d.unet2d import * + +class UNet2D_DualBranch(nn.Module): + """ + A dual branch network using UNet2D as backbone. + + * Reference: Xiangde Luo, Minhao Hu, Wenjun Liao, Shuwei Zhai, Tao Song, Guotai Wang, + Shaoting Zhang. ScribblScribble-Supervised Medical Image Segmentation via + Dual-Branch Network and Dynamically Mixed Pseudo Labels Supervision. + `MICCAI 2022. `_ + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.UNet2D` for details. + In addition, the following field should be included: + + :param output_mode: (str) How to obtain the result during the inference. + `average`: taking average of the two branches. + `first`: takeing the result in the first branch. + `second`: taking the result in the second branch. + """ + def __init__(self, params): + super(UNet2D_DualBranch, self).__init__() + params = self.get_default_parameters(params) + self.output_mode = params["output_mode"] + self.encoder = Encoder(params) + self.decoder1 = Decoder(params) + self.decoder2 = Decoder(params) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], + 'up_mode': 2, + 'multiscale_pred': False, + 'output_mode': "average" + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + f = self.encoder(x) + output1 = self.decoder1(f) + output2 = self.decoder2(f) + if(len(x_shape) == 5): + new_shape = [N, D] + list(output1.shape)[1:] + output1 = torch.reshape(output1, new_shape) + output1 = torch.transpose(output1, 1, 2) + output2 = torch.reshape(output2, new_shape) + output2 = torch.transpose(output2, 1, 2) + + if(self.training): + return output1, output2 + else: + if(self.output_mode == "average"): + return (output1 + output2)/2 + elif(self.output_mode == "first"): + return output1 + else: + return output2 + +class UNet2D_TriBranch(nn.Module): + """ + A tri-branch network using UNet2D as backbone. The super class for MCNet2D and MTNet2D. + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.UNet2D` for details. + """ + def __init__(self, params): + super(UNet2D_TriBranch, self).__init__() + params = self.get_default_parameters(params) + self.encoder = Encoder(params) + self.decoder1 = Decoder(params) + self.decoder2 = Decoder(params) + self.decoder3 = Decoder(params) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], + 'up_mode': 2, + 'multiscale_pred': False, + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + feature = self.encoder(x) + output1 = self.decoder1(feature) + new_shape = [N, D] + list(output1.shape)[1:] + output1 = torch.transpose(torch.reshape(output1, new_shape), 1, 2) + if(not self.training): + return output1 + output2 = self.decoder2(feature) + output3 = self.decoder3(feature) + if(len(x_shape) == 5): + output2 = torch.transpose(torch.reshape(output2, new_shape), 1, 2) + output3 = torch.transpose(torch.reshape(output3, new_shape), 1, 2) + return output1, output2, output3 + +class MCNet2D(UNet2D_TriBranch): + """ + A tri-branch network using UNet2D as backbone. + + * Reference: Yicheng Wu, Zongyuan Ge et al. Mutual consistency learning for + semi-supervised medical image segmentation. + `Medical Image Analysis 2022. `_ + + The original code is at: https://github.com/ycwu1997/MC-Net + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.UNet2D` for details. + """ + def __init__(self, params): + super(MCNet2D, self).__init__(params) + in_chns = params['in_chns'] + class_num = params['class_num'] + ft_chns = params['feature_chns'] + dropout = params['dropout'] + params1 = {'in_chns': in_chns, + 'feature_chns': ft_chns, + 'dropout': dropout, + 'class_num': class_num, + 'up_mode': 0 } + params2 = {'in_chns': in_chns, + 'feature_chns': ft_chns, + 'dropout': dropout, + 'class_num': class_num, + 'up_mode': 1 } + params3 = {'in_chns': in_chns, + 'feature_chns': ft_chns, + 'dropout': dropout, + 'class_num': class_num, + 'up_mode': 2 } + self.encoder = Encoder(params1) + self.decoder1 = Decoder(params1) + self.decoder2 = Decoder(params2) + self.decoder3 = Decoder(params3) \ No newline at end of file From 59c82f3ff8230be5586bae6f40000dfdf42b18b4 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 4 Oct 2024 15:27:21 +0800 Subject: [PATCH 202/225] update 3D networks --- pymic/net/net3d/fmunetv3.py | 262 ++++++++++++++++++++++++++++++++++++ pymic/net/net_dict_seg.py | 4 +- 2 files changed, 264 insertions(+), 2 deletions(-) create mode 100644 pymic/net/net3d/fmunetv3.py diff --git a/pymic/net/net3d/fmunetv3.py b/pymic/net/net3d/fmunetv3.py new file mode 100644 index 0000000..1367209 --- /dev/null +++ b/pymic/net/net3d/fmunetv3.py @@ -0,0 +1,262 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import itertools +import logging +import torch +import torch.nn as nn +from pymic.net.net_init import Initialization_He, Initialization_XavierUniform + +dim0 = {0:3, 1:2, 2:2} +dim1 = {0:3, 1:3, 2:2} +conv_knl = {2: (1, 3, 3), 3: 3} +conv_pad = {2: (0, 1, 1), 3: 1} +pool_knl = {2: (1, 2, 2), 3: 2} +down_stride = {2: (1, 2, 2), 3: 2} + +class ResConv(nn.Module): + def __init__(self, out_channels, dim = 3, dropout_p = 0.0, depth = 2): + super(ResConv, self).__init__() + assert(dim == 2 or dim == 3) + self.out_channels = out_channels + self.conv_list = nn.ModuleList([nn.Sequential( + nn.InstanceNorm3d(out_channels, affine = True), + nn.LeakyReLU(), + nn.Dropout(dropout_p), + nn.Conv3d(out_channels, out_channels, kernel_size=conv_knl[dim], padding=conv_pad[dim])) + for i in range(depth)]) + + def forward(self, x): + for conv in self.conv_list: + x = conv(x) + x + return x + +class DownSample(nn.Module): + """downsampling based on convolution + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + :param downsample: (bool) Use downsample or not after convolution. + """ + def __init__(self, in_channels, out_channels, dim = 3): + super(DownSample, self).__init__() + self.down = nn.Sequential( + nn.InstanceNorm3d(in_channels, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=conv_knl[dim], + padding=conv_pad[dim], stride = down_stride[dim]) + ) + + def forward(self, x): + return self.down(x) + +class UpCatConv(nn.Module): + """Upsampling followed by `ResConv` block + + :param in_channels1: (int) Input channel number for low-resolution feature map. + :param in_channels2: (int) Input channel number for high-resolution feature map. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear` for 3D and `Bilinear` for 2D). + The default value is 2. + """ + def __init__(self, in_channels1, in_channels2, out_channels, dim = 3): + super(UpCatConv, self).__init__() + + self.up = nn.Sequential( + nn.InstanceNorm3d(in_channels1, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels1, in_channels2, kernel_size=1, padding=0), + nn.Upsample(scale_factor=pool_knl[dim], mode='trilinear', align_corners=True) + ) + + self.conv = nn.Sequential( + nn.InstanceNorm3d(in_channels2*2, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels2 * 2, out_channels, kernel_size=conv_knl[dim], padding=conv_pad[dim]) + ) + + def forward(self, x_l, x_h): + """ + x_l: low-resolution feature map. + x_h: high-resolution feature map. + """ + y = torch.cat([x_h, self.up(x_l)], dim=1) + return self.conv(y) + +class Encoder(nn.Module): + """ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. + + res_mode: resolution mode: 0-- isotrpic, 1-- near isotrpic, 2-- isotropic + """ + def __init__(self, ft_chns, res_mode = 0, dropout_p = 0, depth = 2): + super(Encoder, self).__init__() + d0, d1 = dim0[res_mode], dim1[res_mode] + + self.en_conv0 = ResConv(ft_chns[0], d0, 0, depth) + self.en_conv1 = ResConv(ft_chns[1], d1, 0, depth) + self.en_conv2 = ResConv(ft_chns[2], 3, dropout_p, depth) + self.en_conv3 = ResConv(ft_chns[3], 3, dropout_p, depth) + self.en_conv4 = ResConv(ft_chns[4], 3, dropout_p, depth) + + self.down0 = DownSample(ft_chns[0], ft_chns[1], d0) + self.down1 = DownSample(ft_chns[1], ft_chns[2], d1) + self.down2 = DownSample(ft_chns[2], ft_chns[3], 3) + self.down3 = DownSample(ft_chns[3], ft_chns[4], 3) + + def forward(self, x): + x0 = self.en_conv0(x) + x1 = self.en_conv1(self.down0(x0)) + x2 = self.en_conv2(self.down1(x1)) + x3 = self.en_conv3(self.down2(x2)) + x4 = self.en_conv4(self.down3(x3)) + return [x0, x1, x2, x3, x4] + +class Decoder(nn.Module): + """ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. + """ + def __init__(self, ft_chns, res_mode = 0, dropout_p = 0, depth = 2): + super(Decoder, self).__init__() + d0, d1 = dim0[res_mode], dim1[res_mode] + + self.upcat0 = UpCatConv(ft_chns[1], ft_chns[0], ft_chns[0], d0) + self.upcat1 = UpCatConv(ft_chns[2], ft_chns[1], ft_chns[1], d1) + self.upcat2 = UpCatConv(ft_chns[3], ft_chns[2], ft_chns[2], 3) + self.upcat3 = UpCatConv(ft_chns[4], ft_chns[3], ft_chns[3], 3) + + self.de_conv0 = ResConv(ft_chns[0], d0, 0, depth) + self.de_conv1 = ResConv(ft_chns[1], d1, 0, depth) + self.de_conv2 = ResConv(ft_chns[2], 3, dropout_p, depth) + self.de_conv3 = ResConv(ft_chns[3], 3, dropout_p, depth) + self.de_conv4 = ResConv(ft_chns[4], 3, dropout_p, depth) + + def forward(self, x): + x0, x1, x2, x3, x4 = x + x4_de = self.de_conv4(x4) + x3_de = self.de_conv3(self.upcat3(x4_de, x3)) + x2_de = self.de_conv2(self.upcat2(x3_de, x2)) + x1_de = self.de_conv1(self.upcat1(x2_de, x1)) + x0_de = self.de_conv0(self.upcat0(x1_de, x0)) + return [x0_de, x1_de, x2_de, x3_de] + +class FMUNetV3(nn.Module): + """ + A 2.5D network combining 3D convolutions with 2D convolutions. + + * Reference: Guotai Wang, Jonathan Shapey, Wenqi Li, Reuben Dorent, Alex Demitriadis, + Sotirios Bisdas, Ian Paddick, Robert Bradford, Shaoting Zhang, Sébastien Ourselin, + Tom Vercauteren: Automatic Segmentation of Vestibular Schwannoma from T2-Weighted + MRI by Deep Spatial Attention with Hardness-Weighted Loss. + `MICCAI (2) 2019: 264-272. `_ + + Note that the attention module in the orininal paper is not used here. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param conv_dims: (list) The convolution dimension (2 or 3) for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, params): + super(FMUNetV3, self).__init__() + params = self.get_default_parameters(params) + + self.stage = 'train' + in_chns = params['in_chns'] + ft_chns = params['feature_chns'] + res_mode = params['res_mode'] + dropout = params['dropout'] + depth = params['depth'] + cls_num = params['class_num'] + self.mul_pred = params.get('multiscale_pred', True) + self.tune_mode= params.get('finetune_mode', 'all') + self.load_mode= params.get('weights_load_mode', 'all') + + d0 = dim0[res_mode] + self.project = nn.Conv3d(in_chns, ft_chns[0], kernel_size=conv_knl[d0], padding=conv_pad[d0]) + self.encoder = Encoder(ft_chns, res_mode, dropout, depth) + # self.decoder = Decoder(ft_chns, res_mode, dropout, depth = 2) + self.decoder = Decoder(ft_chns, res_mode, dropout, depth) + + self.out_layers = nn.ModuleList() + dims = [dim0[res_mode], dim1[res_mode], 3, 3] + for i in range(4): + out_layer = nn.Sequential( + nn.InstanceNorm3d(ft_chns[i], affine = True), + nn.LeakyReLU(), + nn.Conv3d(ft_chns[i], cls_num, kernel_size=conv_knl[dims[i]], padding=conv_pad[dims[i]])) + self.out_layers.append(out_layer) + + init = params['initialization'].lower() + weightInitializer = Initialization_He(1e-2) if init == 'he' else Initialization_XavierUniform() + self.apply(weightInitializer) + + def get_default_parameters(self, params): + default_param = { + 'finetune_mode': 'all', + 'initialization': 'he', + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': 0.2, + 'res_mode': 0, + 'depth': 2, + 'multiscale_pred': True + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def set_stage(self, stage): + self.stage = stage + + def forward(self, x): + x_en = self.encoder(self.project(x)) + x_de = self.decoder(x_en) + output = self.out_layers[0](x_de[0]) + if(self.mul_pred and self.stage == 'train'): + output = [output] + for i in range(1, len(x_de)): + output.append(self.out_layers[i](x_de[i])) + return output + + def get_parameters_to_update(self): + if(self.tune_mode == 'all'): + return self.parameters() + + up_params = itertools.chain() + if(self.tune_mode == 'decoder'): + up_blocks = [self.decoder, self.out_layers] + else: + raise ValueError("undefined fine-tune mode for FMUNet: {0:}".format(self.tune_mode)) + for block in up_blocks: + up_params = itertools.chain(up_params, block.parameters()) + return up_params + + def get_parameters_to_load(self): + state_dict = self.state_dict() + if(self.load_mode == 'encoder'): + state_dict = {k:v for k, v in state_dict.items() if "project" in k or "encoder" in k } + return state_dict \ No newline at end of file diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index 687710b..d7f759b 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -19,7 +19,7 @@ from pymic.net.net2d.unet2d_multi_decoder import UNet2D_DualBranch, MCNet2D from pymic.net.net2d.unet2d_canet import CANet from pymic.net.net2d.unet2d_cct import UNet2D_CCT -from pymic.net.net2d.unet2d_mtnet import MTNet2D +# from pymic.net.net2d.unet2d_mtnet import MTNet2D from pymic.net.net2d.cople_net import COPLENet from pymic.net.net2d.unet2d_attention import AttentionUNet2D from pymic.net.net2d.unet2d_pp import UNet2Dpp @@ -55,7 +55,7 @@ 'CANet': CANet, 'COPLENet': COPLENet, 'MCNet2D': MCNet2D, - 'MTNet2D': MTNet2D, + # 'MTNet2D': MTNet2D, 'UNet2D': UNet2D, 'UNet2D_DualBranch': UNet2D_DualBranch, 'UNet2D_CCT': UNet2D_CCT, From b99b5de9a330f7334c0d45527b08b2b4205fefd9 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 4 Oct 2024 15:28:36 +0800 Subject: [PATCH 203/225] Update __init__.py --- pymic/net_run/semi_sup/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymic/net_run/semi_sup/__init__.py b/pymic/net_run/semi_sup/__init__.py index cb5d1a3..769a66b 100644 --- a/pymic/net_run/semi_sup/__init__.py +++ b/pymic/net_run/semi_sup/__init__.py @@ -3,7 +3,7 @@ from pymic.net_run.semi_sup.ssl_em import SSLEntropyMinimization from pymic.net_run.semi_sup.ssl_mt import SSLMeanTeacher from pymic.net_run.semi_sup.ssl_mcnet import SSLMCNet -from pymic.net_run.semi_sup.ssl_cdma import SSLCDMA +# from pymic.net_run.semi_sup.ssl_cdma import SSLCDMA from pymic.net_run.semi_sup.ssl_uamt import SSLUncertaintyAwareMeanTeacher from pymic.net_run.semi_sup.ssl_cct import SSLCCT from pymic.net_run.semi_sup.ssl_cps import SSLCPS @@ -13,7 +13,7 @@ SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization, 'MeanTeacher': SSLMeanTeacher, 'MCNet': SSLMCNet, - 'CDMA': SSLCDMA, + # 'CDMA': SSLCDMA, 'UAMT': SSLUncertaintyAwareMeanTeacher, 'CCT': SSLCCT, 'CPS': SSLCPS, From 981b772f8e75002c1f13d789fb2c5454c5683a80 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 4 Oct 2024 15:47:37 +0800 Subject: [PATCH 204/225] load pretrained weights for classification models --- pymic/net_run/agent_cls.py | 49 +++++++++++++++++++++++++++----------- pymic/net_run/train.py | 2 ++ 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index a31df84..728f805 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -19,7 +19,7 @@ from pymic.net.net_dict_cls import TorchClsNetDict from pymic.transform.trans_dict import TransformDict from pymic.net_run.agent_abstract import NetRunAgent -from pymic.util.general import mixup +from pymic.util.general import mixup, tensor_shape_match import warnings warnings.filterwarnings('ignore', '.*output shape of zoom.*') @@ -212,6 +212,27 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): logging.info('valid loss {0:.4f}, avg {1:} {2:.4f}'.format( valid_scalars['loss'], metrics, valid_scalars[metrics])) + def load_pretrained_weights(self, network, pretrained_dict, device_ids): + if(len(device_ids) > 1): + if(hasattr(network.module, "get_parameters_to_load")): + model_dict = network.module.get_parameters_to_load() + else: + model_dict = network.module.state_dict() + else: + if(hasattr(network, "get_parameters_to_load")): + model_dict = network.get_parameters_to_load() + else: + model_dict = network.state_dict() + pretrained_dict = {k: v for k, v in pretrained_dict.items() if \ + k in model_dict and tensor_shape_match(pretrained_dict[k], model_dict[k])} + logging.info("Initializing the following parameters with pre-trained model") + for k in pretrained_dict: + logging.info(k) + if (len(device_ids) > 1): + network.module.load_state_dict(pretrained_dict, strict = False) + else: + network.load_state_dict(pretrained_dict, strict = False) + def train_valid(self): device_ids = self.config['training']['gpus'] if(len(device_ids) > 1): @@ -227,7 +248,7 @@ def train_valid(self): ckpt_prefix = self.config['training'].get('ckpt_prefix', None) if(ckpt_prefix is None): ckpt_prefix = ckpt_dir.split('/')[-1] - iter_start = self.config['training']['iter_start'] + iter_start = 0 iter_max = self.config['training']['iter_max'] iter_valid = self.config['training']['iter_valid'] iter_save = self.config['training']['iter_save'] @@ -243,18 +264,18 @@ def train_valid(self): self.max_val_score = 0.0 self.max_val_it = 0 self.best_model_wts = None - self.checkpoint = None - if(iter_start > 0): - checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start) - self.checkpoint = torch.load(checkpoint_file, map_location = self.device) - assert(self.checkpoint['iteration'] == iter_start) - if(len(device_ids) > 1): - self.net.module.load_state_dict(self.checkpoint['model_state_dict']) - else: - self.net.load_state_dict(self.checkpoint['model_state_dict']) - self.max_val_score = self.checkpoint.get('valid_pred', 0) - self.max_val_it = self.checkpoint['iteration'] - self.best_model_wts = self.checkpoint['model_state_dict'] + ckpt_init_name = self.config['training'].get('ckpt_init_name', None) + ckpt_init_mode = self.config['training'].get('ckpt_init_mode', 0) + + if(ckpt_init_name is not None): + checkpoint = torch.load(ckpt_dir + "/" + ckpt_init_name, map_location = self.device) + pretrained_dict = checkpoint['model_state_dict'] + self.load_pretrained_weights(self.net, pretrained_dict, device_ids) + if(ckpt_init_mode > 0): # Load other information + iter_start = checkpoint['iteration'] + self.max_val_score = checkpoint.get('valid_pred', 0) + self.max_val_it = checkpoint['iteration'] + self.best_model_wts = checkpoint['model_state_dict'] self.create_optimizer(self.get_parameters_to_update()) self.create_loss_calculator() diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index ed60fa1..ec0002f 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -60,6 +60,8 @@ def main(): required=False, default=None) parser.add_argument("-ckpt_dir", help="the output dir for trained model", required=False, default=None) + parser.add_argument("-iter_max", help="the maximal iteration number for training", + required=False, default=None) parser.add_argument("-gpus", help="the gpus for runing, e.g., [0]", required=False, default=None) args = parser.parse_args() From d833edc2528afd297798a30d0ce918559f64ed13 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 4 Oct 2024 16:31:54 +0800 Subject: [PATCH 205/225] update 3d networks --- pymic/net/net2d/cople_net.py | 34 ++++++++++++++++++++++------------ pymic/net/net3d/unet3d_scse.py | 23 +++++++++++++---------- 2 files changed, 35 insertions(+), 22 deletions(-) diff --git a/pymic/net/net2d/cople_net.py b/pymic/net/net2d/cople_net.py index bd54fd6..046dee8 100644 --- a/pymic/net/net2d/cople_net.py +++ b/pymic/net/net2d/cople_net.py @@ -120,20 +120,30 @@ class UpBlock(nn.Module): Upssampling followed by ConvBNActBlock. """ def __init__(self, in_channels1, in_channels2, out_channels, - bilinear=True, dropout_p = 0.5): + up_mode = 2, dropout_p = 0.5): super(UpBlock, self).__init__() - self.bilinear = bilinear - if bilinear: - self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + if(isinstance(up_mode, int)): + up_mode_values = ["transconv", "nearest", "bilinear", "bicubic"] + if(up_mode > 3): + raise ValueError("The upsample mode should be 0-3, but {0:} is given.".format(up_mode)) + self.up_mode = up_mode_values[up_mode] else: + self.up_mode = up_mode.lower() + + if (self.up_mode == "transconv"): self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2) + else: + self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) + if(self.up_mode == "nearest"): + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode) + else: + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode, align_corners=True) self.conv = ConvBNActBlock(in_channels2 * 2, out_channels, dropout_p) def forward(self, x1, x2): - if self.bilinear: + if self.up_mode != "transconv": x1 = self.conv1x1(x1) - x1 = self.up(x1) + x1 = self.up(x1) x_cat = torch.cat([x2, x1], dim=1) y = self.conv(x_cat) return y + x_cat @@ -165,7 +175,7 @@ def __init__(self, params): self.ft_chns = self.params['feature_chns'] self.dropout = self.params['dropout'] self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] + self.up_mode = self.params.get('up_mode', 2) assert(len(self.ft_chns) == 5) f0_half = int(self.ft_chns[0] / 2) @@ -183,10 +193,10 @@ def __init__(self, params): self.bridge2= ConvLayer(self.ft_chns[2], f2_half) self.bridge3= ConvLayer(self.ft_chns[3], f3_half) - self.up1 = UpBlock(self.ft_chns[4], f3_half, self.ft_chns[3], dropout_p = self.dropout[3]) - self.up2 = UpBlock(self.ft_chns[3], f2_half, self.ft_chns[2], dropout_p = self.dropout[2]) - self.up3 = UpBlock(self.ft_chns[2], f1_half, self.ft_chns[1], dropout_p = self.dropout[1]) - self.up4 = UpBlock(self.ft_chns[1], f0_half, self.ft_chns[0], dropout_p = self.dropout[0]) + self.up1 = UpBlock(self.ft_chns[4], f3_half, self.ft_chns[3], self.up_mode, dropout_p = self.dropout[3]) + self.up2 = UpBlock(self.ft_chns[3], f2_half, self.ft_chns[2], self.up_mode, dropout_p = self.dropout[2]) + self.up3 = UpBlock(self.ft_chns[2], f1_half, self.ft_chns[1], self.up_mode, dropout_p = self.dropout[1]) + self.up4 = UpBlock(self.ft_chns[1], f0_half, self.ft_chns[0], self.up_mode, dropout_p = self.dropout[0]) f4 = self.ft_chns[4] aspp_chns = [int(f4 / 4), int(f4 / 4), int(f4 / 4), int(f4 / 4)] diff --git a/pymic/net/net3d/unet3d_scse.py b/pymic/net/net3d/unet3d_scse.py index 79abf9c..49cecc4 100644 --- a/pymic/net/net3d/unet3d_scse.py +++ b/pymic/net/net3d/unet3d_scse.py @@ -76,12 +76,14 @@ class EncoderScSE(Encoder): def __init__(self, params): super(EncoderScSE, self).__init__(params) - self.in_conv= ConvScSEBlock3D(self.in_chns, self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) + in_chns = self.params['in_chns'] + dropout = self.params['dropout'] + self.in_conv= ConvScSEBlock3D(in_chns, self.ft_chns[0], dropout[0]) + self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], dropout[1]) + self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], dropout[2]) + self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], dropout[3]) if(len(self.ft_chns) == 5): - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) + self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], dropout[4]) class DecoderScSE(Decoder): """ @@ -92,12 +94,13 @@ class DecoderScSE(Decoder): """ def __init__(self, params): super(DecoderScSE, self).__init__(params) - + dropout = self.params['dropout'] + up_mode = self.params.get('up_mode', 2) if(len(self.ft_chns) == 5): - self.up1 = UpBlockScSE(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.up_mode) - self.up2 = UpBlockScSE(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.up_mode) - self.up3 = UpBlockScSE(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.up_mode) - self.up4 = UpBlockScSE(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.up_mode) + self.up1 = UpBlockScSE(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout[3], up_mode) + self.up2 = UpBlockScSE(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout[2], up_mode) + self.up3 = UpBlockScSE(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout[1], up_mode) + self.up4 = UpBlockScSE(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout[0], up_mode) class UNet3D_ScSE(UNet3D): From bbc75bfcc616698f8ffc50d3fcb2619095510a66 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 4 Oct 2024 17:16:01 +0800 Subject: [PATCH 206/225] Update wsl_gatedcrf.py --- pymic/net_run/weak_sup/wsl_gatedcrf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymic/net_run/weak_sup/wsl_gatedcrf.py b/pymic/net_run/weak_sup/wsl_gatedcrf.py index 1ecae4a..7eaa67d 100644 --- a/pymic/net_run/weak_sup/wsl_gatedcrf.py +++ b/pymic/net_run/weak_sup/wsl_gatedcrf.py @@ -73,6 +73,7 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) + t2 = time.time() loss_sup = self.get_loss_value(data, outputs, y) # for gated CRF loss, the input should be like NCHW From 908eb6f67ee54064c7147dec85674d41db6a600c Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 4 Oct 2024 17:24:39 +0800 Subject: [PATCH 207/225] Update nll_co_teaching.py --- pymic/net_run/noisy_label/nll_co_teaching.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymic/net_run/noisy_label/nll_co_teaching.py b/pymic/net_run/noisy_label/nll_co_teaching.py index d46b05b..33e375c 100644 --- a/pymic/net_run/noisy_label/nll_co_teaching.py +++ b/pymic/net_run/noisy_label/nll_co_teaching.py @@ -50,7 +50,8 @@ def training(self): rampup_end = nll_cfg.get('rampup_end', iter_max) train_loss_no_select1, train_loss_no_select2 = 0, 0 - train_loss1, train_avg_loss2 = 0, 0 + train_loss1, train_avg_loss1 = 0, 0 + train_loss2, train_avg_loss2 = 0, 0 train_dice_list = [] data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() From 6f04440b9417142bd9f8c03efb12547e311e5294 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 15 Nov 2024 13:51:19 +0800 Subject: [PATCH 208/225] Update __init__.py --- pymic/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymic/__init__.py b/pymic/__init__.py index 33943e4..b7531d9 100644 --- a/pymic/__init__.py +++ b/pymic/__init__.py @@ -1,7 +1,7 @@ from __future__ import absolute_import from enum import Enum -__version__ = "0.4.1" +__version__ = "0.4.2" # 2024.11.15 class TaskType(Enum): CLASSIFICATION_ONE_HOT = 1 From 63e9a531ab6b0c99321a5a7b8e2d7c1d005c3312 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 15 Nov 2024 14:34:15 +0800 Subject: [PATCH 209/225] update config for evaluation --- pymic/util/evaluation_cls.py | 58 +++++++++++++++++++++--------------- setup.py | 2 +- 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/pymic/util/evaluation_cls.py b/pymic/util/evaluation_cls.py index af11a17..a65953a 100644 --- a/pymic/util/evaluation_cls.py +++ b/pymic/util/evaluation_cls.py @@ -3,7 +3,7 @@ Evaluation module for classification tasks. """ from __future__ import absolute_import, print_function - +import argparse import os import csv import sys @@ -75,15 +75,15 @@ def binary_evaluation(config): The arguments are given in the `config` dictionary. It should have the following fields: - :param metric_list: (list) A list of evaluation metrics. + :param metric: (list) A list of evaluation metrics. The supported metrics are {`accuracy`, `recall`, `sensitivity`, `specificity`, `precision`, `auc`}. - :param ground_truth_csv: (str) The csv file for ground truth. - :param predict_prob_csv: (str) The csv file for prediction probability. + :param gt_csv: (str) The csv file for ground truth. + :param pred_prob_csv: (str) The csv file for prediction probability. """ - metric_list = config['metric_list'] - gt_csv = config['ground_truth_csv'] - prob_csv= config['predict_prob_csv'] + metric_list = config['metric'] + gt_csv = config['gt_csv'] + prob_csv= config['pred_prob_csv'] gt_items = pd.read_csv(gt_csv) prob_items = pd.read_csv(prob_csv) assert(len(gt_items) == len(prob_items)) @@ -111,15 +111,15 @@ def nexcl_evaluation(config): The arguments are given in the `config` dictionary. It should have the following fields: - :param metric_list: (list) A list of evaluation metrics. + :param metric: (list) A list of evaluation metrics. The supported metrics are {`accuracy`, `recall`, `sensitivity`, `specificity`, `precision`, `auc`}. - :param ground_truth_csv: (str) The csv file for ground truth. - :param predict_prob_csv: (str) The csv file for prediction probability. + :param gt_csv: (str) The csv file for ground truth. + :param pred_prob_csv: (str) The csv file for prediction probability. """ - metric_list = config['metric_list'] - gt_csv = config['ground_truth_csv'] - prob_csv = config['predict_prob_csv'] + metric_list = config['metric'] + gt_csv = config['gt_csv'] + prob_csv = config['pred_prob_csv'] gt_items = pd.read_csv(gt_csv) prob_items= pd.read_csv(prob_csv) assert(len(gt_items) == len(prob_items)) @@ -163,25 +163,35 @@ def main(): .. code-block:: none - pymic_evaluate_cls config.cfg + pymic_evaluate_cls -cfg config.cfg The configuration file should have an `evaluation` section with the following fields: :param task_type: (str) `cls` or `cls_nexcl`. - :param metric_list: (list) A list of evaluation metrics. + :param metric: (list) A list of evaluation metrics. The supported metrics are {`accuracy`, `recall`, `sensitivity`, `specificity`, `precision`, `auc`}. - :param ground_truth_csv: (str) The csv file for ground truth. - :param predict_prob_csv: (str) The csv file for prediction probability. + :param gt_csv: (str) The csv file for ground truth. + :param pred_prob_csv: (str) The csv file for prediction probability. """ - if(len(sys.argv) < 2): - print('Number of arguments should be 2. e.g.') - print(' pymic_evaluate_cls config.cfg') - exit() - config_file = str(sys.argv[1]) - assert(os.path.isfile(config_file)) - config = parse_config(config_file)['evaluation'] + parser = argparse.ArgumentParser() + parser.add_argument("-cfg", help="configuration file for evaluation", + required=False, default=None) + parser.add_argument("-metric", help="evaluation metrics, e.g., accuracy, or [accuracy, auc]", + required=False, default=None) + parser.add_argument("-gt_csv", help="csv file for ground truth", + required=False, default=None) + parser.add_argument("-pred_prob_csv", help="csv file for probability prediction", + required=False, default=None) + args = parser.parse_args() + print(args) + if(args.cfg is not None): + config = parse_config(args)['evaluation'] + + # config_file = str(sys.argv[1]) + # assert(os.path.isfile(config_file)) + # config = parse_config(config_file)['evaluation'] task_type = config.get('task_type', "cls") if(task_type == "cls"): # default exclusive classification binary_evaluation(config) diff --git a/setup.py b/setup.py index ebb738f..312b1b0 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.4.1", + version = "0.4.2", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, From 6dff55e782e8d7c055c73bc3e7dd6818d952e11a Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 15 Nov 2024 16:20:48 +0800 Subject: [PATCH 210/225] upgrade to v0.5.0 --- README.md | 4 ++-- pymic/__init__.py | 2 +- setup.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 6f34f92..bedeaae 100644 --- a/README.md +++ b/README.md @@ -47,10 +47,10 @@ Run the following command to install the latest released version of PyMIC: ```bash pip install PYMIC ``` -To install a specific version of PYMIC such as 0.4.1, run: +To install a specific version of PYMIC such as 0.5.0, run: ```bash -pip install PYMIC==0.4.1 +pip install PYMIC==0.5.0 ``` Alternatively, you can download the source code for the latest version. Run the following command to compile and install: diff --git a/pymic/__init__.py b/pymic/__init__.py index b7531d9..ae1775d 100644 --- a/pymic/__init__.py +++ b/pymic/__init__.py @@ -1,7 +1,7 @@ from __future__ import absolute_import from enum import Enum -__version__ = "0.4.2" # 2024.11.15 +__version__ = "0.5.0" # 2024.11.15 class TaskType(Enum): CLASSIFICATION_ONE_HOT = 1 diff --git a/setup.py b/setup.py index 312b1b0..cbf7355 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.4.2", + version = "0.5.0", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, From 808dfbef597d455b640a9e4df25b6fd3c30b2eb1 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 3 Dec 2024 15:21:58 +0800 Subject: [PATCH 211/225] update dataloader and nework add FMUNet, and allow missing modalities for multi-modal inputs --- pymic/io/nifty_dataset.py | 16 ++- pymic/net/net3d/fmunet.py | 265 +++++++++++++++++++++++++++++++++++++ pymic/net/net_dict_seg.py | 2 + pymic/net_run/agent_seg.py | 6 +- 4 files changed, 284 insertions(+), 5 deletions(-) create mode 100644 pymic/net/net3d/fmunet.py diff --git a/pymic/io/nifty_dataset.py b/pymic/io/nifty_dataset.py index c3253c9..048ba64 100644 --- a/pymic/io/nifty_dataset.py +++ b/pymic/io/nifty_dataset.py @@ -22,11 +22,13 @@ class NiftyDataset(Dataset): :param transform: (list) List of transforms to be applied on a sample. The built-in transforms can listed in :mod:`pymic.transform.trans_dict`. """ - def __init__(self, root_dir, csv_file, modal_num = 1, + # def __init__(self, root_dir, csv_file, modal_num = 1, + def __init__(self, root_dir, csv_file, modal_num = 1, allow_missing_modal = False, with_label = False, transform=None, task = TaskType.SEGMENTATION): self.root_dir = root_dir self.csv_items = pd.read_csv(csv_file) self.modal_num = modal_num + self.allow_emtpy= allow_missing_modal self.with_label = with_label self.transform = transform self.task = task @@ -89,11 +91,19 @@ def __get_pixel_weight__(self, idx): def __getitem__(self, idx): names_list, image_list = [], [] + image_shape = None for i in range (self.modal_num): image_name = self.csv_items.iloc[idx, i] image_full_name = "{0:}/{1:}".format(self.root_dir, image_name) - image_dict = load_image_as_nd_array(image_full_name) - image_data = image_dict['data_array'] + if(os.path.exists(image_full_name)): + image_dict = load_image_as_nd_array(image_full_name) + image_data = image_dict['data_array'] + elif(self.allow_emtpy and image_shape is not None): + image_data = np.zeros(image_shape) + else: + raise KeyError("File not found: {0:}".format(image_full_name)) + if(i == 0): + image_shape = image_data.shape names_list.append(image_name) image_list.append(image_data) image = np.concatenate(image_list, axis = 0) diff --git a/pymic/net/net3d/fmunet.py b/pymic/net/net3d/fmunet.py new file mode 100644 index 0000000..84ee385 --- /dev/null +++ b/pymic/net/net3d/fmunet.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import itertools +import logging +import torch +import torch.nn as nn +from pymic.net.net_init import Initialization_He, Initialization_XavierUniform + +''' +A copy of fmunetv3, and rename the class as FMUNet. +''' +dim0 = {0:3, 1:2, 2:2} +dim1 = {0:3, 1:3, 2:2} +conv_knl = {2: (1, 3, 3), 3: 3} +conv_pad = {2: (0, 1, 1), 3: 1} +pool_knl = {2: (1, 2, 2), 3: 2} +down_stride = {2: (1, 2, 2), 3: 2} + +class ResConv(nn.Module): + def __init__(self, out_channels, dim = 3, dropout_p = 0.0, depth = 2): + super(ResConv, self).__init__() + assert(dim == 2 or dim == 3) + self.out_channels = out_channels + self.conv_list = nn.ModuleList([nn.Sequential( + nn.InstanceNorm3d(out_channels, affine = True), + nn.LeakyReLU(), + nn.Dropout(dropout_p), + nn.Conv3d(out_channels, out_channels, kernel_size=conv_knl[dim], padding=conv_pad[dim])) + for i in range(depth)]) + + def forward(self, x): + for conv in self.conv_list: + x = conv(x) + x + return x + +class DownSample(nn.Module): + """downsampling based on convolution + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + :param downsample: (bool) Use downsample or not after convolution. + """ + def __init__(self, in_channels, out_channels, dim = 3): + super(DownSample, self).__init__() + self.down = nn.Sequential( + nn.InstanceNorm3d(in_channels, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=conv_knl[dim], + padding=conv_pad[dim], stride = down_stride[dim]) + ) + + def forward(self, x): + return self.down(x) + +class UpCatConv(nn.Module): + """Upsampling followed by `ResConv` block + + :param in_channels1: (int) Input channel number for low-resolution feature map. + :param in_channels2: (int) Input channel number for high-resolution feature map. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear` for 3D and `Bilinear` for 2D). + The default value is 2. + """ + def __init__(self, in_channels1, in_channels2, out_channels, dim = 3): + super(UpCatConv, self).__init__() + + self.up = nn.Sequential( + nn.InstanceNorm3d(in_channels1, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels1, in_channels2, kernel_size=1, padding=0), + nn.Upsample(scale_factor=pool_knl[dim], mode='trilinear', align_corners=True) + ) + + self.conv = nn.Sequential( + nn.InstanceNorm3d(in_channels2*2, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels2 * 2, out_channels, kernel_size=conv_knl[dim], padding=conv_pad[dim]) + ) + + def forward(self, x_l, x_h): + """ + x_l: low-resolution feature map. + x_h: high-resolution feature map. + """ + y = torch.cat([x_h, self.up(x_l)], dim=1) + return self.conv(y) + +class Encoder(nn.Module): + """ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. + + res_mode: resolution mode: 0-- isotrpic, 1-- near isotrpic, 2-- isotropic + """ + def __init__(self, ft_chns, res_mode = 0, dropout_p = 0, depth = 2): + super(Encoder, self).__init__() + d0, d1 = dim0[res_mode], dim1[res_mode] + + self.en_conv0 = ResConv(ft_chns[0], d0, 0, depth) + self.en_conv1 = ResConv(ft_chns[1], d1, 0, depth) + self.en_conv2 = ResConv(ft_chns[2], 3, dropout_p, depth) + self.en_conv3 = ResConv(ft_chns[3], 3, dropout_p, depth) + self.en_conv4 = ResConv(ft_chns[4], 3, dropout_p, depth) + + self.down0 = DownSample(ft_chns[0], ft_chns[1], d0) + self.down1 = DownSample(ft_chns[1], ft_chns[2], d1) + self.down2 = DownSample(ft_chns[2], ft_chns[3], 3) + self.down3 = DownSample(ft_chns[3], ft_chns[4], 3) + + def forward(self, x): + x0 = self.en_conv0(x) + x1 = self.en_conv1(self.down0(x0)) + x2 = self.en_conv2(self.down1(x1)) + x3 = self.en_conv3(self.down2(x2)) + x4 = self.en_conv4(self.down3(x3)) + return [x0, x1, x2, x3, x4] + +class Decoder(nn.Module): + """ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. + """ + def __init__(self, ft_chns, res_mode = 0, dropout_p = 0, depth = 2): + super(Decoder, self).__init__() + d0, d1 = dim0[res_mode], dim1[res_mode] + + self.upcat0 = UpCatConv(ft_chns[1], ft_chns[0], ft_chns[0], d0) + self.upcat1 = UpCatConv(ft_chns[2], ft_chns[1], ft_chns[1], d1) + self.upcat2 = UpCatConv(ft_chns[3], ft_chns[2], ft_chns[2], 3) + self.upcat3 = UpCatConv(ft_chns[4], ft_chns[3], ft_chns[3], 3) + + self.de_conv0 = ResConv(ft_chns[0], d0, 0, depth) + self.de_conv1 = ResConv(ft_chns[1], d1, 0, depth) + self.de_conv2 = ResConv(ft_chns[2], 3, dropout_p, depth) + self.de_conv3 = ResConv(ft_chns[3], 3, dropout_p, depth) + self.de_conv4 = ResConv(ft_chns[4], 3, dropout_p, depth) + + def forward(self, x): + x0, x1, x2, x3, x4 = x + x4_de = self.de_conv4(x4) + x3_de = self.de_conv3(self.upcat3(x4_de, x3)) + x2_de = self.de_conv2(self.upcat2(x3_de, x2)) + x1_de = self.de_conv1(self.upcat1(x2_de, x1)) + x0_de = self.de_conv0(self.upcat0(x1_de, x0)) + return [x0_de, x1_de, x2_de, x3_de] + +class FMUNet(nn.Module): + """ + A 2.5D network combining 3D convolutions with 2D convolutions. + + * Reference: Guotai Wang, Jonathan Shapey, Wenqi Li, Reuben Dorent, Alex Demitriadis, + Sotirios Bisdas, Ian Paddick, Robert Bradford, Shaoting Zhang, Sébastien Ourselin, + Tom Vercauteren: Automatic Segmentation of Vestibular Schwannoma from T2-Weighted + MRI by Deep Spatial Attention with Hardness-Weighted Loss. + `MICCAI (2) 2019: 264-272. `_ + + Note that the attention module in the orininal paper is not used here. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param conv_dims: (list) The convolution dimension (2 or 3) for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, params): + super(FMUNet, self).__init__() + params = self.get_default_parameters(params) + + self.stage = 'train' + in_chns = params['in_chns'] + ft_chns = params['feature_chns'] + res_mode = params['res_mode'] + dropout = params['dropout'] + depth = params['depth'] + cls_num = params['class_num'] + self.mul_pred = params.get('multiscale_pred', True) + self.tune_mode= params.get('finetune_mode', 'all') + self.load_mode= params.get('weights_load_mode', 'all') + + d0 = dim0[res_mode] + self.project = nn.Conv3d(in_chns, ft_chns[0], kernel_size=conv_knl[d0], padding=conv_pad[d0]) + self.encoder = Encoder(ft_chns, res_mode, dropout, depth) + # self.decoder = Decoder(ft_chns, res_mode, dropout, depth = 2) + self.decoder = Decoder(ft_chns, res_mode, dropout, depth) + + self.out_layers = nn.ModuleList() + dims = [dim0[res_mode], dim1[res_mode], 3, 3] + for i in range(4): + out_layer = nn.Sequential( + nn.InstanceNorm3d(ft_chns[i], affine = True), + nn.LeakyReLU(), + nn.Conv3d(ft_chns[i], cls_num, kernel_size=conv_knl[dims[i]], padding=conv_pad[dims[i]])) + self.out_layers.append(out_layer) + + init = params['initialization'].lower() + weightInitializer = Initialization_He(1e-2) if init == 'he' else Initialization_XavierUniform() + self.apply(weightInitializer) + + def get_default_parameters(self, params): + default_param = { + 'finetune_mode': 'all', + 'initialization': 'he', + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': 0.2, + 'res_mode': 0, + 'depth': 2, + 'multiscale_pred': True + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def set_stage(self, stage): + self.stage = stage + + def forward(self, x): + x_en = self.encoder(self.project(x)) + x_de = self.decoder(x_en) + output = self.out_layers[0](x_de[0]) + if(self.mul_pred and self.stage == 'train'): + output = [output] + for i in range(1, len(x_de)): + output.append(self.out_layers[i](x_de[i])) + return output + + def get_parameters_to_update(self): + if(self.tune_mode == 'all'): + return self.parameters() + + up_params = itertools.chain() + if(self.tune_mode == 'decoder'): + up_blocks = [self.decoder, self.out_layers] + else: + raise ValueError("undefined fine-tune mode for FMUNet: {0:}".format(self.tune_mode)) + for block in up_blocks: + up_params = itertools.chain(up_params, block.parameters()) + return up_params + + def get_parameters_to_load(self): + state_dict = self.state_dict() + if(self.load_mode == 'encoder'): + state_dict = {k:v for k, v in state_dict.items() if "project" in k or "encoder" in k } + return state_dict \ No newline at end of file diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index d7f759b..8877ade 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -30,6 +30,7 @@ from pymic.net.net3d.unet3d import UNet3D from pymic.net.net3d.grunet import GRUNet from pymic.net.net3d.fmunetv3 import FMUNetV3 +from pymic.net.net3d.fmunet import FMUNet from pymic.net.net3d.lcovnet import LCOVNet from pymic.net.net3d.unet3d_scse import UNet3D_ScSE from pymic.net.net3d.unet3d_dual_branch import UNet3D_DualBranch @@ -66,6 +67,7 @@ 'UNet2D5': UNet2D5, 'GRUNet': GRUNet, 'LCOVNet': LCOVNet, + 'FMUNet': FMUNet, 'FMUNetV3': FMUNetV3, 'UNet3D': UNet3D, 'UNet3D_ScSE': UNet3D_ScSE, diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 59ffe05..cb30b45 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -68,8 +68,9 @@ def get_stage_dataset_from_config(self, stage): self.test_transforms = transform_list else: with_label = self.config['dataset'].get(stage + '_label', True) - modal_num = self.config['dataset'].get('modal_num', 1) - stage_dir = self.config['dataset'].get('train_dir', None) + modal_num = self.config['dataset'].get('modal_num', 1) + allow_miss = self.config['dataset'].get('allow_missing_modal', False) + stage_dir = self.config['dataset'].get('train_dir', None) if(stage == 'valid' and "valid_dir" in self.config['dataset']): stage_dir = self.config['dataset']['valid_dir'] if(stage == 'test' and "test_dir" in self.config['dataset']): @@ -78,6 +79,7 @@ def get_stage_dataset_from_config(self, stage): dataset = NiftyDataset(root_dir = stage_dir, csv_file = csv_file, modal_num = modal_num, + allow_missing_modal = allow_miss, with_label= with_label, transform = data_transform, task = self.task_type) From 2a240b1751c6139af1abc7882dd3498687e4d04a Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 3 Dec 2024 16:58:07 +0800 Subject: [PATCH 212/225] add postprocess dictionary --- pymic/net_run/agent_abstract.py | 34 +++++++++++++++++++++++---------- pymic/net_run/agent_rec.py | 14 ++++++-------- pymic/net_run/agent_seg.py | 11 +---------- 3 files changed, 31 insertions(+), 28 deletions(-) diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index d8abb19..01ad808 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -56,6 +56,8 @@ def __init__(self, config, stage = 'train'): self.loss_dict = None self.transform_dict = None self.inferer = None + self.postprocess_dict = None + self.postprocessor = None self.tensor_type = config['dataset']['tensor_type'] self.task_type = config['dataset']['task_type'] self.deterministic = config['training'].get('deterministic', True) @@ -102,6 +104,14 @@ def set_net_dict(self, net_dict): """ self.net_dict = net_dict + def set_postprocess_dict(self, postprocess_dict): + """ + Set the available methods for postprocess, including customized postprocess methods. + + :param postprocess_dict: (dictionary) A dictionary of available postprocess methods. + """ + self.postprocess_dict = postprocess_dict + def set_loss_dict(self, loss_dict): """ Set the available loss functions, including customized loss functions. @@ -258,7 +268,11 @@ def create_dataset(self): if(self.train_set is None): self.train_set = self.get_stage_dataset_from_config('train') if(self.valid_set is None): - self.valid_set = self.get_stage_dataset_from_config('valid') + valid_csv = self.config['dataset'].get('valid_csv', None) + if valid_csv is not None: + self.valid_set = self.get_stage_dataset_from_config('valid') + else: + logging.warning("Dataset for validation is not created, as valid_dir is not provided.") if(self.deterministic): def worker_init_fn(worker_id): # workder_seed = self.random_seed+worker_id @@ -269,18 +283,20 @@ def worker_init_fn(worker_id): else: worker_init = None - bn_train = self.config['dataset']['train_batch_size'] - bn_valid = self.config['dataset'].get('valid_batch_size', 1) num_worker = self.config['dataset'].get('num_worker', 8) - g_train, g_valid = torch.Generator(), torch.Generator() + bn_train = self.config['dataset']['train_batch_size'] + g_train = torch.Generator() g_train.manual_seed(self.random_seed) - g_valid.manual_seed(self.random_seed) self.train_loader = torch.utils.data.DataLoader(self.train_set, batch_size = bn_train, shuffle=True, num_workers= num_worker, worker_init_fn=worker_init, generator = g_train, drop_last = True) - self.valid_loader = torch.utils.data.DataLoader(self.valid_set, - batch_size = bn_valid, shuffle=False, num_workers= num_worker, - worker_init_fn=worker_init, generator = g_valid) + if(self.valid_set is not None): + bn_valid = self.config['dataset'].get('valid_batch_size', 1) + g_valid = torch.Generator() + g_valid.manual_seed(self.random_seed) + self.valid_loader = torch.utils.data.DataLoader(self.valid_set, + batch_size = bn_valid, shuffle=False, num_workers= num_worker, + worker_init_fn=worker_init, generator = g_valid) else: bn_test = self.config['dataset'].get('test_batch_size', 1) if(self.test_set is None): @@ -327,5 +343,3 @@ def run(self): else: self.infer() - - diff --git a/pymic/net_run/agent_rec.py b/pymic/net_run/agent_rec.py index cc83526..0d78b43 100644 --- a/pymic/net_run/agent_rec.py +++ b/pymic/net_run/agent_rec.py @@ -108,8 +108,8 @@ def training(self): loss = self.get_loss_value(data, outputs, label) t3 = time.time() loss.backward() - t4 = time.time() self.optimizer.step() + t4 = time.time() train_loss = train_loss + loss.item() if(isinstance(outputs, tuple) or isinstance(outputs, list)): @@ -175,8 +175,7 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): 'valid':valid_scalars['loss']} self.summ_writer.add_scalars('loss', loss_scalar, glob_it) self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) - logging.info('train loss {0:.4f}'.format(train_scalars['loss'])) - logging.info('valid loss {0:.4f}'.format(valid_scalars['loss'])) + logging.info('train/valid loss {0:.4f}/{1:.4f}'.format(train_scalars['loss'],valid_scalars['loss'])) logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( train_scalars['data_time'], train_scalars['forward_time'], train_scalars['loss_time'], train_scalars['backward_time'])) @@ -259,9 +258,6 @@ def train_valid(self): logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) logging.info('learning rate {0:}'.format(lr_value)) logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1)) - logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( - train_scalars['data_time'], train_scalars['forward_time'], - train_scalars['loss_time'], train_scalars['backward_time'])) self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it) if(valid_scalars['loss'] < self.min_val_loss): self.min_val_loss = valid_scalars['loss'] @@ -320,19 +316,21 @@ def save_outputs(self, data): if(isinstance(pred, (list, tuple))): pred = pred[0] pred = np.tanh(pred) + if(self.postprocessor is not None): + pred = self.postprocessor(pred) # pred = scipy.special.expit(pred) # save the output predictions test_dir = self.config['dataset'].get('test_dir', None) if(test_dir is None): test_dir = self.config['dataset']['train_dir'] - for i in range(len(names)): + for i in range(pred.shape[1]): save_name = names[i][0].split('/')[-1] if ignore_dir else \ names[i][0].replace('/', '_') if((filename_replace_source is not None) and (filename_replace_target is not None)): save_name = save_name.replace(filename_replace_source, filename_replace_target) print(save_name) save_name = "{0:}/{1:}".format(output_dir, save_name) - save_nd_array_as_image(pred[i][i], save_name, test_dir + '/' + names[i][0]) + save_nd_array_as_image(pred[i][0], save_name, test_dir + '/' + names[i][0]) \ No newline at end of file diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index cb30b45..45e815b 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -150,15 +150,6 @@ def get_loss_value(self, data, pred, gt, param = None): loss_input_dict['class_weight'] = class_weight.to(device) loss_value = self.loss_calculator(loss_input_dict) return loss_value - - def set_postprocessor(self, postprocessor): - """ - Set post processor after prediction. - - :param postprocessor: post processor, such as an instance of - `pymic.util.post_process.PostProcess`. - """ - self.postprocessor = postprocessor def training(self): class_num = self.config['network']['class_num'] @@ -476,7 +467,7 @@ def test_time_dropout(m): self.inferer = Inferer(infer_cfg) postpro_name = self.config['testing'].get('post_process', None) if(self.postprocessor is None and postpro_name is not None): - self.postprocessor = PostProcessDict[postpro_name](self.config['testing']) + self.postprocessor = self.postprocess_dict[postpro_name](self.config['testing']) infer_time_list = [] with torch.no_grad(): for data in self.test_loader: From 4625b0e048d0c9d053301f1fe406f1502ba48349 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 31 Jan 2025 17:40:48 +0800 Subject: [PATCH 213/225] add umamba to pymic --- pymic/net/net2d/umamba.py | 1234 +++++++++++++++++++++++++++++++++++++ pymic/net/net_dict_seg.py | 3 + pymic/test/test_net2d.py | 15 +- 3 files changed, 1250 insertions(+), 2 deletions(-) create mode 100644 pymic/net/net2d/umamba.py diff --git a/pymic/net/net2d/umamba.py b/pymic/net/net2d/umamba.py new file mode 100644 index 0000000..63ccacf --- /dev/null +++ b/pymic/net/net2d/umamba.py @@ -0,0 +1,1234 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import logging +import numpy as np +import math +import torch +from torch import nn +from torch.nn import functional as F +from typing import Union, Type, List, Tuple + +from torch.nn.modules.conv import _ConvNd +from torch.nn.modules.dropout import _DropoutNd +from mamba_ssm import Mamba +from torch.cuda.amp import autocast + +# from dynamic_network_architectures.building_blocks.helper import convert_conv_op_to_dim +# from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager +# from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op +# from nnunetv2.utilities.network_initialization import InitWeights_He +# from dynamic_network_architectures.building_blocks.helper import maybe_convert_scalar_to_list, get_matching_pool_op + +def dim_of_conv_op(conv_op: Type[_ConvNd]) -> int: + """ + :param conv_op: conv class + :return: dimension: 1, 2 or 3 + """ + if conv_op == nn.Conv1d: + return 1 + elif conv_op == nn.Conv2d: + return 2 + elif conv_op == nn.Conv3d: + return 3 + else: + raise ValueError("Unknown dimension. Only 1d 2d and 3d conv are supported. got %s" % str(conv_op)) + +def get_matching_pool_op(conv_op: Type[_ConvNd] = None, + dimension: int = None, + adaptive=False, + pool_type: str = 'avg') -> Type[torch.nn.Module]: + """ + You MUST set EITHER conv_op OR dimension. Do not set both! + :param conv_op: + :param dimension: + :param adaptive: + :param pool_type: either 'avg' or 'max' + :return: + """ + assert not ((conv_op is not None) and (dimension is not None)), \ + "You MUST set EITHER conv_op OR dimension. Do not set both!" + assert pool_type in ['avg', 'max'], 'pool_type must be either avg or max' + if conv_op is not None: + dimension = dim_of_conv_op(conv_op) + assert dimension in [1, 2, 3], 'Dimension must be 1, 2 or 3' + + if conv_op is not None: + dimension = dim_of_conv_op(conv_op) + + if dimension == 1: + if pool_type == 'avg': + if adaptive: + return nn.AdaptiveAvgPool1d + else: + return nn.AvgPool1d + elif pool_type == 'max': + if adaptive: + return nn.AdaptiveMaxPool1d + else: + return nn.MaxPool1d + elif dimension == 2: + if pool_type == 'avg': + if adaptive: + return nn.AdaptiveAvgPool2d + else: + return nn.AvgPool2d + elif pool_type == 'max': + if adaptive: + return nn.AdaptiveMaxPool2d + else: + return nn.MaxPool2d + elif dimension == 3: + if pool_type == 'avg': + if adaptive: + return nn.AdaptiveAvgPool3d + else: + return nn.AvgPool3d + elif pool_type == 'max': + if adaptive: + return nn.AdaptiveMaxPool3d + else: + return nn.MaxPool3d + +def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): + """ + This function is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py). + + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + +class DropPath(nn.Module): + """ + This class is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py). + + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + +def make_divisible(v, divisor=8, min_value=None, round_limit=.9): + """ + This function is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/b7cb8d0337b3e7b50516849805ddb9be5fc11644/timm/models/layers/helpers.py#L25) + """ + min_value = min_value or divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < round_limit * v: + new_v += divisor + return new_v + +class SqueezeExcite(nn.Module): + """ + This class is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/squeeze_excite.py) + and slightly modified so that the convolution type can be adapted. + + SE Module as defined in original SE-Nets with a few additions + Additions include: + * divisor can be specified to keep channels % div == 0 (default: 8) + * reduction channels can be specified directly by arg (if rd_channels is set) + * reduction channels can be specified by float rd_ratio (default: 1/16) + * global max pooling can be added to the squeeze aggregation + * customizable activation, normalization, and gate layer + """ + def __init__( + self, channels, conv_op, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False, + act_layer=nn.ReLU, norm_layer=None, gate_layer=nn.Sigmoid): + super(SqueezeExcite, self).__init__() + self.add_maxpool = add_maxpool + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + self.fc1 = conv_op(channels, rd_channels, kernel_size=1, bias=True) + self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity() + self.act = act_layer(inplace=True) + self.fc2 = conv_op(rd_channels, channels, kernel_size=1, bias=True) + self.gate = gate_layer() + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) + x_se = self.fc1(x_se) + x_se = self.act(self.bn(x_se)) + x_se = self.fc2(x_se) + return x * self.gate(x_se) + + +class ConvDropoutNormReLU(nn.Module): + def __init__(self, + conv_op: Type[_ConvNd], + input_channels: int, + output_channels: int, + kernel_size: Union[int, List[int], Tuple[int, ...]], + stride: Union[int, List[int], Tuple[int, ...]], + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: dict = None, + dropout_op: Union[None, Type[_DropoutNd]] = None, + dropout_op_kwargs: dict = None, + nonlin: Union[None, Type[torch.nn.Module]] = None, + nonlin_kwargs: dict = None, + nonlin_first: bool = False + ): + super(ConvDropoutNormReLU, self).__init__() + self.input_channels = input_channels + self.output_channels = output_channels + if not isinstance(stride, (tuple, list, np.ndarray)): + stride = [stride] * dim_of_conv_op(conv_op) + self.stride = stride + + if not isinstance(kernel_size, (tuple, list, np.ndarray)): + kernel_size = [kernel_size] * dim_of_conv_op(conv_op) + if norm_op_kwargs is None: + norm_op_kwargs = {} + if nonlin_kwargs is None: + nonlin_kwargs = {} + + ops = [] + + self.conv = conv_op( + input_channels, + output_channels, + kernel_size, + stride, + padding=[(i - 1) // 2 for i in kernel_size], + dilation=1, + bias=conv_bias, + ) + ops.append(self.conv) + + if dropout_op is not None: + self.dropout = dropout_op(**dropout_op_kwargs) + ops.append(self.dropout) + + if norm_op is not None: + self.norm = norm_op(output_channels, **norm_op_kwargs) + ops.append(self.norm) + + if nonlin is not None: + self.nonlin = nonlin(**nonlin_kwargs) + ops.append(self.nonlin) + + if nonlin_first and (norm_op is not None and nonlin is not None): + ops[-1], ops[-2] = ops[-2], ops[-1] + + self.all_modules = nn.Sequential(*ops) + + def forward(self, x): + return self.all_modules(x) + + def compute_conv_feature_map_size(self, input_size): + assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \ + "batch channel. Do not give input_size=(b, c, x, y(, z)). " \ + "Give input_size=(x, y(, z))!" + output_size = [i // j for i, j in zip(input_size, self.stride)] # we always do same padding + return np.prod([self.output_channels, *output_size], dtype=np.int64) + +# # from dynamic_network_architectures.building_blocks.residual import BasicBlockD +class BasicBlockD(nn.Module): + def __init__(self, + conv_op: Type[_ConvNd], + input_channels: int, + output_channels: int, + kernel_size: Union[int, List[int], Tuple[int, ...]], + stride: Union[int, List[int], Tuple[int, ...]], + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: dict = None, + dropout_op: Union[None, Type[_DropoutNd]] = None, + dropout_op_kwargs: dict = None, + nonlin: Union[None, Type[torch.nn.Module]] = None, + nonlin_kwargs: dict = None, + stochastic_depth_p: float = 0.0, + squeeze_excitation: bool = False, + squeeze_excitation_reduction_ratio: float = 1. / 16, + # todo wideresnet? + ): + """ + This implementation follows ResNet-D: + + He, Tong, et al. "Bag of tricks for image classification with convolutional neural networks." + Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019. + + The skip has an avgpool (if needed) followed by 1x1 conv instead of just a strided 1x1 conv + + :param conv_op: + :param input_channels: + :param output_channels: + :param kernel_size: refers only to convs in feature extraction path, not to 1x1x1 conv in skip + :param stride: only applies to first conv (and skip). Second conv always has stride 1 + :param conv_bias: + :param norm_op: + :param norm_op_kwargs: + :param dropout_op: only the first conv can have dropout. The second never has + :param dropout_op_kwargs: + :param nonlin: + :param nonlin_kwargs: + :param stochastic_depth_p: + :param squeeze_excitation: + :param squeeze_excitation_reduction_ratio: + """ + super().__init__() + self.input_channels = input_channels + self.output_channels = output_channels + if not isinstance(stride, (tuple, list, np.ndarray)): + stride = [stride] * dim_of_conv_op(conv_op) + self.stride = stride + + if not isinstance(kernel_size, (tuple, list, np.ndarray)): + kernel_size = [kernel_size] * dim_of_conv_op(conv_op) + + if norm_op_kwargs is None: + norm_op_kwargs = {} + if nonlin_kwargs is None: + nonlin_kwargs = {} + + self.conv1 = ConvDropoutNormReLU(conv_op, input_channels, output_channels, kernel_size, stride, conv_bias, + norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs) + self.conv2 = ConvDropoutNormReLU(conv_op, output_channels, output_channels, kernel_size, 1, conv_bias, norm_op, + norm_op_kwargs, None, None, None, None) + + self.nonlin2 = nonlin(**nonlin_kwargs) if nonlin is not None else lambda x: x + + # Stochastic Depth + self.apply_stochastic_depth = False if stochastic_depth_p == 0.0 else True + if self.apply_stochastic_depth: + self.drop_path = DropPath(drop_prob=stochastic_depth_p) + + # Squeeze Excitation + self.apply_se = squeeze_excitation + if self.apply_se: + self.squeeze_excitation = SqueezeExcite(self.output_channels, conv_op, + rd_ratio=squeeze_excitation_reduction_ratio, rd_divisor=8) + + has_stride = (isinstance(stride, int) and stride != 1) or any([i != 1 for i in stride]) + requires_projection = (input_channels != output_channels) + + if has_stride or requires_projection: + ops = [] + if has_stride: + ops.append(get_matching_pool_op(conv_op=conv_op, adaptive=False, pool_type='avg')(stride, stride)) + if requires_projection: + ops.append( + ConvDropoutNormReLU(conv_op, input_channels, output_channels, 1, 1, False, norm_op, + norm_op_kwargs, None, None, None, None + ) + ) + self.skip = nn.Sequential(*ops) + else: + self.skip = lambda x: x + + def forward(self, x): + residual = self.skip(x) + out = self.conv2(self.conv1(x)) + if self.apply_stochastic_depth: + out = self.drop_path(out) + if self.apply_se: + out = self.squeeze_excitation(out) + out += residual + return self.nonlin2(out) + + def compute_conv_feature_map_size(self, input_size): + assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \ + "batch channel. Do not give input_size=(b, c, x, y(, z)). " \ + "Give input_size=(x, y(, z))!" + size_after_stride = [i // j for i, j in zip(input_size, self.stride)] + # conv1 + output_size_conv1 = np.prod([self.output_channels, *size_after_stride], dtype=np.int64) + # conv2 + output_size_conv2 = np.prod([self.output_channels, *size_after_stride], dtype=np.int64) + # skip conv (if applicable) + if (self.input_channels != self.output_channels) or any([i != j for i, j in zip(input_size, size_after_stride)]): + assert isinstance(self.skip, nn.Sequential) + output_size_skip = np.prod([self.output_channels, *size_after_stride], dtype=np.int64) + else: + assert not isinstance(self.skip, nn.Sequential) + output_size_skip = 0 + return output_size_conv1 + output_size_conv2 + output_size_skip + +class UpsampleLayer(nn.Module): + def __init__( + self, + conv_op, + input_channels, + output_channels, + pool_op_kernel_size, + mode='nearest' + ): + super().__init__() + self.conv = conv_op(input_channels, output_channels, kernel_size=1) + self.pool_op_kernel_size = pool_op_kernel_size + self.mode = mode + + def forward(self, x): + x = F.interpolate(x, scale_factor=self.pool_op_kernel_size, mode=self.mode) + x = self.conv(x) + return x + +# class MambaLayer(nn.Module): +# def __init__(self, dim, d_state = 16, d_conv = 4, expand = 2): +# super().__init__() +# self.dim = dim +# self.norm = nn.LayerNorm(dim) +# self.mamba = Mamba( +# d_model=dim, # Model dimension d_model +# d_state=d_state, # SSM state expansion factor +# d_conv=d_conv, # Local convolution width +# expand=expand, # Block expansion factor +# ) + +# @autocast(enabled=False) +# def forward(self, x): +# if x.dtype == torch.float16: +# x = x.type(torch.float32) +# B, C = x.shape[:2] +# assert C == self.dim +# n_tokens = x.shape[2:].numel() +# img_dims = x.shape[2:] +# x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2) +# x_norm = self.norm(x_flat) +# x_mamba = self.mamba(x_norm) +# out = x_mamba.transpose(-1, -2).reshape(B, C, *img_dims) + +# return out + +class MambaLayer(nn.Module): + def __init__(self, dim, d_state = 16, d_conv = 4, expand = 2, channel_token = False): + super().__init__() + self.dim = dim + self.norm = nn.LayerNorm(dim) + self.mamba = Mamba( + d_model=dim, # Model dimension d_model + d_state=d_state, # SSM state expansion factor + d_conv=d_conv, # Local convolution width + expand=expand, # Block expansion factor + ) + self.channel_token = channel_token ## whether to use channel as tokens + + def forward_patch_token(self, x): + B, d_model = x.shape[:2] + assert d_model == self.dim + n_tokens = x.shape[2:].numel() + img_dims = x.shape[2:] + x_flat = x.reshape(B, d_model, n_tokens).transpose(-1, -2) + x_norm = self.norm(x_flat) + x_mamba = self.mamba(x_norm) + out = x_mamba.transpose(-1, -2).reshape(B, d_model, *img_dims) + + return out + + def forward_channel_token(self, x): + B, n_tokens = x.shape[:2] + d_model = x.shape[2:].numel() + assert d_model == self.dim, f"d_model: {d_model}, self.dim: {self.dim}" + img_dims = x.shape[2:] + x_flat = x.flatten(2) + assert x_flat.shape[2] == d_model, f"x_flat.shape[2]: {x_flat.shape[2]}, d_model: {d_model}" + x_norm = self.norm(x_flat) + x_mamba = self.mamba(x_norm) + out = x_mamba.reshape(B, n_tokens, *img_dims) + + return out + + @autocast(enabled=False) + def forward(self, x): + if x.dtype == torch.float16: + x = x.type(torch.float32) + + if self.channel_token: + out = self.forward_channel_token(x) + else: + out = self.forward_patch_token(x) + + return out + + +class BasicResBlock(nn.Module): + def __init__( + self, + conv_op, + input_channels, + output_channels, + norm_op, + norm_op_kwargs, + kernel_size=3, + padding=1, + stride=1, + use_1x1conv=False, + nonlin=nn.LeakyReLU, + nonlin_kwargs={'inplace': True} + ): + super().__init__() + + self.conv1 = conv_op(input_channels, output_channels, kernel_size, stride=stride, padding=padding) + self.norm1 = norm_op(output_channels, **norm_op_kwargs) + self.act1 = nonlin(**nonlin_kwargs) + + self.conv2 = conv_op(output_channels, output_channels, kernel_size, padding=padding) + self.norm2 = norm_op(output_channels, **norm_op_kwargs) + self.act2 = nonlin(**nonlin_kwargs) + + if use_1x1conv: + self.conv3 = conv_op(input_channels, output_channels, kernel_size=1, stride=stride) + else: + self.conv3 = None + + def forward(self, x): + y = self.conv1(x) + y = self.act1(self.norm1(y)) + y = self.norm2(self.conv2(y)) + if self.conv3: + x = self.conv3(x) + y += x + return self.act2(y) + +class UNetResEncoder(nn.Module): + def __init__(self, + input_channels: int, + n_stages: int, + features_per_stage: Union[int, List[int], Tuple[int, ...]], + conv_op: Type[_ConvNd], + kernel_sizes: Union[int, List[int], Tuple[int, ...]], + strides: Union[int, List[int], Tuple[int, ...], Tuple[Tuple[int, ...], ...]], + n_blocks_per_stage: Union[int, List[int], Tuple[int, ...]], + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: dict = None, + nonlin: Union[None, Type[torch.nn.Module]] = None, + nonlin_kwargs: dict = None, + return_skips: bool = False, + stem_channels: int = None, + pool_type: str = 'conv', + ): + super().__init__() + if isinstance(kernel_sizes, int): + kernel_sizes = [kernel_sizes] * n_stages + if isinstance(features_per_stage, int): + features_per_stage = [features_per_stage] * n_stages + if isinstance(n_blocks_per_stage, int): + n_blocks_per_stage = [n_blocks_per_stage] * n_stages + if isinstance(strides, int): + strides = [strides] * n_stages + + assert len( + kernel_sizes) == n_stages, "kernel_sizes must have as many entries as we have resolution stages (n_stages)" + assert len( + n_blocks_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have resolution stages (n_stages)" + assert len( + features_per_stage) == n_stages, "features_per_stage must have as many entries as we have resolution stages (n_stages)" + assert len(strides) == n_stages, "strides must have as many entries as we have resolution stages (n_stages). " \ + "Important: first entry is recommended to be 1, else we run strided conv drectly on the input" + + pool_op = get_matching_pool_op(conv_op, pool_type=pool_type) if pool_type != 'conv' else None + + self.conv_pad_sizes = [] + for krnl in kernel_sizes: + self.conv_pad_sizes.append([i // 2 for i in krnl]) + + stem_channels = features_per_stage[0] + + self.stem = nn.Sequential( + BasicResBlock( + conv_op = conv_op, + input_channels = input_channels, + output_channels = stem_channels, + norm_op=norm_op, + norm_op_kwargs=norm_op_kwargs, + kernel_size=kernel_sizes[0], + padding=self.conv_pad_sizes[0], + stride=1, + nonlin=nonlin, + nonlin_kwargs=nonlin_kwargs, + use_1x1conv=True + ), + *[ + BasicBlockD( + conv_op = conv_op, + input_channels = stem_channels, + output_channels = stem_channels, + kernel_size = kernel_sizes[0], + stride = 1, + conv_bias = conv_bias, + norm_op = norm_op, + norm_op_kwargs = norm_op_kwargs, + nonlin = nonlin, + nonlin_kwargs = nonlin_kwargs, + ) for _ in range(n_blocks_per_stage[0] - 1) + ] + ) + + + input_channels = stem_channels + + # now build the network + stages = [] + for s in range(n_stages): + stage = nn.Sequential( + BasicResBlock( + conv_op = conv_op, + norm_op = norm_op, + norm_op_kwargs = norm_op_kwargs, + input_channels = input_channels, + output_channels = features_per_stage[s], + kernel_size = kernel_sizes[s], + padding=self.conv_pad_sizes[s], + stride=strides[s], + use_1x1conv=True, + nonlin = nonlin, + nonlin_kwargs = nonlin_kwargs + ), + *[ + BasicBlockD( + conv_op = conv_op, + input_channels = features_per_stage[s], + output_channels = features_per_stage[s], + kernel_size = kernel_sizes[s], + stride = 1, + conv_bias = conv_bias, + norm_op = norm_op, + norm_op_kwargs = norm_op_kwargs, + nonlin = nonlin, + nonlin_kwargs = nonlin_kwargs, + ) for _ in range(n_blocks_per_stage[s] - 1) + ] + ) + + stages.append(stage) + input_channels = features_per_stage[s] + + self.stages = nn.Sequential(*stages) + self.output_channels = features_per_stage + self.strides = [[item] * dim_of_conv_op(conv_op) if not isinstance(item, (tuple, list, np.ndarray)) \ + else item for item in strides] + # self.strides = [maybe_convert_scalar_to_list(conv_op, i) for i in strides] + + self.return_skips = return_skips + + self.conv_op = conv_op + self.norm_op = norm_op + self.norm_op_kwargs = norm_op_kwargs + self.nonlin = nonlin + self.nonlin_kwargs = nonlin_kwargs + #self.dropout_op = dropout_op + #self.dropout_op_kwargs = dropout_op_kwargs + self.conv_bias = conv_bias + self.kernel_sizes = kernel_sizes + + def forward(self, x): + if self.stem is not None: + x = self.stem(x) + ret = [] + for s in self.stages: + x = s(x) + ret.append(x) + if self.return_skips: + return ret + else: + return ret[-1] + + def compute_conv_feature_map_size(self, input_size): + if self.stem is not None: + output = self.stem.compute_conv_feature_map_size(input_size) + else: + output = np.int64(0) + + for s in range(len(self.stages)): + output += self.stages[s].compute_conv_feature_map_size(input_size) + input_size = [i // j for i, j in zip(input_size, self.strides[s])] + + return output + + +class UNetResDecoder(nn.Module): + def __init__(self, + encoder, + num_classes, + n_conv_per_stage: Union[int, Tuple[int, ...], List[int]], + deep_supervision, nonlin_first: bool = False): + + super().__init__() + self.deep_supervision = deep_supervision + self.encoder = encoder + self.num_classes = num_classes + n_stages_encoder = len(encoder.output_channels) + if isinstance(n_conv_per_stage, int): + n_conv_per_stage = [n_conv_per_stage] * (n_stages_encoder - 1) + assert len(n_conv_per_stage) == n_stages_encoder - 1, "n_conv_per_stage must have as many entries as we have " \ + "resolution stages - 1 (n_stages in encoder - 1), " \ + "here: %d" % n_stages_encoder + + stages = [] + upsample_layers = [] + + seg_layers = [] + for s in range(1, n_stages_encoder): + input_features_below = encoder.output_channels[-s] + input_features_skip = encoder.output_channels[-(s + 1)] + stride_for_upsampling = encoder.strides[-s] + upsample_layers.append(UpsampleLayer( + conv_op = encoder.conv_op, + input_channels = input_features_below, + output_channels = input_features_skip, + pool_op_kernel_size = stride_for_upsampling, + mode='nearest' + )) + + stages.append(nn.Sequential( + BasicResBlock( + conv_op = encoder.conv_op, + norm_op = encoder.norm_op, + norm_op_kwargs = encoder.norm_op_kwargs, + nonlin = encoder.nonlin, + nonlin_kwargs = encoder.nonlin_kwargs, + input_channels = 2 * input_features_skip if s < n_stages_encoder - 1 else input_features_skip, + output_channels = input_features_skip, + kernel_size = encoder.kernel_sizes[-(s + 1)], + padding=encoder.conv_pad_sizes[-(s + 1)], + stride=1, + use_1x1conv=True + ), + *[ + BasicBlockD( + conv_op = encoder.conv_op, + input_channels = input_features_skip, + output_channels = input_features_skip, + kernel_size = encoder.kernel_sizes[-(s + 1)], + stride = 1, + conv_bias = encoder.conv_bias, + norm_op = encoder.norm_op, + norm_op_kwargs = encoder.norm_op_kwargs, + nonlin = encoder.nonlin, + nonlin_kwargs = encoder.nonlin_kwargs, + ) for _ in range(n_conv_per_stage[s-1] - 1) + ] + )) + seg_layers.append(encoder.conv_op(input_features_skip, num_classes, 1, 1, 0, bias=True)) + + self.stages = nn.ModuleList(stages) + self.upsample_layers = nn.ModuleList(upsample_layers) + self.seg_layers = nn.ModuleList(seg_layers) + + def forward(self, skips): + lres_input = skips[-1] + seg_outputs = [] + for s in range(len(self.stages)): + x = self.upsample_layers[s](lres_input) + if s < (len(self.stages) - 1): + x = torch.cat((x, skips[-(s+2)]), 1) + x = self.stages[s](x) + if self.deep_supervision: + seg_outputs.append(self.seg_layers[s](x)) + elif s == (len(self.stages) - 1): + seg_outputs.append(self.seg_layers[-1](x)) + lres_input = x + seg_outputs = seg_outputs[::-1] + + if not self.deep_supervision: + r = seg_outputs[0] + else: + r = seg_outputs + return r + + def compute_conv_feature_map_size(self, input_size): + skip_sizes = [] + for s in range(len(self.encoder.strides) - 1): + skip_sizes.append([i // j for i, j in zip(input_size, self.encoder.strides[s])]) + input_size = skip_sizes[-1] + + assert len(skip_sizes) == len(self.stages) + + output = np.int64(0) + for s in range(len(self.stages)): + output += self.stages[s].compute_conv_feature_map_size(skip_sizes[-(s+1)]) + output += np.prod([self.encoder.output_channels[-(s+2)], *skip_sizes[-(s+1)]], dtype=np.int64) + if self.deep_supervision or (s == (len(self.stages) - 1)): + output += np.prod([self.num_classes, *skip_sizes[-(s+1)]], dtype=np.int64) + return output + + +class UMambaBot(nn.Module): + """ + UMambaBot that uses Mamba block at the bottleneck of UNet. + + * Reference: Jun Ma, Feifei Li, Bo Wang. + U-Mamba: Enhancing long-range dependency for biomedical image segmentation. + arxiv 2403.20035, 2024. + + The implementation is based on the code at: + https://github.com/bowang-lab/U-Mamba. + + The parameters for the backbone should be given in the `params` dictionary. + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param class_num: (int) The class number for segmentation task. + :param n_blocks_per_stage: (int) the number of con blocks at each stage. + """ + def __init__(self, params): + super(UMambaBot, self).__init__() + params = self.get_default_parameters(params) + for p in params: + print(p, params[p]) + + # def __init__(self, + # input_channels: int, + # n_stages: int, + # features_per_stage: Union[int, List[int], Tuple[int, ...]], + # conv_op: Type[_ConvNd], + # kernel_sizes: Union[int, List[int], Tuple[int, ...]], + # strides: Union[int, List[int], Tuple[int, ...]], + # n_conv_per_stage: Union[int, List[int], Tuple[int, ...]], + # num_classes: int, + # n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]], + # conv_bias: bool = False, + # norm_op: Union[None, Type[nn.Module]] = None, + # norm_op_kwargs: dict = None, + # dropout_op: Union[None, Type[_DropoutNd]] = None, + # dropout_op_kwargs: dict = None, + # nonlin: Union[None, Type[torch.nn.Module]] = None, + # nonlin_kwargs: dict = None, + # deep_supervision: bool = False, + # stem_channels: int = None + # ): + # super().__init__() + + input_channels = params['in_chns'] + features_per_stage = params['feature_chns'] + num_classes = params['class_num'] + n_blocks_per_stage = params['n_blocks_per_stage'] + n_conv_per_stage_decoder = n_blocks_per_stage + n_stages = len(features_per_stage) + conv_op = nn.Conv2d + kernel_sizes = [(3,3)] * len(features_per_stage) + strides = [(1, 1)] + [(2,2)] * (len(features_per_stage) - 1) + # strides = [(1,1)] * len(features_per_stage) + + conv_bias = True + norm_op = nn.InstanceNorm2d + norm_op_kwargs = {"affine":True} + nonlin=nn.LeakyReLU + nonlin_kwargs={'inplace': True} + deep_supervision = False + stem_channels = None + + if isinstance(n_blocks_per_stage, int): + n_blocks_per_stage = [n_blocks_per_stage] * n_stages + if isinstance(n_conv_per_stage_decoder, int): + n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1) + + for s in range(math.ceil(n_stages / 2), n_stages): + n_blocks_per_stage[s] = 1 + + for s in range(math.ceil((n_stages - 1) / 2 + 0.5), n_stages - 1): + n_conv_per_stage_decoder[s] = 1 + + + assert len(n_blocks_per_stage) == n_stages, "n_blocks_per_stage must have as many entries as we have " \ + f"resolution stages. here: {n_stages}. " \ + f"n_blocks_per_stage: {n_blocks_per_stage}" + assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one less entries " \ + f"as we have resolution stages. here: {n_stages} " \ + f"stages, so it should have {n_stages - 1} entries. " \ + f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}" + + + self.encoder = UNetResEncoder( + input_channels, + n_stages, + features_per_stage, + conv_op, + kernel_sizes, + strides, + n_blocks_per_stage, + conv_bias, + norm_op, + norm_op_kwargs, + nonlin, + nonlin_kwargs, + return_skips=True, + stem_channels=stem_channels + ) + + self.mamba_layer = MambaLayer(dim = features_per_stage[-1]) + + self.decoder = UNetResDecoder(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'n_blocks_per_stage': 2 + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + skips = self.encoder(x) + # for skip in skips: + # print(skip.shape) + skips[-1] = self.mamba_layer(skips[-1]) + output = self.decoder(skips) + if(len(x_shape) == 5): + if(isinstance(output, (list,tuple))): + for i in range(len(output)): + new_shape = [N, D] + list(output[i].shape)[1:] + output[i] = torch.transpose(torch.reshape(output[i], new_shape), 1, 2) + else: + new_shape = [N, D] + list(output.shape)[1:] + output = torch.transpose(torch.reshape(output, new_shape), 1, 2) + return output + + def compute_conv_feature_map_size(self, input_size): + assert len(input_size) == dim_of_conv_op(self.encoder.conv_op), "just give the image size without color/feature channels or " \ + "batch channel. Do not give input_size=(b, c, x, y(, z)). " \ + "Give input_size=(x, y(, z))!" + return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size) + +class ResidualMambaEncoder(nn.Module): + def __init__(self, + input_size: Tuple[int, ...], + input_channels: int, + n_stages: int, + features_per_stage: Union[int, List[int], Tuple[int, ...]], + conv_op: Type[_ConvNd], + kernel_sizes: Union[int, List[int], Tuple[int, ...]], + strides: Union[int, List[int], Tuple[int, ...], Tuple[Tuple[int, ...], ...]], + n_blocks_per_stage: Union[int, List[int], Tuple[int, ...]], + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: dict = None, + nonlin: Union[None, Type[torch.nn.Module]] = None, + nonlin_kwargs: dict = None, + return_skips: bool = False, + stem_channels: int = None, + pool_type: str = 'conv', + ): + super().__init__() + if isinstance(kernel_sizes, int): + kernel_sizes = [kernel_sizes] * n_stages + if isinstance(features_per_stage, int): + features_per_stage = [features_per_stage] * n_stages + if isinstance(n_blocks_per_stage, int): + n_blocks_per_stage = [n_blocks_per_stage] * n_stages + if isinstance(strides, int): + strides = [strides] * n_stages + assert len( + kernel_sizes) == n_stages, "kernel_sizes must have as many entries as we have resolution stages (n_stages)" + assert len( + n_blocks_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have resolution stages (n_stages)" + assert len( + features_per_stage) == n_stages, "features_per_stage must have as many entries as we have resolution stages (n_stages)" + assert len(strides) == n_stages, "strides must have as many entries as we have resolution stages (n_stages). " \ + "Important: first entry is recommended to be 1, else we run strided conv drectly on the input" + + pool_op = get_matching_pool_op(conv_op, pool_type=pool_type) if pool_type != 'conv' else None + + do_channel_token = [False] * n_stages + feature_map_sizes = [] + feature_map_size = input_size + for s in range(n_stages): + feature_map_sizes.append([i // j for i, j in zip(feature_map_size, strides[s])]) + feature_map_size = feature_map_sizes[-1] + if np.prod(feature_map_size) <= features_per_stage[s]: + do_channel_token[s] = True + + + print(f"feature_map_sizes: {feature_map_sizes}") + print(f"do_channel_token: {do_channel_token}") + + self.conv_pad_sizes = [] + for krnl in kernel_sizes: + self.conv_pad_sizes.append([i // 2 for i in krnl]) + + stem_channels = features_per_stage[0] + self.stem = nn.Sequential( + BasicResBlock( + conv_op = conv_op, + input_channels = input_channels, + output_channels = stem_channels, + norm_op=norm_op, + norm_op_kwargs=norm_op_kwargs, + kernel_size=kernel_sizes[0], + padding=self.conv_pad_sizes[0], + stride=1, + nonlin=nonlin, + nonlin_kwargs=nonlin_kwargs, + use_1x1conv=True + ), + *[ + BasicBlockD( + conv_op = conv_op, + input_channels = stem_channels, + output_channels = stem_channels, + kernel_size = kernel_sizes[0], + stride = 1, + conv_bias = conv_bias, + norm_op = norm_op, + norm_op_kwargs = norm_op_kwargs, + nonlin = nonlin, + nonlin_kwargs = nonlin_kwargs, + ) for _ in range(n_blocks_per_stage[0] - 1) + ] + ) + + input_channels = stem_channels + + stages = [] + mamba_layers = [] + for s in range(n_stages): + stage = nn.Sequential( + BasicResBlock( + conv_op = conv_op, + norm_op = norm_op, + norm_op_kwargs = norm_op_kwargs, + input_channels = input_channels, + output_channels = features_per_stage[s], + kernel_size = kernel_sizes[s], + padding=self.conv_pad_sizes[s], + stride=strides[s], + use_1x1conv=True, + nonlin = nonlin, + nonlin_kwargs = nonlin_kwargs + ), + *[ + BasicBlockD( + conv_op = conv_op, + input_channels = features_per_stage[s], + output_channels = features_per_stage[s], + kernel_size = kernel_sizes[s], + stride = 1, + conv_bias = conv_bias, + norm_op = norm_op, + norm_op_kwargs = norm_op_kwargs, + nonlin = nonlin, + nonlin_kwargs = nonlin_kwargs, + ) for _ in range(n_blocks_per_stage[s] - 1) + ] + ) + + if bool(s % 2) ^ bool(n_stages % 2): ## gurantee the last stage has mamaba layer + mamba_layers.append( + MambaLayer( + dim = np.prod(feature_map_sizes[s]) if do_channel_token[s] else features_per_stage[s], + channel_token = do_channel_token[s] + ) + ) + else: + mamba_layers.append(nn.Identity()) + + stages.append(stage) + input_channels = features_per_stage[s] + + self.mamba_layers = nn.ModuleList(mamba_layers) + self.stages = nn.ModuleList(stages) + self.output_channels = features_per_stage + self.strides = [[item] * dim_of_conv_op(conv_op) if not isinstance(item, (tuple, list, np.ndarray)) \ + else item for item in strides] + # self.strides = [maybe_convert_scalar_to_list(conv_op, i) for i in strides] + self.return_skips = return_skips + + self.conv_op = conv_op + self.norm_op = norm_op + self.norm_op_kwargs = norm_op_kwargs + self.nonlin = nonlin + self.nonlin_kwargs = nonlin_kwargs + #self.dropout_op = dropout_op + #self.dropout_op_kwargs = dropout_op_kwargs + self.conv_bias = conv_bias + self.kernel_sizes = kernel_sizes + + def forward(self, x): + if self.stem is not None: + x = self.stem(x) + ret = [] + for s in range(len(self.stages)): + x = self.stages[s](x) + x = self.mamba_layers[s](x) + ret.append(x) + if self.return_skips: + return ret + else: + return ret[-1] + + def compute_conv_feature_map_size(self, input_size): + if self.stem is not None: + output = self.stem.compute_conv_feature_map_size(input_size) + else: + output = np.int64(0) + + for s in range(len(self.stages)): + output += self.stages[s].compute_conv_feature_map_size(input_size) + input_size = [i // j for i, j in zip(input_size, self.strides[s])] + + return output + + +class UMambaEnc(nn.Module): + """ + UMambaEnc that uses Mamba block at the encoder and bottleneck of UNet. + + * Reference: Jun Ma, Feifei Li, Bo Wang. + U-Mamba: Enhancing long-range dependency for biomedical image segmentation. + arxiv 2403.20035, 2024. + + The implementation is based on the code at: + https://github.com/bowang-lab/U-Mamba. + + The parameters for the backbone should be given in the `params` dictionary. + + :param input_size: (list) the size of input image, such as [256, 256] + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param class_num: (int) The class number for segmentation task. + :param n_blocks_per_stage: (int) the number of con blocks at each stage. + """ + def __init__(self, params): + super(UMambaEnc, self).__init__() + params = self.get_default_parameters(params) + for p in params: + print(p, params[p]) + # def __init__(self, + # input_size: Tuple[int, ...], + # input_channels: int, + # n_stages: int, + # features_per_stage: Union[int, List[int], Tuple[int, ...]], + # conv_op: Type[_ConvNd], + # kernel_sizes: Union[int, List[int], Tuple[int, ...]], + # strides: Union[int, List[int], Tuple[int, ...]], + # n_conv_per_stage: Union[int, List[int], Tuple[int, ...]], + # num_classes: int, + # n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]], + # conv_bias: bool = False, + # norm_op: Union[None, Type[nn.Module]] = None, + # norm_op_kwargs: dict = None, + # dropout_op: Union[None, Type[_DropoutNd]] = None, + # dropout_op_kwargs: dict = None, + # nonlin: Union[None, Type[torch.nn.Module]] = None, + # nonlin_kwargs: dict = None, + # deep_supervision: bool = False, + # stem_channels: int = None + # ): + # super().__init__() + + input_size = params['input_size'] + input_channels = params['in_chns'] + features_per_stage = params['feature_chns'] + num_classes = params['class_num'] + n_blocks_per_stage = params['n_blocks_per_stage'] + n_conv_per_stage_decoder = n_blocks_per_stage + n_stages = len(features_per_stage) + conv_op = nn.Conv2d + kernel_sizes = [(3,3)] * len(features_per_stage) + strides = [(1, 1)] + [(2,2)] * (len(features_per_stage) - 1) + + conv_bias = True + norm_op = nn.InstanceNorm2d + norm_op_kwargs = {"affine":True} + nonlin=nn.LeakyReLU + nonlin_kwargs={'inplace': True} + deep_supervision = False + stem_channels = None + + if isinstance(n_blocks_per_stage, int): + n_blocks_per_stage = [n_blocks_per_stage] * n_stages + if isinstance(n_conv_per_stage_decoder, int): + n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1) + + for s in range(math.ceil(n_stages / 2), n_stages): + n_blocks_per_stage[s] = 1 + + for s in range(math.ceil((n_stages - 1) / 2 + 0.5), n_stages - 1): + n_conv_per_stage_decoder[s] = 1 + + + assert len(n_blocks_per_stage) == n_stages, "n_blocks_per_stage must have as many entries as we have " \ + f"resolution stages. here: {n_stages}. " \ + f"n_blocks_per_stage: {n_blocks_per_stage}" + assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one less entries " \ + f"as we have resolution stages. here: {n_stages} " \ + f"stages, so it should have {n_stages - 1} entries. " \ + f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}" + self.encoder = ResidualMambaEncoder( + input_size, + input_channels, + n_stages, + features_per_stage, + conv_op, + kernel_sizes, + strides, + n_blocks_per_stage, + conv_bias, + norm_op, + norm_op_kwargs, + nonlin, + nonlin_kwargs, + return_skips=True, + stem_channels=stem_channels + ) + + self.decoder = UNetResDecoder(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'n_blocks_per_stage': 2 + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + skips = self.encoder(x) + output = self.decoder(skips) + if(len(x_shape) == 5): + if(isinstance(output, (list,tuple))): + for i in range(len(output)): + new_shape = [N, D] + list(output[i].shape)[1:] + output[i] = torch.transpose(torch.reshape(output[i], new_shape), 1, 2) + else: + new_shape = [N, D] + list(output.shape)[1:] + output = torch.transpose(torch.reshape(output, new_shape), 1, 2) + return output + + def compute_conv_feature_map_size(self, input_size): + assert len(input_size) == dim_of_conv_op(self.encoder.conv_op), "just give the image size without color/feature channels or " \ + "batch channel. Do not give input_size=(b, c, x, y(, z)). " \ + "Give input_size=(x, y(, z))!" + return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size) + diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index 8877ade..ffd82ae 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -26,6 +26,7 @@ from pymic.net.net2d.unet2d_scse import UNet2D_ScSE from pymic.net.net2d.trans2d.transunet import TransUNet from pymic.net.net2d.trans2d.swinunet import SwinUNet +from pymic.net.net2d.umamba import UMambaBot, UMambaEnc from pymic.net.net3d.unet2d5 import UNet2D5 from pymic.net.net3d.unet3d import UNet3D from pymic.net.net3d.grunet import GRUNet @@ -62,6 +63,8 @@ 'UNet2D_CCT': UNet2D_CCT, 'UNet2Dpp': UNet2Dpp, 'UNet2D_ScSE': UNet2D_ScSE, + 'UMambaBot': UMambaBot, + 'UMambaEnc': UMambaEnc, 'TransUNet': TransUNet, 'SwinUNet': SwinUNet, 'UNet2D5': UNet2D5, diff --git a/pymic/test/test_net2d.py b/pymic/test/test_net2d.py index aafaf20..9013386 100644 --- a/pymic/test/test_net2d.py +++ b/pymic/test/test_net2d.py @@ -5,7 +5,7 @@ import numpy as np from pymic.net.net2d.unet2d import UNet2D from pymic.net.net2d.unet2d_scse import UNet2D_ScSE - +from pymic.net.net2d.umamba_bot import UMambaBot def test_unet2d(): params = {'in_chns':4, 'feature_chns':[16, 32, 64, 128, 256], @@ -52,6 +52,17 @@ def test_unet2d_scse(): else: print(out.shape) +def test_umamba(): + x = np.random.rand(4, 4, 10, 256, 256) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + Net = UMambaBot() + out = Net(xt) + y = out.detach().numpy() + print(y.shape) + if __name__ == "__main__": # test_unet2d() - test_unet2d_scse() \ No newline at end of file + # test_unet2d_scse() + test_umamba() \ No newline at end of file From 9890b77ccaab4f963e0208291df93d242239967f Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 31 Jan 2025 22:47:09 +0800 Subject: [PATCH 214/225] add ultralight_vm_unet --- pymic/net/net2d/unet2d_vm_light.py | 284 +++++++++++++++++++++++++++++ pymic/net/net_dict_seg.py | 2 + 2 files changed, 286 insertions(+) create mode 100644 pymic/net/net2d/unet2d_vm_light.py diff --git a/pymic/net/net2d/unet2d_vm_light.py b/pymic/net/net2d/unet2d_vm_light.py new file mode 100644 index 0000000..ed3f1c5 --- /dev/null +++ b/pymic/net/net2d/unet2d_vm_light.py @@ -0,0 +1,284 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import math +import torch +from torch import nn +import torch.nn.functional as F + +from timm.models.layers import trunc_normal_ +from mamba_ssm import Mamba + + +class PVMLayer(nn.Module): + def __init__(self, input_dim, output_dim, d_state = 16, d_conv = 4, expand = 2): + super().__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.norm = nn.LayerNorm(input_dim) + self.mamba = Mamba( + d_model=input_dim//4, # Model dimension d_model + d_state=d_state, # SSM state expansion factor + d_conv=d_conv, # Local convolution width + expand=expand, # Block expansion factor + ) + self.proj = nn.Linear(input_dim, output_dim) + self.skip_scale= nn.Parameter(torch.ones(1)) + + def forward(self, x): + if x.dtype == torch.float16: + x = x.type(torch.float32) + B, C = x.shape[:2] + assert C == self.input_dim + n_tokens = x.shape[2:].numel() + img_dims = x.shape[2:] + x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2) + x_norm = self.norm(x_flat) + + x1, x2, x3, x4 = torch.chunk(x_norm, 4, dim=2) + x_mamba1 = self.mamba(x1) + self.skip_scale * x1 + x_mamba2 = self.mamba(x2) + self.skip_scale * x2 + x_mamba3 = self.mamba(x3) + self.skip_scale * x3 + x_mamba4 = self.mamba(x4) + self.skip_scale * x4 + x_mamba = torch.cat([x_mamba1, x_mamba2,x_mamba3,x_mamba4], dim=2) + + x_mamba = self.norm(x_mamba) + x_mamba = self.proj(x_mamba) + out = x_mamba.transpose(-1, -2).reshape(B, self.output_dim, *img_dims) + return out + + +class Channel_Att_Bridge(nn.Module): + def __init__(self, c_list, split_att='fc'): + super().__init__() + c_list_sum = sum(c_list) - c_list[-1] + self.split_att = split_att + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.get_all_att = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False) + self.att1 = nn.Linear(c_list_sum, c_list[0]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[0], 1) + self.att2 = nn.Linear(c_list_sum, c_list[1]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[1], 1) + self.att3 = nn.Linear(c_list_sum, c_list[2]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[2], 1) + self.att4 = nn.Linear(c_list_sum, c_list[3]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[3], 1) + self.att5 = nn.Linear(c_list_sum, c_list[4]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[4], 1) + self.sigmoid = nn.Sigmoid() + + def forward(self, t1, t2, t3, t4, t5): + att = torch.cat((self.avgpool(t1), + self.avgpool(t2), + self.avgpool(t3), + self.avgpool(t4), + self.avgpool(t5)), dim=1) + att = self.get_all_att(att.squeeze(-1).transpose(-1, -2)) + if self.split_att != 'fc': + att = att.transpose(-1, -2) + att1 = self.sigmoid(self.att1(att)) + att2 = self.sigmoid(self.att2(att)) + att3 = self.sigmoid(self.att3(att)) + att4 = self.sigmoid(self.att4(att)) + att5 = self.sigmoid(self.att5(att)) + if self.split_att == 'fc': + att1 = att1.transpose(-1, -2).unsqueeze(-1).expand_as(t1) + att2 = att2.transpose(-1, -2).unsqueeze(-1).expand_as(t2) + att3 = att3.transpose(-1, -2).unsqueeze(-1).expand_as(t3) + att4 = att4.transpose(-1, -2).unsqueeze(-1).expand_as(t4) + att5 = att5.transpose(-1, -2).unsqueeze(-1).expand_as(t5) + else: + att1 = att1.unsqueeze(-1).expand_as(t1) + att2 = att2.unsqueeze(-1).expand_as(t2) + att3 = att3.unsqueeze(-1).expand_as(t3) + att4 = att4.unsqueeze(-1).expand_as(t4) + att5 = att5.unsqueeze(-1).expand_as(t5) + + return att1, att2, att3, att4, att5 + + +class Spatial_Att_Bridge(nn.Module): + def __init__(self): + super().__init__() + self.shared_conv2d = nn.Sequential(nn.Conv2d(2, 1, 7, stride=1, padding=9, dilation=3), + nn.Sigmoid()) + + def forward(self, t1, t2, t3, t4, t5): + t_list = [t1, t2, t3, t4, t5] + att_list = [] + for t in t_list: + avg_out = torch.mean(t, dim=1, keepdim=True) + max_out, _ = torch.max(t, dim=1, keepdim=True) + att = torch.cat([avg_out, max_out], dim=1) + att = self.shared_conv2d(att) + att_list.append(att) + return att_list[0], att_list[1], att_list[2], att_list[3], att_list[4] + + +class SC_Att_Bridge(nn.Module): + def __init__(self, c_list, split_att='fc'): + super().__init__() + + self.catt = Channel_Att_Bridge(c_list, split_att=split_att) + self.satt = Spatial_Att_Bridge() + + def forward(self, t1, t2, t3, t4, t5): + r1, r2, r3, r4, r5 = t1, t2, t3, t4, t5 + + satt1, satt2, satt3, satt4, satt5 = self.satt(t1, t2, t3, t4, t5) + t1, t2, t3, t4, t5 = satt1 * t1, satt2 * t2, satt3 * t3, satt4 * t4, satt5 * t5 + + r1_, r2_, r3_, r4_, r5_ = t1, t2, t3, t4, t5 + t1, t2, t3, t4, t5 = t1 + r1, t2 + r2, t3 + r3, t4 + r4, t5 + r5 + + catt1, catt2, catt3, catt4, catt5 = self.catt(t1, t2, t3, t4, t5) + t1, t2, t3, t4, t5 = catt1 * t1, catt2 * t2, catt3 * t3, catt4 * t4, catt5 * t5 + + return t1 + r1_, t2 + r2_, t3 + r3_, t4 + r4_, t5 + r5_ + + +class UltraLight_VM_UNet(nn.Module): + def __init__(self, params): + """ + UltraLight_VM_UNet that is a lightweight model using CNN and Mamba. + + * Reference: Renkai Wu, Yinghao Liu, Pengchen Liang, Qing Chang. + UltraLight VM-UNet: Parallel Vision Mamba Significantly Reduces Parameters for Skin Lesion Segmentation. + arxiv 2403.20035, 2024. + + The implementation is based on the code at: + https://github.com/wurenkai/UltraLight-VM-UNet. + + The parameters for the backbone should be given in the `params` dictionary. + + :param in_chns: (int) Input channel number. + :param class_num: (int) The class number for segmentation task. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 6, by default it is [8, 16, 24, 32, 48, 64]. + :param bridge: (int) If the bridge based on spatial and channel attentions is used or not. + By default it is True. + """ + super(UltraLight_VM_UNet, self).__init__() + + input_channels = params['in_chns'] + num_classes = params['class_num'] + c_list = params.get('feature_chns', [8, 16, 24, 32, 48, 64]) + self.bridge = params.get('bridge', True) + split_att = 'fc' + # def __init__(self, num_classes=1, input_channels=3, c_list=[8,16,24,32,48,64], + # split_att='fc', bridge=True): + # super().__init__() + # self.bridge = bridge + + self.encoder1 = nn.Sequential( + nn.Conv2d(input_channels, c_list[0], 3, stride=1, padding=1), + ) + self.encoder2 =nn.Sequential( + nn.Conv2d(c_list[0], c_list[1], 3, stride=1, padding=1), + ) + self.encoder3 = nn.Sequential( + nn.Conv2d(c_list[1], c_list[2], 3, stride=1, padding=1), + ) + self.encoder4 = nn.Sequential( + PVMLayer(input_dim=c_list[2], output_dim=c_list[3]) + ) + self.encoder5 = nn.Sequential( + PVMLayer(input_dim=c_list[3], output_dim=c_list[4]) + ) + self.encoder6 = nn.Sequential( + PVMLayer(input_dim=c_list[4], output_dim=c_list[5]) + ) + + if self.bridge: + self.scab = SC_Att_Bridge(c_list, split_att) + print('SC_Att_Bridge was used') + + self.decoder1 = nn.Sequential( + PVMLayer(input_dim=c_list[5], output_dim=c_list[4]) + ) + self.decoder2 = nn.Sequential( + PVMLayer(input_dim=c_list[4], output_dim=c_list[3]) + ) + self.decoder3 = nn.Sequential( + PVMLayer(input_dim=c_list[3], output_dim=c_list[2]) + ) + self.decoder4 = nn.Sequential( + nn.Conv2d(c_list[2], c_list[1], 3, stride=1, padding=1), + ) + self.decoder5 = nn.Sequential( + nn.Conv2d(c_list[1], c_list[0], 3, stride=1, padding=1), + ) + self.ebn1 = nn.GroupNorm(4, c_list[0]) + self.ebn2 = nn.GroupNorm(4, c_list[1]) + self.ebn3 = nn.GroupNorm(4, c_list[2]) + self.ebn4 = nn.GroupNorm(4, c_list[3]) + self.ebn5 = nn.GroupNorm(4, c_list[4]) + self.dbn1 = nn.GroupNorm(4, c_list[4]) + self.dbn2 = nn.GroupNorm(4, c_list[3]) + self.dbn3 = nn.GroupNorm(4, c_list[2]) + self.dbn4 = nn.GroupNorm(4, c_list[1]) + self.dbn5 = nn.GroupNorm(4, c_list[0]) + + self.final = nn.Conv2d(c_list[0], num_classes, kernel_size=1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv1d): + n = m.kernel_size[0] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + out = F.gelu(F.max_pool2d(self.ebn1(self.encoder1(x)),2,2)) + t1 = out # b, c0, H/2, W/2 + + out = F.gelu(F.max_pool2d(self.ebn2(self.encoder2(out)),2,2)) + t2 = out # b, c1, H/4, W/4 + + out = F.gelu(F.max_pool2d(self.ebn3(self.encoder3(out)),2,2)) + t3 = out # b, c2, H/8, W/8 + + out = F.gelu(F.max_pool2d(self.ebn4(self.encoder4(out)),2,2)) + t4 = out # b, c3, H/16, W/16 + + out = F.gelu(F.max_pool2d(self.ebn5(self.encoder5(out)),2,2)) + t5 = out # b, c4, H/32, W/32 + + if self.bridge: t1, t2, t3, t4, t5 = self.scab(t1, t2, t3, t4, t5) + + out = F.gelu(self.encoder6(out)) # b, c5, H/32, W/32 + + out5 = F.gelu(self.dbn1(self.decoder1(out))) # b, c4, H/32, W/32 + out5 = torch.add(out5, t5) # b, c4, H/32, W/32 + + out4 = F.gelu(F.interpolate(self.dbn2(self.decoder2(out5)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c3, H/16, W/16 + out4 = torch.add(out4, t4) # b, c3, H/16, W/16 + + out3 = F.gelu(F.interpolate(self.dbn3(self.decoder3(out4)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c2, H/8, W/8 + out3 = torch.add(out3, t3) # b, c2, H/8, W/8 + + out2 = F.gelu(F.interpolate(self.dbn4(self.decoder4(out3)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c1, H/4, W/4 + out2 = torch.add(out2, t2) # b, c1, H/4, W/4 + + out1 = F.gelu(F.interpolate(self.dbn5(self.decoder5(out2)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c0, H/2, W/2 + out1 = torch.add(out1, t1) # b, c0, H/2, W/2 + + out0 = F.interpolate(self.final(out1),scale_factor=(2,2),mode ='bilinear',align_corners=True) # b, num_class, H, W + + if(len(x_shape) == 5): + new_shape = [N, D] + list(out0.shape)[1:] + out0 = torch.transpose(torch.reshape(out0, new_shape), 1, 2) + return out0 + diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index ffd82ae..b557bec 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -27,6 +27,7 @@ from pymic.net.net2d.trans2d.transunet import TransUNet from pymic.net.net2d.trans2d.swinunet import SwinUNet from pymic.net.net2d.umamba import UMambaBot, UMambaEnc +from pymic.net.net2d.unet2d_vm_light import UltraLight_VM_UNet from pymic.net.net3d.unet2d5 import UNet2D5 from pymic.net.net3d.unet3d import UNet3D from pymic.net.net3d.grunet import GRUNet @@ -65,6 +66,7 @@ 'UNet2D_ScSE': UNet2D_ScSE, 'UMambaBot': UMambaBot, 'UMambaEnc': UMambaEnc, + 'UltraLight_VM_UNet': UltraLight_VM_UNet, 'TransUNet': TransUNet, 'SwinUNet': SwinUNet, 'UNet2D5': UNet2D5, From a279271326ddea66a630f26556673d4332bdf7e6 Mon Sep 17 00:00:00 2001 From: taigw Date: Sun, 2 Feb 2025 22:19:21 +0800 Subject: [PATCH 215/225] add VMUNet --- pymic/net/net2d/unet2d_vm.py | 820 +++++++++++++++++++++++++++++ pymic/net/net2d/unet2d_vm_light.py | 2 +- pymic/net/net_dict_seg.py | 2 + 3 files changed, 823 insertions(+), 1 deletion(-) create mode 100644 pymic/net/net2d/unet2d_vm.py diff --git a/pymic/net/net2d/unet2d_vm.py b/pymic/net/net2d/unet2d_vm.py new file mode 100644 index 0000000..126fa30 --- /dev/null +++ b/pymic/net/net2d/unet2d_vm.py @@ -0,0 +1,820 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import math +from functools import partial +from typing import Optional, Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from einops import rearrange, repeat +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +try: + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref +except: + pass + +# an alternative for mamba_ssm (in which causal_conv1d is needed) +try: + from selective_scan import selective_scan_fn as selective_scan_fn_v1 + from selective_scan import selective_scan_ref as selective_scan_ref_v1 +except: + pass + +DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})" + + +def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False): + """ + u: r(B D L) + delta: r(B D L) + A: r(D N) + B: r(B N L) + C: r(B N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + + ignores: + [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] + """ + import numpy as np + + # fvcore.nn.jit_handles + def get_flops_einsum(input_shapes, equation): + np_arrs = [np.zeros(s) for s in input_shapes] + optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1] + for line in optim.split("\n"): + if "optimized flop" in line.lower(): + # divided by 2 because we count MAC (multiply-add counted as one flop) + flop = float(np.floor(float(line.split(":")[-1]) / 2)) + return flop + + + assert not with_complex + + flops = 0 # below code flops = 0 + if False: + ... + """ + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + if A.is_complex(): + if is_variable_B: + B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) + if is_variable_C: + C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2)) + else: + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) + ys = [] + """ + + flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln") + if with_Group: + flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln") + else: + flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln") + if False: + ... + """ + deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + """ + + in_for_flops = B * D * N + if with_Group: + in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd") + else: + in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd") + flops += L * in_for_flops + if False: + ... + """ + for i in range(u.shape[2]): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum('bdn,dn->bd', x, C) + else: + if C.dim() == 3: + y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + else: + y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + if y.is_complex(): + y = y.real * 2 + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + """ + + if with_D: + flops += B * D * L + if with_Z: + flops += B * D * L + if False: + ... + """ + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + """ + + return flops + + +class PatchEmbed2D(nn.Module): + r""" Image to Patch Embedding + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs): + super().__init__() + if isinstance(patch_size, int): + patch_size = (patch_size, patch_size) + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = self.proj(x).permute(0, 2, 3, 1) + if self.norm is not None: + x = self.norm(x) + return x + + +class PatchMerging2D(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + B, H, W, C = x.shape + + SHAPE_FIX = [-1, -1] + if (W % 2 != 0) or (H % 2 != 0): + print(f"Warning, x.shape {x.shape} is not match even ===========", flush=True) + SHAPE_FIX[0] = H // 2 + SHAPE_FIX[1] = W // 2 + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + + if SHAPE_FIX[0] > 0: + x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] + x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] + x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] + x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] + + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, H//2, W//2, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class PatchExpand2D(nn.Module): + def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim*2 + self.dim_scale = dim_scale + self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False) + self.norm = norm_layer(self.dim // dim_scale) + + def forward(self, x): + B, H, W, C = x.shape + x = self.expand(x) + + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale) + x= self.norm(x) + + return x + + +class Final_PatchExpand2D(nn.Module): + def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.dim_scale = dim_scale + self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False) + self.norm = norm_layer(self.dim // dim_scale) + + def forward(self, x): + B, H, W, C = x.shape + x = self.expand(x) + + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale) + x= self.norm(x) + + return x + + +class SS2D(nn.Module): + def __init__( + self, + d_model, + d_state=16, + # d_state="auto", # 20240109 + d_conv=3, + expand=2, + dt_rank="auto", + dt_min=0.001, + dt_max=0.1, + dt_init="random", + dt_scale=1.0, + dt_init_floor=1e-4, + dropout=0., + conv_bias=True, + bias=False, + device=None, + dtype=None, + **kwargs, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + # self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109 + self.d_conv = d_conv + self.expand = expand + self.d_inner = int(self.expand * self.d_model) + self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank + + self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) + self.conv2d = nn.Conv2d( + in_channels=self.d_inner, + out_channels=self.d_inner, + groups=self.d_inner, + bias=conv_bias, + kernel_size=d_conv, + padding=(d_conv - 1) // 2, + **factory_kwargs, + ) + self.act = nn.SiLU() + + self.x_proj = ( + nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), + nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), + nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), + nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), + ) + self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner) + del self.x_proj + + self.dt_projs = ( + self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), + self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), + self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), + self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), + ) + self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank) + self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner) + del self.dt_projs + + self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K=4, D, N) + self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K=4, D, N) + + # self.selective_scan = selective_scan_fn + self.forward_core = self.forward_corev0 + + self.out_norm = nn.LayerNorm(self.d_inner) + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + self.dropout = nn.Dropout(dropout) if dropout > 0. else None + + @staticmethod + def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs): + dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) + + # Initialize special dt projection to preserve variance at initialization + dt_init_std = dt_rank**-0.5 * dt_scale + if dt_init == "constant": + nn.init.constant_(dt_proj.weight, dt_init_std) + elif dt_init == "random": + nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) + else: + raise NotImplementedError + + # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max + dt = torch.exp( + torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ).clamp(min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + dt_proj.bias.copy_(inv_dt) + # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit + dt_proj.bias._no_reinit = True + + return dt_proj + + @staticmethod + def A_log_init(d_state, d_inner, copies=1, device=None, merge=True): + # S4D real initialization + A = repeat( + torch.arange(1, d_state + 1, dtype=torch.float32, device=device), + "n -> d n", + d=d_inner, + ).contiguous() + A_log = torch.log(A) # Keep A_log in fp32 + if copies > 1: + A_log = repeat(A_log, "d n -> r d n", r=copies) + if merge: + A_log = A_log.flatten(0, 1) + A_log = nn.Parameter(A_log) + A_log._no_weight_decay = True + return A_log + + @staticmethod + def D_init(d_inner, copies=1, device=None, merge=True): + # D "skip" parameter + D = torch.ones(d_inner, device=device) + if copies > 1: + D = repeat(D, "n1 -> r n1", r=copies) + if merge: + D = D.flatten(0, 1) + D = nn.Parameter(D) # Keep in fp32 + D._no_weight_decay = True + return D + + def forward_corev0(self, x: torch.Tensor): + self.selective_scan = selective_scan_fn + + B, C, H, W = x.shape + L = H * W + K = 4 + + x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) + xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) + + x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight) + # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) + dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) + dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight) + # dts = dts + self.dt_projs_bias.view(1, K, -1, 1) + + xs = xs.float().view(B, -1, L) # (b, k * d, l) + dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) + Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l) + Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l) + Ds = self.Ds.float().view(-1) # (k * d) + As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state) + dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) + + out_y = self.selective_scan( + xs, dts, + As, Bs, Cs, Ds, z=None, + delta_bias=dt_projs_bias, + delta_softplus=True, + return_last_state=False, + ).view(B, K, -1, L) + assert out_y.dtype == torch.float + + inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) + wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) + invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) + + return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y + + # an alternative to forward_corev1 + def forward_corev1(self, x: torch.Tensor): + self.selective_scan = selective_scan_fn_v1 + + B, C, H, W = x.shape + L = H * W + K = 4 + + x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) + xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) + + x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight) + # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) + dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) + dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight) + # dts = dts + self.dt_projs_bias.view(1, K, -1, 1) + + xs = xs.float().view(B, -1, L) # (b, k * d, l) + dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) + Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l) + Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l) + Ds = self.Ds.float().view(-1) # (k * d) + As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state) + dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) + + out_y = self.selective_scan( + xs, dts, + As, Bs, Cs, Ds, + delta_bias=dt_projs_bias, + delta_softplus=True, + ).view(B, K, -1, L) + assert out_y.dtype == torch.float + + inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) + wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) + invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) + + return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y + + def forward(self, x: torch.Tensor, **kwargs): + B, H, W, C = x.shape + + xz = self.in_proj(x) + x, z = xz.chunk(2, dim=-1) # (b, h, w, d) + + x = x.permute(0, 3, 1, 2).contiguous() + x = self.act(self.conv2d(x)) # (b, d, h, w) + y1, y2, y3, y4 = self.forward_core(x) + assert y1.dtype == torch.float32 + y = y1 + y2 + y3 + y4 + y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) + y = self.out_norm(y) + y = y * F.silu(z) + out = self.out_proj(y) + if self.dropout is not None: + out = self.dropout(out) + return out + + +class VSSBlock(nn.Module): + def __init__( + self, + hidden_dim: int = 0, + drop_path: float = 0, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + attn_drop_rate: float = 0, + d_state: int = 16, + **kwargs, + ): + super().__init__() + self.ln_1 = norm_layer(hidden_dim) + self.self_attention = SS2D(d_model=hidden_dim, dropout=attn_drop_rate, d_state=d_state, **kwargs) + self.drop_path = DropPath(drop_path) + + def forward(self, input: torch.Tensor): + x = input + self.drop_path(self.self_attention(self.ln_1(input))) + return x + + +class VSSLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + depth (int): Number of blocks. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim, + depth, + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + d_state=16, + **kwargs, + ): + super().__init__() + self.dim = dim + self.use_checkpoint = use_checkpoint + + self.blocks = nn.ModuleList([ + VSSBlock( + hidden_dim=dim, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + attn_drop_rate=attn_drop, + d_state=d_state, + ) + for i in range(depth)]) + + if True: # is this really applied? Yes, but been overriden later in VSSM! + def _init_weights(module: nn.Module): + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + p = p.clone().detach_() # fake init, just to keep the seed .... + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + self.apply(_init_weights) + + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + + if self.downsample is not None: + x = self.downsample(x) + + return x + + + +class VSSLayer_up(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + depth (int): Number of blocks. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim, + depth, + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + upsample=None, + use_checkpoint=False, + d_state=16, + **kwargs, + ): + super().__init__() + self.dim = dim + self.use_checkpoint = use_checkpoint + + self.blocks = nn.ModuleList([ + VSSBlock( + hidden_dim=dim, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + attn_drop_rate=attn_drop, + d_state=d_state, + ) + for i in range(depth)]) + + if True: # is this really applied? Yes, but been overriden later in VSSM! + def _init_weights(module: nn.Module): + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + p = p.clone().detach_() # fake init, just to keep the seed .... + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + self.apply(_init_weights) + + if upsample is not None: + self.upsample = upsample(dim=dim, norm_layer=norm_layer) + else: + self.upsample = None + + + def forward(self, x): + if self.upsample is not None: + x = self.upsample(x) + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + return x + + + +# class VSSM(nn.Module): +class VMUNet(nn.Module): + """ + VM_UNet that is a UNet-like pure vision mambal network for segmentation. + + * Reference: Jiacheng Ruan et al., VM-UNet: Vision Mamba UNet for Medical Image Segmentation. + arxiv 2403.09157, 2024. + + The implementation is based on the code at: + https://github.com/JCruan519/VM-UNet. + + The parameters for the backbone should be given in the `params` dictionary. + + :param in_chns: (int) Input channel number. + :param class_num: (int) The class number for segmentation task. + :param depths: (list) The depth of VSS block at each resolution level. + The length should be 4, by default it is [2, 2, 9, 2]. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4, by default it is [96, 192, 384, 768]. + """ + # def __init__(self, c=4, in_chans=3, num_classes=1000, depths=[2, 2, 9, 2], depths_decoder=[2, 9, 2, 2], + # dims=[96, 192, 384, 768], dims_decoder=[768, 384, 192, 96], d_state=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + # norm_layer=nn.LayerNorm, patch_norm=True, + # use_checkpoint=False, **kwargs): + # super().__init__() + def __init__(self, params): + super(VMUNet, self).__init__() + in_chans = params['in_chns'] + num_classes = params['class_num'] + patch_size = params.get('patch_size', 4) + depths = params.get('depths', [2, 2, 9, 2]) + depths_decoder = depths.copy() + depths_decoder.reverse() + dims = params.get('feature_chns', [96, 192, 384, 768]) + dims_decoder = dims.copy() + dims_decoder.reverse() + d_state = params.get('d_state', 16) + drop_rate = params.get('drop_rate', 0.) + attn_drop_rate = params.get('att_drop_rate', 0.) + drop_path_rate = params.get('path_drop_rate', 0.1) + + norm_layer = nn.LayerNorm + patch_norm = True + use_checkpoint = False + + self.num_classes = num_classes + self.num_layers = len(depths) + if isinstance(dims, int): + dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)] + self.embed_dim = dims[0] + self.num_features = dims[-1] + self.dims = dims + + self.patch_embed = PatchEmbed2D(patch_size=patch_size, in_chans=in_chans, embed_dim=self.embed_dim, + norm_layer=norm_layer if patch_norm else None) + + # WASTED absolute position embedding ====================== + self.ape = False + # self.ape = False + # drop_rate = 0.0 + if self.ape: + self.patches_resolution = self.patch_embed.patches_resolution + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, *self.patches_resolution, self.embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_decoder))][::-1] + + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = VSSLayer( + dim=dims[i_layer], + depth=depths[i_layer], + d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109 + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging2D if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + ) + self.layers.append(layer) + + self.layers_up = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = VSSLayer_up( + dim=dims_decoder[i_layer], + depth=depths_decoder[i_layer], + d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109 + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr_decoder[sum(depths_decoder[:i_layer]):sum(depths_decoder[:i_layer + 1])], + norm_layer=norm_layer, + upsample=PatchExpand2D if (i_layer != 0) else None, + use_checkpoint=use_checkpoint, + ) + self.layers_up.append(layer) + + self.final_up = Final_PatchExpand2D(dim=dims_decoder[-1], dim_scale=4, norm_layer=norm_layer) + self.final_conv = nn.Conv2d(dims_decoder[-1]//4, num_classes, 1) + + # self.norm = norm_layer(self.num_features) + # self.avgpool = nn.AdaptiveAvgPool1d(1) + # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m: nn.Module): + """ + out_proj.weight which is previously initilized in VSSBlock, would be cleared in nn.Linear + no fc.weight found in the any of the model parameters + no nn.Embedding found in the any of the model parameters + so the thing is, VSSBlock initialization is useless + + Conv2D is not intialized !!! + """ + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward_features(self, x): + skip_list = [] + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + skip_list.append(x) + x = layer(x) + return x, skip_list + + def forward_features_up(self, x, skip_list): + for inx, layer_up in enumerate(self.layers_up): + if inx == 0: + x = layer_up(x) + else: + x = layer_up(x+skip_list[-inx]) + + return x + + def forward_final(self, x): + x = self.final_up(x) + x = x.permute(0,3,1,2) + x = self.final_conv(x) + return x + + def forward_backbone(self, x): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + return x + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + x, skip_list = self.forward_features(x) + x = self.forward_features_up(x, skip_list) + x = self.forward_final(x) + + if(len(x_shape) == 5): + new_shape = [N, D] + list(x.shape)[1:] + x = torch.transpose(torch.reshape(x, new_shape), 1, 2) + + return x + + + + + + diff --git a/pymic/net/net2d/unet2d_vm_light.py b/pymic/net/net2d/unet2d_vm_light.py index ed3f1c5..ab1de76 100644 --- a/pymic/net/net2d/unet2d_vm_light.py +++ b/pymic/net/net2d/unet2d_vm_light.py @@ -150,7 +150,7 @@ def __init__(self, params): :param class_num: (int) The class number for segmentation task. :param feature_chns: (list) Feature channel for each resolution level. The length should be 6, by default it is [8, 16, 24, 32, 48, 64]. - :param bridge: (int) If the bridge based on spatial and channel attentions is used or not. + :param bridge: (bool) If the bridge based on spatial and channel attentions is used or not. By default it is True. """ super(UltraLight_VM_UNet, self).__init__() diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index b557bec..6f0f0c6 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -27,6 +27,7 @@ from pymic.net.net2d.trans2d.transunet import TransUNet from pymic.net.net2d.trans2d.swinunet import SwinUNet from pymic.net.net2d.umamba import UMambaBot, UMambaEnc +from pymic.net.net2d.unet2d_vm import VMUNet from pymic.net.net2d.unet2d_vm_light import UltraLight_VM_UNet from pymic.net.net3d.unet2d5 import UNet2D5 from pymic.net.net3d.unet3d import UNet3D @@ -66,6 +67,7 @@ 'UNet2D_ScSE': UNet2D_ScSE, 'UMambaBot': UMambaBot, 'UMambaEnc': UMambaEnc, + 'VMUNet':VMUNet, 'UltraLight_VM_UNet': UltraLight_VM_UNet, 'TransUNet': TransUNet, 'SwinUNet': SwinUNet, From 9481ec32aab3dfb818be8c75056a6622046c1ffc Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 1 Aug 2025 11:24:36 +0800 Subject: [PATCH 216/225] add DMSPS add DMSPS for weakly supervised segmentation add adaptive region specific Tverskyloss --- pymic/loss/seg/ars_tversky.py | 67 ++++++++++++++ pymic/net_run/weak_sup/__init__.py | 4 +- pymic/net_run/weak_sup/wsl_dmpls.py | 10 ++- pymic/net_run/weak_sup/wsl_dmsps.py | 131 ++++++++++++++++++++++++++++ 4 files changed, 208 insertions(+), 4 deletions(-) create mode 100644 pymic/loss/seg/ars_tversky.py create mode 100644 pymic/net_run/weak_sup/wsl_dmsps.py diff --git a/pymic/loss/seg/ars_tversky.py b/pymic/loss/seg/ars_tversky.py new file mode 100644 index 0000000..4fafeae --- /dev/null +++ b/pymic/loss/seg/ars_tversky.py @@ -0,0 +1,67 @@ + +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch.nn as nn +from pymic.loss.seg.abstract import AbstractSegLoss + +class ARSTverskyLoss(AbstractSegLoss): + """ + The Adaptive Region-Specific Loss in this paper: + + * Y. Chen et al.: Adaptive Region-Specific Loss for Improved Medical Image Segmentation. + `IEEE TPAMI 2023. `_ + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `ARSTversky_patch_size`: (list) the patch size. + :param `A`: the lowest weight for FP or FN (default 0.3) + :param `B`: the gap between lowest and highest weight (default 0.4) + """ + def __init__(self, params): + super(ARSTverskyLoss, self).__init__(params) + self.patch_size = params['ARSTversky_patch_size'.lower()] + self.a = params.get('ARSTversky_a'.lower(), 0.3) + self.b = params.get('ARSTversky_b'.lower(), 0.4) + + self.dim = len(self.patch_size) + assert self.dim in [2, 3], "The num of dim must be 2 or 3." + if self.dim == 3: + self.pool = nn.AvgPool3d(kernel_size=self.patch_size, stride=self.patch_size) + elif self.dim == 2: + self.pool = nn.AvgPool2d(kernel_size=self.patch_size, stride=self.patch_size) + + def forward(self, loss_input_dict): + predict = loss_input_dict['prediction'] + soft_y = loss_input_dict['ground_truth'] + + if(isinstance(predict, (list, tuple))): + predict = predict[0] + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) + + smooth = 1e-5 + if self.dim == 2: + assert predict.shape[-2] % self.patch_size[0] == 0, "image size % patch size must be 0 in dimension y" + assert predict.shape[-1] % self.patch_size[1] == 0, "image size % patch size must be 0 in dimension x" + elif self.dim == 3: + assert predict.shape[-3] % self.patch_size[0] == 0, "image size % patch size must be 0 in dimension z" + assert predict.shape[-2] % self.patch_size[1] == 0, "image size % patch size must be 0 in dimension y" + assert predict.shape[-1] % self.patch_size[2] == 0, "image size % patch size must be 0 in dimension x" + + tp = predict * soft_y + fp = predict * (1 - soft_y) + fn = (1 - predict) * soft_y + + region_tp = self.pool(tp) + region_fp = self.pool(fp) + region_fn = self.pool(fn) + + alpha = self.a + self.b * (region_fp + smooth) / (region_fp + region_fn + smooth) + beta = self.a + self.b * (region_fn + smooth) / (region_fp + region_fn + smooth) + + region_tversky = (region_tp + smooth) / (region_tp + alpha * region_fp + beta * region_fn + smooth) + region_tversky = 1 - region_tversky + loss = region_tversky.mean() + return loss \ No newline at end of file diff --git a/pymic/net_run/weak_sup/__init__.py b/pymic/net_run/weak_sup/__init__.py index b3c8332..a583ae8 100644 --- a/pymic/net_run/weak_sup/__init__.py +++ b/pymic/net_run/weak_sup/__init__.py @@ -6,10 +6,12 @@ from pymic.net_run.weak_sup.wsl_tv import WSLTotalVariation from pymic.net_run.weak_sup.wsl_ustm import WSLUSTM from pymic.net_run.weak_sup.wsl_dmpls import WSLDMPLS +from pymic.net_run.weak_sup.wsl_dmsps import WSLDMSPS WSLMethodDict = {'EntropyMinimization': WSLEntropyMinimization, 'GatedCRF': WSLGatedCRF, 'MumfordShah': WSLMumfordShah, 'TotalVariation': WSLTotalVariation, 'USTM': WSLUSTM, - 'DMPLS': WSLDMPLS} \ No newline at end of file + 'DMPLS': WSLDMPLS, + 'DMSPS': WSLDMSPS} \ No newline at end of file diff --git a/pymic/net_run/weak_sup/wsl_dmpls.py b/pymic/net_run/weak_sup/wsl_dmpls.py index 01902c7..ea96bbb 100644 --- a/pymic/net_run/weak_sup/wsl_dmpls.py +++ b/pymic/net_run/weak_sup/wsl_dmpls.py @@ -5,11 +5,11 @@ import random import time import torch -from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.dice import DiceLoss +from pymic.loss.seg.ce import CrossEntropyLoss from pymic.net_run.weak_sup import WSLSegAgent from pymic.util.ramps import get_rampup_ratio @@ -42,9 +42,13 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] wsl_cfg = self.config['weakly_supervised_learning'] - iter_max = self.config['training']['iter_max'] + iter_max = self.config['training']['iter_max'] rampup_start = wsl_cfg.get('rampup_start', 0) rampup_end = wsl_cfg.get('rampup_end', iter_max) + pseudo_loss_type = wsl_cfg.get('pseudo_sup_loss', 'dice_loss') + if (pseudo_loss_type not in ('dice_loss', 'ce_loss')): + raise ValueError("""For pseudo supervision loss, only dice_loss and ce_loss \ + are supported.""") train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 @@ -83,7 +87,7 @@ def training(self): pseudo_lab = get_soft_label(pseudo_lab, class_num, self.tensor_type) # calculate the pseudo label supervision loss - loss_calculator = DiceLoss() + loss_calculator = DiceLoss() if pseudo_loss_type == 'dice_loss' else CrossEntropyLoss() loss_dict1 = {"prediction":outputs1, 'ground_truth':pseudo_lab} loss_dict2 = {"prediction":outputs2, 'ground_truth':pseudo_lab} loss_reg = 0.5 * (loss_calculator(loss_dict1) + loss_calculator(loss_dict2)) diff --git a/pymic/net_run/weak_sup/wsl_dmsps.py b/pymic/net_run/weak_sup/wsl_dmsps.py new file mode 100644 index 0000000..dca3643 --- /dev/null +++ b/pymic/net_run/weak_sup/wsl_dmsps.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import logging +import numpy as np +import random +import time +import torch +from pymic.loss.seg.util import get_soft_label +from pymic.loss.seg.util import reshape_prediction_and_ground_truth +from pymic.loss.seg.util import get_classwise_dice +from pymic.loss.seg.dice import DiceLoss +from pymic.loss.seg.ce import CrossEntropyLoss +from pymic.net_run.weak_sup import WSLSegAgent +from pymic.util.ramps import get_rampup_ratio + +class WSLDMSPS(WSLSegAgent): + """ + Weakly supervised segmentation based on Dynamically Mixed Pseudo Labels Supervision. + + * Reference: Meng Han, Xiangde Luo, Xiangjiang Xie, Wenjun Liao, Shichuan Zhang, Tao Song, + Guotai Wang, Shaoting Zhang. DMSPS: Dynamically mixed soft pseudo-label supervision for + scribble-supervised medical image segmentation. + `Medical Image Analysis 2024. `_ + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `weakly_supervised_learning` is needed. See :doc:`usage.wsl` for details. + """ + def __init__(self, config, stage = 'train'): + net_type = config['network']['net_type'] + if net_type not in ['UNet2D_DualBranch', 'UNet3D_DualBranch']: + raise ValueError("""For WSL_DMPLS, a dual branch network is expected. \ + It only supports UNet2D_DualBranch and UNet3D_DualBranch currently.""") + super(WSLDMSPS, self).__init__(config, stage) + + def training(self): + class_num = self.config['network']['class_num'] + iter_valid = self.config['training']['iter_valid'] + wsl_cfg = self.config['weakly_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = wsl_cfg.get('rampup_start', 0) + rampup_end = wsl_cfg.get('rampup_end', iter_max) + pseudo_loss_type = wsl_cfg.get('pseudo_sup_loss', 'ce_loss') + if (pseudo_loss_type not in ('dice_loss', 'ce_loss')): + raise ValueError("""For pseudo supervision loss, only dice_loss and ce_loss \ + are supported.""") + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 + train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 + self.net.train() + for it in range(iter_valid): + t0 = time.time() + try: + data = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data = next(self.trainIter) + t1 = time.time() + # get the inputs + inputs = self.convert_tensor_type(data['image']) + y = self.convert_tensor_type(data['label_prob']) + + inputs, y = inputs.to(self.device), y.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + outputs1, outputs2 = self.net(inputs) + t2 = time.time() + + loss_sup1 = self.get_loss_value(data, outputs1, y) + loss_sup2 = self.get_loss_value(data, outputs2, y) + loss_sup = 0.5 * (loss_sup1 + loss_sup2) + + # get pseudo label with dynamical mix + outputs_soft1 = torch.softmax(outputs1, dim=1) + outputs_soft2 = torch.softmax(outputs2, dim=1) + beta = random.random() + pseudo_lab = beta*outputs_soft1.detach() + (1.0-beta)*outputs_soft2.detach() + # pseudo_lab = torch.argmax(pseudo_lab, dim = 1, keepdim = True) + # pseudo_lab = get_soft_label(pseudo_lab, class_num, self.tensor_type) + + # calculate the pseudo label supervision loss + loss_calculator = DiceLoss() if pseudo_loss_type == 'dice_loss' else CrossEntropyLoss() + loss_dict1 = {"prediction":outputs1, 'ground_truth':pseudo_lab} + loss_dict2 = {"prediction":outputs2, 'ground_truth':pseudo_lab} + loss_reg = 0.5 * (loss_calculator(loss_dict1) + loss_calculator(loss_dict2)) + + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio + loss = loss_sup + regular_w*loss_reg + t3 = time.time() + loss.backward() + t4 = time.time() + self.optimizer.step() + + train_loss = train_loss + loss.item() + train_loss_sup = train_loss_sup + loss_sup.item() + train_loss_reg = train_loss_reg + loss_reg.item() + # get dice evaluation for each class in annotated images + if(isinstance(outputs1, tuple) or isinstance(outputs1, list)): + outputs1 = outputs1[0] + p_argmax = torch.argmax(outputs1, dim = 1, keepdim = True) + p_soft = get_soft_label(p_argmax, class_num, self.tensor_type) + p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) + dice_list = get_classwise_dice(p_soft, y) + train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 + train_avg_loss = train_loss / iter_valid + train_avg_loss_sup = train_loss_sup / iter_valid + train_avg_loss_reg = train_loss_reg / iter_valid + train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) + train_avg_dice = train_cls_dice[1:].mean() + + train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, + 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time} + return train_scalers + \ No newline at end of file From 70736d54f0f933614d8a0d08d9d47de1c2a384ae Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 1 Aug 2025 11:32:36 +0800 Subject: [PATCH 217/225] Create history.txt --- docs/history.txt | 1 + 1 file changed, 1 insertion(+) create mode 100644 docs/history.txt diff --git a/docs/history.txt b/docs/history.txt new file mode 100644 index 0000000..3ca7b28 --- /dev/null +++ b/docs/history.txt @@ -0,0 +1 @@ +2025.8.1 Add code of DMSPS \ No newline at end of file From a4634f26262a26b64a118fe9c6fe9e32c95b6ee8 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 2 Aug 2025 12:41:51 +0800 Subject: [PATCH 218/225] add support to h5 files 1, add support to h5 files 2, edit Rescale, Pad and Rotate so that setting output size to a 2D list is allowed for 3D images --- pymic/io/image_read_write.py | 4 +- pymic/io/nifty_dataset.py | 143 ++++++++++++++++++++++++----------- pymic/net_run/agent_seg.py | 12 ++- pymic/transform/pad.py | 5 ++ pymic/transform/rescale.py | 8 +- pymic/transform/rotate.py | 31 ++++++-- 6 files changed, 143 insertions(+), 60 deletions(-) diff --git a/pymic/io/image_read_write.py b/pymic/io/image_read_write.py index efbe656..cb17259 100644 --- a/pymic/io/image_read_write.py +++ b/pymic/io/image_read_write.py @@ -81,8 +81,8 @@ def load_image_as_nd_array(image_name): if (image_name.endswith(".nii.gz") or image_name.endswith(".nii") or image_name.endswith(".mha")): image_dict = load_nifty_volume_as_4d_array(image_name) - elif(image_name.endswith(".jpg") or image_name.endswith(".jpeg") or - image_name.endswith(".tif") or image_name.endswith(".png")): + elif(image_name.lower().endswith(".jpg") or image_name.lower().endswith(".jpeg") or + image_name.lower().endswith(".tif") or image_name.lower().endswith(".png")): image_dict = load_rgb_image_as_3d_array(image_name) else: raise ValueError("unsupported image format: {0:}".format(image_name)) diff --git a/pymic/io/nifty_dataset.py b/pymic/io/nifty_dataset.py index 048ba64..cb23e2f 100644 --- a/pymic/io/nifty_dataset.py +++ b/pymic/io/nifty_dataset.py @@ -3,12 +3,26 @@ import logging import os +import h5py import pandas as pd import numpy as np from torch.utils.data import Dataset from pymic import TaskType from pymic.io.image_read_write import load_image_as_nd_array +def check_and_expand_dim(x, img_dim): + """ + check the input dim and expand it with a channel dimension if necessary. + For 2D images, return a 3D numpy array with a shape of [C, H, W] + for 3D images, return a 3D numpy array with a shape of [C, D, H, W] + """ + input_dim = len(x.shape) + if(input_dim == 2 and img_dim == 2): + x = np.expand_dims(x, axis = 0) + elif(input_dim == 3 and img_dim == 3): + x = np.expand_dims(x, axis = 0) + return x + class NiftyDataset(Dataset): """ Dataset for loading images for segmentation. It generates 4D tensors with @@ -16,37 +30,64 @@ class NiftyDataset(Dataset): with dimention order [C, H, W] for 2D images. :param root_dir: (str) Directory with all the images. - :param csv_file: (str) Path to the csv file with image names. - :param modal_num: (int) Number of modalities. + :param csv: (str) Path to the csv file with image names. If it is None, + the images will be those under root_dir. This only works for testing with + a single input modality. If the images are stored in h5 files, the *.csv file + only has one column, while for other types of images such as .nii.gz and.png, + each column is for an input modality, and the last column is for label. + :param modal_num: (int) Number of modalities. This is only used if the data_file is *.csv. + :param image_dim: (int) Spacial dimension of the input image. This is ony used for h5 files. :param with_label: (bool) Load the data with segmentation ground truth or not. :param transform: (list) List of transforms to be applied on a sample. The built-in transforms can listed in :mod:`pymic.transform.trans_dict`. """ - # def __init__(self, root_dir, csv_file, modal_num = 1, - def __init__(self, root_dir, csv_file, modal_num = 1, allow_missing_modal = False, - with_label = False, transform=None, task = TaskType.SEGMENTATION): + def __init__(self, root_dir, csv_file, modal_num = 1, image_dim = 3, allow_missing_modal = False, + with_label = True, transform=None, task = TaskType.SEGMENTATION): self.root_dir = root_dir - self.csv_items = pd.read_csv(csv_file) + if(csv_file is not None): + self.csv_items = pd.read_csv(csv_file) + else: + img_names = os.listdir(root_dir) + img_names = [item for item in img_names if ("nii" in item or "jpg" in item or + "jpeg" in item or "bmp" in item or "png" in item)] + csv_dict = {"image":img_names} + self.csv_items = pd.DataFrame.from_dict(csv_dict) + self.modal_num = modal_num + self.image_dim = image_dim self.allow_emtpy= allow_missing_modal self.with_label = with_label self.transform = transform self.task = task + self.h5files = False assert self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION] - csv_keys = list(self.csv_items.keys()) - if('label' not in csv_keys): + # check if the files are h5 images, and if the labels are provided. + temp_name = self.csv_items.iloc[0, 0] + logging.warning(temp_name) + if(temp_name.endswith(".h5")): + self.h5files = True + temp_full_name = "{0:}/{1:}".format(self.root_dir, temp_name) + h5f = h5py.File(temp_full_name, 'r') + if('label' not in h5f): + self.with_label = False + else: + csv_keys = list(self.csv_items.keys()) + if('label' not in csv_keys): + self.with_label = False + + self.image_weight_idx = None + self.pixel_weight_idx = None + if('image_weight' in csv_keys): + self.image_weight_idx = csv_keys.index('image_weight') + if('pixel_weight' in csv_keys): + self.pixel_weight_idx = csv_keys.index('pixel_weight') + if(not self.with_label): logging.warning("`label` section is not found in the csv file {0:}".format( - csv_file) + "\n -- This is only allowed for self-supervised learning" + + csv_file) + "or the corresponding h5 file." + + "\n -- This is only allowed for self-supervised learning" + "\n -- when `SelfSuperviseLabel` is used in the transform, or when" + "\n -- loading the unlabeled data for preprocessing.") - self.with_label = False - self.image_weight_idx = None - self.pixel_weight_idx = None - if('image_weight' in csv_keys): - self.image_weight_idx = csv_keys.index('image_weight') - if('pixel_weight' in csv_keys): - self.pixel_weight_idx = csv_keys.index('pixel_weight') def __len__(self): return len(self.csv_items) @@ -92,36 +133,46 @@ def __get_pixel_weight__(self, idx): def __getitem__(self, idx): names_list, image_list = [], [] image_shape = None - for i in range (self.modal_num): - image_name = self.csv_items.iloc[idx, i] - image_full_name = "{0:}/{1:}".format(self.root_dir, image_name) - if(os.path.exists(image_full_name)): - image_dict = load_image_as_nd_array(image_full_name) - image_data = image_dict['data_array'] - elif(self.allow_emtpy and image_shape is not None): - image_data = np.zeros(image_shape) - else: - raise KeyError("File not found: {0:}".format(image_full_name)) - if(i == 0): - image_shape = image_data.shape - names_list.append(image_name) - image_list.append(image_data) - image = np.concatenate(image_list, axis = 0) - image = np.asarray(image, np.float32) - - sample = {'image': image, 'names' : names_list, - 'origin':image_dict['origin'], - 'spacing': image_dict['spacing'], - 'direction':image_dict['direction']} - if (self.with_label): - sample['label'], label_name = self.__getlabel__(idx) - sample['names'].append(label_name) - assert(image.shape[1:] == sample['label'].shape[1:]) - if (self.image_weight_idx is not None): - sample['image_weight'] = self.csv_items.iloc[idx, self.image_weight_idx] - if (self.pixel_weight_idx is not None): - sample['pixel_weight'] = self.__get_pixel_weight__(idx) - assert(image.shape[1:] == sample['pixel_weight'].shape[1:]) + if(self.h5files): + sample_name = self.csv_items.iloc[idx, 0] + h5f = h5py.File(self.root_dir + '/' + sample_name, 'r') + img = check_and_expand_dim(h5f['image'][:], self.image_dim) + sample = {'image':img} + if(self.with_label): + lab = check_and_expand_dim(h5f['label'][:], self.image_dim) + sample['label'] = lab + sample['names'] = [sample_name] + else: + for i in range (self.modal_num): + image_name = self.csv_items.iloc[idx, i] + image_full_name = "{0:}/{1:}".format(self.root_dir, image_name) + if(os.path.exists(image_full_name)): + image_dict = load_image_as_nd_array(image_full_name) + image_data = image_dict['data_array'] + elif(self.allow_emtpy and image_shape is not None): + image_data = np.zeros(image_shape) + else: + raise KeyError("File not found: {0:}".format(image_full_name)) + if(i == 0): + image_shape = image_data.shape + names_list.append(image_name) + image_list.append(image_data) + image = np.concatenate(image_list, axis = 0) + image = np.asarray(image, np.float32) + + sample = {'image': image, 'names' : names_list, + 'origin':image_dict['origin'], + 'spacing': image_dict['spacing'], + 'direction':image_dict['direction']} + if (self.with_label): + sample['label'], label_name = self.__getlabel__(idx) + sample['names'].append(label_name) + assert(image.shape[1:] == sample['label'].shape[1:]) + if (self.image_weight_idx is not None): + sample['image_weight'] = self.csv_items.iloc[idx, self.image_weight_idx] + if (self.pixel_weight_idx is not None): + sample['pixel_weight'] = self.__get_pixel_weight__(idx) + assert(image.shape[1:] == sample['pixel_weight'].shape[1:]) if self.transform: sample = self.transform(sample) diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 45e815b..e1a3377 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -71,14 +71,18 @@ def get_stage_dataset_from_config(self, stage): modal_num = self.config['dataset'].get('modal_num', 1) allow_miss = self.config['dataset'].get('allow_missing_modal', False) stage_dir = self.config['dataset'].get('train_dir', None) - if(stage == 'valid' and "valid_dir" in self.config['dataset']): - stage_dir = self.config['dataset']['valid_dir'] - if(stage == 'test' and "test_dir" in self.config['dataset']): - stage_dir = self.config['dataset']['test_dir'] + stage_dim = self.config['dataset'].get('train_dim', 3) + if(stage == 'valid'): # and "valid_dir" in self.config['dataset']): + stage_dir = self.config['dataset'].get('valid_dir', stage_dir) + stage_dim = self.config['dataset'].get('valid_dim', stage_dim) + if(stage == 'test'): # and "test_dir" in self.config['dataset']): + stage_dir = self.config['dataset'].get('test_dir', stage_dir) + stage_dim = self.config['dataset'].get('test_dim', stage_dim) logging.info("Creating dataset for {0:}".format(stage)) dataset = NiftyDataset(root_dir = stage_dir, csv_file = csv_file, modal_num = modal_num, + image_dim = stage_dim, allow_missing_modal = allow_miss, with_label= with_label, transform = data_transform, diff --git a/pymic/transform/pad.py b/pymic/transform/pad.py index 8624aa2..509643d 100644 --- a/pymic/transform/pad.py +++ b/pymic/transform/pad.py @@ -38,6 +38,11 @@ def __call__(self, sample): image = sample['image'] input_shape = image.shape input_dim = len(input_shape) - 1 + + if(input_dim == 3): + if(len(self.output_size) == 2): + # for 3D images, igore the z-axis + self.output_size = [input_shape[1]] + list(self.output_size) assert(len(self.output_size) == input_dim) if(self.ceil_mode): multiple = [int(math.ceil(float(input_shape[1+i])/self.output_size[i]))\ diff --git a/pymic/transform/rescale.py b/pymic/transform/rescale.py index 47271ec..154b1e0 100644 --- a/pymic/transform/rescale.py +++ b/pymic/transform/rescale.py @@ -17,9 +17,9 @@ class Rescale(AbstractTransform): following fields: :param `Rescale_output_size`: (list/tuple or int) The output size along each spatial axis, - such as [D, H, W] or [H, W]. If D is None, the input image is only reslcaled in 2D. - If int, the smallest axis is matched to output_size keeping aspect ratio the same - as the input. + such as [D, H, W] or [H, W]. For 3D images, if D is None, or the lenght of tuple/list is 2, + the input image is only reslcaled in 2D. If int, the smallest axis is matched to output_size + keeping aspect ratio the same as the input. :param `Rescale_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `True`. """ @@ -38,6 +38,8 @@ def __call__(self, sample): output_size = self.output_size if(output_size[0] is None): output_size[0] = input_shape[1] + if(input_dim == 3 and len(self.output_size) == 2): + output_size = [input_shape[1]] + list(output_size) assert(len(output_size) == input_dim) else: min_edge = min(input_shape[1:]) diff --git a/pymic/transform/rotate.py b/pymic/transform/rotate.py index 5f85e28..b09f8da 100644 --- a/pymic/transform/rotate.py +++ b/pymic/transform/rotate.py @@ -19,13 +19,19 @@ class RandomRotate(AbstractTransform): :param `RandomRotate_angle_range_d`: (list/tuple or None) Rotation angle (degree) range along depth axis (x-y plane), e.g., (-90, 90). + The length of the list/tuple can be larger than 2, when `RandomRotate_discrete_mode` is True. If None, no rotation along this axis. :param `RandomRotate_angle_range_h`: (list/tuple or None) Rotation angle (degree) range along height axis (x-z plane), e.g., (-90, 90). + The length of the list/tuple can be larger than 2, when `RandomRotate_discrete_mode` is True. If None, no rotation along this axis. Only used for 3D images. :param `RandomRotate_angle_range_w`: (list/tuple or None) Rotation angle (degree) range along width axis (y-z plane), e.g., (-90, 90). + The length of the list/tuple can be larger than 2, when `RandomRotate_discrete_mode` is True. If None, no rotation along this axis. Only used for 3D images. + :param `RandomRotate_discrete_mode`: (optional, bool) Whether the rotate angles + are discrete values in rangle range. For example, if you only want to rotate + the images with a fixed set of angles like (90, 180, 270), then set discrete_mode mode as True. :param `RandomRotate_probability`: (optional, float) The probability of applying RandomRotate. Default is 0.5. :param `RandomRotate_inverse`: (optional, bool) @@ -36,8 +42,11 @@ def __init__(self, params): self.angle_range_d = params['RandomRotate_angle_range_d'.lower()] self.angle_range_h = params.get('RandomRotate_angle_range_h'.lower(), None) self.angle_range_w = params.get('RandomRotate_angle_range_w'.lower(), None) + self.discrete_mode = params.get('RandomRotate_discrete_mode'.lower(), False) self.prob = params.get('RandomRotate_probability'.lower(), 0.5) self.inverse = params.get('RandomRotate_inverse'.lower(), True) + if(len(self.angle_range_d) > 2): + assert(self.discrete_mode) def __apply_transformation(self, image, transform_param_list, order = 1): """ @@ -63,15 +72,27 @@ def __call__(self, sample): transform_param_list = [] if(self.angle_range_d is not None): - angle_d = np.random.uniform(self.angle_range_d[0], self.angle_range_d[1]) + if(self.discrete_mode): + idx = random.randint(0, len(self.angle_range_d) - 1) + angle_d = self.angle_range_d[idx] + else: + angle_d = np.random.uniform(self.angle_range_d[0], self.angle_range_d[1]) transform_param_list.append([angle_d, (-1, -2)]) if(input_dim == 3): if(self.angle_range_h is not None): - angle_h = np.random.uniform(self.angle_range_h[0], self.angle_range_h[1]) - transform_param_list.append([angle_h, (-1, -3)]) + if(self.discrete_mode): + idx = random.randint(0, len(self.angle_range_h) - 1) + angle_h = self.angle_range_h[idx] + else: + angle_h = np.random.uniform(self.angle_range_h[0], self.angle_range_h[1]) + transform_param_list.append([angle_h, (-1, -3)]) if(self.angle_range_w is not None): - angle_w = np.random.uniform(self.angle_range_w[0], self.angle_range_w[1]) - transform_param_list.append([angle_w, (-2, -3)]) + if(self.discrete_mode): + idx = random.randint(0, len(self.angle_range_w) - 1) + angle_w = self.angle_range_w[idx] + else: + angle_w = np.random.uniform(self.angle_range_w[0], self.angle_range_w[1]) + transform_param_list.append([angle_w, (-2, -3)]) assert(len(transform_param_list) > 0) # select a random transform from the possible list rather than # use a combination for higher efficiency From 4590b3045d9248e516d3d8f74a49ada085781125 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 5 Aug 2025 10:42:56 +0800 Subject: [PATCH 219/225] update dataset and transform fix code for loading h5 images update transforms --- pymic/io/nifty_dataset.py | 20 +++---- pymic/net_run/agent_seg.py | 5 +- pymic/net_run/weak_sup/wsl_dmsps.py | 45 +++++++++------- pymic/transform/rotate.py | 81 ++++++++++++++++++----------- pymic/transform/trans_dict.py | 4 ++ 5 files changed, 94 insertions(+), 61 deletions(-) diff --git a/pymic/io/nifty_dataset.py b/pymic/io/nifty_dataset.py index cb23e2f..5424e91 100644 --- a/pymic/io/nifty_dataset.py +++ b/pymic/io/nifty_dataset.py @@ -41,8 +41,9 @@ class NiftyDataset(Dataset): :param transform: (list) List of transforms to be applied on a sample. The built-in transforms can listed in :mod:`pymic.transform.trans_dict`. """ - def __init__(self, root_dir, csv_file, modal_num = 1, image_dim = 3, allow_missing_modal = False, - with_label = True, transform=None, task = TaskType.SEGMENTATION): + def __init__(self, root_dir, csv_file, modal_num = 1, image_dim = 3, + allow_missing_modal = False, label_key = "label", + transform=None, task = TaskType.SEGMENTATION): self.root_dir = root_dir if(csv_file is not None): self.csv_items = pd.read_csv(csv_file) @@ -56,10 +57,11 @@ def __init__(self, root_dir, csv_file, modal_num = 1, image_dim = 3, allow_missi self.modal_num = modal_num self.image_dim = image_dim self.allow_emtpy= allow_missing_modal - self.with_label = with_label + self.label_key = label_key self.transform = transform self.task = task self.h5files = False + self.with_label = True assert self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION] # check if the files are h5 images, and if the labels are provided. @@ -69,11 +71,11 @@ def __init__(self, root_dir, csv_file, modal_num = 1, image_dim = 3, allow_missi self.h5files = True temp_full_name = "{0:}/{1:}".format(self.root_dir, temp_name) h5f = h5py.File(temp_full_name, 'r') - if('label' not in h5f): + if(self.label_key not in h5f): self.with_label = False else: csv_keys = list(self.csv_items.keys()) - if('label' not in csv_keys): + if(self.label_key not in csv_keys): self.with_label = False self.image_weight_idx = None @@ -84,7 +86,7 @@ def __init__(self, root_dir, csv_file, modal_num = 1, image_dim = 3, allow_missi self.pixel_weight_idx = csv_keys.index('pixel_weight') if(not self.with_label): logging.warning("`label` section is not found in the csv file {0:}".format( - csv_file) + "or the corresponding h5 file." + + csv_file) + " or the corresponding h5 file." + "\n -- This is only allowed for self-supervised learning" + "\n -- when `SelfSuperviseLabel` is used in the transform, or when" + "\n -- loading the unlabeled data for preprocessing.") @@ -94,7 +96,7 @@ def __len__(self): def __getlabel__(self, idx): csv_keys = list(self.csv_items.keys()) - label_idx = csv_keys.index('label') + label_idx = csv_keys.index(self.label_key) label_name = self.csv_items.iloc[idx, label_idx] label_name_full = "{0:}/{1:}".format(self.root_dir, label_name) label = load_image_as_nd_array(label_name_full)['data_array'] @@ -139,8 +141,8 @@ def __getitem__(self, idx): img = check_and_expand_dim(h5f['image'][:], self.image_dim) sample = {'image':img} if(self.with_label): - lab = check_and_expand_dim(h5f['label'][:], self.image_dim) - sample['label'] = lab + lab = check_and_expand_dim(h5f[self.label_key][:], self.image_dim) + sample['label'] = np.asarray(lab, np.float32) sample['names'] = [sample_name] else: for i in range (self.modal_num): diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index e1a3377..0259e3e 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -72,19 +72,22 @@ def get_stage_dataset_from_config(self, stage): allow_miss = self.config['dataset'].get('allow_missing_modal', False) stage_dir = self.config['dataset'].get('train_dir', None) stage_dim = self.config['dataset'].get('train_dim', 3) + stage_lab_key = self.config['dataset'].get('train_label_key', 'label') if(stage == 'valid'): # and "valid_dir" in self.config['dataset']): stage_dir = self.config['dataset'].get('valid_dir', stage_dir) stage_dim = self.config['dataset'].get('valid_dim', stage_dim) + stage_lab_key = self.config['dataset'].get('valid_label_key', 'label') if(stage == 'test'): # and "test_dir" in self.config['dataset']): stage_dir = self.config['dataset'].get('test_dir', stage_dir) stage_dim = self.config['dataset'].get('test_dim', stage_dim) + stage_lab_key = self.config['dataset'].get('test_label_key', 'label') logging.info("Creating dataset for {0:}".format(stage)) dataset = NiftyDataset(root_dir = stage_dir, csv_file = csv_file, modal_num = modal_num, image_dim = stage_dim, allow_missing_modal = allow_miss, - with_label= with_label, + label_key = stage_lab_key, transform = data_transform, task = self.task_type) return dataset diff --git a/pymic/net_run/weak_sup/wsl_dmsps.py b/pymic/net_run/weak_sup/wsl_dmsps.py index dca3643..5eec2a4 100644 --- a/pymic/net_run/weak_sup/wsl_dmsps.py +++ b/pymic/net_run/weak_sup/wsl_dmsps.py @@ -5,11 +5,13 @@ import random import time import torch +from PIL import Image from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.dice import DiceLoss from pymic.loss.seg.ce import CrossEntropyLoss +# from torch.nn.modules.loss import CrossEntropyLoss as TorchCELoss from pymic.net_run.weak_sup import WSLSegAgent from pymic.util.ramps import get_rampup_ratio @@ -33,11 +35,11 @@ class WSLDMSPS(WSLSegAgent): """ def __init__(self, config, stage = 'train'): net_type = config['network']['net_type'] - if net_type not in ['UNet2D_DualBranch', 'UNet3D_DualBranch']: - raise ValueError("""For WSL_DMPLS, a dual branch network is expected. \ - It only supports UNet2D_DualBranch and UNet3D_DualBranch currently.""") + # if net_type not in ['UNet2D_DualBranch', 'UNet3D_DualBranch']: + # raise ValueError("""For WSL_DMPLS, a dual branch network is expected. \ + # It only supports UNet2D_DualBranch and UNet3D_DualBranch currently.""") super(WSLDMSPS, self).__init__(config, stage) - + def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] @@ -49,10 +51,12 @@ def training(self): if (pseudo_loss_type not in ('dice_loss', 'ce_loss')): raise ValueError("""For pseudo supervision loss, only dice_loss and ce_loss \ are supported.""") + pseudo_loss_func = CrossEntropyLoss() if pseudo_loss_type == 'ce_loss' else DiceLoss() train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() + # ce_loss = CrossEntropyLoss() for it in range(iter_valid): t0 = time.time() try: @@ -66,7 +70,6 @@ def training(self): y = self.convert_tensor_type(data['label_prob']) inputs, y = inputs.to(self.device), y.to(self.device) - # zero the parameter gradients self.optimizer.zero_grad() @@ -78,23 +81,26 @@ def training(self): loss_sup2 = self.get_loss_value(data, outputs2, y) loss_sup = 0.5 * (loss_sup1 + loss_sup2) - # get pseudo label with dynamical mix + # torch_ce_loss = TorchCELoss(ignore_index=class_num) + # torch_ce_loss2 = TorchCELoss() + # loss_ce1 = torch_ce_loss(outputs1, label[:].long()) + # loss_ce2 = torch_ce_loss(outputs2, label[:].long()) + # loss_sup = 0.5 * (loss_ce1 + loss_ce2) + + # get pseudo label with dynamic mixture outputs_soft1 = torch.softmax(outputs1, dim=1) outputs_soft2 = torch.softmax(outputs2, dim=1) - beta = random.random() - pseudo_lab = beta*outputs_soft1.detach() + (1.0-beta)*outputs_soft2.detach() - # pseudo_lab = torch.argmax(pseudo_lab, dim = 1, keepdim = True) - # pseudo_lab = get_soft_label(pseudo_lab, class_num, self.tensor_type) - - # calculate the pseudo label supervision loss - loss_calculator = DiceLoss() if pseudo_loss_type == 'dice_loss' else CrossEntropyLoss() - loss_dict1 = {"prediction":outputs1, 'ground_truth':pseudo_lab} - loss_dict2 = {"prediction":outputs2, 'ground_truth':pseudo_lab} - loss_reg = 0.5 * (loss_calculator(loss_dict1) + loss_calculator(loss_dict2)) + alpha = random.random() + soft_pseudo_label = alpha * outputs_soft1.detach() + (1.0-alpha) * outputs_soft2.detach() + # loss_reg = 0.5*(torch_ce_loss2(outputs_soft1, soft_pseudo_label) +torch_ce_loss2(outputs_soft2, soft_pseudo_label) ) + loss_dict1 = {"prediction":outputs_soft1, 'ground_truth':soft_pseudo_label} + loss_dict2 = {"prediction":outputs_soft2, 'ground_truth':soft_pseudo_label} + loss_reg = 0.5 * (pseudo_loss_func(loss_dict1) + pseudo_loss_func(loss_dict2)) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") - regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio - loss = loss_sup + regular_w*loss_reg + regular_w = wsl_cfg.get('regularize_w', 8.0) * rampup_ratio + loss = loss_sup + regular_w*loss_reg t3 = time.time() loss.backward() t4 = time.time() @@ -127,5 +133,4 @@ def training(self): 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, 'data_time': data_time, 'forward_time':gpu_time, 'loss_time':loss_time, 'backward_time':back_time} - return train_scalers - \ No newline at end of file + return train_scalers \ No newline at end of file diff --git a/pymic/transform/rotate.py b/pymic/transform/rotate.py index b09f8da..5de77d7 100644 --- a/pymic/transform/rotate.py +++ b/pymic/transform/rotate.py @@ -19,19 +19,13 @@ class RandomRotate(AbstractTransform): :param `RandomRotate_angle_range_d`: (list/tuple or None) Rotation angle (degree) range along depth axis (x-y plane), e.g., (-90, 90). - The length of the list/tuple can be larger than 2, when `RandomRotate_discrete_mode` is True. If None, no rotation along this axis. :param `RandomRotate_angle_range_h`: (list/tuple or None) Rotation angle (degree) range along height axis (x-z plane), e.g., (-90, 90). - The length of the list/tuple can be larger than 2, when `RandomRotate_discrete_mode` is True. If None, no rotation along this axis. Only used for 3D images. :param `RandomRotate_angle_range_w`: (list/tuple or None) Rotation angle (degree) range along width axis (y-z plane), e.g., (-90, 90). - The length of the list/tuple can be larger than 2, when `RandomRotate_discrete_mode` is True. If None, no rotation along this axis. Only used for 3D images. - :param `RandomRotate_discrete_mode`: (optional, bool) Whether the rotate angles - are discrete values in rangle range. For example, if you only want to rotate - the images with a fixed set of angles like (90, 180, 270), then set discrete_mode mode as True. :param `RandomRotate_probability`: (optional, float) The probability of applying RandomRotate. Default is 0.5. :param `RandomRotate_inverse`: (optional, bool) @@ -42,11 +36,8 @@ def __init__(self, params): self.angle_range_d = params['RandomRotate_angle_range_d'.lower()] self.angle_range_h = params.get('RandomRotate_angle_range_h'.lower(), None) self.angle_range_w = params.get('RandomRotate_angle_range_w'.lower(), None) - self.discrete_mode = params.get('RandomRotate_discrete_mode'.lower(), False) self.prob = params.get('RandomRotate_probability'.lower(), 0.5) self.inverse = params.get('RandomRotate_inverse'.lower(), True) - if(len(self.angle_range_d) > 2): - assert(self.discrete_mode) def __apply_transformation(self, image, transform_param_list, order = 1): """ @@ -61,38 +52,21 @@ def __apply_transformation(self, image, transform_param_list, order = 1): return image def __call__(self, sample): - # if(random.random() > self.prob): - # sample['RandomRotate_triggered'] = False - # return sample - # else: - # sample['RandomRotate_triggered'] = True image = sample['image'] input_shape = image.shape input_dim = len(input_shape) - 1 transform_param_list = [] if(self.angle_range_d is not None): - if(self.discrete_mode): - idx = random.randint(0, len(self.angle_range_d) - 1) - angle_d = self.angle_range_d[idx] - else: - angle_d = np.random.uniform(self.angle_range_d[0], self.angle_range_d[1]) + angle_d = np.random.uniform(self.angle_range_d[0], self.angle_range_d[1]) transform_param_list.append([angle_d, (-1, -2)]) if(input_dim == 3): if(self.angle_range_h is not None): - if(self.discrete_mode): - idx = random.randint(0, len(self.angle_range_h) - 1) - angle_h = self.angle_range_h[idx] - else: - angle_h = np.random.uniform(self.angle_range_h[0], self.angle_range_h[1]) - transform_param_list.append([angle_h, (-1, -3)]) + angle_h = np.random.uniform(self.angle_range_h[0], self.angle_range_h[1]) + transform_param_list.append([angle_h, (-1, -3)]) if(self.angle_range_w is not None): - if(self.discrete_mode): - idx = random.randint(0, len(self.angle_range_w) - 1) - angle_w = self.angle_range_w[idx] - else: - angle_w = np.random.uniform(self.angle_range_w[0], self.angle_range_w[1]) - transform_param_list.append([angle_w, (-2, -3)]) + angle_w = np.random.uniform(self.angle_range_w[0], self.angle_range_w[1]) + transform_param_list.append([angle_w, (-2, -3)]) assert(len(transform_param_list) > 0) # select a random transform from the possible list rather than # use a combination for higher efficiency @@ -123,4 +97,49 @@ def inverse_transform_for_prediction(self, sample): transform_param_list[i][0] = - transform_param_list[i][0] sample['predict'] = self.__apply_transformation(sample['predict'] , transform_param_list, 1) + return sample + +class RandomRot90(AbstractTransform): + """ + Random rotate an image in x-y plane with angles in [90, 180, 270]. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `RandomRot90_probability`: (optional, float) + The probability of applying RandomRot90. Default is 0.75. + :param `RandomRot90_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `True`. + """ + def __init__(self, params): + super(RandomRot90, self).__init__(params) + self.prob = params.get('RandomRot90_probability'.lower(), 0.75) + self.inverse = params.get('RandomRot90_inverse'.lower(), True) + + def __call__(self, sample): + if(random.random() > self.prob): + sample['RandomRot90_triggered'] = False + sample['RandomRot90_Param'] = 0 + return sample + else: + sample['RandomRot90_triggered'] = True + image = sample['image'] + rote_k = random.randint(1, 3) + sample['RandomRot90_Param'] = rote_k + image_t = np.rot90(image, rote_k, (-2, -1)) + sample['image'] = image_t + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + sample['label'] = np.rot90(sample['label'], rote_k, (-2, -1)) + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + sample['pixel_weight'] = np.rot90(sample['pixel_weight'], rote_k, (-2, -1)) + return sample + + def inverse_transform_for_prediction(self, sample): + if(not sample['RandomRot90_triggered']): + return sample + rote_k = sample['RandomRot90_Param'] + rote_i = 4 - rote_k + sample['predict'] = np.rot90(sample['predict'], rote_i, (-2, -1)) return sample \ No newline at end of file diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index 332e594..e4bfe24 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -25,6 +25,7 @@ 'RandomRescale': RandomRescale, 'RandomFlip': RandomFlip, 'RandomRotate': RandomRotate, + 'RandomRot90': RandomRot90, 'ReduceLabelDim': ReduceLabelDim, 'Rescale': Rescale, 'SelfSuperviseLabel': SelfSuperviseLabel, @@ -43,6 +44,7 @@ from pymic.transform.normalize import * from pymic.transform.crop import * from pymic.transform.crop4dino import Crop4Dino +from pymic.transform.crop4voco import Crop4VoCo from pymic.transform.crop4vox2vec import Crop4Vox2Vec from pymic.transform.crop4vf import Crop4VolumeFusion, VolumeFusion, VolumeFusionShuffle from pymic.transform.label_convert import * @@ -57,6 +59,7 @@ 'CropHumanRegion': CropHumanRegion, 'CenterCrop': CenterCrop, 'Crop4Dino': Crop4Dino, + 'Crop4VoCo': Crop4VoCo, 'Crop4Vox2Vec': Crop4Vox2Vec, 'Crop4VolumeFusion': Crop4VolumeFusion, 'GrayscaleToRGB': GrayscaleToRGB, @@ -83,6 +86,7 @@ 'RandomTranspose': RandomTranspose, 'RandomFlip': RandomFlip, 'RandomRotate': RandomRotate, + 'RandomRot90': RandomRot90, 'ReduceLabelDim': ReduceLabelDim, 'Rescale': Rescale, 'Resample': Resample, From 4baccae73d1c7eeeaf0353c8a53602f05ea5687a Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 11 Aug 2025 15:32:05 +0800 Subject: [PATCH 220/225] update files for DMSPS --- pymic/io/image_read_write.py | 8 ++- pymic/loss/loss_dict_seg.py | 2 + pymic/loss/seg/ce.py | 4 +- pymic/net/net2d/unet2d_multi_decoder.py | 19 +++---- pymic/net_run/weak_sup/wsl_dmsps.py | 70 ++++++++++++++++++++++++- pymic/transform/rescale.py | 20 +++++-- pymic/util/evaluation_seg.py | 24 ++++----- pymic/util/parse_config.py | 23 +++++++- 8 files changed, 137 insertions(+), 33 deletions(-) diff --git a/pymic/io/image_read_write.py b/pymic/io/image_read_write.py index cb17259..f570628 100644 --- a/pymic/io/image_read_write.py +++ b/pymic/io/image_read_write.py @@ -99,7 +99,7 @@ def save_array_as_nifty_volume(data, image_name, reference_name = None, spacing :param spacing: (list or tuple) the spacing of a volume data when `reference_name` is not provided. """ img = sitk.GetImageFromArray(data) - if(reference_name is not None): + if((reference_name is not None) and (not reference_name.endswith(".h5"))): img_ref = sitk.ReadImage(reference_name) #img.CopyInformation(img_ref) img.SetSpacing(img_ref.GetSpacing()) @@ -141,11 +141,15 @@ def save_nd_array_as_image(data, image_name, reference_name = None, spacing = [1 """ data_dim = len(data.shape) assert(data_dim == 2 or data_dim == 3) + if(image_name.endswith(".h5")): + if(data_dim == 3): + image_name = image_name.replace(".h5", ".nii.gz") + else: + image_name = image_name.replace(".h5", ".png") if (image_name.endswith(".nii.gz") or image_name.endswith(".nii") or image_name.endswith(".mha")): assert(data_dim == 3) save_array_as_nifty_volume(data, image_name, reference_name, spacing) - elif(image_name.endswith(".jpg") or image_name.endswith(".jpeg") or image_name.endswith(".tif") or image_name.endswith(".png")): assert(data_dim == 2) diff --git a/pymic/loss/loss_dict_seg.py b/pymic/loss/loss_dict_seg.py index fd72ce4..36e6a21 100644 --- a/pymic/loss/loss_dict_seg.py +++ b/pymic/loss/loss_dict_seg.py @@ -26,6 +26,7 @@ from pymic.loss.seg.dice import DiceLoss, FocalDiceLoss, \ NoiseRobustDiceLoss, BinaryDiceLoss, GroupDiceLoss from pymic.loss.seg.exp_log import ExpLogLoss +from pymic.loss.seg.ars_tversky import ARSTverskyLoss from pymic.loss.seg.mse import MSELoss, MAELoss from pymic.loss.seg.slsr import SLSRLoss @@ -35,6 +36,7 @@ 'DiceLoss': DiceLoss, 'BinaryDiceLoss': BinaryDiceLoss, 'FocalDiceLoss': FocalDiceLoss, + 'ARSTverskyLoss': ARSTverskyLoss, 'NoiseRobustDiceLoss': NoiseRobustDiceLoss, 'GroupDiceLoss': GroupDiceLoss, 'ExpLogLoss': ExpLogLoss, diff --git a/pymic/loss/seg/ce.py b/pymic/loss/seg/ce.py index 4edbbc3..bf036a3 100644 --- a/pymic/loss/seg/ce.py +++ b/pymic/loss/seg/ce.py @@ -36,7 +36,7 @@ def forward(self, loss_input_dict): soft_y = reshape_tensor_to_2D(soft_y) # for numeric stability - predict = predict * 0.999 + 5e-4 + # predict = predict * (1-1e-10) + 0.5e-10 ce = - soft_y* torch.log(predict) if(cls_w is not None): ce = torch.sum(ce*cls_w, dim = 1) @@ -46,7 +46,7 @@ def forward(self, loss_input_dict): ce = torch.mean(ce) else: pix_w = torch.squeeze(reshape_tensor_to_2D(pix_w)) - ce = torch.sum(pix_w * ce) / (pix_w.sum() + 1e-5) + ce = torch.sum(pix_w * ce) / (pix_w.sum() + 1e-10) return ce class GeneralizedCELoss(AbstractSegLoss): diff --git a/pymic/net/net2d/unet2d_multi_decoder.py b/pymic/net/net2d/unet2d_multi_decoder.py index 2c7b3c4..03bd99f 100644 --- a/pymic/net/net2d/unet2d_multi_decoder.py +++ b/pymic/net/net2d/unet2d_multi_decoder.py @@ -63,15 +63,16 @@ def forward(self, x): output2 = torch.reshape(output2, new_shape) output2 = torch.transpose(output2, 1, 2) - if(self.training): - return output1, output2 - else: - if(self.output_mode == "average"): - return (output1 + output2)/2 - elif(self.output_mode == "first"): - return output1 - else: - return output2 + return output1, output2 + # if(self.training): + # return output1, output2 + # else: + # if(self.output_mode == "average"): + # return (output1 + output2)/2 + # elif(self.output_mode == "first"): + # return output1 + # else: + # return output2 class UNet2D_TriBranch(nn.Module): """ diff --git a/pymic/net_run/weak_sup/wsl_dmsps.py b/pymic/net_run/weak_sup/wsl_dmsps.py index 5eec2a4..b610f72 100644 --- a/pymic/net_run/weak_sup/wsl_dmsps.py +++ b/pymic/net_run/weak_sup/wsl_dmsps.py @@ -1,11 +1,13 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import logging +import os import numpy as np import random import time import torch -from PIL import Image +import scipy +from pymic.io.image_read_write import save_nd_array_as_image from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice @@ -133,4 +135,68 @@ def training(self): 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, 'data_time': data_time, 'forward_time':gpu_time, 'loss_time':loss_time, 'backward_time':back_time} - return train_scalers \ No newline at end of file + return train_scalers + + def save_outputs(self, data): + """ + Save prediction output. + + :param data: (dictionary) A data dictionary with prediciton result and other + information such as input image name. + """ + output_dir = self.config['testing']['output_dir'] + test_mode = self.config['testing'].get('dmsps_test_mode', 0) + uct_threshold = self.config['testing'].get('dmsps_uncertainty_threshold', 0.1) + # DMSPS_test_mode == 0: only save the segmentation label for the main decoder + # DMSPS_test_mode == 1: save all the results, including the the probability map of each decoder, + # the uncertainty map, and the confident predictions + if(not os.path.exists(output_dir)): + os.makedirs(output_dir, exist_ok=True) + + names, pred = data['names'], data['predict'] + pred0, pred1 = pred + prob0 = scipy.special.softmax(pred0, axis = 1) + prob1 = scipy.special.softmax(pred1, axis = 1) + prob_mean = (prob0 + prob1) / 2 + lab0 = np.asarray(np.argmax(prob0, axis = 1), np.uint8) + lab1 = np.asarray(np.argmax(prob1, axis = 1), np.uint8) + lab_mean = np.asarray(np.argmax(prob_mean, axis = 1), np.uint8) + + # save the output and (optionally) probability predictions + test_dir = self.config['dataset'].get('test_dir', None) + if(test_dir is None): + test_dir = self.config['dataset']['train_dir'] + img_name = names[0][0].split('/')[-1] + print(img_name) + lab0_name = img_name + if(".h5" in lab0_name): + lab0_name = lab0_name.replace(".h5", ".nii.gz") + save_nd_array_as_image(lab0[0], output_dir + "/" + lab0_name, test_dir + '/' + names[0][0]) + if(test_mode == 1): + lab1_name = lab0_name.replace(".nii.gz", "_predaux.nii.gz") + save_nd_array_as_image(lab1[0], output_dir + "/" + lab1_name, test_dir + '/' + names[0][0]) + C = pred0.shape[1] + uct = -1.0 * np.sum(prob_mean * np.log(prob_mean), axis=1, keepdims=False)/ np.log(C) + uct_name = lab0_name.replace(".nii.gz", "_uncertainty.nii.gz") + save_nd_array_as_image(uct[0], output_dir + "/" + uct_name, test_dir + '/' + names[0][0]) + conf_mask = uct < uct_threshold + conf_lab = conf_mask * lab_mean + (1 - conf_mask)*4 + conf_lab_name = lab0_name.replace(".nii.gz", "_seeds_expand.nii.gz") + + # get the largest connected component in each slice for each class + D, H, W = conf_lab[0].shape + from pymic.util.image_process import get_largest_k_components + for d in range(D): + lab2d = conf_lab[0][d] + for c in range(C): + lab2d_c = lab2d == c + mask_c = get_largest_k_components(lab2d_c, k = 1) + diff = lab2d_c != mask_c + if(np.sum(diff) > 0): + lab2d[diff] = C + conf_lab[0][d] = lab2d + save_nd_array_as_image(conf_lab[0], output_dir + "/" + conf_lab_name, test_dir + '/' + img_name) + + + + \ No newline at end of file diff --git a/pymic/transform/rescale.py b/pymic/transform/rescale.py index 154b1e0..ba519c7 100644 --- a/pymic/transform/rescale.py +++ b/pymic/transform/rescale.py @@ -72,12 +72,22 @@ def inverse_transform_for_prediction(self, sample): origin_shape = json.loads(sample['Rescale_origin_shape']) origin_dim = len(origin_shape) - 1 predict = sample['predict'] - input_shape = predict.shape - scale = [(origin_shape[1:][i] + 0.0)/input_shape[2:][i] for \ - i in range(origin_dim)] - scale = [1.0, 1.0] + scale - output_predict = ndimage.interpolation.zoom(predict, scale, order = 1) + if(isinstance(predict, tuple) or isinstance(predict, list)): + output_predict = [] + for predict_i in predict: + input_shape = predict_i.shape + scale = [(origin_shape[1:][i] + 0.0)/input_shape[2:][i] for \ + i in range(origin_dim)] + scale = [1.0, 1.0] + scale + output_predict_i = ndimage.interpolation.zoom(predict_i, scale, order = 1) + output_predict.append(output_predict_i) + else: + input_shape = predict.shape + scale = [(origin_shape[1:][i] + 0.0)/input_shape[2:][i] for \ + i in range(origin_dim)] + scale = [1.0, 1.0] + scale + output_predict = ndimage.interpolation.zoom(predict, scale, order = 1) sample['predict'] = output_predict return sample diff --git a/pymic/util/evaluation_seg.py b/pymic/util/evaluation_seg.py index 926ff0e..5099d3b 100644 --- a/pymic/util/evaluation_seg.py +++ b/pymic/util/evaluation_seg.py @@ -212,8 +212,10 @@ def get_binary_evaluation_score(s_volume, g_volume, spacing, metric): score = binary_iou(s_volume,g_volume) elif(metric_lower == 'assd'): score = binary_assd(s_volume, g_volume, spacing) + score = min(score, 20) # to reject outliers elif(metric_lower == "hd95"): score = binary_hd95(s_volume, g_volume, spacing) + score = min(score, 50) # to reject outliers elif(metric_lower == "rve"): score = binary_relative_volume_error(s_volume, g_volume) elif(metric_lower == "volume"): @@ -269,8 +271,8 @@ def evaluation(config): :param label_fuse: (option, bool) If true, fuse the labels in the `label_list` as the foreground, and other labels as the background. Default is False. :param organ_name: (str) The name of the organ for segmentation. - :param ground_truth_folder_root: (str) The root dir of ground truth images. - :param segmentation_folder_root: (str or list) The root dir of segmentation images. + :param ground_truth_folder: (str) The root dir of ground truth images. + :param segmentation_folder: (str or list) The root dir of segmentation images. When a list is given, each list element should be the root dir of the results of one method. :param evaluation_image_pair: (str) The csv file that provide the segmentation images and the corresponding ground truth images. @@ -366,23 +368,23 @@ def main(): """ parser = argparse.ArgumentParser() - parser.add_argument("-cfg", help="configuration file for evaluation", + parser.add_argument("--cfg", help="configuration file for evaluation", required=False, default=None) - parser.add_argument("-metric", help="evaluation metrics, e.g., dice, or [dice, assd]", + parser.add_argument("--metric", help="evaluation metrics, e.g., dice, or [dice, assd]", required=False, default=None) - parser.add_argument("-cls_num", help="number of classes", + parser.add_argument("--cls_num", help="number of classes", required=False, default=None) - parser.add_argument("-cls_index", help="The class index for evaluation, e.g., 255, [1, 2]", + parser.add_argument("--cls_index", help="The class index for evaluation, e.g., 255, [1, 2]", required=False, default=None) - parser.add_argument("-gt_dir", help="path of folder for ground truth", + parser.add_argument("--gt_dir", help="path of folder for ground truth", required=False, default=None) - parser.add_argument("-seg_dir", help="path of folder for segmentation", + parser.add_argument("--seg_dir", help="path of folder for segmentation", required=False, default=None) - parser.add_argument("-name_pair", help="the .csv file for name mapping in case" + parser.add_argument("--name_pair", help="the .csv file for name mapping in case" " the names of one case are different in the gt_dir " " and seg_dir", required=False, default=None) - parser.add_argument("-out", help="the output .csv file name", + parser.add_argument("--out", help="the output .csv file name", required=False, default=None) args = parser.parse_args() print(args) @@ -402,5 +404,3 @@ def main(): if __name__ == '__main__': main() - - main() diff --git a/pymic/util/parse_config.py b/pymic/util/parse_config.py index 09f6db0..3be02cb 100644 --- a/pymic/util/parse_config.py +++ b/pymic/util/parse_config.py @@ -96,12 +96,22 @@ def parse_config(args): args_key = getattr(args, key) if(args_key is not None): val_str = args_key + print(section, key, val_str) if(len(val_str)>0): val = parse_value_from_string(val_str) output[section][key] = val else: val = None - print(section, key, val) + + for key in ["train_dir", "train_csv", "valid_csv", "test_dir", "test_csv"]: + if key in args and getattr(args, key) is not None: + output["dataset"][key] = parse_value_from_string(getattr(args, key)) + for key in ["ckpt_dir", "iter_max", "gpus"]: + if key in args and getattr(args, key) is not None: + output["training"][key] = parse_value_from_string(getattr(args, key)) + for key in ["output_dir", "ckpt_mode", "ckpt_name"]: + if key in args and getattr(args, key) is not None: + output["testing"][key] = parse_value_from_string(getattr(args, key)) return output def synchronize_config(config): @@ -133,6 +143,17 @@ def synchronize_config(config): if('RandomResizedCrop' in transform and \ 'RandomResizedCrop_output_size'.lower() not in data_cfg): data_cfg['RandomResizedCrop_output_size'.lower()] = patch_size + if('testing' in config): + test_cfg = config['testing'] + sliding_window_enable = test_cfg.get("sliding_window_enable", False) + if(sliding_window_enable): + sliding_window_size = test_cfg.get("sliding_window_size", None) + if(sliding_window_size is None): + test_cfg["sliding_window_size"] = patch_size + sliding_window_stride = test_cfg.get("sliding_window_stride", None) + if(sliding_window_stride is None): + test_cfg["sliding_window_stride"] = [item // 2 for item in patch_size] + config['testing'] = test_cfg config['dataset'] = data_cfg # config['network'] = net_cfg return config From 254491e3deaeb2df7f47078846dcb6ba36937345 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 12 Aug 2025 12:44:23 +0800 Subject: [PATCH 221/225] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index cbf7355..51d939d 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.5.0", + version = "0.5.1", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, From 92a9bf9e7fc9745cf6a153d85409d5b4671cc632 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 12 Aug 2025 13:08:04 +0800 Subject: [PATCH 222/225] Create crop4voco.py --- pymic/transform/crop4voco.py | 107 +++++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 pymic/transform/crop4voco.py diff --git a/pymic/transform/crop4voco.py b/pymic/transform/crop4voco.py new file mode 100644 index 0000000..6c52ca7 --- /dev/null +++ b/pymic/transform/crop4voco.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import json +import math +import random +import numpy as np +from scipy import ndimage +from pymic import TaskType +from pymic.transform.abstract_transform import AbstractTransform +from pymic.transform.crop import CenterCrop +from pymic.transform.intensity import * +from pymic.util.image_process import * + +def get_position_label(roi=96, num_crops=4): + half = roi // 2 + max_roi = roi * num_crops + center_x, center_y = np.random.randint(low=half, high=max_roi - half), \ + np.random.randint(low=half, high=max_roi - half) + + x_min, x_max = center_x - half, center_x + half + y_min, y_max = center_y - half, center_y + half + + total_area = roi * roi + labels = [] + for j in range(num_crops): + for i in range(num_crops): + crop_x_min, crop_x_max = i * roi, (i + 1) * roi + crop_y_min, crop_y_max = j * roi, (j + 1) * roi + + dx = min(crop_x_max, x_max) - max(crop_x_min, x_min) + dy = min(crop_y_max, y_max) - max(crop_y_min, y_min) + if dx <= 0 or dy <= 0: + area = 0 + else: + area = (dx * dy) / total_area + labels.append(area) + + labels = np.asarray(labels).reshape(1, num_crops * num_crops) + return x_min, y_min, labels + +class Crop4VoCo(CenterCrop): + """ + Randomly crop an volume into two views with augmentation. This is used for + self-supervised pretraining such as DeSD. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `DualViewCrop_output_size`: (list/tuple) Desired output size [D, H, W]. + The output channel is the same as the input channel. + :param `DualViewCrop_scale_lower_bound`: (list/tuple) Lower bound of the range of scale + for each dimension. e.g. (1.0, 0.5, 0.5). + param `DualViewCrop_scale_upper_bound`: (list/tuple) Upper bound of the range of scale + for each dimension. e.g. (1.0, 2.0, 2.0). + :param `DualViewCrop_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `False`. Currently, the inverse transform is not supported, and + this transform is assumed to be used only during training stage. + """ + def __init__(self, params): + roi_size = params.get('Crop4VoCo_roi_size'.lower(), 64) + if isinstance(roi_size, int): + self.roi_size = [roi_size] * 3 + else: + self.roi_size = roi_size + self.roi_num = params.get('Crop4VoCo_roi_num'.lower(), 2) + self.base_num = params.get('Crop4VoCo_base_num'.lower(), 4) + + self.inverse = params.get('Crop4VoCo_inverse'.lower(), False) + self.task = params['Task'.lower()] + + def __call__(self, sample): + image = sample['image'] + channel, input_size = image.shape[0], image.shape[1:] + input_dim = len(input_size) + # print(input_size, self.roi_size) + assert(input_size[0] == self.roi_size[0]) + assert(input_size[1] == self.roi_size[1] * self.base_num) + assert(input_size[2] == self.roi_size[2] * self.base_num) + + base_num, roi_num, roi_size = self.base_num, self.roi_num, self.roi_size + base_crops, roi_crops, roi_labels = [], [], [] + crop_size = [channel] + list(roi_size) + for j in range(base_num): + for i in range(base_num): + crop_min = [0, 0, roi_size[1]*j, roi_size[2]*i] + crop_max = [crop_min[d] + crop_size[d] for d in range(4)] + crop_out = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + base_crops.append(crop_out) + + for i in range(roi_num): + x_min, y_min, label = get_position_label(self.roi_size[2], base_num) + # print('label', label) + crop_min = [0, 0, y_min, x_min] + crop_max = [crop_min[d] + crop_size[d] for d in range(4)] + crop_out = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + roi_crops.append(crop_out) + roi_labels.append(label) + roi_labels = np.concatenate(roi_labels, 0).reshape(roi_num, base_num * base_num) + + base_crops = np.stack(base_crops, 0) + roi_crops = np.stack(roi_crops, 0) + sample['image'] = base_crops, roi_crops, roi_labels + return sample + + \ No newline at end of file From 48adf1d330d7dd9a2cd476b46c639d1b2858ea85 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 12 Aug 2025 14:42:29 +0800 Subject: [PATCH 223/225] update dataset and requirement version --- pymic/net_run/noisy_label/nll_clslsr.py | 17 +++++++++++------ pymic/net_run/noisy_label/nll_dast.py | 16 ++++++++++------ pymic/net_run/predict.py | 17 ++++++++++------- pymic/net_run/semi_sup/ssl_abstract.py | 10 +++++++--- pymic/net_run/train.py | 10 +++++----- pymic/util/evaluation_cls.py | 8 ++++---- requirements.txt | 18 ++++++++++-------- 7 files changed, 57 insertions(+), 39 deletions(-) diff --git a/pymic/net_run/noisy_label/nll_clslsr.py b/pymic/net_run/noisy_label/nll_clslsr.py index c977eba..3f1059e 100644 --- a/pymic/net_run/noisy_label/nll_clslsr.py +++ b/pymic/net_run/noisy_label/nll_clslsr.py @@ -165,15 +165,20 @@ def get_confidence_map(cfg_file): transform_list.append(one_transform) data_transform = transforms.Compose(transform_list) + stage_dir = config['dataset']['train_dir'] csv_file = config['dataset']['train_csv'] modal_num = config['dataset'].get('modal_num', 1) - stage_dir = config['dataset']['train_dir'] + stage_dim = config['dataset'].get('train_dim', 3) + lab_key = config['dataset'].get('train_label_key', 'label') + dataset = NiftyDataset(root_dir = stage_dir, - csv_file = csv_file, - modal_num = modal_num, - with_label= True, - transform = data_transform, - task = agent.task_type) + csv_file = csv_file, + modal_num = modal_num, + image_dim = stage_dim, + allow_missing_modal = False, + label_key = lab_key, + transform = data_transform, + task = agent.task_type) agent.set_datasets(None, None, dataset) agent.transform_list = transform_list diff --git a/pymic/net_run/noisy_label/nll_dast.py b/pymic/net_run/noisy_label/nll_dast.py index 938e10a..95203ba 100644 --- a/pymic/net_run/noisy_label/nll_dast.py +++ b/pymic/net_run/noisy_label/nll_dast.py @@ -129,13 +129,17 @@ def get_noisy_dataset_from_config(self): data_transform = transforms.Compose(transform_list) modal_num = self.config['dataset'].get('modal_num', 1) - csv_file = self.config['dataset'].get('train_csv_noise', None) + stage_dim = self.config['dataset'].get('train_dim', 3) + lab_key = self.config['dataset'].get('train_label_key', 'label') + csv_file = self.config['dataset'].get('train_csv_noise', None) dataset = NiftyDataset(root_dir = self.config['dataset']['train_dir'], - csv_file = csv_file, - modal_num = modal_num, - with_label= True, - transform = data_transform , - task = self.task_type) + csv_file = csv_file, + modal_num = modal_num, + image_dim = stage_dim, + allow_missing_modal = False, + label_key = lab_key, + transform = data_transform, + task = self.task_type) return dataset diff --git a/pymic/net_run/predict.py b/pymic/net_run/predict.py index e618be6..d63cbad 100644 --- a/pymic/net_run/predict.py +++ b/pymic/net_run/predict.py @@ -21,23 +21,26 @@ def main(): exit() parser = argparse.ArgumentParser() parser.add_argument("cfg", help="configuration file for testing") - parser.add_argument("-test_csv", help="the csv file for testing images", - required=False, default=None) - parser.add_argument("-output_dir", help="the output dir for inference results", + parser.add_argument("--test_csv", help="the csv file for testing images", required=False, default=None) - parser.add_argument("-ckpt_dir", help="the dir for trained model", + parser.add_argument("--test_dir", help="the dir for testing images", required=False, default=None) - parser.add_argument("-ckpt_mode", help="the mode for chekpoint: 0-latest, 1-best, 2-customized", + parser.add_argument("--output_dir", help="the output dir for inference results", required=False, default=None) - parser.add_argument("-ckpt_name", help="the name chekpoint if ckpt_mode = 2", + parser.add_argument("--ckpt_dir", help="the dir for trained model", required=False, default=None) - parser.add_argument("-gpus", help="the gpus for runing, e.g., [0]", + parser.add_argument("--ckpt_mode", help="the mode for chekpoint: 0-latest, 1-best, 2-customized", + required=False, default=None) + parser.add_argument("--ckpt_name", help="the name chekpoint if ckpt_mode = 2", + required=False, default=None) + parser.add_argument("--gpus", help="the gpus for runing, e.g., [0]", required=False, default=None) args = parser.parse_args() if(not os.path.isfile(args.cfg)): raise ValueError("The config file does not exist: " + args.cfg) config = parse_config(args) config = synchronize_config(config) + print(config) log_dir = config['testing']['output_dir'] if(not os.path.exists(log_dir)): os.makedirs(log_dir, exist_ok=True) diff --git a/pymic/net_run/semi_sup/ssl_abstract.py b/pymic/net_run/semi_sup/ssl_abstract.py index 69a09fd..4925859 100644 --- a/pymic/net_run/semi_sup/ssl_abstract.py +++ b/pymic/net_run/semi_sup/ssl_abstract.py @@ -52,12 +52,16 @@ def get_unlabeled_dataset_from_config(self): self.transform_list.append(one_transform) data_transform = transforms.Compose(self.transform_list) - csv_file = self.config['dataset'].get('train_csv_unlab', None) + csv_file = self.config['dataset'].get('train_csv_unlab', None) + stage_dim = self.config['dataset'].get('train_dim', 3) dataset = NiftyDataset(root_dir = train_dir, csv_file = csv_file, modal_num = modal_num, - with_label= False, - transform = data_transform ) + image_dim = stage_dim, + allow_missing_modal = False, + label_key = None, + transform = data_transform, + task = self.task_type) return dataset def create_dataset(self): diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index ec0002f..d98145a 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -54,15 +54,15 @@ def main(): exit() parser = argparse.ArgumentParser() parser.add_argument("cfg", help="configuration file for training") - parser.add_argument("-train_csv", help="the csv file for training images", + parser.add_argument("--train_csv", help="the csv file for training images", required=False, default=None) - parser.add_argument("-valid_csv", help="the csv file for validation images", + parser.add_argument("--valid_csv", help="the csv file for validation images", required=False, default=None) - parser.add_argument("-ckpt_dir", help="the output dir for trained model", + parser.add_argument("--ckpt_dir", help="the output dir for trained model", required=False, default=None) - parser.add_argument("-iter_max", help="the maximal iteration number for training", + parser.add_argument("--iter_max", help="the maximal iteration number for training", required=False, default=None) - parser.add_argument("-gpus", help="the gpus for runing, e.g., [0]", + parser.add_argument("--gpus", help="the gpus for runing, e.g., [0]", required=False, default=None) args = parser.parse_args() if(not os.path.isfile(args.cfg)): diff --git a/pymic/util/evaluation_cls.py b/pymic/util/evaluation_cls.py index a65953a..686a811 100644 --- a/pymic/util/evaluation_cls.py +++ b/pymic/util/evaluation_cls.py @@ -176,13 +176,13 @@ def main(): :param pred_prob_csv: (str) The csv file for prediction probability. """ parser = argparse.ArgumentParser() - parser.add_argument("-cfg", help="configuration file for evaluation", + parser.add_argument("--cfg", help="configuration file for evaluation", required=False, default=None) - parser.add_argument("-metric", help="evaluation metrics, e.g., accuracy, or [accuracy, auc]", + parser.add_argument("--metric", help="evaluation metrics, e.g., accuracy, or [accuracy, auc]", required=False, default=None) - parser.add_argument("-gt_csv", help="csv file for ground truth", + parser.add_argument("--gt_csv", help="csv file for ground truth", required=False, default=None) - parser.add_argument("-pred_prob_csv", help="csv file for probability prediction", + parser.add_argument("--pred_prob_csv", help="csv file for probability prediction", required=False, default=None) args = parser.parse_args() print(args) diff --git a/requirements.txt b/requirements.txt index cac47f3..e70fc83 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,14 @@ h5py matplotlib>=3.1.2 -numpy>=1.17.4 -pandas>=0.25.3 -scikit-image>=0.16.2 -scikit-learn>=0.22 -scipy>=1.3.3 -SimpleITK>=2.0.0 +numpy>=1.23.5 +pandas>=1.5.2 +scikit-image>=0.19.3 +scikit-learn>=1.2.0 +scipy>=1.10.0 +SimpleITK>=2.0.2 tensorboard tensorboardX -torch>=1.1.12 -torchvision>=0.13.0 +torch>=1.13.1 +torchvision>=0.14.1 +causal-conv1d>=1.5.0 +mamba-ssm>=2.2.4 From 46835c73382c12480950d7714e2e6acd8acb9882 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 12 Aug 2025 15:24:13 +0800 Subject: [PATCH 224/225] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 51d939d..879ee6c 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.5.1", + version = "0.5.4", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, From b27763cf6f78569a0748f74c2922894be024a7bd Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 12 Aug 2025 15:52:56 +0800 Subject: [PATCH 225/225] Update README.md --- README.md | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index bedeaae..e7757c4 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,16 @@ BibTeX entry: pages = {107398}, } +# News +* 2025/08 PyMIC has contained the implementation of [`DMSPS`][dmsps_paper], a state-of-the-art weakly supervised segmentation method by learning from scribble annotations. +* 2025/05 Several self-supervised learning methods have been provided in PyMIC, including [`VolF`][volf_paper], [`VoCo`][voco_paper] and [`Vox2Vec`][vox2vec_paper]. +* 2025/01 Novel architectures are available now, such as `UMamba`, `VMUNet`, `SwinUNet`, `TransUNet` and `UNETR++`. + +[dmsps_paper]: https://www.sciencedirect.com/science/article/pii/S1361841524001993 +[volf_paper]: https://arxiv.org/abs/2306.16925 +[voco_paper]: https://arxiv.org/abs/2402.17300 +[vox2vec_paper]:https://conferences.miccai.org/2023/papers/712-Paper3421.html + # Features PyMIC provides flixible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions: * Support for annotation-efficient image segmentation, especially for semi-supervised, self-supervised, self-supervised, weakly-supervised and noisy-label learning. @@ -33,9 +43,10 @@ PyMIC provides flixible modules for medical image computing tasks including clas # Usage ## Requirement -* [Pytorch][torch_link] version >=1.0.1 +* [Pytorch][torch_link] version >=1.13.1 * [TensorboardX][tbx_link] to visualize training performance * Some common python packages such as Numpy, Pandas, SimpleITK +* causal-conv1d>=1.5.0 and mamba-ssm>=2.2.4 are required if you want to use Mamba in PyMIC. * See `requirements.txt` for details. [torch_link]:https://pytorch.org/ @@ -47,10 +58,10 @@ Run the following command to install the latest released version of PyMIC: ```bash pip install PYMIC ``` -To install a specific version of PYMIC such as 0.5.0, run: +To install a specific version of PYMIC such as 0.5.4, run: ```bash -pip install PYMIC==0.5.0 +pip install PYMIC==0.5.4 ``` Alternatively, you can download the source code for the latest version. Run the following command to compile and install: @@ -76,8 +87,11 @@ Using PyMIC, it becomes easy to develop deep learning models for different proje 4, [UGIR][ugir] (MICCAI 2020) Uncertainty-guided interactive refinement for medical image segmentation. +5, [DMSPS][dmsps] (MedIA 2024) Weakly supervised segmentation by learning from scribbles. + [myops]: https://github.com/HiLab-git/MyoPS2020 [coplenet]:https://github.com/HiLab-git/COPLE-Net [hn_gtv]: https://github.com/HiLab-git/Head-Neck-GTV [ugir]: https://github.com/HiLab-git/UGIR +[dmsps]: https://github.com/HiLab-git/DMSPS