From 713a56e948c0362f56c73f850e79d6a37c7c0421 Mon Sep 17 00:00:00 2001 From: taigw Date: Sun, 26 Feb 2023 16:39:21 +0800 Subject: [PATCH 01/86] Update index.rst update reference --- docs/source/index.rst | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index f9b62ba..1724b3a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -28,9 +28,10 @@ Citation If you use PyMIC for your research, please acknowledge it accordingly by citing our paper: -`G. Wang, X. Luo, R. Gu, S. Yang, Y. Qu, S. Zhai, Q. Zhao, K. Li, S. Zhang. (2022). +`G. Wang, X. Luo, R. Gu, S. Yang, Y. Qu, S. Zhai, Q. Zhao, K. Li, S. Zhang. PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation. -arXiv, 2208.09350. `_ + Computer Methods and Programs in Biomedicine (CMPB). 231 (2023): 107398. + `_ BibTeX entry: @@ -41,8 +42,8 @@ BibTeX entry: author = {Guotai Wang and Xiangde Luo and Ran Gu and Shuojue Yang and Yijie Qu and Shuwei Zhai and Qianfei Zhao and Kang Li and Shaoting Zhang}, title = {{PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation}}, year = {2022}, - url = {http://arxiv.org/abs/2208.09350}, - journal = {arXiv}, - volume = {2208.09350}, - pages = {1-10}, + url = {https://doi.org/10.1016/j.cmpb.2023.107398}, + journal = {Computer Methods and Programs in Biomedicine}, + volume = {231}, + pages = {107398}, } From 4ff4088401618a06eac3a09de205f6227514674c Mon Sep 17 00:00:00 2001 From: taigw Date: Sun, 26 Feb 2023 17:23:34 +0800 Subject: [PATCH 02/86] update structure of docs --- docs/source/conf.py | 4 +- docs/source/index.rst | 7 +-- docs/source/installation.rst | 3 +- docs/source/pymic.net_run.noisy_label.rst | 45 +++++++++++++ docs/source/pymic.net_run.rst | 14 ++++- docs/source/pymic.net_run.self_sup.rst | 21 +++++++ docs/source/pymic.net_run.semi_sup.rst | 69 ++++++++++++++++++++ docs/source/pymic.net_run.weak_sup.rst | 69 ++++++++++++++++++++ docs/source/pymic.net_run_nll.rst | 53 ---------------- docs/source/pymic.net_run_ssl.rst | 77 ----------------------- docs/source/pymic.net_run_wsl.rst | 77 ----------------------- docs/source/pymic.rst | 3 - docs/source/setup.rst | 7 +++ docs/source/usage.quickstart.rst | 36 +++++++---- 14 files changed, 254 insertions(+), 231 deletions(-) create mode 100644 docs/source/pymic.net_run.noisy_label.rst create mode 100644 docs/source/pymic.net_run.self_sup.rst create mode 100644 docs/source/pymic.net_run.semi_sup.rst create mode 100644 docs/source/pymic.net_run.weak_sup.rst delete mode 100644 docs/source/pymic.net_run_nll.rst delete mode 100644 docs/source/pymic.net_run_ssl.rst delete mode 100644 docs/source/pymic.net_run_wsl.rst create mode 100644 docs/source/setup.rst diff --git a/docs/source/conf.py b/docs/source/conf.py index cf1b568..09f2b50 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -9,8 +9,8 @@ copyright = '2021, HiLab' author = 'HiLab' -release = '0.1' -version = '0.1.0' +release = '0.4' +version = '0.4.0' # -- General configuration diff --git a/docs/source/index.rst b/docs/source/index.rst index 1724b3a..c1b6523 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -28,10 +28,9 @@ Citation If you use PyMIC for your research, please acknowledge it accordingly by citing our paper: -`G. Wang, X. Luo, R. Gu, S. Yang, Y. Qu, S. Zhai, Q. Zhao, K. Li, S. Zhang. -PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation. - Computer Methods and Programs in Biomedicine (CMPB). 231 (2023): 107398. - `_ +`G. Wang, X. Luo, R. Gu, S. Yang, Y. Qu, S. Zhai, Q. Zhao, K. Li, S. Zhang. +PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation. +Computer Methods and Programs in Biomedicine (CMPB). 231 (2023): 107398. `_ BibTeX entry: diff --git a/docs/source/installation.rst b/docs/source/installation.rst index ced640f..c1055e1 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -17,7 +17,8 @@ Alternatively, you can download or clone the code from `GitHub `_ - `h5py `_ diff --git a/docs/source/pymic.net_run.noisy_label.rst b/docs/source/pymic.net_run.noisy_label.rst new file mode 100644 index 0000000..04d38bb --- /dev/null +++ b/docs/source/pymic.net_run.noisy_label.rst @@ -0,0 +1,45 @@ +pymic.net\_run.noisy\_label package +=================================== + +Submodules +---------- + +pymic.net\_run.noisy\_label.nll\_clslsr module +---------------------------------------------- + +.. automodule:: pymic.net_run.noisy_label.nll_clslsr + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.noisy\_label.nll\_co\_teaching module +---------------------------------------------------- + +.. automodule:: pymic.net_run.noisy_label.nll_co_teaching + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.noisy\_label.nll\_dast module +-------------------------------------------- + +.. automodule:: pymic.net_run.noisy_label.nll_dast + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.noisy\_label.nll\_trinet module +---------------------------------------------- + +.. automodule:: pymic.net_run.noisy_label.nll_trinet + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.net_run.noisy_label + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.net_run.rst b/docs/source/pymic.net_run.rst index 74aab12..5f9ed95 100644 --- a/docs/source/pymic.net_run.rst +++ b/docs/source/pymic.net_run.rst @@ -1,6 +1,16 @@ pymic.net\_run package ====================== +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + pymic.net_run.semi_sup + pymic.net_run.weak_sup + pymic.net_run.noisy_label + Submodules ---------- @@ -44,10 +54,10 @@ pymic.net\_run.infer\_func module :undoc-members: :show-inheritance: -pymic.net\_run.net\_run module +pymic.net\_run.semi\_sup module ------------------------------ -.. automodule:: pymic.net_run.net_run +.. automodule:: pymic.net_run.semi_sup :members: :undoc-members: :show-inheritance: diff --git a/docs/source/pymic.net_run.self_sup.rst b/docs/source/pymic.net_run.self_sup.rst new file mode 100644 index 0000000..a4568d1 --- /dev/null +++ b/docs/source/pymic.net_run.self_sup.rst @@ -0,0 +1,21 @@ +pymic.net\_run.self\_sup package +================================ + +Submodules +---------- + +pymic.net\_run.self\_sup.self\_sl\_agent module +----------------------------------------------- + +.. automodule:: pymic.net_run.self_sup.self_sl_agent + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.net_run.self_sup + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.net_run.semi_sup.rst b/docs/source/pymic.net_run.semi_sup.rst new file mode 100644 index 0000000..6ed157d --- /dev/null +++ b/docs/source/pymic.net_run.semi_sup.rst @@ -0,0 +1,69 @@ +pymic.net\_run.semi\_sup package +================================ + +Submodules +---------- + +pymic.net\_run.semi\_sup.ssl\_abstract module +--------------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_abstract + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.semi\_sup.ssl\_cct module +---------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_cct + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.semi\_sup.ssl\_cps module +---------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_cps + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.semi\_sup.ssl\_em module +--------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_em + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.semi\_sup.ssl\_mt module +--------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_mt + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.semi\_sup.ssl\_uamt module +----------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_uamt + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.semi\_sup.ssl\_urpc module +----------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_urpc + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.net_run.semi_sup + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.net_run.weak_sup.rst b/docs/source/pymic.net_run.weak_sup.rst new file mode 100644 index 0000000..b906f72 --- /dev/null +++ b/docs/source/pymic.net_run.weak_sup.rst @@ -0,0 +1,69 @@ +pymic.net\_run.weak\_sup package +================================ + +Submodules +---------- + +pymic.net\_run.weak\_sup.wsl\_abstract module +--------------------------------------------- + +.. automodule:: pymic.net_run.weak_sup.wsl_abstract + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.weak\_sup.wsl\_dmpls module +------------------------------------------ + +.. automodule:: pymic.net_run.weak_sup.wsl_dmpls + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.weak\_sup.wsl\_em module +--------------------------------------- + +.. automodule:: pymic.net_run.weak_sup.wsl_em + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.weak\_sup.wsl\_gatedcrf module +--------------------------------------------- + +.. automodule:: pymic.net_run.weak_sup.wsl_gatedcrf + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.weak\_sup.wsl\_mumford\_shah module +-------------------------------------------------- + +.. automodule:: pymic.net_run.weak_sup.wsl_mumford_shah + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.weak\_sup.wsl\_tv module +--------------------------------------- + +.. automodule:: pymic.net_run.weak_sup.wsl_tv + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.weak\_sup.wsl\_ustm module +----------------------------------------- + +.. automodule:: pymic.net_run.weak_sup.wsl_ustm + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.net_run.weak_sup + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.net_run_nll.rst b/docs/source/pymic.net_run_nll.rst deleted file mode 100644 index 40ee9ef..0000000 --- a/docs/source/pymic.net_run_nll.rst +++ /dev/null @@ -1,53 +0,0 @@ -pymic.net\_run\_nll package -=========================== - -Submodules ----------- - -pymic.net\_run\_nll.nll\_clslsr module --------------------------------------- - -.. automodule:: pymic.net_run_nll.nll_clslsr - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_nll.nll\_co\_teaching module --------------------------------------------- - -.. automodule:: pymic.net_run_nll.nll_co_teaching - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_nll.nll\_dast module ------------------------------------- - -.. automodule:: pymic.net_run_nll.nll_dast - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_nll.nll\_main module ------------------------------------- - -.. automodule:: pymic.net_run_nll.nll_main - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_nll.nll\_trinet module --------------------------------------- - -.. automodule:: pymic.net_run_nll.nll_trinet - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: pymic.net_run_nll - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/pymic.net_run_ssl.rst b/docs/source/pymic.net_run_ssl.rst deleted file mode 100644 index 236e2d6..0000000 --- a/docs/source/pymic.net_run_ssl.rst +++ /dev/null @@ -1,77 +0,0 @@ -pymic.net\_run\_ssl package -=========================== - -Submodules ----------- - -pymic.net\_run\_ssl.ssl\_abstract module ----------------------------------------- - -.. automodule:: pymic.net_run_ssl.ssl_abstract - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_ssl.ssl\_cct module ------------------------------------ - -.. automodule:: pymic.net_run_ssl.ssl_cct - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_ssl.ssl\_cps module ------------------------------------ - -.. automodule:: pymic.net_run_ssl.ssl_cps - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_ssl.ssl\_em module ----------------------------------- - -.. automodule:: pymic.net_run_ssl.ssl_em - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_ssl.ssl\_main module ------------------------------------- - -.. automodule:: pymic.net_run_ssl.ssl_main - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_ssl.ssl\_mt module ----------------------------------- - -.. automodule:: pymic.net_run_ssl.ssl_mt - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_ssl.ssl\_uamt module ------------------------------------- - -.. automodule:: pymic.net_run_ssl.ssl_uamt - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_ssl.ssl\_urpc module ------------------------------------- - -.. automodule:: pymic.net_run_ssl.ssl_urpc - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: pymic.net_run_ssl - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/pymic.net_run_wsl.rst b/docs/source/pymic.net_run_wsl.rst deleted file mode 100644 index 5eda921..0000000 --- a/docs/source/pymic.net_run_wsl.rst +++ /dev/null @@ -1,77 +0,0 @@ -pymic.net\_run\_wsl package -=========================== - -Submodules ----------- - -pymic.net\_run\_wsl.wsl\_abstract module ----------------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_abstract - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_wsl.wsl\_dmpls module -------------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_dmpls - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_wsl.wsl\_em module ----------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_em - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_wsl.wsl\_gatedcrf module ----------------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_gatedcrf - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_wsl.wsl\_main module ------------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_main - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_wsl.wsl\_mumford\_shah module ---------------------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_mumford_shah - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_wsl.wsl\_tv module ----------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_tv - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_wsl.wsl\_ustm module ------------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_ustm - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: pymic.net_run_wsl - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/pymic.rst b/docs/source/pymic.rst index 7545740..dd180f0 100644 --- a/docs/source/pymic.rst +++ b/docs/source/pymic.rst @@ -12,9 +12,6 @@ Subpackages pymic.loss pymic.net pymic.net_run - pymic.net_run_nll - pymic.net_run_ssl - pymic.net_run_wsl pymic.transform pymic.util diff --git a/docs/source/setup.rst b/docs/source/setup.rst new file mode 100644 index 0000000..552eb49 --- /dev/null +++ b/docs/source/setup.rst @@ -0,0 +1,7 @@ +setup module +============ + +.. automodule:: setup + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/usage.quickstart.rst b/docs/source/usage.quickstart.rst index 95cf20d..bced277 100644 --- a/docs/source/usage.quickstart.rst +++ b/docs/source/usage.quickstart.rst @@ -12,13 +12,13 @@ for segmentation with full supervision, run the fullowing command: .. code-block:: bash - pymic_run train myconfig.cfg + pymic_train myconfig.cfg After training, run the following command for testing: .. code-block:: bash - pymic_run test myconfig.cfg + pymic_test myconfig.cfg .. tip:: @@ -51,11 +51,13 @@ file used for segmentation of lung from radiograph, which can be find in [dataset] # tensor type (float or double) tensor_type = float + task_type = seg root_dir = ../../PyMIC_data/JSRT train_csv = config/jsrt_train.csv valid_csv = config/jsrt_valid.csv test_csv = config/jsrt_test.csv + train_batch_size = 4 # data transforms @@ -69,19 +71,26 @@ file used for segmentation of lung from radiograph, which can be find in LabelConvert_source_list = [0, 255] LabelConvert_target_list = [0, 1] + [network] + # this section gives parameters for network + # the keys may be different for different networks + + # type of network net_type = UNet2D - # Parameters for UNet2D + + # number of class, required for segmentation task class_num = 2 in_chns = 1 feature_chns = [16, 32, 64, 128, 256] dropout = [0, 0, 0.3, 0.4, 0.5] bilinear = False - deep_supervise= False + multiscale_pred = False [training] # list of gpus gpus = [0] + loss_type = DiceLoss # for optimizers @@ -95,8 +104,8 @@ file used for segmentation of lung from radiograph, which can be find in lr_gamma = 0.5 lr_milestones = [2000, 4000, 6000] - ckpt_save_dir = model/unet_dice_loss - ckpt_prefix = unet + ckpt_save_dir = model/unet + ckpt_prefix = unet # start iter iter_start = 0 @@ -107,9 +116,10 @@ file used for segmentation of lung from radiograph, which can be find in [testing] # list of gpus gpus = [0] + # checkpoint mode can be [0-latest, 1-best, 2-specified] - ckpt_mode = 0 - output_dir = result + ckpt_mode = 0 + output_dir = result/unet # convert the label of prediction output label_source = [0, 1] @@ -131,17 +141,18 @@ For example, for segmentation tasks, run: pymic_eval_seg evaluation.cfg -The configuration file is like (an example from ``PYMIC_examples/seg_ssl/ACDC``): +The configuration file is like (an example from +`PyMIC_examples/seg_ssl/ACDC `_): .. code-block:: none [evaluation] - metric = dice + metric_list = [dice, hd95] label_list = [1,2,3] organ_name = heart ground_truth_folder_root = ../../PyMIC_data/ACDC/preprocess - segmentation_folder_root = result/unet2d_em + segmentation_folder_root = result/unet2d_urpc evaluation_image_pair = config/data/image_test_gt_seg.csv See :mod:`pymic.util.evaluation_seg.evaluation` for details of the configuration required. @@ -152,7 +163,8 @@ For classification tasks, run: pymic_eval_cls evaluation.cfg -The configuration file is like (an example from ``PYMIC_examples/classification/CHNCXR``): +The configuration file is like (an example from +`PyMIC_examples/classification/CHNCXR `_): .. code-block:: none From 6afeb8697d4fa9fb5dd1fb5cea12ca82f37baf07 Mon Sep 17 00:00:00 2001 From: taigw Date: Sun, 26 Feb 2023 21:45:56 +0800 Subject: [PATCH 03/86] update docs for v0.4.0 update docs for v0.4.0 --- docs/source/pymic.net_run.rst | 8 ------ docs/source/usage.fsl.rst | 29 +++++++++---------- docs/source/usage.nll.rst | 48 ++++++++++++-------------------- docs/source/usage.ssl.rst | 49 ++++++++++++++------------------- docs/source/usage.wsl.rst | 43 ++++++++++------------------- pymic/net_run/agent_abstract.py | 1 + pymic/net_run/get_optimizer.py | 16 +++++++++++ pymic/transform/intensity.py | 17 ++++++------ 8 files changed, 90 insertions(+), 121 deletions(-) diff --git a/docs/source/pymic.net_run.rst b/docs/source/pymic.net_run.rst index 5f9ed95..9f3bc26 100644 --- a/docs/source/pymic.net_run.rst +++ b/docs/source/pymic.net_run.rst @@ -54,14 +54,6 @@ pymic.net\_run.infer\_func module :undoc-members: :show-inheritance: -pymic.net\_run.semi\_sup module ------------------------------- - -.. automodule:: pymic.net_run.semi_sup - :members: - :undoc-members: - :show-inheritance: - Module contents --------------- diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index 053daf7..937bb01 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -28,8 +28,8 @@ configuration file for running. .. tip:: If you use the built-in modules such as ``UNet`` and ``Dice`` + ``CrossEntropy`` loss - for segmentation, you don't need to write the above code. Just just use the ``pymic_run`` - command. + for segmentation, you don't need to write the above code. Just just use the ``pymic_train`` + command. See examples in `PyMIC_examples/segmentation/ `_. Dataset ------- @@ -207,7 +207,7 @@ hyper-parameters. For example, the following is a configuration for using ``2DUN feature_chns = [16, 32, 64, 128, 256] dropout = [0, 0, 0.3, 0.4, 0.5] bilinear = False - deep_supervise= False + multiscale_pred = False The ``SegNetDict`` in :mod:`pymic.net.net_dict_seg` lists all the built-in network structures currently implemented in PyMIC. @@ -299,9 +299,6 @@ Itreations For training iterations, the following parameters need to be specified in the configuration file: -* ``iter_start``: the start iteration, by default is 0. None zero value means the - iteration where a pre-trained model stopped for continuing with the trainnig. - * ``iter_max``: the maximal allowed iteration for training. * ``iter_valid``: if the value is K, it means evaluating the performance on the @@ -321,9 +318,9 @@ Optimizer For optimizer, users need to set ``optimizer``, ``learning_rate``, ``momentum`` and ``weight_decay``. The built-in optimizers include ``SGD``, ``Adam``, ``SparseAdam``, ``Adadelta``, ``Adagrad``, ``Adamax``, ``ASGD``, -``LBFGS``, ``RMSprop`` and ``Rprop`` that are implemented in :mod:`torch.optim`. +``LBFGS``, ``RMSprop`` and ``Rprop`` that are implemented in `torch.optim`. -You can also use customized optimizers via :mod:`SegmentationAgent.set_optimizer()`. +You can also use customized optimizers via `SegmentationAgent.set_optimizer()`. Learning Rate Scheduler ^^^^^^^^^^^^^^^^^^^^^^^ @@ -335,7 +332,7 @@ the configuration file. Parameters related to ``ReduceLROnPlateau`` include ``lr_gamma``. Parameters related to ``MultiStepLR`` include ``lr_gamma`` and ``lr_milestones``. -You can also use customized lr schedulers via :mod:`SegmentationAgent.set_scheduler()`. +You can also use customized lr schedulers via `SegmentationAgent.set_scheduler()`. Other Options ^^^^^^^^^^^^^ @@ -373,8 +370,8 @@ test-time augmentation, etc. The following is a list of options availble for inf * ``ckpt_name`` (string, optinal): the full path to the checkpoint if ckpt_mode = 2. * ``post_process`` (string, default is None): the post process method after inference. - The current available post processing is :mod:`PostKeepLargestComponent`. Uses can also - specify customized post process methods via :mod:`SegmentationAgent.set_postprocessor()`. + The current available post processing is :mod:`pymic.util.post_process.PostKeepLargestComponent`. + Uses can also specify customized post process methods via `SegmentationAgent.set_postprocessor()`. * ``sliding_window_enable`` (bool, default is False): use sliding window for inference or not. @@ -390,14 +387,14 @@ test-time augmentation, etc. The following is a list of options availble for inf * ``ignore_dir`` (bool, default is True): if the input image name has a `/`, it will be replaced with `_` in the output file name. -* ``save_probability`` (boold, default is False): save the output probability for each class. +* ``save_probability`` (bool, default is False): save the output probability for each class. * ``label_source`` (list, default is None): a list of label to be converted after prediction. For example, - :mod:`label_source` = [0, 1] and :mod:`label_target` = [0, 255] will convert label value from 1 to 255. + `label_source` = [0, 1] and `label_target` = [0, 255] will convert label value from 1 to 255. -* ``label_target`` (list, default is None): a list of label after conversion. Use this with :mod:`label_source`. +* ``label_target`` (list, default is None): a list of label after conversion. Use this with `label_source`. * ``filename_replace_source`` (string, default is None): the substring in the filename will be replaced with - a new substring specified by :mod:`filename_replace_target`. + a new substring specified by `filename_replace_target`. -* ``filename_replace_target`` (string, default is None): work with :mod:`filename_replace_source`. \ No newline at end of file +* ``filename_replace_target`` (string, default is None): work with `filename_replace_source`. \ No newline at end of file diff --git a/docs/source/usage.nll.rst b/docs/source/usage.nll.rst index f6edd33..0a87be9 100644 --- a/docs/source/usage.nll.rst +++ b/docs/source/usage.nll.rst @@ -3,43 +3,28 @@ Noisy Label Learning ==================== -pymic_nll ---------- - -:mod:`pymic_nll` is the command for using built-in NLL methods for training. -Similarly to :mod:`pymic_run`, it should be followed by two parameters, specifying the -stage and configuration file, respectively. The training and testing commands are: - -.. code-block:: bash - - pymic_nll train myconfig_nll.cfg - pymic_nll test myconfig_nll.cfg - -.. tip:: - - If the NLL method only involves one network, either ``pymic_nll`` or ``pymic_run`` - can be used for inference. Their difference only exists in the training stage. - .. note:: Some NLL methods only use noise-robust loss functions without complex training process, and just combining the standard :mod:`SegmentationAgent` with such - loss function works for training. ``pymic_run`` instead of ``pymic_nll`` should - be used for these methods. + loss function works for training. NLL Configurations ------------------ -In the configuration file for ``pymic_nll``, in addition to those used in standard fully +In the configuration file for noisy label learning, in addition to those used in standard fully supervised learning, there is a ``noisy_label_learning`` section that is specifically designed -for NLL methods. In that section, users need to specify the ``nll_method`` and configurations -related to the NLL method. For example, the correspoinding configuration for CoTeaching is: +for NLL methods. In that section, users need to specify the ``method_name`` and configurations +related to the NLL method. ``supervise_type`` should be set as "`noisy_label`" in the ``dataset`` section. + For example, the correspoinding configuration for CoTeaching is: .. code-block:: none [dataset] ... + supervise_type = noisy_label + ... [network] ... @@ -48,7 +33,7 @@ related to the NLL method. For example, the correspoinding configuration for CoT ... [noisy_label_learning] - nll_method = CoTeaching + method_name = CoTeaching co_teaching_select_ratio = 0.8 rampup_start = 1000 rampup_end = 8000 @@ -60,13 +45,14 @@ related to the NLL method. For example, the correspoinding configuration for CoT The configuration items vary with different NLL methods. Please refer to the API of each built-in NLL method for details of the correspoinding configuration. + See examples in `PyMIC_examples/seg_nll/ `_. + Built-in NLL Methods -------------------- -Some NLL methods only use noise-robust loss functions. They are used with ``pymic_run`` -for training. Just set ``loss_type`` to one of them in the configuration file, similarly -to the fully supervised learning. +Some NLL methods only use noise-robust loss functions. They are used with a standard fully supervised training +paradigm. Just set ``supervise_type`` = `fully_sup`, and use ``loss_type`` to one of them in the configuration file: * ``GCELoss``: (`NeurIPS 2018 `_) Generalized cross entropy loss. @@ -78,7 +64,7 @@ to the fully supervised learning. Noise-robust Dice loss. The other NLL methods are implemented in child classes of -:mod:`pymic.net_run_nll.nll_abstract.NLLSegAgent`, and they are: +:mod:`pymic.net_run.agent_seg.SegmentationAgent`, and they are: * ``CLSLSR``: (`MICCAI 2020 `_) Confident learning with spatial label smoothing regularization. @@ -95,15 +81,15 @@ The other NLL methods are implemented in child classes of Customized NLL Methods ---------------------- -PyMIC alo supports customizing NLL methods by inheriting the :mod:`NLLSegAgent` class. -You may only need to rewrite the :mod:`training()` method and reuse most part of the +PyMIC alo supports customized NLL methods by inheriting the `SegmentationAgent` class. +You may only need to rewrite the `training()` method and reuse most part of the existing pipeline, such as data loading, validation and inference methods. For example: .. code-block:: none - from pymic.net_run_nll.nll_abstract import NLLSegAgent + from pymic.net_run.agent_seg import SegmentationAgent - class MyNLLMethod(NLLSegAgent): + class MyNLLMethod(SegmentationAgent): def __init__(self, config, stage = 'train'): super(MyNLLMethod, self).__init__(config, stage) ... diff --git a/docs/source/usage.ssl.rst b/docs/source/usage.ssl.rst index 143d3f8..0fd8dc6 100644 --- a/docs/source/usage.ssl.rst +++ b/docs/source/usage.ssl.rst @@ -3,37 +3,22 @@ Semi-Supervised Learning ========================= -pymic_ssl ---------- - -:mod:`pymic_ssl` is the command for using built-in semi-supervised methods for training. -Similarly to :mod:`pymic_run`, it should be followed by two parameters, specifying the -stage and configuration file, respectively. The training and testing commands are: - -.. code-block:: bash - - pymic_ssl train myconfig_ssl.cfg - pymic_ssl test myconfig_ssl.cfg - -.. tip:: - - If the SSL method only involves one network, either ``pymic_ssl`` or ``pymic_run`` - can be used for inference. Their difference only exists in the training stage. - SSL Configurations ------------------ -In the configuration file for ``pymic_ssl``, in addition to those used in fully +In the configuration file for semi-supervised segmentation, in addition to those used in fully supervised learning, there are some items specificalized for semi-supervised learning. Users should provide values for the following items in ``dataset`` section of the configuration file: +* ``supervise_type`` (string): The value should be "`semi_sup`". + * ``train_csv_unlab`` (string): the csv file for unlabeled dataset. Note that ``train_csv`` is only used for labeled dataset. * ``train_batch_size_unlab`` (int): the batch size for unlabeled dataset. - Note that ``train_batch_size`` means the batch size for the labeled dataset. + Note that `train_batch_size` means the batch size for the labeled dataset. * ``train_transform_unlab`` (list): a list of transforms used for unlabeled data. @@ -43,7 +28,12 @@ The following is an example of the ``dataset`` section for semi-supervised learn .. code-block:: none ... - root_dir =../../PyMIC_data/ACDC/preprocess/ + + tensor_type = float + task_type = seg + supervise_type = semi_sup + + root_dir = ../../PyMIC_data/ACDC/preprocess/ train_csv = config/data/image_train_r10_lab.csv train_csv_unlab = config/data/image_train_r10_unlab.csv valid_csv = config/data/image_valid.csv @@ -60,14 +50,14 @@ The following is an example of the ``dataset`` section for semi-supervised learn ... In addition, there is a ``semi_supervised_learning`` section that is specifically designed -for SSL methods. In that section, users need to specify the ``ssl_method`` and configurations +for SSL methods. In that section, users need to specify the ``method_name`` and configurations related to the SSL method. For example, the correspoinding configuration for CPS is: .. code-block:: none ... [semi_supervised_learning] - ssl_method = CPS + method_name = CPS regularize_w = 0.1 rampup_start = 1000 rampup_end = 20000 @@ -76,14 +66,15 @@ related to the SSL method. For example, the correspoinding configuration for CPS .. note:: The configuration items vary with different SSL methods. Please refer to the API - of each built-in SSL method for details of the correspoinding configuration. + of each built-in SSL method for details of the correspoinding configuration. + See examples in `PyMIC_examples/seg_ssl/ `_. Built-in SSL Methods -------------------- -:mod:`pymic.net_run_ssl.ssl_abstract.SSLSegAgent` is the abstract class used for -semi-supervised learning. The built-in SSL methods are child classes of :mod:`SSLSegAgent`. -The available SSL methods implemnted in PyMIC are listed in :mod:`pymic.net_run_ssl.ssl_main.SSLMethodDict`, +:mod:`pymic.net_run.semi_sup.ssl_abstract.SSLSegAgent` is the abstract class used for +semi-supervised learning. The built-in SSL methods are child classes of `SSLSegAgent`. +The available SSL methods implemnted in PyMIC are listed in `pymic.net_run.semi_sup.SSLMethodDict`, and they are: * ``EntropyMinimization``: (`NeurIPS 2005 `_) @@ -103,13 +94,13 @@ and they are: Customized SSL Methods ---------------------- -PyMIC alo supports customizing SSL methods by inheriting the :mod:`SSLSegAgent` class. -You may only need to rewrite the :mod:`training()` method and reuse most part of the +PyMIC alo supports customizing SSL methods by inheriting the `SSLSegAgent` class. +You may only need to rewrite the `training()` method and reuse most part of the existing pipeline, such as data loading, validation and inference methods. For example: .. code-block:: none - from pymic.net_run_ssl.ssl_abstract import SSLSegAgent + from pymic.net_run.semi_sup import SSLSegAgent class MySSLMethod(SSLSegAgent): def __init__(self, config, stage = 'train'): diff --git a/docs/source/usage.wsl.rst b/docs/source/usage.wsl.rst index 00471f6..10d6da2 100644 --- a/docs/source/usage.wsl.rst +++ b/docs/source/usage.wsl.rst @@ -3,23 +3,6 @@ Weakly-Supervised Learning ========================== -pymic_wsl ---------- - -:mod:`pymic_wsl` is the command for using built-in weakly-supervised methods for training. -Similarly to :mod:`pymic_run`, it should be followed by two parameters, specifying the -stage and configuration file, respectively. The training and testing commands are: - -.. code-block:: bash - - pymic_wsl train myconfig_wsl.cfg - pymic_wsl test myconfig_wsl.cfg - -.. tip:: - - If the WSL method only involves one network, either ``pymic_wsl`` or ``pymic_run`` - can be used for inference. Their difference only exists in the training stage. - .. note:: Currently, the weakly supervised methods supported by PyMIC are only for learning @@ -31,17 +14,19 @@ stage and configuration file, respectively. The training and testing commands ar WSL Configurations ------------------ -In the configuration file for ``pymic_wsl``, in addition to those used in fully +In the configuration file for weakly supervised learning, in addition to those used in fully supervised learning, there are some items specificalized for weakly-supervised learning. -First, in the :mod:`train_transform` list, a special transform named :mod:`PartialLabelToProbability` +First, ``supervise_type`` should be set as "`weak_sup`" in the ``dataset`` section. + +Second, in the ``train_transform`` list, a special transform named `PartialLabelToProbability` should be used to transform patial labels into a one-hot probability map and a weighting map of pixels (i.e., the weight of a pixel is 1 if labeled and 0 otherwise). The patial cross entropy loss on labeled pixels is actually implemented by a weighted cross entropy loss. The loss setting is `loss_type = CrossEntropyLoss`. -Second, there is a ``weakly_supervised_learning`` section that is specifically designed -for WSL methods. In that section, users need to specify the ``wsl_method`` and configurations +Thirdly, there is a ``weakly_supervised_learning`` section that is specifically designed +for WSL methods. In that section, users need to specify the ``method_name`` and configurations related to the WSL method. For example, the correspoinding configuration for GatedCRF is: @@ -50,6 +35,7 @@ related to the WSL method. For example, the correspoinding configuration for Gat [dataset] ... + supervise_type = weak_sup root_dir = ../../PyMIC_data/ACDC/preprocess train_csv = config/data/image_train.csv valid_csv = config/data/image_valid.csv @@ -72,7 +58,7 @@ related to the WSL method. For example, the correspoinding configuration for Gat ... [weakly_supervised_learning] - wsl_method = GatedCRF + method_name = GatedCRF regularize_w = 0.1 rampup_start = 2000 rampup_end = 15000 @@ -90,13 +76,14 @@ related to the WSL method. For example, the correspoinding configuration for Gat The configuration items vary with different WSL methods. Please refer to the API of each built-in WSL method for details of the correspoinding configuration. + See examples in `PyMIC_examples/seg_wsl/ `_. Built-in WSL Methods -------------------- -:mod:`pymic.net_run_wsl.wsl_abstract.WSLSegAgent` is the abstract class used for -weakly-supervised learning. The built-in WSL methods are child classes of :mod:`WSLSegAgent`. -The available WSL methods implemnted in PyMIC are listed in :mod:`pymic.net_run_wsl.wsl_main.WSLMethodDict`, +:mod:`pymic.net_run.weak_sup.wsl_abstract.WSLSegAgent` is the abstract class used for +weakly-supervised learning. The built-in WSL methods are child classes of `WSLSegAgent`. +The available WSL methods implemnted in PyMIC are listed in `pymic.net_run.weak_sup.WSLMethodDict`, and they are: * ``EntropyMinimization``: (`NeurIPS 2005 `_) @@ -120,13 +107,13 @@ and they are: Customized WSL Methods ---------------------- -PyMIC alo supports customizing WSL methods by inheriting the :mod:`WSLSegAgent` class. -You may only need to rewrite the :mod:`training()` method and reuse most part of the +PyMIC alo supports customizing WSL methods by inheriting the `WSLSegAgent` class. +You may only need to rewrite the `training()` method and reuse most part of the existing pipeline, such as data loading, validation and inference methods. For example: .. code-block:: none - from pymic.net_run_wsl.wsl_abstract import WSLSegAgent + from pymic.net_run.weak_sup import WSLSegAgent class MyWSLMethod(WSLSegAgent): def __init__(self, config, stage = 'train'): diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index 9131ba0..7a509d1 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -282,6 +282,7 @@ def create_optimizer(self, params, checkpoint = None): :param params: network parameters for optimization. Usually it is obtained by `self.get_parameters_to_update()`. + :param checkpoint: A previous checkpoint to load. Default is `None`. """ opt_params = self.config['training'] if(self.optimizer is None): diff --git a/pymic/net_run/get_optimizer.py b/pymic/net_run/get_optimizer.py index ad8fda0..755ebcf 100644 --- a/pymic/net_run/get_optimizer.py +++ b/pymic/net_run/get_optimizer.py @@ -7,6 +7,15 @@ from pymic.util.general import keyword_match def get_optimizer(name, net_params, optim_params): + """ + Create an optimizer for learnable parameters. + + :param name: (string) Name of the optimizer. Should be one of {`SGD`, `Adam`, + `SparseAdam`, `Adadelta`, `Adagrad`, `Adamax`, `ASGD`, `LBFGS`, `RMSprop`, `Rprop`}. + :param net_params: Learnable parameters that need to be set for an optimizer. + :param optim_params: (dict) The parameters required for the target optimizer. + :return: An instance of the target optimizer. + """ lr = optim_params['learning_rate'] momentum = optim_params['momentum'] weight_decay = optim_params['weight_decay'] @@ -39,6 +48,13 @@ def get_optimizer(name, net_params, optim_params): def get_lr_scheduler(optimizer, sched_params): + """ + Create learning rate scheduler for an optimizer + + :param optimizer: An optimizer instance. + :param sched_params: (dict) The parameters required for the scheduler. + :return: An instance of the target learning rate scheduler. + """ name = sched_params["lr_scheduler"] val_it = sched_params["iter_valid"] epoch_last = sched_params["last_iter"] diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index 3b5ee9d..96458fa 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -14,21 +14,20 @@ def bernstein_poly(i, n, t): """ - The Bernstein polynomial of n, i as a function of t + The Bernstein polynomial of n, i as a function of t. """ return comb(n, i) * ( t**(n-i) ) * (1 - t)**i def bezier_curve(points, nTimes=1000): """ - Given a set of control points, return the - bezier curve defined by the control points. - Control points should be a list of lists, or list of tuples - such as [ [1,1], - [2,3], - [4,5], ..[Xn, Yn] ] - nTimes is the number of time steps, defaults to 1000 - See http://processingjs.nihongoresources.com/bezierinfo/ + Given a set of control points, return the + bezier curve defined by the control points. + Control points should be a list of lists, or list of tuples + such as [ [1,1], [2,3], [4,5], ..[Xn, Yn] ]. + + `nTimes` is the number of time steps, defaults to 1000. + See http://processingjs.nihongoresources.com/bezierinfo/ """ nPoints = len(points) From e17a2fa6699710c074026f880979ce203bc8483f Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 15 Mar 2023 16:29:58 +0800 Subject: [PATCH 04/86] update task type: reconstruction update task type: reconstruction --- pymic/net_run/__init__.py | 2 - pymic/net_run/agent_abstract.py | 3 +- pymic/net_run/agent_cls.py | 2 +- pymic/net_run/agent_rec.py | 274 ++++++++++++++++++++++++ pymic/net_run/agent_seg.py | 2 +- pymic/net_run/get_optimizer.py | 16 -- pymic/net_run/predict.py | 7 +- pymic/net_run/self_sup/self_sl_agent.py | 227 +------------------- pymic/net_run/train.py | 7 +- pymic/transform/crop.py | 8 +- pymic/transform/flip.py | 4 +- pymic/transform/intensity.py | 17 +- pymic/transform/label_convert.py | 50 +++++ pymic/transform/pad.py | 4 +- pymic/transform/rescale.py | 8 +- pymic/transform/rotate.py | 4 +- pymic/transform/trans_dict.py | 2 + pymic/transform/transpose.py | 5 +- 18 files changed, 367 insertions(+), 275 deletions(-) create mode 100644 pymic/net_run/agent_rec.py diff --git a/pymic/net_run/__init__.py b/pymic/net_run/__init__.py index 72b8078..e69de29 100644 --- a/pymic/net_run/__init__.py +++ b/pymic/net_run/__init__.py @@ -1,2 +0,0 @@ -from __future__ import absolute_import -from . import * \ No newline at end of file diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index 7a509d1..2e6c062 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -57,7 +57,7 @@ def __init__(self, config, stage = 'train'): self.transform_dict = None self.inferer = None self.tensor_type = config['dataset']['tensor_type'] - self.task_type = config['dataset']['task_type'] #cls, cls_mtbc, seg + self.task_type = config['dataset']['task_type'] #cls, cls_mtbc, seg, rec self.deterministic = config['training'].get('deterministic', True) self.random_seed = config['training'].get('random_seed', 1) if(self.deterministic): @@ -282,7 +282,6 @@ def create_optimizer(self, params, checkpoint = None): :param params: network parameters for optimization. Usually it is obtained by `self.get_parameters_to_update()`. - :param checkpoint: A previous checkpoint to load. Default is `None`. """ opt_params = self.config['training'] if(self.optimizer is None): diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index 5610982..862a703 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -58,7 +58,7 @@ def get_stage_dataset_from_config(self, stage): data_transform = None else: transform_param = self.config['dataset'] - transform_param['task'] = 'classification' + transform_param['task'] = self.task_type for name in transform_names: if(name not in self.transform_dict): raise(ValueError("Undefined transform {0:}".format(name))) diff --git a/pymic/net_run/agent_rec.py b/pymic/net_run/agent_rec.py new file mode 100644 index 0000000..607623d --- /dev/null +++ b/pymic/net_run/agent_rec.py @@ -0,0 +1,274 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import copy +import logging +import time +import logging +import numpy as np +import os +import scipy +import torch +import torch.nn as nn +from datetime import datetime +from torch.optim import lr_scheduler +from tensorboardX import SummaryWriter +from pymic.io.image_read_write import save_nd_array_as_image +from pymic.net_run.infer_func import Inferer +from pymic.net_run.agent_seg import SegmentationAgent +from pymic.loss.seg.mse import MAELoss, MSELoss + +ReconstructionLossDict = { + 'MAELoss': MAELoss, + 'MSELoss': MSELoss + } + +class ReconstructionAgent(SegmentationAgent): + """ + An agent for image reconstruction (pixel-level intensity prediction). + """ + def __init__(self, config, stage = 'train'): + super(ReconstructionAgent, self).__init__(config, stage) + output_act_name = config['network'].get('output_activation', 'sigmoid') + if(output_act_name == "sigmoid"): + self.out_act = nn.Sigmoid() + elif(output_act_name == "tanh"): + self.out_act = nn.Tanh() + else: + raise ValueError("For reconstruction task, only sigmoid and tanh are " + \ + "supported for output_activation.") + + def create_loss_calculator(self): + if(self.loss_dict is None): + self.loss_dict = ReconstructionLossDict + loss_name = self.config['training']['loss_type'] + if isinstance(loss_name, (list, tuple)): + raise ValueError("Undefined loss function {0:}".format(loss_name)) + elif (loss_name not in self.loss_dict): + raise ValueError("Undefined loss function {0:}".format(loss_name)) + else: + loss_param = self.config['training'] + loss_param['loss_softmax'] = False + base_loss = self.loss_dict[loss_name](self.config['training']) + if(self.config['training'].get('deep_supervise', False)): + raise ValueError("Deep supervised loss not implemented for reconstruction tasks") + # weight = self.config['training'].get('deep_supervise_weight', None) + # mode = self.config['training'].get('deep_supervise_mode', 2) + # params = {'deep_supervise_weight': weight, + # 'deep_supervise_mode': mode, + # 'base_loss':base_loss} + # self.loss_calculator = DeepSuperviseLoss(params) + else: + self.loss_calculator = base_loss + + def training(self): + iter_valid = self.config['training']['iter_valid'] + train_loss = 0 + self.net.train() + for it in range(iter_valid): + try: + data = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data = next(self.trainIter) + # get the inputs + inputs = self.convert_tensor_type(data['image']) + label = self.convert_tensor_type(data['label']) + + # for debug + # from pymic.io.image_read_write import save_nd_array_as_image + # for i in range(inputs.shape[0]): + # image_i = inputs[i][0] + # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) + # save_nd_array_as_image(image_i, image_name, reference_name = None) + # return + + inputs, label = inputs.to(self.device), label.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + outputs = self.net(inputs) + outputs = self.out_act(outputs) + loss = self.get_loss_value(data, outputs, label) + loss.backward() + self.optimizer.step() + train_loss = train_loss + loss.item() + # get dice evaluation for each class + if(isinstance(outputs, tuple) or isinstance(outputs, list)): + outputs = outputs[0] + + train_avg_loss = train_loss / iter_valid + train_scalers = {'loss': train_avg_loss} + return train_scalers + + def validation(self): + class_num = self.config['network']['class_num'] + if(self.inferer is None): + infer_cfg = self.config['testing'] + infer_cfg['class_num'] = class_num + self.inferer = Inferer(infer_cfg) + + valid_loss_list = [] + validIter = iter(self.valid_loader) + with torch.no_grad(): + self.net.eval() + for data in validIter: + inputs = self.convert_tensor_type(data['image']) + label = self.convert_tensor_type(data['label']) + inputs, label = inputs.to(self.device), label.to(self.device) + outputs = self.inferer.run(self.net, inputs) + outputs = self.out_act(outputs) + # The tensors are on CPU when calculating loss for validation data + loss = self.get_loss_value(data, outputs, label) + valid_loss_list.append(loss.item()) + + valid_avg_loss = np.asarray(valid_loss_list).mean() + valid_scalers = {'loss': valid_avg_loss} + return valid_scalers + + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): + loss_scalar ={'train':train_scalars['loss'], + 'valid':valid_scalars['loss']} + self.summ_writer.add_scalars('loss', loss_scalar, glob_it) + self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) + logging.info('train loss {0:.4f}'.format(train_scalars['loss'])) + logging.info('valid loss {0:.4f}'.format(valid_scalars['loss'])) + + def train_valid(self): + device_ids = self.config['training']['gpus'] + if(len(device_ids) > 1): + self.device = torch.device("cuda:0") + self.net = nn.DataParallel(self.net, device_ids = device_ids) + else: + self.device = torch.device("cuda:{0:}".format(device_ids[0])) + self.net.to(self.device) + ckpt_dir = self.config['training']['ckpt_save_dir'] + ckpt_prefix = self.config['training'].get('ckpt_prefix', None) + if(ckpt_prefix is None): + ckpt_prefix = ckpt_dir.split('/')[-1] + iter_start = self.config['training']['iter_start'] + iter_max = self.config['training']['iter_max'] + iter_valid = self.config['training']['iter_valid'] + iter_save = self.config['training'].get('iter_save', None) + early_stop_it = self.config['training'].get('early_stop_patience', None) + if(iter_save is None): + iter_save_list = [iter_max] + elif(isinstance(iter_save, (tuple, list))): + iter_save_list = iter_save + else: + iter_save_list = range(0, iter_max + 1, iter_save) + + self.min_val_loss = 10000.0 + self.max_val_it = 0 + self.best_model_wts = None + self.checkpoint = None + if(iter_start > 0): + checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start) + self.checkpoint = torch.load(checkpoint_file, map_location = self.device) + # assert(self.checkpoint['iteration'] == iter_start) + if(len(device_ids) > 1): + self.net.module.load_state_dict(self.checkpoint['model_state_dict']) + else: + self.net.load_state_dict(self.checkpoint['model_state_dict']) + self.min_val_loss = self.checkpoint.get('valid_loss', 10000) + # self.max_val_it = self.checkpoint['iteration'] + self.max_val_it = iter_start + self.best_model_wts = self.checkpoint['model_state_dict'] + + self.create_optimizer(self.get_parameters_to_update()) + self.create_loss_calculator() + + self.trainIter = iter(self.train_loader) + + logging.info("{0:} training start".format(str(datetime.now())[:-7])) + self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) + self.glob_it = iter_start + for it in range(iter_start, iter_max, iter_valid): + lr_value = self.optimizer.param_groups[0]['lr'] + t0 = time.time() + train_scalars = self.training() + t1 = time.time() + valid_scalars = self.validation() + t2 = time.time() + if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step(-valid_scalars['loss']) + else: + self.scheduler.step() + + self.glob_it = it + iter_valid + logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) + logging.info('learning rate {0:}'.format(lr_value)) + logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1)) + self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it) + if(valid_scalars['loss'] < self.min_val_loss): + self.min_val_loss = valid_scalars['loss'] + self.max_val_it = self.glob_it + if(len(device_ids) > 1): + self.best_model_wts = copy.deepcopy(self.net.module.state_dict()) + else: + self.best_model_wts = copy.deepcopy(self.net.state_dict()) + + stop_now = True if(early_stop_it is not None and \ + self.glob_it - self.max_val_it > early_stop_it) else False + if ((self.glob_it in iter_save_list) or stop_now): + save_dict = {'iteration': self.glob_it, + 'valid_loss': valid_scalars['loss'], + 'model_state_dict': self.net.module.state_dict() \ + if len(device_ids) > 1 else self.net.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict()} + save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.glob_it) + torch.save(save_dict, save_name) + txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt') + txt_file.write(str(self.glob_it)) + txt_file.close() + if(stop_now): + logging.info("The training is early stopped") + break + # save the best performing checkpoint + save_dict = {'iteration': self.max_val_it, + 'valid_loss': self.min_val_loss, + 'model_state_dict': self.best_model_wts, + 'optimizer_state_dict': self.optimizer.state_dict()} + save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it) + torch.save(save_dict, save_name) + txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') + txt_file.write(str(self.max_val_it)) + txt_file.close() + logging.info('The best performing iter is {0:}, valid loss {1:}'.format(\ + self.max_val_it, self.min_val_loss)) + self.summ_writer.close() + + def save_outputs(self, data): + """ + Save prediction output. + + :param data: (dictionary) A data dictionary with prediciton result and other + information such as input image name. + """ + output_dir = self.config['testing']['output_dir'] + ignore_dir = self.config['testing'].get('filename_ignore_dir', True) + filename_replace_source = self.config['testing'].get('filename_replace_source', None) + filename_replace_target = self.config['testing'].get('filename_replace_target', None) + if(not os.path.exists(output_dir)): + os.makedirs(output_dir, exist_ok=True) + + names, pred = data['names'], data['predict'] + if(isinstance(pred, (list, tuple))): + pred = pred[0] + if(isinstance(self.out_act, nn.Sigmoid)): + pred = scipy.special.expit(pred) + else: + pred = np.tanh(pred) + # save the output predictions + root_dir = self.config['dataset']['root_dir'] + for i in range(len(names)): + save_name = names[i].split('/')[-1] if ignore_dir else \ + names[i].replace('/', '_') + if((filename_replace_source is not None) and (filename_replace_target is not None)): + save_name = save_name.replace(filename_replace_source, filename_replace_target) + print(save_name) + save_name = "{0:}/{1:}".format(output_dir, save_name) + save_nd_array_as_image(pred[i][i], save_name, root_dir + '/' + names[i]) + + \ No newline at end of file diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 887a516..b85c716 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -54,7 +54,7 @@ def get_stage_dataset_from_config(self, stage): data_transform = None else: transform_param = self.config['dataset'] - transform_param['task'] = 'segmentation' + transform_param['task'] = self.task_type for name in transform_names: if(name not in self.transform_dict): raise(ValueError("Undefined transform {0:}".format(name))) diff --git a/pymic/net_run/get_optimizer.py b/pymic/net_run/get_optimizer.py index 755ebcf..ad8fda0 100644 --- a/pymic/net_run/get_optimizer.py +++ b/pymic/net_run/get_optimizer.py @@ -7,15 +7,6 @@ from pymic.util.general import keyword_match def get_optimizer(name, net_params, optim_params): - """ - Create an optimizer for learnable parameters. - - :param name: (string) Name of the optimizer. Should be one of {`SGD`, `Adam`, - `SparseAdam`, `Adadelta`, `Adagrad`, `Adamax`, `ASGD`, `LBFGS`, `RMSprop`, `Rprop`}. - :param net_params: Learnable parameters that need to be set for an optimizer. - :param optim_params: (dict) The parameters required for the target optimizer. - :return: An instance of the target optimizer. - """ lr = optim_params['learning_rate'] momentum = optim_params['momentum'] weight_decay = optim_params['weight_decay'] @@ -48,13 +39,6 @@ def get_optimizer(name, net_params, optim_params): def get_lr_scheduler(optimizer, sched_params): - """ - Create learning rate scheduler for an optimizer - - :param optimizer: An optimizer instance. - :param sched_params: (dict) The parameters required for the scheduler. - :return: An instance of the target learning rate scheduler. - """ name = sched_params["lr_scheduler"] val_it = sched_params["iter_valid"] epoch_last = sched_params["last_iter"] diff --git a/pymic/net_run/predict.py b/pymic/net_run/predict.py index ca4ef25..31fff86 100644 --- a/pymic/net_run/predict.py +++ b/pymic/net_run/predict.py @@ -7,6 +7,7 @@ from pymic.util.parse_config import * from pymic.net_run.agent_cls import ClassificationAgent from pymic.net_run.agent_seg import SegmentationAgent +from pymic.net_run.agent_rec import ReconstructionAgent def main(): """ @@ -34,11 +35,13 @@ def main(): logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) task = config['dataset']['task_type'] - assert task in ['cls', 'cls_nexcl', 'seg'] + assert task in ['cls', 'cls_nexcl', 'seg', 'rec'] if(task == 'cls' or task == 'cls_nexcl'): agent = ClassificationAgent(config, 'test') - else: + elif(task == 'seg'): agent = SegmentationAgent(config, 'test') + else: + agent = ReconstructionAgent(config, 'test') agent.run() if __name__ == "__main__": diff --git a/pymic/net_run/self_sup/self_sl_agent.py b/pymic/net_run/self_sup/self_sl_agent.py index 24a6e66..c352adf 100644 --- a/pymic/net_run/self_sup/self_sl_agent.py +++ b/pymic/net_run/self_sup/self_sl_agent.py @@ -3,31 +3,10 @@ import copy import logging import time -import logging -import numpy as np -import random -import torch -import torch.nn as nn -import torchvision.transforms as transforms -from datetime import datetime -from random import random -from torch.optim import lr_scheduler -from tensorboardX import SummaryWriter -from pymic.io.nifty_dataset import NiftyDataset -from pymic.loss.seg.util import get_soft_label -from pymic.loss.seg.util import reshape_prediction_and_ground_truth -from pymic.loss.seg.util import get_classwise_dice -from pymic.net_run.infer_func import Inferer -from pymic.net_run.agent_seg import SegmentationAgent -from pymic.transform.trans_dict import TransformDict -from pymic.loss.seg.mse import MAELoss, MSELoss +from pymic.net_run.agent_rec import ReconstructionAgent -RegressionLossDict = { - 'MAELoss': MAELoss, - 'MSELoss': MSELoss - } -class SelfSLSegAgent(SegmentationAgent): +class SelfSLSegAgent(ReconstructionAgent): """ Abstract class for self-supervised segmentation. @@ -42,204 +21,4 @@ class SelfSLSegAgent(SegmentationAgent): """ def __init__(self, config, stage = 'train'): super(SelfSLSegAgent, self).__init__(config, stage) - self.transform_dict = TransformDict - - def create_loss_calculator(self): - if(self.loss_dict is None): - self.loss_dict = RegressionLossDict - loss_name = self.config['training']['loss_type'] - if isinstance(loss_name, (list, tuple)): - raise ValueError("Undefined loss function {0:}".format(loss_name)) - elif (loss_name not in self.loss_dict): - raise ValueError("Undefined loss function {0:}".format(loss_name)) - else: - loss_param = self.config['training'] - loss_param['loss_softmax'] = False - base_loss = self.loss_dict[loss_name](self.config['training']) - if(self.config['training'].get('deep_supervise', False)): - raise ValueError("Deep supervised loss not implemented for self-supervised learning") - # weight = self.config['training'].get('deep_supervise_weight', None) - # mode = self.config['training'].get('deep_supervise_mode', 2) - # params = {'deep_supervise_weight': weight, - # 'deep_supervise_mode': mode, - # 'base_loss':base_loss} - # self.loss_calculator = DeepSuperviseLoss(params) - else: - self.loss_calculator = base_loss - - def training(self): - iter_valid = self.config['training']['iter_valid'] - train_loss = 0 - self.net.train() - for it in range(iter_valid): - try: - data = next(self.trainIter) - except StopIteration: - self.trainIter = iter(self.train_loader) - data = next(self.trainIter) - # get the inputs - inputs = self.convert_tensor_type(data['image']) - label = self.convert_tensor_type(data['label']) - - # for debug - # from pymic.io.image_read_write import save_nd_array_as_image - # for i in range(inputs.shape[0]): - # image_i = inputs[i][0] - # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) - # save_nd_array_as_image(image_i, image_name, reference_name = None) - # return - - inputs, label = inputs.to(self.device), label.to(self.device) - - # zero the parameter gradients - self.optimizer.zero_grad() - - # forward + backward + optimize - outputs = self.net(inputs) - outputs = nn.Sigmoid()(outputs) - loss = self.get_loss_value(data, outputs, label) - loss.backward() - self.optimizer.step() - train_loss = train_loss + loss.item() - # get dice evaluation for each class - if(isinstance(outputs, tuple) or isinstance(outputs, list)): - outputs = outputs[0] - - train_avg_loss = train_loss / iter_valid - train_scalers = {'loss': train_avg_loss} - return train_scalers - - def validation(self): - if(self.inferer is None): - infer_cfg = self.config['testing'] - self.inferer = Inferer(infer_cfg) - - valid_loss_list = [] - validIter = iter(self.valid_loader) - with torch.no_grad(): - self.net.eval() - for data in validIter: - inputs = self.convert_tensor_type(data['image']) - label = self.convert_tensor_type(data['label']) - inputs, label = inputs.to(self.device), label.to(self.device) - outputs = self.inferer.run(self.net, inputs) - outputs = nn.Sigmoid()(outputs) - # The tensors are on CPU when calculating loss for validation data - loss = self.get_loss_value(data, outputs, label) - valid_loss_list.append(loss.item()) - - valid_avg_loss = np.asarray(valid_loss_list).mean() - valid_scalers = {'loss': valid_avg_loss} - return valid_scalers - - def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): - loss_scalar ={'train':train_scalars['loss'], - 'valid':valid_scalars['loss']} - self.summ_writer.add_scalars('loss', loss_scalar, glob_it) - self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) - logging.info('train loss {0:.4f}'.format(train_scalars['loss'])) - logging.info('valid loss {0:.4f}'.format(valid_scalars['loss'])) - - def train_valid(self): - device_ids = self.config['training']['gpus'] - if(len(device_ids) > 1): - self.device = torch.device("cuda:0") - self.net = nn.DataParallel(self.net, device_ids = device_ids) - else: - self.device = torch.device("cuda:{0:}".format(device_ids[0])) - self.net.to(self.device) - ckpt_dir = self.config['training']['ckpt_save_dir'] - ckpt_prefix = self.config['training'].get('ckpt_prefix', None) - if(ckpt_prefix is None): - ckpt_prefix = ckpt_dir.split('/')[-1] - iter_start = self.config['training']['iter_start'] - iter_max = self.config['training']['iter_max'] - iter_valid = self.config['training']['iter_valid'] - iter_save = self.config['training'].get('iter_save', None) - early_stop_it = self.config['training'].get('early_stop_patience', None) - if(iter_save is None): - iter_save_list = [iter_max] - elif(isinstance(iter_save, (tuple, list))): - iter_save_list = iter_save - else: - iter_save_list = range(0, iter_max + 1, iter_save) - - self.min_val_loss = 10000.0 - self.max_val_it = 0 - self.best_model_wts = None - self.checkpoint = None - if(iter_start > 0): - checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start) - self.checkpoint = torch.load(checkpoint_file, map_location = self.device) - # assert(self.checkpoint['iteration'] == iter_start) - if(len(device_ids) > 1): - self.net.module.load_state_dict(self.checkpoint['model_state_dict']) - else: - self.net.load_state_dict(self.checkpoint['model_state_dict']) - self.min_val_loss = self.checkpoint.get('valid_loss', 10000) - # self.max_val_it = self.checkpoint['iteration'] - self.max_val_it = iter_start - self.best_model_wts = self.checkpoint['model_state_dict'] - - self.create_optimizer(self.get_parameters_to_update()) - self.create_loss_calculator() - - self.trainIter = iter(self.train_loader) - - logging.info("{0:} training start".format(str(datetime.now())[:-7])) - self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) - self.glob_it = iter_start - for it in range(iter_start, iter_max, iter_valid): - lr_value = self.optimizer.param_groups[0]['lr'] - t0 = time.time() - train_scalars = self.training() - t1 = time.time() - valid_scalars = self.validation() - t2 = time.time() - if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step(-valid_scalars['loss']) - else: - self.scheduler.step() - - self.glob_it = it + iter_valid - logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) - logging.info('learning rate {0:}'.format(lr_value)) - logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1)) - self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it) - if(valid_scalars['loss'] < self.min_val_loss): - self.min_val_loss = valid_scalars['loss'] - self.max_val_it = self.glob_it - if(len(device_ids) > 1): - self.best_model_wts = copy.deepcopy(self.net.module.state_dict()) - else: - self.best_model_wts = copy.deepcopy(self.net.state_dict()) - - stop_now = True if(early_stop_it is not None and \ - self.glob_it - self.max_val_it > early_stop_it) else False - if ((self.glob_it in iter_save_list) or stop_now): - save_dict = {'iteration': self.glob_it, - 'valid_loss': valid_scalars['loss'], - 'model_state_dict': self.net.module.state_dict() \ - if len(device_ids) > 1 else self.net.state_dict(), - 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.glob_it) - torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt') - txt_file.write(str(self.glob_it)) - txt_file.close() - if(stop_now): - logging.info("The training is early stopped") - break - # save the best performing checkpoint - save_dict = {'iteration': self.max_val_it, - 'valid_loss': self.min_val_loss, - 'model_state_dict': self.best_model_wts, - 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it) - torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') - txt_file.write(str(self.max_val_it)) - txt_file.close() - logging.info('The best performing iter is {0:}, valid loss {1:}'.format(\ - self.max_val_it, self.min_val_loss)) - self.summ_writer.close() \ No newline at end of file + \ No newline at end of file diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index 1478527..107a519 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -13,7 +13,7 @@ from pymic.net_run.noisy_label import NLLMethodDict from pymic.net_run.self_sup import SelfSLSegAgent -def get_segmentation_agent(config, sup_type): +def get_seg_rec_agent(config, sup_type): assert(sup_type in ['fully_sup', 'semi_sup', 'self_sup', 'weak_sup', 'noisy_label']) if(sup_type == 'fully_sup'): logging.info("\n********** Fully Supervised Learning **********\n") @@ -86,12 +86,13 @@ def main(): logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) task = config['dataset']['task_type'] - assert task in ['cls', 'cls_nexcl', 'seg'] + assert task in ['cls', 'cls_nexcl', 'seg', 'rec'] if(task == 'cls' or task == 'cls_nexcl'): agent = ClassificationAgent(config, 'train') else: sup_type = config['dataset'].get('supervise_type', 'fully_sup') - agent = get_segmentation_agent(config, sup_type) + agent = get_seg_rec_agent(config, sup_type) + agent.run() if __name__ == "__main__": diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index a27288d..60f5c9b 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -55,12 +55,12 @@ def __call__(self, sample): image_t = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) sample['image'] = image_t - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and self.task in ['seg', 'rec']): label = sample['label'] crop_max[0] = label.shape[0] label = crop_ND_volume_with_bounding_box(label, crop_min, crop_max) sample['label'] = label - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and self.task in ['seg', 'rec']): weight = sample['pixel_weight'] crop_max[0] = weight.shape[0] weight = crop_ND_volume_with_bounding_box(weight, crop_min, crop_max) @@ -300,13 +300,13 @@ def __call__(self, sample): image_t = ndimage.interpolation.zoom(image_t, scale, order = 1) sample['image'] = image_t - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and self.task in ['seg', 'rec']): label = sample['label'] crop_max[0] = label.shape[0] label = crop_ND_volume_with_bounding_box(label, crop_min, crop_max) label = ndimage.interpolation.zoom(label, scale, order = 0) sample['label'] = label - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and self.task ['seg', 'rec']): weight = sample['pixel_weight'] crop_max[0] = weight.shape[0] weight = crop_ND_volume_with_bounding_box(weight, crop_min, crop_max) diff --git a/pymic/transform/flip.py b/pymic/transform/flip.py index ca0915e..0462935 100644 --- a/pymic/transform/flip.py +++ b/pymic/transform/flip.py @@ -52,9 +52,9 @@ def __call__(self, sample): # current pytorch does not support negative strides image_t = np.flip(image, flip_axis).copy() sample['image'] = image_t - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and self.task in ['seg', 'rec']): sample['label'] = np.flip(sample['label'] , flip_axis).copy() - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and self.task in ['seg', 'rec']): sample['pixel_weight'] = np.flip(sample['pixel_weight'] , flip_axis).copy() return sample diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index 96458fa..3b5ee9d 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -14,20 +14,21 @@ def bernstein_poly(i, n, t): """ - The Bernstein polynomial of n, i as a function of t. + The Bernstein polynomial of n, i as a function of t """ return comb(n, i) * ( t**(n-i) ) * (1 - t)**i def bezier_curve(points, nTimes=1000): """ - Given a set of control points, return the - bezier curve defined by the control points. - Control points should be a list of lists, or list of tuples - such as [ [1,1], [2,3], [4,5], ..[Xn, Yn] ]. - - `nTimes` is the number of time steps, defaults to 1000. - See http://processingjs.nihongoresources.com/bezierinfo/ + Given a set of control points, return the + bezier curve defined by the control points. + Control points should be a list of lists, or list of tuples + such as [ [1,1], + [2,3], + [4,5], ..[Xn, Yn] ] + nTimes is the number of time steps, defaults to 1000 + See http://processingjs.nihongoresources.com/bezierinfo/ """ nPoints = len(points) diff --git a/pymic/transform/label_convert.py b/pymic/transform/label_convert.py index 0dcae37..1efaab6 100644 --- a/pymic/transform/label_convert.py +++ b/pymic/transform/label_convert.py @@ -93,6 +93,27 @@ def __call__(self, sample): sample['label_prob'] = label_prob return sample +class LabelSmooth(AbstractTransform): + """ + Apply label smoothing to one-hot labels. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `LabelSmooth_alpha`: (float) Alpha value for label smoothing. + :param `LabelSmooth_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. + """ + def __init__(self, params): + super(LabelSmooth, self).__init__(params) + self.alpha = params['LabelSmooth_alpha'.lower()] + self.inverse = params.get('LabelSmooth_inverse'.lower(), False) + + def __call__(self, sample): + label_prob = sample['label_prob'] + K = list(label_prob.shape)[1] + sample['label_prob'] = label_prob * (1.0 - self.alpha) + self.alpha / K + return sample class PartialLabelToProbability(AbstractTransform): """ @@ -130,5 +151,34 @@ def __call__(self, sample): return sample +class SelfSuperviseLabel(AbstractTransform): + """ + Convert one-channel partial label map to one-hot multi-channel probability map. + This is used for segmentation tasks only. In the input label map, 0 represents the + background class, 1 to C-1 represent the foreground classes, and C represents + unlabeled pixels. In the output dictionary, `label_prob` is the one-hot probability + map, and `pixel_weight` represents a weighting map, where the weight for a pixel + is 0 if the label is unkown. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `PartialLabelToProbability_class_num`: (int) The class number for the + segmentation task. + :param `PartialLabelToProbability_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. + """ + def __init__(self, params): + """ + class_num (int): the class number in the label map + """ + super(SelfSuperviseLabel, self).__init__(params) + self.inverse = params.get('SelfSuperviseLabel_inverse'.lower(), False) + + def __call__(self, sample): + image = sample['image'] + label = image * 1.0 + sample['label'] = label + return sample diff --git a/pymic/transform/pad.py b/pymic/transform/pad.py index 0ec196c..4d292be 100644 --- a/pymic/transform/pad.py +++ b/pymic/transform/pad.py @@ -59,11 +59,11 @@ def __call__(self, sample): sample['image'] = image_t - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and self.task in ['seg', 'rec']): label = sample['label'] label = np.pad(label, pad, 'reflect') if(max(margin) > 0) else label sample['label'] = label - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and self.task in ['seg', 'rec']): weight = sample['pixel_weight'] weight = np.pad(weight, pad, 'reflect') if(max(margin) > 0) else weight sample['pixel_weight'] = weight diff --git a/pymic/transform/rescale.py b/pymic/transform/rescale.py index 2a671fd..660b156 100644 --- a/pymic/transform/rescale.py +++ b/pymic/transform/rescale.py @@ -48,11 +48,11 @@ def __call__(self, sample): sample['image'] = image_t sample['Rescale_origin_shape'] = json.dumps(input_shape) - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and self.task in ['seg', 'rec']): label = sample['label'] label = ndimage.interpolation.zoom(label, scale, order = 0) sample['label'] = label - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and self.task in ['seg', 'rec']): weight = sample['pixel_weight'] weight = ndimage.interpolation.zoom(weight, scale, order = 1) sample['pixel_weight'] = weight @@ -126,11 +126,11 @@ def __call__(self, sample): sample['image'] = image_t sample['RandomRescale_Param'] = json.dumps(input_shape) - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and self.task in ['seg', 'rec']): label = sample['label'] label = ndimage.interpolation.zoom(label, scale, order = 0) sample['label'] = label - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and self.task in ['seg', 'rec']): weight = sample['pixel_weight'] weight = ndimage.interpolation.zoom(weight, scale, order = 1) sample['pixel_weight'] = weight diff --git a/pymic/transform/rotate.py b/pymic/transform/rotate.py index 2aa06d4..bd68e6a 100644 --- a/pymic/transform/rotate.py +++ b/pymic/transform/rotate.py @@ -78,10 +78,10 @@ def __call__(self, sample): sample['RandomRotate_Param'] = json.dumps(transform_param_list) image_t = self.__apply_transformation(image, transform_param_list, 1) sample['image'] = image_t - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and self.task in ['seg', 'rec']): sample['label'] = self.__apply_transformation(sample['label'] , transform_param_list, 0) - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and self.task in ['seg', 'rec']): sample['pixel_weight'] = self.__apply_transformation(sample['pixel_weight'] , transform_param_list, 1) return sample diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index bc72c93..a7d96fd 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -25,6 +25,7 @@ 'RandomRotate': RandomRotate, 'ReduceLabelDim': ReduceLabelDim, 'Rescale': Rescale, + 'SelfSuperviseLabel': SelfSuperviseLabel, 'Pad': Pad. """ @@ -67,6 +68,7 @@ 'RandomRotate': RandomRotate, 'ReduceLabelDim': ReduceLabelDim, 'Rescale': Rescale, + 'SelfSuperviseLabel': SelfSuperviseLabel, 'OutPainting': OutPainting, 'Pad': Pad, } diff --git a/pymic/transform/transpose.py b/pymic/transform/transpose.py index 9c73bda..67e8611 100644 --- a/pymic/transform/transpose.py +++ b/pymic/transform/transpose.py @@ -39,10 +39,11 @@ def __call__(self, sample): if(transpose_axis is not None): image_t = np.transpose(image, transpose_axis) sample['image'] = image_t - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and self.task in ['seg', 'rec']): sample['label'] = np.transpose(sample['label'] , transpose_axis) - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and self.task in ['seg', 'rec']): sample['pixel_weight'] = np.transpose(sample['pixel_weight'] , transpose_axis) + return sample From 0015adefba9895c9591d7952511f444c088d2611 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 20 Mar 2023 10:54:48 +0800 Subject: [PATCH 05/86] update task type update task type --- pymic/__init__.py | 19 ++++++++++++++++++- pymic/io/nifty_dataset.py | 22 ++++++++++++++++++---- pymic/net_run/agent_abstract.py | 2 +- pymic/net_run/agent_cls.py | 12 ++++++------ pymic/net_run/agent_rec.py | 21 +++++++++++++++++++++ pymic/net_run/agent_seg.py | 3 ++- pymic/net_run/predict.py | 12 +++++++----- pymic/net_run/semi_sup/ssl_abstract.py | 2 +- pymic/net_run/train.py | 6 +++--- pymic/transform/crop.py | 13 +++++++++---- pymic/transform/flip.py | 7 +++++-- pymic/transform/label_convert.py | 5 +++-- pymic/transform/pad.py | 7 +++++-- pymic/transform/rescale.py | 13 +++++++++---- pymic/transform/rotate.py | 7 +++++-- pymic/transform/transpose.py | 11 ++++++----- pymic/util/parse_config.py | 2 ++ pymic/util/preprocess.py | 1 + 18 files changed, 122 insertions(+), 43 deletions(-) diff --git a/pymic/__init__.py b/pymic/__init__.py index cb6356a..1520d82 100644 --- a/pymic/__init__.py +++ b/pymic/__init__.py @@ -1,2 +1,19 @@ from __future__ import absolute_import -__version__ = "0.4.0" \ No newline at end of file +from enum import Enum + +__version__ = "0.4.0" + +class TaskType(Enum): + CLASSIFICATION_ONE_HOT = 1 + CLASSIFICATION_COEXIST = 2 + REGRESSION = 3 + SEGMENTATION = 4 + RECONSTRUCTION = 5 + +TaskDict = { + 'cls': TaskType.CLASSIFICATION_ONE_HOT, + 'cls_coexist': TaskType.CLASSIFICATION_COEXIST, + 'regress': TaskType.REGRESSION, + 'seg': TaskType.SEGMENTATION, + 'rec': TaskType.RECONSTRUCTION +} \ No newline at end of file diff --git a/pymic/io/nifty_dataset.py b/pymic/io/nifty_dataset.py index bb1ff23..9812d13 100644 --- a/pymic/io/nifty_dataset.py +++ b/pymic/io/nifty_dataset.py @@ -1,12 +1,14 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import logging import os import torch import pandas as pd import numpy as np from torch.utils.data import Dataset, DataLoader from torchvision import transforms, utils +from pymic import TaskType from pymic.io.image_read_write import load_image_as_nd_array class NiftyDataset(Dataset): @@ -23,14 +25,21 @@ class NiftyDataset(Dataset): The built-in transforms can listed in :mod:`pymic.transform.trans_dict`. """ def __init__(self, root_dir, csv_file, modal_num = 1, - with_label = False, transform=None): + with_label = False, transform=None, task = TaskType.SEGMENTATION): self.root_dir = root_dir self.csv_items = pd.read_csv(csv_file) self.modal_num = modal_num self.with_label = with_label self.transform = transform + self.task = task + assert self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION] csv_keys = list(self.csv_items.keys()) + if('label' not in csv_keys): + logging.warning("`label` section is not found in the csv file {0:}".format( + csv_file) + "\n -- This is only allowed for self-supervised learning" + + "\n -- when `SelfSuperviseLabel` is used in the transform.") + self.with_label = False self.image_weight_idx = None self.pixel_weight_idx = None if('image_weight' in csv_keys): @@ -42,12 +51,15 @@ def __len__(self): return len(self.csv_items) def __getlabel__(self, idx): - csv_keys = list(self.csv_items.keys()) + csv_keys = list(self.csv_items.keys()) label_idx = csv_keys.index('label') label_name = "{0:}/{1:}".format(self.root_dir, self.csv_items.iloc[idx, label_idx]) label = load_image_as_nd_array(label_name)['data_array'] - label = np.asarray(label, np.int32) + if(self.task == TaskType.SEGMENTATION): + label = np.asarray(label, np.int32) + elif(self.task == TaskType.RECONSTRUCTION): + label = np.asarray(label, np.float32) return label def __get_pixel_weight__(self, idx): @@ -101,10 +113,12 @@ class ClassificationDataset(NiftyDataset): The built-in transforms can listed in :mod:`pymic.transform.trans_dict`. """ def __init__(self, root_dir, csv_file, modal_num = 1, class_num = 2, - with_label = False, transform=None): + with_label = False, transform=None, task = TaskType.CLASSIFICATION_ONE_HOT): super(ClassificationDataset, self).__init__(root_dir, csv_file, modal_num, with_label, transform) self.class_num = class_num + self.task = task + assert self.task in [TaskType.CLASSIFICATION_ONE_HOT, TaskType.CLASSIFICATION_COEXIST] def __getlabel__(self, idx): csv_keys = list(self.csv_items.keys()) diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index 2e6c062..109eabc 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -57,7 +57,7 @@ def __init__(self, config, stage = 'train'): self.transform_dict = None self.inferer = None self.tensor_type = config['dataset']['tensor_type'] - self.task_type = config['dataset']['task_type'] #cls, cls_mtbc, seg, rec + self.task_type = config['dataset']['task_type'] self.deterministic = config['training'].get('deterministic', True) self.random_seed = config['training'].get('random_seed', 1) if(self.deterministic): diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index 862a703..1ded3b1 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -13,6 +13,7 @@ from torch.optim import lr_scheduler from torchvision import transforms from tensorboardX import SummaryWriter +from pymic import TaskType from pymic.io.nifty_dataset import ClassificationDataset from pymic.loss.loss_dict_cls import PyMICClsLossDict from pymic.net.net_dict_cls import TorchClsNetDict @@ -38,7 +39,6 @@ class ClassificationAgent(NetRunAgent): def __init__(self, config, stage = 'train'): super(ClassificationAgent, self).__init__(config, stage) self.transform_dict = TransformDict - assert(self.task_type in ["cls", "cls_nexcl"]) def get_stage_dataset_from_config(self, stage): assert(stage in ['train', 'valid', 'test']) @@ -119,12 +119,12 @@ def get_evaluation_score(self, outputs, labels): metrics = self.config['training'].get("evaluation_metric", "accuracy") if(metrics != "accuracy"): # default classification accuracy raise ValueError("Not implemeted for metric {0:}".format(metrics)) - if(self.task_type == "cls"): + if(self.task_type == TaskType.CLASSIFICATION_ONE_HOT): out_argmax = torch.argmax(outputs, 1) lab_argmax = torch.argmax(labels, 1) consis = self.convert_tensor_type(out_argmax == lab_argmax) score = torch.mean(consis) - elif(self.task_type == "cls_nexcl"): #nonexclusive classification + elif(self.task_type == TaskType.CLASSIFICATION_COEXIST): preds = self.convert_tensor_type(outputs > 0.5) consis= self.convert_tensor_type(preds == labels.data) score = torch.mean(consis) @@ -346,15 +346,15 @@ def infer(self): infer_time = time.time() - start_time infer_time_list.append(infer_time) - if (self.task_type == "cls"): + if (self.task_type == TaskType.CLASSIFICATION_ONE_HOT): out_prob = nn.Softmax(dim = 1)(out_digit).detach().cpu().numpy() out_lab = np.argmax(out_prob, axis=1) - else: #self.task_type == "cls_nexcl" + else: #self.task_type == TaskType.CLASSIFICATION_COEXIST out_prob = nn.Sigmoid()(out_digit).detach().cpu().numpy() out_lab = np.asarray(out_prob > 0.5, np.uint8) for i in range(len(names)): print(names[i], out_lab[i]) - if(self.task_type == "cls"): + if(self.task_type == TaskType.CLASSIFICATION_ONE_HOT): out_lab_list.append([names[i]] + [out_lab[i]]) else: out_lab_list.append([names[i]] + out_lab[i].tolist()) diff --git a/pymic/net_run/agent_rec.py b/pymic/net_run/agent_rec.py index 607623d..188a5fc 100644 --- a/pymic/net_run/agent_rec.py +++ b/pymic/net_run/agent_rec.py @@ -113,6 +113,9 @@ def validation(self): validIter = iter(self.valid_loader) with torch.no_grad(): self.net.eval() + + # for debug + # save_num = 0 for data in validIter: inputs = self.convert_tensor_type(data['image']) label = self.convert_tensor_type(data['label']) @@ -123,6 +126,24 @@ def validation(self): loss = self.get_loss_value(data, outputs, label) valid_loss_list.append(loss.item()) + # for debug + # print(inputs.shape, label.shape, outputs.shape) + # inputs = inputs.cpu().numpy() + # label = label.cpu().numpy() + # outputs = outputs.cpu().numpy() + # for i in range(inputs.shape[0]): + # image_i = inputs[i][0] + # label_i = label[i][0] + # output_i = outputs[i][0] + # image_name = "temp/case{0:}_image.nii.gz".format(save_num + i) + # label_name = "temp/case{0:}_label.nii.gz".format(save_num + i) + # output_name= "temp/case{0:}_output.nii.gz".format(save_num + i) + # save_nd_array_as_image(image_i, image_name, reference_name = None) + # save_nd_array_as_image(label_i, label_name, reference_name = None) + # save_nd_array_as_image(output_i, output_name, reference_name = None) + # save_num += inputs.shape[0] + # if(save_num > 20): + # break valid_avg_loss = np.asarray(valid_loss_list).mean() valid_scalers = {'loss': valid_avg_loss} return valid_scalers diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index b85c716..fb99075 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -67,7 +67,8 @@ def get_stage_dataset_from_config(self, stage): csv_file = csv_file, modal_num = modal_num, with_label= not (stage == 'test'), - transform = data_transform ) + transform = data_transform, + task = self.task_type) return dataset def create_network(self): diff --git a/pymic/net_run/predict.py b/pymic/net_run/predict.py index 31fff86..80134d8 100644 --- a/pymic/net_run/predict.py +++ b/pymic/net_run/predict.py @@ -4,6 +4,7 @@ import os import sys from datetime import datetime +from pymic import TaskType from pymic.util.parse_config import * from pymic.net_run.agent_cls import ClassificationAgent from pymic.net_run.agent_seg import SegmentationAgent @@ -34,14 +35,15 @@ def main(): level=logging.INFO, format='%(message)s') # for python 3.6 logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) - task = config['dataset']['task_type'] - assert task in ['cls', 'cls_nexcl', 'seg', 'rec'] - if(task == 'cls' or task == 'cls_nexcl'): + task = config['dataset']['task_type'] + if(task == TaskType.CLASSIFICATION_ONE_HOT or task == TaskType.CLASSIFICATION_COEXIST): agent = ClassificationAgent(config, 'test') - elif(task == 'seg'): + elif(task == TaskType.SEGMENTATION): agent = SegmentationAgent(config, 'test') - else: + elif(task == TaskType.RECONSTRUCTION): agent = ReconstructionAgent(config, 'test') + else: + raise ValueError("Undefined task for inference: {0:}".format(task)) agent.run() if __name__ == "__main__": diff --git a/pymic/net_run/semi_sup/ssl_abstract.py b/pymic/net_run/semi_sup/ssl_abstract.py index 5a46257..b27edc9 100644 --- a/pymic/net_run/semi_sup/ssl_abstract.py +++ b/pymic/net_run/semi_sup/ssl_abstract.py @@ -44,7 +44,7 @@ def get_unlabeled_dataset_from_config(self): data_transform = None else: transform_param = self.config['dataset'] - transform_param['task'] = 'segmentation' + transform_param['task'] = self.task_type for name in transform_names: if(name not in self.transform_dict): raise(ValueError("Undefined transform {0:}".format(name))) diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index 107a519..0167f2f 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -5,6 +5,7 @@ import sys import shutil from datetime import datetime +from pymic import TaskType from pymic.util.parse_config import * from pymic.net_run.agent_cls import ClassificationAgent from pymic.net_run.agent_seg import SegmentationAgent @@ -47,7 +48,7 @@ def get_seg_rec_agent(config, sup_type): 'inpainting_probability': 0.2 } config['dataset']['train_transform'].extend(transforms) - config['dataset']['valid_transform'].extend(transforms) + # config['dataset']['valid_transform'].extend(transforms) config['dataset'].update(genesis_cfg) logging_config(config['dataset']) else: @@ -86,8 +87,7 @@ def main(): logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) task = config['dataset']['task_type'] - assert task in ['cls', 'cls_nexcl', 'seg', 'rec'] - if(task == 'cls' or task == 'cls_nexcl'): + if(task == TaskType.CLASSIFICATION_ONE_HOT or task == TaskType.CLASSIFICATION_COEXIST): agent = ClassificationAgent(config, 'train') else: sup_type = config['dataset'].get('supervise_type', 'fully_sup') diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index 60f5c9b..acadc49 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -7,6 +7,7 @@ import random import numpy as np from scipy import ndimage +from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * @@ -55,12 +56,14 @@ def __call__(self, sample): image_t = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) sample['image'] = image_t - if('label' in sample and self.task in ['seg', 'rec']): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): label = sample['label'] crop_max[0] = label.shape[0] label = crop_ND_volume_with_bounding_box(label, crop_min, crop_max) sample['label'] = label - if('pixel_weight' in sample and self.task in ['seg', 'rec']): + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): weight = sample['pixel_weight'] crop_max[0] = weight.shape[0] weight = crop_ND_volume_with_bounding_box(weight, crop_min, crop_max) @@ -300,13 +303,15 @@ def __call__(self, sample): image_t = ndimage.interpolation.zoom(image_t, scale, order = 1) sample['image'] = image_t - if('label' in sample and self.task in ['seg', 'rec']): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): label = sample['label'] crop_max[0] = label.shape[0] label = crop_ND_volume_with_bounding_box(label, crop_min, crop_max) label = ndimage.interpolation.zoom(label, scale, order = 0) sample['label'] = label - if('pixel_weight' in sample and self.task ['seg', 'rec']): + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): weight = sample['pixel_weight'] crop_max[0] = weight.shape[0] weight = crop_ND_volume_with_bounding_box(weight, crop_min, crop_max) diff --git a/pymic/transform/flip.py b/pymic/transform/flip.py index 0462935..24cafb4 100644 --- a/pymic/transform/flip.py +++ b/pymic/transform/flip.py @@ -7,6 +7,7 @@ import random import numpy as np from scipy import ndimage +from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * @@ -52,9 +53,11 @@ def __call__(self, sample): # current pytorch does not support negative strides image_t = np.flip(image, flip_axis).copy() sample['image'] = image_t - if('label' in sample and self.task in ['seg', 'rec']): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): sample['label'] = np.flip(sample['label'] , flip_axis).copy() - if('pixel_weight' in sample and self.task in ['seg', 'rec']): + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): sample['pixel_weight'] = np.flip(sample['pixel_weight'] , flip_axis).copy() return sample diff --git a/pymic/transform/label_convert.py b/pymic/transform/label_convert.py index 1efaab6..e3accdf 100644 --- a/pymic/transform/label_convert.py +++ b/pymic/transform/label_convert.py @@ -7,6 +7,7 @@ import random import numpy as np from scipy import ndimage +from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * @@ -80,13 +81,13 @@ def __init__(self, params): self.inverse = params.get('LabelToProbability_inverse'.lower(), False) def __call__(self, sample): - if(self.task == 'segmentation'): + if(self.task == TaskType.SEGMENTATION): label = sample['label'][0] # sample['label'] is (1, h, w) label_prob = np.zeros((self.class_num, *label.shape), dtype = np.float32) for i in range(self.class_num): label_prob[i] = label == i*np.ones_like(label) sample['label_prob'] = label_prob - elif(self.task == 'classification'): + elif(self.task == TaskType.CLASSIFICATION_ONE_HOT): label_idx = sample['label'] label_prob = np.zeros((self.class_num,), np.float32) label_prob[label_idx] = 1.0 diff --git a/pymic/transform/pad.py b/pymic/transform/pad.py index 4d292be..91cf6da 100644 --- a/pymic/transform/pad.py +++ b/pymic/transform/pad.py @@ -7,6 +7,7 @@ import random import numpy as np from scipy import ndimage +from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * @@ -59,11 +60,13 @@ def __call__(self, sample): sample['image'] = image_t - if('label' in sample and self.task in ['seg', 'rec']): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): label = sample['label'] label = np.pad(label, pad, 'reflect') if(max(margin) > 0) else label sample['label'] = label - if('pixel_weight' in sample and self.task in ['seg', 'rec']): + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): weight = sample['pixel_weight'] weight = np.pad(weight, pad, 'reflect') if(max(margin) > 0) else weight sample['pixel_weight'] = weight diff --git a/pymic/transform/rescale.py b/pymic/transform/rescale.py index 660b156..355712e 100644 --- a/pymic/transform/rescale.py +++ b/pymic/transform/rescale.py @@ -5,6 +5,7 @@ import random import numpy as np from scipy import ndimage +from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * @@ -48,11 +49,13 @@ def __call__(self, sample): sample['image'] = image_t sample['Rescale_origin_shape'] = json.dumps(input_shape) - if('label' in sample and self.task in ['seg', 'rec']): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): label = sample['label'] label = ndimage.interpolation.zoom(label, scale, order = 0) sample['label'] = label - if('pixel_weight' in sample and self.task in ['seg', 'rec']): + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): weight = sample['pixel_weight'] weight = ndimage.interpolation.zoom(weight, scale, order = 1) sample['pixel_weight'] = weight @@ -126,11 +129,13 @@ def __call__(self, sample): sample['image'] = image_t sample['RandomRescale_Param'] = json.dumps(input_shape) - if('label' in sample and self.task in ['seg', 'rec']): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): label = sample['label'] label = ndimage.interpolation.zoom(label, scale, order = 0) sample['label'] = label - if('pixel_weight' in sample and self.task in ['seg', 'rec']): + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): weight = sample['pixel_weight'] weight = ndimage.interpolation.zoom(weight, scale, order = 1) sample['pixel_weight'] = weight diff --git a/pymic/transform/rotate.py b/pymic/transform/rotate.py index bd68e6a..65e5328 100644 --- a/pymic/transform/rotate.py +++ b/pymic/transform/rotate.py @@ -5,6 +5,7 @@ import random import numpy as np from scipy import ndimage +from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * @@ -78,10 +79,12 @@ def __call__(self, sample): sample['RandomRotate_Param'] = json.dumps(transform_param_list) image_t = self.__apply_transformation(image, transform_param_list, 1) sample['image'] = image_t - if('label' in sample and self.task in ['seg', 'rec']): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): sample['label'] = self.__apply_transformation(sample['label'] , transform_param_list, 0) - if('pixel_weight' in sample and self.task in ['seg', 'rec']): + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): sample['pixel_weight'] = self.__apply_transformation(sample['pixel_weight'] , transform_param_list, 1) return sample diff --git a/pymic/transform/transpose.py b/pymic/transform/transpose.py index 67e8611..6f5d5fa 100644 --- a/pymic/transform/transpose.py +++ b/pymic/transform/transpose.py @@ -4,6 +4,7 @@ import json import random import numpy as np +from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform @@ -39,12 +40,12 @@ def __call__(self, sample): if(transpose_axis is not None): image_t = np.transpose(image, transpose_axis) sample['image'] = image_t - if('label' in sample and self.task in ['seg', 'rec']): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): sample['label'] = np.transpose(sample['label'] , transpose_axis) - if('pixel_weight' in sample and self.task in ['seg', 'rec']): - sample['pixel_weight'] = np.transpose(sample['pixel_weight'] , transpose_axis) - - + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + sample['pixel_weight'] = np.transpose(sample['pixel_weight'] , transpose_axis) return sample def inverse_transform_for_prediction(self, sample): diff --git a/pymic/util/parse_config.py b/pymic/util/parse_config.py index 232762f..18afd08 100644 --- a/pymic/util/parse_config.py +++ b/pymic/util/parse_config.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, print_function +from pymic import TaskDict import configparser import logging @@ -103,6 +104,7 @@ def synchronize_config(config): data_cfg = config['dataset'] net_cfg = config['network'] # data_cfg["modal_num"] = net_cfg["in_chns"] + data_cfg["task_type"] = TaskDict[data_cfg["task_type"]] data_cfg["LabelToProbability_class_num".lower()] = net_cfg["class_num"] if "PartialLabelToProbability" in data_cfg['train_transform']: data_cfg["PartialLabelToProbability_class_num".lower()] = net_cfg["class_num"] diff --git a/pymic/util/preprocess.py b/pymic/util/preprocess.py index 5f20372..c0dc9a1 100644 --- a/pymic/util/preprocess.py +++ b/pymic/util/preprocess.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import os import numpy as np import SimpleITK as sitk From 462e8a5daaca6f36ef0f8cd2ca49304649f641a8 Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 12 Apr 2023 14:14:10 +0800 Subject: [PATCH 06/86] update transform add IntensityClip and edit Normalization --- pymic/net_run/agent_seg.py | 12 +- pymic/transform/crop.py | 64 +++++------ pymic/transform/intensity.py | 39 ++++++- pymic/transform/normalize.py | 7 +- pymic/transform/trans_dict.py | 2 + pymic/util/evaluation_seg.py | 199 +++++++++++++++++++--------------- pymic/util/general.py | 10 +- 7 files changed, 199 insertions(+), 134 deletions(-) diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index fb99075..1bdad06 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -149,15 +149,15 @@ def training(self): # for i in range(inputs.shape[0]): # image_i = inputs[i][0] # label_i = labels_prob[i][1] - # pixw_i = pix_w[i][0] - # print(image_i.shape, label_i.shape, pixw_i.shape) + # # pixw_i = pix_w[i][0] + # print(image_i.shape, label_i.shape) # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) - # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) + # # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) # save_nd_array_as_image(image_i, image_name, reference_name = None) # save_nd_array_as_image(label_i, label_name, reference_name = None) - # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) - # continue + # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) + # # continue inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) @@ -297,7 +297,7 @@ def train_valid(self): if(ckpt_init_mode > 0): # Load other information self.max_val_dice = checkpoint.get('valid_pred', 0) - iter_start = checkpoint['iteration'] - 1 + iter_start = checkpoint['iteration'] self.max_val_it = iter_start self.best_model_wts = checkpoint['model_state_dict'] ckpt_for_optm = checkpoint diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index acadc49..1012712 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -243,62 +243,51 @@ def _get_param_for_inverse_transform(self, sample): class RandomResizedCrop(CenterCrop): """ - Randomly crop the input image (shape [C, H, W]). Only 2D images are supported. - + Randomly resize and crop the input image (shape [C, D, H, W]). The arguments should be written in the `params` dictionary, and it has the following fields: - :param `RandomResizedCrop_output_size`: (list/tuple) Desired output size [H, W]. + :param `RandomResizedCrop_output_size`: (list/tuple) Desired output size [D, H, W]. The output channel is the same as the input channel. - :param `RandomResizedCrop_scale`: (list/tuple) Range of scale, e.g. (0.08, 1.0). - :param `RandomResizedCrop_ratio`: (list/tuple) Range of aspect ratio, e.g. (0.75, 1.33). + :param `RandomResizedCrop_scale_range`: (list/tuple) Range of scale, e.g. (0.08, 1.0). :param `RandomResizedCrop_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `False`. Currently, the inverse transform is not supported, and this transform is assumed to be used only during training stage. """ def __init__(self, params): self.output_size = params['RandomResizedCrop_output_size'.lower()] - self.scale = params['RandomResizedCrop_scale'.lower()] - self.ratio = params['RandomResizedCrop_ratio'.lower()] + self.scale = params['RandomResizedCrop_scale_range'.lower()] self.inverse = params.get('RandomResizedCrop_inverse'.lower(), False) self.task = params['Task'.lower()] assert isinstance(self.output_size, (list, tuple)) assert isinstance(self.scale, (list, tuple)) - assert isinstance(self.ratio, (list, tuple)) - def _get_crop_param(self, sample): + def __call__(self, sample): image = sample['image'] - input_shape = image.shape - input_dim = len(input_shape) - 1 - assert(input_dim == 2) + channel, input_size = image.shape[0], image.shape[1:] + input_dim = len(input_size) assert(input_dim == len(self.output_size)) - scale = self.scale[0] + random.random()*(self.scale[1] - self.scale[0]) - ratio = self.ratio[0] + random.random()*(self.ratio[1] - self.ratio[0]) - crop_w = input_shape[-1] * scale - crop_h = crop_w * ratio - crop_h = min(crop_h, input_shape[-2]) - output_shape = [int(crop_h), int(crop_w)] - - crop_margin = [input_shape[i + 1] - output_shape[i]\ - for i in range(input_dim)] + crop_size = [int(self.output_size[i] * scale) for i in range(input_dim)] + crop_margin = [input_size[i] - crop_size[i] for i in range(input_dim)] + pad_image = False + if(min(crop_margin) < 0): + pad_image = True + pad_size = [max(0, -crop_margin[i]) for i in range(input_dim)] + pad_lower = [int(pad_size[i] / 2) for i in range(input_dim)] + pad_upper = [pad_size[i] - pad_lower[i] for i in range(input_dim)] + pad = [(pad_lower[i], pad_upper[i]) for i in range(input_dim)] + pad = tuple([(0, 0)] + pad) + image = np.pad(image, pad, 'reflect') + crop_margin = [max(0, crop_margin[i]) for i in range(input_dim)] + crop_min = [random.randint(0, item) for item in crop_margin] - crop_max = [crop_min[i] + output_shape[i] \ - for i in range(input_dim)] + crop_max = [crop_min[i] + crop_size[i] for i in range(input_dim)] crop_min = [0] + crop_min - crop_max = list(input_shape[0:1]) + crop_max - sample['RandomResizedCrop_Param'] = json.dumps((input_shape, crop_min, crop_max)) - return sample, crop_min, crop_max - - def __call__(self, sample): - image = sample['image'] - input_shape = image.shape - input_dim = len(input_shape) - 1 - sample, crop_min, crop_max = self._get_crop_param(sample) + crop_max = [channel] + crop_max image_t = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) - crp_shape = image_t.shape - scale = [(self.output_size[i] + 0.0)/crp_shape[1:][i] for i in range(input_dim)] + scale = [(self.output_size[i] + 0.0)/crop_size[i] for i in range(input_dim)] scale = [1.0] + scale image_t = ndimage.interpolation.zoom(image_t, scale, order = 1) sample['image'] = image_t @@ -306,13 +295,18 @@ def __call__(self, sample): if('label' in sample and \ self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): label = sample['label'] + if(pad_image): + label = np.pad(label, pad, 'reflect') crop_max[0] = label.shape[0] label = crop_ND_volume_with_bounding_box(label, crop_min, crop_max) - label = ndimage.interpolation.zoom(label, scale, order = 0) + order = 0 if(self.task == TaskType.SEGMENTATION) else 1 + label = ndimage.interpolation.zoom(label, scale, order = order) sample['label'] = label if('pixel_weight' in sample and \ self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): weight = sample['pixel_weight'] + if(pad_image): + weight = np.pad(weight, pad, 'reflect') crop_max[0] = weight.shape[0] weight = crop_ND_volume_with_bounding_box(weight, crop_min, crop_max) weight = ndimage.interpolation.zoom(weight, scale, order = 1) diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index 3b5ee9d..1a13190 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -44,6 +44,40 @@ def bezier_curve(points, nTimes=1000): return xvals, yvals +class IntensityClip(AbstractTransform): + """ + Clip the intensity for input image + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `IntensityClip_channels`: (list) A list of int for specifying the channels. + :param `IntensityClip_lower`: (list) The lower bound for clip in each channel. + :param `IntensityClip_upper`: (list) The upper bound for clip in each channel. + :param `IntensityClip_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. + """ + def __init__(self, params): + super(IntensityClip, self).__init__(params) + self.channels = params['IntensityClip_channels'.lower()] + self.lower = params.get('IntensityClip_lower'.lower(), None) + self.upper = params.get('IntensityClip_upper'.lower(), None) + self.inverse = params.get('IntensityClip_inverse'.lower(), False) + + def __call__(self, sample): + image = sample['image'] + lower = self.lower if self.lower is not None else [None] * len(self.channels) + upper = self.upper if self.upper is not None else [None] * len(self.channels) + for chn in self.channels: + lower_c, upper_c = lower[chn], upper[chn] + if(lower_c is None): + lower_c = np.percentile(image[chn], 0.05) + if(upper_c is None): + upper_c = np.percentile(image[chn, 99.95]) + image[chn] = np.clip(image[chn], lower_c, upper_c) + sample['image'] = image + return sample + class GammaCorrection(AbstractTransform): """ Apply random gamma correction to given channels. @@ -76,8 +110,9 @@ def __call__(self, sample): img_c = image[chn] v_min = img_c.min() v_max = img_c.max() - img_c = (img_c - v_min)/(v_max - v_min) - img_c = np.power(img_c, gamma_c)*(v_max - v_min) + v_min + if(v_min < v_max): + img_c = (img_c - v_min)/(v_max - v_min) + img_c = np.power(img_c, gamma_c)*(v_max - v_min) + v_min image[chn] = img_c sample['image'] = image diff --git a/pymic/transform/normalize.py b/pymic/transform/normalize.py index 4e493dd..77852d2 100644 --- a/pymic/transform/normalize.py +++ b/pymic/transform/normalize.py @@ -53,9 +53,12 @@ def __call__(self, sample): if(chn_mean is None): if(self.ingore_np): pixels = image[chn][image[chn] > 0] - chn_mean, chn_std = pixels.mean(), pixels.std() + if(len(pixels) > 0): + chn_mean, chn_std = pixels.mean(), pixels.std() + 1e-5 + else: + chn_mean, chn_std = 0.0, 1.0 else: - chn_mean, chn_std = image[chn].mean(), image[chn].std() + chn_mean, chn_std = image[chn].mean(), image[chn].std() + 1e-5 chn_norm = (image[chn] - chn_mean)/chn_std diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index a7d96fd..e8f1c37 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -14,6 +14,7 @@ 'LabelConvert': LabelConvert, 'LabelConvertNonzero': LabelConvertNonzero, 'LabelToProbability': LabelToProbability, + 'IntensityClip': IntensityClip, 'NormalizeWithMeanStd': NormalizeWithMeanStd, 'NormalizeWithMinMax': NormalizeWithMinMax, 'NormalizeWithPercentiles': NormalizeWithPercentiles, @@ -55,6 +56,7 @@ 'LabelConvertNonzero': LabelConvertNonzero, 'LabelToProbability': LabelToProbability, 'LocalShuffling': LocalShuffling, + 'IntensityClip': IntensityClip, 'NonLinearTransform': NonLinearTransform, 'NormalizeWithMeanStd': NormalizeWithMeanStd, 'NormalizeWithMinMax': NormalizeWithMinMax, diff --git a/pymic/util/evaluation_seg.py b/pymic/util/evaluation_seg.py index ba04a73..836401d 100644 --- a/pymic/util/evaluation_seg.py +++ b/pymic/util/evaluation_seg.py @@ -3,15 +3,19 @@ Evaluation module for segmenation tasks. """ from __future__ import absolute_import, print_function +import argparse import csv import os import sys import pandas as pd import numpy as np +from os.path import join from scipy import ndimage from pymic.io.image_read_write import * from pymic.util.image_process import * -from pymic.util.parse_config import parse_config +from pymic.util.general import is_image_name +from pymic.util.parse_config import parse_config, parse_value_from_string + def binary_dice(s, g, resize = False): @@ -257,108 +261,127 @@ def evaluation(config): When a list is given, each list element should be the root dir of the results of one method. :param evaluation_image_pair: (str) The csv file that provide the segmentation images and the corresponding ground truth images. - :param ground_truth_label_convert_source: (optional, list) The list of source - labels for label conversion in the ground truth. - :param ground_truth_label_convert_target: (optional, list) The list of target - labels for label conversion in the ground truth. - :param segmentation_label_convert_source: (optional, list) The list of source - labels for label conversion in the segmentation. - :param segmentation_label_convert_target: (optional, list) The list of target - labels for label conversion in the segmentation. """ metric_list = config['metric_list'] - label_list = config['label_list'] - label_fuse = config.get('label_fuse', False) - organ_name = config['organ_name'] - gt_root = config['ground_truth_folder_root'] - seg_root = config['segmentation_folder_root'] - if(not(isinstance(seg_root, tuple) or isinstance(seg_root, list))): - seg_root = [seg_root] - image_pair_csv = config['evaluation_image_pair'] - ground_truth_label_convert_source = config.get('ground_truth_label_convert_source', None) - ground_truth_label_convert_target = config.get('ground_truth_label_convert_target', None) - segmentation_label_convert_source = config.get('segmentation_label_convert_source', None) - segmentation_label_convert_target = config.get('segmentation_label_convert_target', None) - - image_items = pd.read_csv(image_pair_csv) - item_num = len(image_items) - - for seg_root_n in seg_root: # for each segmentation method - for metric in metric_list: - score_all_data = [] - name_score_list= [] - for i in range(item_num): - gt_name = image_items.iloc[i, 0] - seg_name = image_items.iloc[i, 1] - # seg_name = seg_name.replace(".nii.gz", "_pred.nii.gz") - gt_full_name = gt_root + '/' + gt_name - seg_full_name = seg_root_n + '/' + seg_name - - s_dict = load_image_as_nd_array(seg_full_name) - g_dict = load_image_as_nd_array(gt_full_name) - s_volume = s_dict["data_array"]; s_spacing = s_dict["spacing"] - g_volume = g_dict["data_array"]; g_spacing = g_dict["spacing"] - # for dim in range(len(s_spacing)): - # assert(s_spacing[dim] == g_spacing[dim]) - if((ground_truth_label_convert_source is not None) and \ - ground_truth_label_convert_target is not None): - g_volume = convert_label(g_volume, ground_truth_label_convert_source, \ - ground_truth_label_convert_target) - - if((segmentation_label_convert_source is not None) and \ - segmentation_label_convert_target is not None): - s_volume = convert_label(s_volume, segmentation_label_convert_source, \ - segmentation_label_convert_target) - - score_vector = get_multi_class_evaluation_score(s_volume, g_volume, label_list, - label_fuse, s_spacing, metric ) - if(len(label_list) > 1): - score_vector.append(np.asarray(score_vector).mean()) - score_all_data.append(score_vector) - name_score_list.append([seg_name] + score_vector) - print(seg_name, score_vector) - score_all_data = np.asarray(score_all_data) - score_mean = score_all_data.mean(axis = 0) - score_std = score_all_data.std(axis = 0) - name_score_list.append(['mean'] + list(score_mean)) - name_score_list.append(['std'] + list(score_std)) + if(not isinstance(metric_list, list)): + metric_list = [metric_list] + label_list = config.get('label_list', None) + if(label_list is None): + label_list = range(1, config["class_number"]) + elif(not isinstance(label_list, list)): + label_list = [label_list] + label_fuse = config.get('label_fuse', False) + output_name = config.get('output_name', None) + gt_root = config['ground_truth_folder_root'] + seg_root = config['segmentation_folder_root'] + image_pair_csv = config.get('evaluation_image_pair', None) + + if(image_pair_csv is not None): + image_pair = pd.read_csv(image_pair_csv) + gt_names, seg_names = image_pair.iloc[:, 0], image_pair.iloc[:, 1] + else: + seg_names = sorted(os.listdir(seg_root)) + seg_names = [item for item in seg_names if is_image_name(item)] + gt_names = seg_names + - # save the result as csv - score_csv = "{0:}/{1:}_{2:}_all.csv".format(seg_root_n, organ_name, metric) - with open(score_csv, mode='w') as csv_file: - csv_writer = csv.writer(csv_file, delimiter=',', - quotechar='"',quoting=csv.QUOTE_MINIMAL) - head = ['image'] + ["class_{0:}".format(i) for i in label_list] - if(len(label_list) > 1): - head = head + ["average"] - csv_writer.writerow(head) - for item in name_score_list: - csv_writer.writerow(item) - - print("{0:} mean ".format(metric), score_mean) - print("{0:} std ".format(metric), score_std) + for metric in metric_list: + print(metric) + score_all_data = [] + name_score_list= [] + for i in range(len(gt_names)): + gt_full_name = join(gt_root, gt_names[i]) + seg_full_name = join(seg_root, seg_names[i]) + s_dict = load_image_as_nd_array(seg_full_name) + g_dict = load_image_as_nd_array(gt_full_name) + s_volume = s_dict["data_array"]; s_spacing = s_dict["spacing"] + g_volume = g_dict["data_array"]; g_spacing = g_dict["spacing"] + # for dim in range(len(s_spacing)): + # assert(s_spacing[dim] == g_spacing[dim]) + + score_vector = get_multi_class_evaluation_score(s_volume, g_volume, label_list, + label_fuse, s_spacing, metric ) + if(len(label_list) > 1): + score_vector.append(np.asarray(score_vector).mean()) + score_all_data.append(score_vector) + name_score_list.append([seg_names[i]] + score_vector) + print(seg_names[i], score_vector) + score_all_data = np.asarray(score_all_data) + score_mean = score_all_data.mean(axis = 0) + score_std = score_all_data.std(axis = 0) + name_score_list.append(['mean'] + list(score_mean)) + name_score_list.append(['std'] + list(score_std)) + + # save the result as csv + if(output_name is None): + output_name = "{0:}/eval_{1:}.csv".format(seg_root, metric) + with open(output_name, mode='w') as csv_file: + csv_writer = csv.writer(csv_file, delimiter=',', + quotechar='"',quoting=csv.QUOTE_MINIMAL) + head = ['image'] + ["class_{0:}".format(i) for i in label_list] + if(len(label_list) > 1): + head = head + ["average"] + csv_writer.writerow(head) + for item in name_score_list: + csv_writer.writerow(item) + + print("{0:} mean ".format(metric), score_mean) + print("{0:} std ".format(metric), score_std) def main(): """ Main function for evaluation of segmentation results. - A configuration file is needed for runing. e.g., + You can use a configuration file for runing. e.g., .. code-block:: none - pymic_evaluate_cls config.cfg + pymic_evaluate_seg -cfg config.cfg The configuration file should have an `evaluation` section. See :mod:`pymic.util.evaluation_seg.evaluation` for details of the configuration required. + + In addition, you can also provide a list of args in the command if -cfg is not used. For example: + + .. code-block:: none + + pymic_evaluate_seg -metric dice -cls_index 255 -gt_dir ground_truth_dir -seg_dir segmentation_dir + """ - if(len(sys.argv) < 2): - print('Number of arguments should be 2. e.g.') - print(' pymic_evaluate_seg config.cfg') - exit() - config_file = str(sys.argv[1]) - assert(os.path.isfile(config_file)) - config = parse_config(config_file)['evaluation'] + parser = argparse.ArgumentParser() + parser.add_argument("-cfg", help="configuration file for evaluation", + required=False, default=None) + parser.add_argument("-metric", help="evaluation metrics, e.g., dice, or [dice, assd]", + required=False, default=None) + parser.add_argument("-cls_num", help="number of classes", + required=False, default=None) + parser.add_argument("-cls_index", help="The class index for evaluation, e.g., 255, [1, 2]", + required=False, default=None) + parser.add_argument("-gt_dir", help="path of folder for ground truth", + required=False, default=None) + parser.add_argument("-seg_dir", help="path of folder for segmentation", + required=False, default=None) + parser.add_argument("-name_pair", help="the .csv file for name mapping in case" + " the names of one case are different in the gt_dir " + " and seg_dir", + required=False, default=None) + parser.add_argument("-out", help="the output .csv file name", + required=False, default=None) + args = parser.parse_args() + print(args) + if(args.cfg is not None): + config = parse_config(args.cfg)['evaluation'] + else: + config = {} + config['metric_list'] = parse_value_from_string(args.metric) + config['label_list'] = None if args.cls_index is None else parse_value_from_string(args.cls_index) + config['class_number']= None if args.cls_num is None else parse_value_from_string(args.cls_num) + config['ground_truth_folder_root'] = args.gt_dir + config['segmentation_folder_root'] = args.seg_dir + config['evaluation_image_pair'] = args.name_pair + config['output_name'] = args.out + print(config) evaluation(config) - + if __name__ == '__main__': main() diff --git a/pymic/util/general.py b/pymic/util/general.py index 075d2e1..cf52746 100644 --- a/pymic/util/general.py +++ b/pymic/util/general.py @@ -26,7 +26,15 @@ def tensor_shape_match(a,b): return False return True - +def is_image_name(x): + valid_names = ["jpg", "jpeg", "png", "bmp", "nii.gz", + "tif", "nii", "nii.gz", "mha"] + valid = False + for item in valid_names: + if(x.endswith(item)): + valid = True + break + return valid def get_one_hot_seg(label, class_num): """ From 64cbca070df47dcbd8458ff4217af63270185692 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 4 May 2023 21:59:08 +0800 Subject: [PATCH 07/86] update train and inference mode 1, allow unlabeled validation dataset (for self-supervised training) 2, add gaussian weight map for inference --- pymic/net_run/agent_seg.py | 15 +++++++++++-- pymic/net_run/infer_func.py | 43 ++++++++++++++++++++++++------------- 2 files changed, 41 insertions(+), 17 deletions(-) diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 1bdad06..2e703d4 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -63,10 +63,14 @@ def get_stage_dataset_from_config(self, stage): data_transform = transforms.Compose(self.transform_list) csv_file = self.config['dataset'].get(stage + '_csv', None) + if(stage == 'test'): + with_label = False + else: + with_label = self.config['dataset'].get(stage + '_label', True) dataset = NiftyDataset(root_dir = root_dir, csv_file = csv_file, modal_num = modal_num, - with_label= not (stage == 'test'), + with_label= with_label, transform = data_transform, task = self.task_type) return dataset @@ -189,8 +193,15 @@ def training(self): def validation(self): class_num = self.config['network']['class_num'] if(self.inferer is None): - infer_cfg = self.config['testing'] + infer_cfg = {} infer_cfg['class_num'] = class_num + infer_cfg['sliding_window_enable'] = self.config['testing'].get('sliding_window_enable', False) + if(infer_cfg['sliding_window_enable']): + patch_size = self.config['dataset'].get('patch_size', None) + if(patch_size is None): + patch_size = self.config['testing']['sliding_window_size'] + infer_cfg['sliding_window_size'] = patch_size + infer_cfg['sliding_window_stride'] = [i//2 for i in patch_size] self.inferer = Inferer(infer_cfg) valid_loss_list = [] diff --git a/pymic/net_run/infer_func.py b/pymic/net_run/infer_func.py index 43162d0..f6ed515 100644 --- a/pymic/net_run/infer_func.py +++ b/pymic/net_run/infer_func.py @@ -2,6 +2,8 @@ from __future__ import print_function, division import torch +import numpy as np +from scipy.ndimage.filters import gaussian_filter from torch.nn.functional import interpolate class Inferer(object): @@ -48,9 +50,19 @@ def __get_prediction_number_and_scales(self, tempx): output_num, scales = 1, None return output_num, scales + def __get_gaussian_weight_map(self, window_size, sigma_scale = 1.0/8): + w = np.zeros(window_size) + center = [i//2 for i in window_size] + sigmas = [i*sigma_scale for i in window_size] + w[tuple(center)] = 1.0 + w = gaussian_filter(w, sigmas, 0, mode='constant', cval=0) + return w + def __infer_with_sliding_window(self, image): """ - Use sliding window to predict segmentation for large images. + Use sliding window to predict segmentation for large images. The outupt of each + sliding window is weighted by a Gaussian map that hihglights contributions of windows + with a centroid closer to a given pixel. Note that the network may output a list of tensors with difference sizes. """ window_size = [x for x in self.config['sliding_window_size']] @@ -86,9 +98,10 @@ def __infer_with_sliding_window(self, image): crop_start_list.append([d_min, h_min, w_min]) output_shape = [batch_size, class_num] + img_shape - mask_shape = [batch_size, class_num] + window_size - counter = torch.zeros(output_shape).to(image.device) - temp_mask = torch.ones(mask_shape).to(image.device) + weight = torch.zeros(output_shape).to(image.device) + temp_w = self.__get_gaussian_weight_map(window_size) + temp_w = np.broadcast_to(temp_w, [batch_size, class_num] + window_size) + temp_w = torch.from_numpy(temp_w).to(image.device) temp_in_shape = img_full_shape[:2] + window_size tempx = torch.ones(temp_in_shape).to(image.device) out_num, scale_list = self.__get_prediction_number_and_scales(tempx) @@ -104,12 +117,12 @@ def __infer_with_sliding_window(self, image): if(isinstance(patch_out, (tuple, list))): patch_out = patch_out[0] if(img_dim == 2): - output[:, :, c0[0]:c1[0], c0[1]:c1[1]] += patch_out - counter[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_mask + output[:, :, c0[0]:c1[0], c0[1]:c1[1]] += patch_out * temp_w + weight[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_w else: - output[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += patch_out - counter[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_mask - return output/counter + output[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += patch_out * temp_w + weight[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_w + return output/weight else: # for multiple prediction output_list= [] for i in range(out_num): @@ -129,14 +142,14 @@ def __infer_with_sliding_window(self, image): c0_i = [int(c0[d] * scale_list[i][d]) for d in range(img_dim)] c1_i = [int(c1[d] * scale_list[i][d]) for d in range(img_dim)] if(img_dim == 2): - output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1]] += patch_out[i] - counter[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_mask + output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1]] += patch_out[i] * temp_w + weight[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_w else: - output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1], c0_i[2]:c1_i[2]] += patch_out[i] - counter[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_mask + output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1], c0_i[2]:c1_i[2]] += patch_out[i] * temp_w + weight[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_w for i in range(out_num): - counter_i = interpolate(counter, scale_factor = scale_list[i]) - output_list[i] = output_list[i] / counter_i + weight_i = interpolate(weight, scale_factor = scale_list[i]) + output_list[i] = output_list[i] / weight_i return output_list def run(self, model, image): From c29631ce7f3aa51b7fced92c54b34fec69417d80 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 4 May 2023 22:05:03 +0800 Subject: [PATCH 08/86] update transform and image process update transform and image process --- pymic/transform/crop.py | 39 +++++--------- pymic/transform/pad.py | 2 +- pymic/util/image_process.py | 101 ++++++++++++++++++++++++++++++++++-- pymic/util/parse_config.py | 10 ++++ 4 files changed, 121 insertions(+), 31 deletions(-) diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index 1012712..b9345c8 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -198,39 +198,28 @@ def __init__(self, params): def _get_crop_param(self, sample): image = sample['image'] - input_shape = image.shape - input_dim = len(input_shape) - 1 + chns = image.shape[0] + input_shape = image.shape[1:] + input_dim = len(input_shape) assert(input_dim == len(self.output_size)) - temp_output_size = self.output_size - if(input_dim == 3 and self.output_size[0] is None): - # note that output size is [D, H, W] and input is [C, D, H, W] - temp_output_size = [input_shape[1]] + self.output_size[1:] - crop_margin = [input_shape[i + 1] - temp_output_size[i]\ - for i in range(input_dim)] + crop_margin = [input_shape[i] - self.output_size[i] for i in range(input_dim)] crop_min = [0 if item == 0 else random.randint(0, item) for item in crop_margin] + crop_max = [crop_min[i] + self.output_size[i] for i in range(input_dim)] if(self.fg_focus and random.random() < self.fg_ratio): - label = sample['label'] + label = sample['label'][0] mask = np.zeros_like(label) for temp_lab in self.mask_label: mask = np.maximum(mask, label == temp_lab) - if(mask.sum() == 0): - bb_min = [0] * (input_dim + 1) - bb_max = mask.shape - else: - bb_min, bb_max = get_ND_bounding_box(mask) - bb_min, bb_max = bb_min[1:], bb_max[1:] - crop_min = [random.randint(bb_min[i], bb_max[i]) - int(temp_output_size[i]/2) \ - for i in range(input_dim)] - crop_min = [max(0, item) for item in crop_min] - crop_min = [min(crop_min[i], input_shape[i+1] - temp_output_size[i]) \ - for i in range(input_dim)] - - crop_max = [crop_min[i] + temp_output_size[i] \ - for i in range(input_dim)] + if(mask.max() > 0): + crop_min, crop_max = get_random_box_from_mask(mask, self.output_size) + # to avoid Typeerror: object of type int64 is not json serializable + crop_min = [int(i) for i in crop_min] + crop_max = [int(i) for i in crop_max] crop_min = [0] + crop_min - crop_max = list(input_shape[0:1]) + crop_max - sample['RandomCrop_Param'] = json.dumps((input_shape, crop_min, crop_max)) + crop_max = [chns] + crop_max + + sample['RandomCrop_Param'] = json.dumps((image.shape, crop_min, crop_max)) return sample, crop_min, crop_max def _get_param_for_inverse_transform(self, sample): diff --git a/pymic/transform/pad.py b/pymic/transform/pad.py index 91cf6da..c9b75fe 100644 --- a/pymic/transform/pad.py +++ b/pymic/transform/pad.py @@ -22,7 +22,7 @@ class Pad(AbstractTransform): following fields: :param `Pad_output_size`: (list/tuple) The output size along each spatial axis. - :param `Pad_ceil_mode`: (optional, bool) If true (by default), the real output size will + :param `Pad_ceil_mode`: (optional, bool) If true, the real output size will be the minimal integer multiples of output_size higher than the input size. For example, the input image has a shape of [3, 100, 100], `Pad_output_size` = [32, 32], and the real output size will be [3, 128, 128] if `Pad_ceil_mode` = True. diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index 9a484e8..d5d7a7e 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, print_function +import random import numpy as np import SimpleITK as sitk from scipy import ndimage @@ -96,7 +97,7 @@ def set_ND_volume_roi_with_bounding_box_range(volume, bb_min, bb_max, sub_volume raise ValueError("array dimension should be 2 to 5") return out -def crop_and_pad_ND_array_to_desired_shape(image, out_shape, pad_mod): +def crop_and_pad_ND_array_to_desired_shape(image, out_shape, pad_mod='reflect'): """ Crop and pad an image to a given shape. @@ -136,6 +137,96 @@ def crop_and_pad_ND_array_to_desired_shape(image, out_shape, pad_mod): return image_pad +def random_crop_ND_volume(volume, out_shape): + """ + randomly crop a volume with to a given shape. + + :param volume: The input ND array. + :param out_shape: (list) The desired output shape. + """ + in_shape = volume.shape + dim = len(in_shape) + + # pad the image first if the input size is smaller than the output size + pad_shape = [max(out_shape[i], in_shape[i]) for i in range(dim)] + mgnp = [pad_shape[i] - in_shape[i] for i in range(dim)] + if(max(mgnp) == 0): + image_pad = volume + else: + ml = [int(mgnp[i]/2) for i in range(dim)] + mr = [mgnp[i] - ml[i] for i in range(dim)] + pad = [(ml[i], mr[i]) for i in range(dim)] + pad = tuple(pad) + image_pad = np.pad(volume, pad, 'reflect') + + bb_min = [random.randint(0, pad_shape[i] - out_shape[i]) for i in range(dim)] + bb_max = [bb_min[i] + out_shape[i] for i in range(dim)] + crop_volume = crop_ND_volume_with_bounding_box(image_pad, bb_min, bb_max) + return crop_volume + +def get_random_box_from_mask(mask, out_shape): + mask_shape = mask.shape + dim = len(out_shape) + left_margin = [int(out_shape[i]/2) for i in range(dim)] + right_margin= [mask_shape[i] - (out_shape[i] - left_margin[i]) + 1 for i in range(dim)] + + valid_center_shape = [right_margin[i] - left_margin[i] for i in range(dim)] + valid_mask = np.zeros(mask_shape) + valid_mask = set_ND_volume_roi_with_bounding_box_range(valid_mask, + left_margin, right_margin, np.ones(valid_center_shape)) + valid_mask = valid_mask * mask + + indexes = np.where(valid_mask) + voxel_num = len(indexes[0]) + j = random.randint(0, voxel_num - 1) + bb_c = [indexes[i][j] for i in range(dim)] + bb_min = [bb_c[i] - left_margin[i] for i in range(dim)] + bb_max = [bb_min[i] + out_shape[i] for i in range(dim)] + return bb_min, bb_max + +def random_crop_ND_volume_with_mask(volume, out_shape, mask): + """ + randomly crop a volume with to a given shape. + + :param volume: The input ND array. + :param out_shape: (list) The desired output shape. + :param mask: A binary ND array. Default is None. If not None, + the center of the cropped region should be limited to the mask region. + """ + in_shape = volume.shape + dim = len(in_shape) + # pad the image first if the input size is smaller than the output size + pad_shape = [max(out_shape[i], in_shape[i]) for i in range(dim)] + mgnp = [pad_shape[i] - in_shape[i] for i in range(dim)] + if(max(mgnp) == 0): + image_pad, mask_pad = volume, mask + else: + ml = [int(mgnp[i]/2) for i in range(dim)] + mr = [mgnp[i] - ml[i] for i in range(dim)] + pad = [(ml[i], mr[i]) for i in range(dim)] + pad = tuple(pad) + image_pad = np.pad(volume, pad, 'reflect') + mask_pad = np.pad(mask, pad, 'reflect') + + bb_min, bb_max = get_random_box_from_mask(mask_pad, out_shape) + # left_margin = [int(out_shape[i]/2) for i in range(dim)] + # right_margin= [pad_shape[i] - (out_shape[i] - left_margin[i]) + 1 for i in range(dim)] + + # valid_center_shape = [right_margin[i] - left_margin[i] for i in range(dim)] + # valid_mask = np.zeros(pad_shape) + # valid_mask = set_ND_volume_roi_with_bounding_box_range(valid_mask, + # left_margin, right_margin, np.ones(valid_center_shape)) + # valid_mask = valid_mask * mask_pad + + # indexes = np.where(valid_mask) + # voxel_num = len(indexes[0]) + # j = random.randint(0, voxel_num) + # bb_c = [indexes[i][j] for i in range(dim)] + # bb_min = [bb_c[i] - left_margin[i] for i in range(dim)] + # bb_max = [bb_min[i] + out_shape[i] for i in range(dim)] + crop_volume = crop_ND_volume_with_bounding_box(image_pad, bb_min, bb_max) + return crop_volume + def get_largest_k_components(image, k = 1): """ Get the largest K components from 2D or 3D binary image. @@ -200,11 +291,11 @@ def convert_label(label, source_list, target_list): :param target_list: A list of target labels, e.g. [0, 1, 2, 3] """ assert(len(source_list) == len(target_list)) - label_converted = np.zeros_like(label) + label_converted = label * 1 for i in range(len(source_list)): - label_temp = np.asarray(label == source_list[i], label.dtype) - label_temp = label_temp * target_list[i] - label_converted = label_converted + label_temp + label_s = np.asarray(label == source_list[i], label.dtype) + label_t = label_s * target_list[i] + label_converted[label_s > 0] = label_t[label_s > 0] return label_converted def resample_sitk_image_to_given_spacing(image, spacing, order): diff --git a/pymic/util/parse_config.py b/pymic/util/parse_config.py index 18afd08..a12cc76 100644 --- a/pymic/util/parse_config.py +++ b/pymic/util/parse_config.py @@ -108,6 +108,16 @@ def synchronize_config(config): data_cfg["LabelToProbability_class_num".lower()] = net_cfg["class_num"] if "PartialLabelToProbability" in data_cfg['train_transform']: data_cfg["PartialLabelToProbability_class_num".lower()] = net_cfg["class_num"] + patch_size = data_cfg.get('patch_size', None) + if(patch_size is not None): + if('Pad' in data_cfg['train_transform']): + data_cfg['Pad_output_size'.lower()] = patch_size + if('CenterCrop' in data_cfg['train_transform']): + data_cfg['CenterCrop_output_size'.lower()] = patch_size + if('RandomCrop' in data_cfg['train_transform']): + data_cfg['RandomCrop_output_size'.lower()] = patch_size + if('RandomResizedCrop' in data_cfg['train_transform']): + data_cfg['RandomResizedCrop_output_size'.lower()] = patch_size config['dataset'] = data_cfg config['network'] = net_cfg return config From d0e800aec25123f7856b1b21f518c24d900202c2 Mon Sep 17 00:00:00 2001 From: taigw Date: Sun, 7 May 2023 21:26:10 +0800 Subject: [PATCH 09/86] update transform and inference mode 1, add CropwithForeground 2, all large batch size for sliding window inference 3, add BinaryDice loss and GroupDiceLoss --- pymic/loss/loss_dict_seg.py | 5 +- pymic/loss/seg/dice.py | 53 +++++++++++++++++++++ pymic/net_run/infer_func.py | 87 ++++++++++++++++++++++------------- pymic/transform/crop.py | 40 ++++++++++++++++ pymic/transform/trans_dict.py | 1 + 5 files changed, 154 insertions(+), 32 deletions(-) diff --git a/pymic/loss/loss_dict_seg.py b/pymic/loss/loss_dict_seg.py index 97c537e..fd72ce4 100644 --- a/pymic/loss/loss_dict_seg.py +++ b/pymic/loss/loss_dict_seg.py @@ -23,7 +23,8 @@ from __future__ import print_function, division import torch.nn as nn from pymic.loss.seg.ce import CrossEntropyLoss, GeneralizedCELoss -from pymic.loss.seg.dice import DiceLoss, FocalDiceLoss, NoiseRobustDiceLoss +from pymic.loss.seg.dice import DiceLoss, FocalDiceLoss, \ + NoiseRobustDiceLoss, BinaryDiceLoss, GroupDiceLoss from pymic.loss.seg.exp_log import ExpLogLoss from pymic.loss.seg.mse import MSELoss, MAELoss from pymic.loss.seg.slsr import SLSRLoss @@ -32,8 +33,10 @@ 'CrossEntropyLoss': CrossEntropyLoss, 'GeneralizedCELoss': GeneralizedCELoss, 'DiceLoss': DiceLoss, + 'BinaryDiceLoss': BinaryDiceLoss, 'FocalDiceLoss': FocalDiceLoss, 'NoiseRobustDiceLoss': NoiseRobustDiceLoss, + 'GroupDiceLoss': GroupDiceLoss, 'ExpLogLoss': ExpLogLoss, 'MAELoss': MAELoss, 'MSELoss': MSELoss, diff --git a/pymic/loss/seg/dice.py b/pymic/loss/seg/dice.py index c3a1134..350e0c4 100644 --- a/pymic/loss/seg/dice.py +++ b/pymic/loss/seg/dice.py @@ -31,6 +31,59 @@ def forward(self, loss_input_dict): dice_loss = 1.0 - dice_score.mean() return dice_loss +class BinaryDiceLoss(AbstractSegLoss): + ''' + Fuse all the foreground classes together and calculate the Dice value. + ''' + def __init__(self, params = None): + super(BinaryDiceLoss, self).__init__(params) + + def forward(self, loss_input_dict): + predict = loss_input_dict['prediction'] + soft_y = loss_input_dict['ground_truth'] + + if(isinstance(predict, (list, tuple))): + predict = predict[0] + if(self.softmax): + predict = nn.Softmax(dim = 1)(predict) + predict = 1.0 - predict[:, :1, :, :, :] + soft_y = 1.0 - soft_y[:, :1, :, :, :] + predict = reshape_tensor_to_2D(predict) + soft_y = reshape_tensor_to_2D(soft_y) + dice_score = get_classwise_dice(predict, soft_y) + dice_loss = 1.0 - dice_score.mean() + return dice_loss + +class GroupDiceLoss(AbstractSegLoss): + ''' + Fuse all the foreground classes together and calculate the Dice value. + ''' + def __init__(self, params = None): + super(GroupDiceLoss, self).__init__(params) + self.group = 2 + + def forward(self, loss_input_dict): + predict = loss_input_dict['prediction'] + soft_y = loss_input_dict['ground_truth'] + + if(isinstance(predict, (list, tuple))): + predict = predict[0] + if(self.softmax): + predict = nn.Softmax(dim = 1)(predict) + predict = reshape_tensor_to_2D(predict) + soft_y = reshape_tensor_to_2D(soft_y) + num_class = list(predict.size())[1] + cls_per_group = (num_class - 1) // self.group + loss_all = 0.0 + for g in range(self.group): + c0 = 1 + g*cls_per_group + c1 = min(num_class, c0 + cls_per_group) + pred_g = torch.sum(predict[:, c0:c1], dim = 1, keepdim = True) + y_g = torch.sum( soft_y[:, c0:c1], dim = 1, keepdim = True) + loss_all += 1.0 - get_classwise_dice(pred_g, y_g)[0] + avg_loss = loss_all / self.group + return avg_loss + class FocalDiceLoss(AbstractSegLoss): """ Focal Dice according to the following paper: diff --git a/pymic/net_run/infer_func.py b/pymic/net_run/infer_func.py index f6ed515..d2b9af2 100644 --- a/pymic/net_run/infer_func.py +++ b/pymic/net_run/infer_func.py @@ -63,13 +63,16 @@ def __infer_with_sliding_window(self, image): Use sliding window to predict segmentation for large images. The outupt of each sliding window is weighted by a Gaussian map that hihglights contributions of windows with a centroid closer to a given pixel. - Note that the network may output a list of tensors with difference sizes. + Note that the network may output a list of tensors with difference sizes for multi-scale prediction. """ window_size = [x for x in self.config['sliding_window_size']] window_stride = [x for x in self.config['sliding_window_stride']] + window_batch = self.config.get('sliding_window_batch', 1) class_num = self.config['class_num'] img_full_shape = list(image.shape) batch_size = img_full_shape[0] + assert(batch_size == 1 or window_batch == 1) + img_chns = img_full_shape[1] img_shape = img_full_shape[2:] img_dim = len(img_shape) if(img_dim != 2 and img_dim !=3): @@ -105,23 +108,37 @@ def __infer_with_sliding_window(self, image): temp_in_shape = img_full_shape[:2] + window_size tempx = torch.ones(temp_in_shape).to(image.device) out_num, scale_list = self.__get_prediction_number_and_scales(tempx) + + window_num = len(crop_start_list) + assert(window_num >= window_batch) + patches_shape = [window_batch, img_chns] + window_size + patches_in = torch.ones(patches_shape).to(image.device) if(out_num == 1): # for a single prediction output = torch.zeros(output_shape).to(image.device) - for c0 in crop_start_list: - c1 = [c0[d] + window_size[d] for d in range(img_dim)] - if(img_dim == 2): - patch_in = image[:, :, c0[0]:c1[0], c0[1]:c1[1]] - else: - patch_in = image[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] - patch_out = self.model(patch_in) - if(isinstance(patch_out, (tuple, list))): - patch_out = patch_out[0] - if(img_dim == 2): - output[:, :, c0[0]:c1[0], c0[1]:c1[1]] += patch_out * temp_w - weight[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_w - else: - output[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += patch_out * temp_w - weight[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_w + for w_i in range(0, window_num, window_batch): + for k in range(window_batch): + if(w_i + k >= window_num): + break + c0 = crop_start_list[w_i + k] + c1 = [c0[d] + window_size[d] for d in range(img_dim)] + if(img_dim == 2): + patches_in[k] = image[:, :, c0[0]:c1[0], c0[1]:c1[1]] + else: + patches_in[k] = image[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] + patches_out = self.model(patches_in) + if(isinstance(patches_out, (tuple, list))): + patches_out = patches_out[0] + for k in range(window_batch): + if(w_i + k >= window_num): + break + c0 = crop_start_list[w_i + k] + c1 = [c0[d] + window_size[d] for d in range(img_dim)] + if(img_dim == 2): + output[:, :, c0[0]:c1[0], c0[1]:c1[1]] += patches_out[k] * temp_w + weight[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_w + else: + output[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += patches_out[k] * temp_w + weight[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_w return output/weight else: # for multiple prediction output_list= [] @@ -130,23 +147,31 @@ def __infer_with_sliding_window(self, image): [int(img_shape[d] * scale_list[i][d]) for d in range(img_dim)] output_list.append(torch.zeros(output_shape_i).to(image.device)) - for c0 in crop_start_list: - c1 = [c0[d] + window_size[d] for d in range(img_dim)] - if(img_dim == 2): - patch_in = image[:, :, c0[0]:c1[0], c0[1]:c1[1]] - else: - patch_in = image[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] - patch_out = self.model(patch_in) - - for i in range(out_num): - c0_i = [int(c0[d] * scale_list[i][d]) for d in range(img_dim)] - c1_i = [int(c1[d] * scale_list[i][d]) for d in range(img_dim)] + for w_i in range(0, window_num, window_batch): + for k in range(window_batch): + if(w_i + k >= window_num): + break + c0 = crop_start_list[w_i + k] + c1 = [c0[d] + window_size[d] for d in range(img_dim)] if(img_dim == 2): - output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1]] += patch_out[i] * temp_w - weight[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_w + patches_in[k] = image[:, :, c0[0]:c1[0], c0[1]:c1[1]] else: - output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1], c0_i[2]:c1_i[2]] += patch_out[i] * temp_w - weight[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_w + patches_in[k] = image[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] + patches_out = self.model(patches_in) + + for i in range(out_num): + for k in range(window_batch): + if(w_i + k >= window_num): + break + c0 = crop_start_list[w_i + k] + c0_i = [int(c0[d] * scale_list[i][d]) for d in range(img_dim)] + c1_i = [int(c1[d] * scale_list[i][d]) for d in range(img_dim)] + if(img_dim == 2): + output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1]] += patches_out[i][k] * temp_w + weight[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_w + else: + output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1], c0_i[2]:c1_i[2]] += patches_out[i][k] * temp_w + weight[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_w for i in range(out_num): weight_i = interpolate(weight, scale_factor = scale_list[i]) output_list[i] = output_list[i] / weight_i diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index b9345c8..95e3489 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -164,7 +164,47 @@ def _get_param_for_inverse_transform(self, sample): params = json.loads(sample['CropWithBoundingBox_Param']) return params +class CropWithForeground(CenterCrop): + """ + Crop the image (shape [C, D, H, W] or [C, H, W]) based on a bounding box. + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `CropWithBoundingBox_start`: (None, or list/tuple) The start index + along each spatial axis. If None, calculate the start index automatically + so that the cropped region is centered at the non-zero region. + :param `CropWithBoundingBox_output_size`: (None or tuple/list): + Desired spatial output size. + If None, set it as the size of bounding box of non-zero region. + :param `CropWithBoundingBox_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `True`. + """ + def __init__(self, params): + self.labels = params.get('CropWithForeground_labels'.lower(), None) + self.margin = params.get('CropWithForeground_margin'.lower(), [5, 10, 10]) + self.inverse = params.get('CropWithForeground_inverse'.lower(), True) + self.task = params['task'] + + def _get_crop_param(self, sample): + image = sample['image'] + label = sample['label'] + input_shape = sample['image'].shape + bb_min, bb_max = get_ND_bounding_box(label, margin=[0] + self.margin) + bb_max[0] = input_shape[0] + + sample['CropWithForeground_Param'] = json.dumps((input_shape, bb_min, bb_max)) + + return sample, bb_min, bb_max + + def _get_param_for_inverse_transform(self, sample): + if(isinstance(sample['CropWithForeground_Param'], list) or \ + isinstance(sample['CropWithForeground_Param'], tuple)): + params = json.loads(sample['CropWithForeground_Param'][0]) + else: + params = json.loads(sample['CropWithForeground_Param']) + return params + class RandomCrop(CenterCrop): """Randomly crop the input image (shape [C, D, H, W] or [C, H, W]). diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index e8f1c37..c1ecdc2 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -46,6 +46,7 @@ 'ChannelWiseThreshold': ChannelWiseThreshold, 'ChannelWiseThresholdWithNormalize': ChannelWiseThresholdWithNormalize, 'CropWithBoundingBox': CropWithBoundingBox, + 'CropWithForeground': CropWithForeground, 'CenterCrop': CenterCrop, 'GrayscaleToRGB': GrayscaleToRGB, 'GammaCorrection': GammaCorrection, From f6be3eafa7b79e8f40cc18708610ab33c7adc69e Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 16 May 2023 09:38:14 +0800 Subject: [PATCH 10/86] update checkpoint save code Save the best checkpoint immediately --- pymic/net/net_dict_seg.py | 31 +++++++++++++++++++++++++++++-- pymic/net_run/agent_abstract.py | 2 ++ pymic/net_run/agent_cls.py | 19 +++++++++---------- pymic/net_run/agent_seg.py | 21 +++++++++++---------- 4 files changed, 51 insertions(+), 22 deletions(-) diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index fc7692f..e6c10bd 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -21,10 +21,24 @@ from pymic.net.net2d.unet2d_attention import AttentionUNet2D from pymic.net.net2d.unet2d_nest import NestedUNet2D from pymic.net.net2d.unet2d_scse import UNet2D_ScSE +from pymic.net.net2d.trans2d.transunet import TransUNet +from pymic.net.net2d.trans2d.swinunet import SwinUNet from pymic.net.net3d.unet2d5 import UNet2D5 from pymic.net.net3d.unet3d import UNet3D from pymic.net.net3d.unet3d_scse import UNet3D_ScSE from pymic.net.net3d.unet3d_dual_branch import UNet3D_DualBranch +from pymic.net.net3d.trans3d.nnFormer_wrap import nnFormer_wrap +from pymic.net.net3d.trans3d.unetr import UNETR +from pymic.net.net3d.trans3d.unetr_pp import UNETR_PP +from pymic.net.net3d.trans3d.MedFormer_v1 import MedFormerV1 +from pymic.net.net3d.trans3d.MedFormer_v2 import MedFormerV2 +from pymic.net.net3d.trans3d.MedFormer_v3 import MedFormerV3 +from pymic.net.net3d.trans3d.MedFormer_va1 import MedFormerVA1 +from pymic.net.net3d.trans3d.HiFormer_v1 import HiFormer_v1 +from pymic.net.net3d.trans3d.HiFormer_v2 import HiFormer_v2 +from pymic.net.net3d.trans3d.HiFormer_v3 import HiFormer_v3 +from pymic.net.net3d.trans3d.HiFormer_v4 import HiFormer_v4 +from pymic.net.net3d.trans3d.HiFormer_v5 import HiFormer_v5 SegNetDict = { 'UNet2D': UNet2D, @@ -34,9 +48,22 @@ 'AttentionUNet2D': AttentionUNet2D, 'NestedUNet2D': NestedUNet2D, 'UNet2D_ScSE': UNet2D_ScSE, + 'TransUNet': TransUNet, + 'SwinUNet': SwinUNet, 'UNet2D5': UNet2D5, 'UNet3D': UNet3D, 'UNet3D_ScSE': UNet3D_ScSE, - 'UNet3D_DualBranch': UNet3D_DualBranch - + 'UNet3D_DualBranch': UNet3D_DualBranch, + 'nnFormer': nnFormer_wrap, + 'UNETR': UNETR, + 'UNETR_PP': UNETR_PP, + 'MedFormerV1': MedFormerV1, + 'MedFormerV2': MedFormerV2, + 'MedFormerV3': MedFormerV3, + 'MedFormerVA1':MedFormerVA1, + 'HiFormer_v1': HiFormer_v1, + 'HiFormer_v2': HiFormer_v2, + 'HiFormer_v3': HiFormer_v3, + 'HiFormer_v4': HiFormer_v4, + 'HiFormer_v5': HiFormer_v5 } diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index 109eabc..7a49a2b 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -148,6 +148,8 @@ def get_checkpoint_name(self): with open(txt_name, 'r') as txt_file: it_num = txt_file.read().replace('\n', '') ckpt_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, it_num) + if(ckpt_mode == 1 and not os.path.isfile(ckpt_name)): + ckpt_name = "{0:}/{1:}_best.pt".format(ckpt_dir, ckpt_prefix) else: ckpt_name = self.config['testing']['ckpt_name'] return ckpt_name diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index 1ded3b1..ee4e25b 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -285,6 +285,15 @@ def train_valid(self): self.best_model_wts = copy.deepcopy(self.net.module.state_dict()) else: self.best_model_wts = copy.deepcopy(self.net.state_dict()) + save_dict = {'iteration': self.max_val_it, + 'valid_pred': self.max_val_score, + 'model_state_dict': self.best_model_wts, + 'optimizer_state_dict': self.optimizer.state_dict()} + save_name = "{0:}/{1:}_best.pt".format(ckpt_dir, ckpt_prefix) + torch.save(save_dict, save_name) + txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') + txt_file.write(str(self.max_val_it)) + txt_file.close() stop_now = True if(early_stop_it is not None and \ self.glob_it - self.max_val_it > early_stop_it) else False @@ -302,16 +311,6 @@ def train_valid(self): if(stop_now): logging.info("The training is early stopped") break - # save the best performing checkpoint - save_dict = {'iteration': self.max_val_it, - 'valid_pred': self.max_val_score, - 'model_state_dict': self.best_model_wts, - 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it) - torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') - txt_file.write(str(self.max_val_it)) - txt_file.close() logging.info('The best perfroming iter is {0:}, valid {1:} {2:}'.format(\ self.max_val_it, metrics, self.max_val_score)) self.summ_writer.close() diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 2e703d4..b007574 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -337,6 +337,7 @@ def train_valid(self): logging.info('learning rate {0:}'.format(lr_value)) logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1)) self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it) + if(valid_scalars['avg_fg_dice'] > self.max_val_dice): self.max_val_dice = valid_scalars['avg_fg_dice'] self.max_val_it = self.glob_it @@ -344,8 +345,17 @@ def train_valid(self): self.best_model_wts = copy.deepcopy(self.net.module.state_dict()) else: self.best_model_wts = copy.deepcopy(self.net.state_dict()) + save_dict = {'iteration': self.max_val_it, + 'valid_pred': self.max_val_dice, + 'model_state_dict': self.best_model_wts, + 'optimizer_state_dict': self.optimizer.state_dict()} + save_name = "{0:}/{1:}_best.pt".format(ckpt_dir, ckpt_prefix) + torch.save(save_dict, save_name) + txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') + txt_file.write(str(self.max_val_it)) + txt_file.close() - stop_now = True if(early_stop_it is not None and \ + stop_now = True if (early_stop_it is not None and \ self.glob_it - self.max_val_it > early_stop_it) else False if ((self.glob_it in iter_save_list) or stop_now): save_dict = {'iteration': self.glob_it, @@ -362,15 +372,6 @@ def train_valid(self): logging.info("The training is early stopped") break # save the best performing checkpoint - save_dict = {'iteration': self.max_val_it, - 'valid_pred': self.max_val_dice, - 'model_state_dict': self.best_model_wts, - 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it) - torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') - txt_file.write(str(self.max_val_it)) - txt_file.close() logging.info('The best performing iter is {0:}, valid dice {1:}'.format(\ self.max_val_it, self.max_val_dice)) self.summ_writer.close() From a205f31e5d4d44d562d83afc9617337820b0e841 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 18 May 2023 12:28:16 +0800 Subject: [PATCH 11/86] update code for MIM add transform MaskedImageModelingLabel --- pymic/loss/seg/mse.py | 6 +- pymic/net/net3d/trans3d/HiFormer_v1.py | 1010 +++++++++++++++++ pymic/net/net3d/trans3d/HiFormer_v2.py | 381 +++++++ pymic/net/net3d/trans3d/HiFormer_v3.py | 455 ++++++++ pymic/net/net3d/trans3d/HiFormer_v4.py | 455 ++++++++ pymic/net/net3d/trans3d/HiFormer_v5.py | 308 +++++ pymic/net/net3d/trans3d/MedFormer_v1.py | 173 +++ pymic/net/net3d/trans3d/MedFormer_v2.py | 464 ++++++++ pymic/net/net3d/trans3d/MedFormer_v3.py | 255 +++++ pymic/net/net3d/trans3d/MedFormer_va1.py | 105 ++ pymic/net/net3d/trans3d/__init__.py | 0 pymic/net/net3d/trans3d/nnFormer_wrap.py | 43 + pymic/net/net3d/trans3d/unetr.py | 227 ++++ pymic/net/net3d/trans3d/unetr_pp.py | 460 ++++++++ pymic/net/net3d/trans3d/unetr_pp_block.py | 278 +++++ pymic/net_run/agent_rec.py | 6 +- pymic/net_run/self_sup/__init__.py | 3 +- .../net_run/self_sup/self_patch_mix_agent.py | 144 +++ pymic/net_run/self_sup/util.py | 167 +++ pymic/transform/label_convert.py | 59 +- pymic/transform/mix.py | 66 ++ pymic/transform/trans_dict.py | 3 +- 22 files changed, 5046 insertions(+), 22 deletions(-) create mode 100644 pymic/net/net3d/trans3d/HiFormer_v1.py create mode 100644 pymic/net/net3d/trans3d/HiFormer_v2.py create mode 100644 pymic/net/net3d/trans3d/HiFormer_v3.py create mode 100644 pymic/net/net3d/trans3d/HiFormer_v4.py create mode 100644 pymic/net/net3d/trans3d/HiFormer_v5.py create mode 100644 pymic/net/net3d/trans3d/MedFormer_v1.py create mode 100644 pymic/net/net3d/trans3d/MedFormer_v2.py create mode 100644 pymic/net/net3d/trans3d/MedFormer_v3.py create mode 100644 pymic/net/net3d/trans3d/MedFormer_va1.py create mode 100644 pymic/net/net3d/trans3d/__init__.py create mode 100644 pymic/net/net3d/trans3d/nnFormer_wrap.py create mode 100644 pymic/net/net3d/trans3d/unetr.py create mode 100644 pymic/net/net3d/trans3d/unetr_pp.py create mode 100644 pymic/net/net3d/trans3d/unetr_pp_block.py create mode 100644 pymic/net_run/self_sup/self_patch_mix_agent.py create mode 100644 pymic/net_run/self_sup/util.py create mode 100644 pymic/transform/mix.py diff --git a/pymic/loss/seg/mse.py b/pymic/loss/seg/mse.py index ad83899..5b657c5 100644 --- a/pymic/loss/seg/mse.py +++ b/pymic/loss/seg/mse.py @@ -40,11 +40,15 @@ def __init__(self, params = None): def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] soft_y = loss_input_dict['ground_truth'] + weight = loss_input_dict.get('pixel_weight', None) if(isinstance(predict, (list, tuple))): predict = predict[0] if(self.softmax): predict = nn.Softmax(dim = 1)(predict) mae = torch.abs(predict - soft_y) - mae = torch.mean(mae) + if(weight is None): + mae = torch.mean(mae) + else: + mae = torch.sum(mae * weight) / weight.sum() return mae diff --git a/pymic/net/net3d/trans3d/HiFormer_v1.py b/pymic/net/net3d/trans3d/HiFormer_v1.py new file mode 100644 index 0000000..af73683 --- /dev/null +++ b/pymic/net/net3d/trans3d/HiFormer_v1.py @@ -0,0 +1,1010 @@ +from einops import rearrange +from copy import deepcopy +from nnformer.utilities.nd_softmax import softmax_helper +from torch import nn +import torch +import numpy as np +import torch.nn.functional +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_3tuple, trunc_normal_ +from pymic.net.net3d.unet3d import ConvBlock, DownBlock +# from nnFormer +class ContiguousGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x + @staticmethod + def backward(ctx, grad_out): + return grad_out.contiguous() + +# from nnFormer +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +# from nnFormer +def window_partition(x, window_size): + + B, S, H, W, C = x.shape + x = x.view(B, S // window_size, window_size, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, C) + return windows + +# from nnFormer +def window_reverse(windows, window_size, S, H, W): + + B = int(windows.shape[0] / (S * H * W / window_size / window_size / window_size)) + x = windows.view(B, S // window_size, H // window_size, W // window_size, window_size, window_size, window_size, -1) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, S, H, W, -1) + return x + + +# from nnFormer +class SwinTransformerBlock_kv(nn.Module): + + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention_kv( + dim, window_size=to_3tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + #self.window_size=to_3tuple(self.window_size) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, mask_matrix,skip=None,x_up=None): + + B, L, C = x.shape + S, H, W = self.input_resolution + + assert L == S * H * W, "input feature has wrong size" + + shortcut = x + skip = self.norm1(skip) + x_up = self.norm1(x_up) + + skip = skip.view(B, S, H, W, C) + x_up = x_up.view(B, S, H, W, C) + x = x.view(B, S, H, W, C) + # pad feature maps to multiples of window size + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + pad_g = (self.window_size - S % self.window_size) % self.window_size + + skip = F.pad(skip, (0, 0, 0, pad_r, 0, pad_b, 0, pad_g)) + x_up = F.pad(x_up, (0, 0, 0, pad_r, 0, pad_b, 0, pad_g)) + _, Sp, Hp, Wp, _ = skip.shape + + + + # cyclic shift + if self.shift_size > 0: + skip = torch.roll(skip, shifts=(-self.shift_size, -self.shift_size,-self.shift_size), dims=(1, 2,3)) + x_up = torch.roll(x_up, shifts=(-self.shift_size, -self.shift_size,-self.shift_size), dims=(1, 2,3)) + attn_mask = mask_matrix + else: + skip = skip + x_up=x_up + attn_mask = None + # partition windows + skip = window_partition(skip, self.window_size) + skip = skip.view(-1, self.window_size * self.window_size * self.window_size, + C) + x_up = window_partition(x_up, self.window_size) + x_up = x_up.view(-1, self.window_size * self.window_size * self.window_size, + C) + attn_windows=self.attn(skip,x_up,mask=attn_mask,pos_embed=None) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Sp, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size, self.shift_size), dims=(1, 2, 3)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0 or pad_g > 0: + x = x[:, :S, :H, :W, :].contiguous() + + x = x.view(B, S * H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + +# from nnFormer +class WindowAttention_kv(nn.Module): + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), + num_heads)) + + # get pair-wise relative position index for each token inside the window + coords_s = torch.arange(self.window_size[0]) + coords_h = torch.arange(self.window_size[1]) + coords_w = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid([coords_s, coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + + relative_coords[:, :, 0] *= 3 * self.window_size[1] - 1 + relative_coords[:, :, 1] *= 2 * self.window_size[1] - 1 + + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + trunc_normal_(self.relative_position_bias_table, std=.02) + + + def forward(self, skip,x_up,pos_embed=None, mask=None): + + B_, N, C = skip.shape + + kv = self.kv(skip) + q = x_up + + kv=kv.reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q = q.reshape(B_,N,self.num_heads,C//self.num_heads).permute(0,2,1,3).contiguous() + k,v = kv[0], kv[1] + q = q * self.scale + attn = (q @ k.transpose(-2, -1).contiguous()) + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] * self.window_size[2], + self.window_size[0] * self.window_size[1] * self.window_size[2], -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous() + if pos_embed is not None: + x = x + pos_embed + x = self.proj(x) + x = self.proj_drop(x) + return x + +# from nnFormer +class WindowAttention(nn.Module): + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), + num_heads)) + + # get pair-wise relative position index for each token inside the window + coords_s = torch.arange(self.window_size[0]) + coords_h = torch.arange(self.window_size[1]) + coords_w = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid([coords_s, coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + + relative_coords[:, :, 0] *= 3 * self.window_size[1] - 1 + relative_coords[:, :, 1] *= 2 * self.window_size[1] - 1 + + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + trunc_normal_(self.relative_position_bias_table, std=.02) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None,pos_embed=None): + + B_, N, C = x.shape + + qkv = self.qkv(x) + + qkv=qkv.reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1).contiguous()) + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] * self.window_size[2], + self.window_size[0] * self.window_size[1] * self.window_size[2], -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous() + if pos_embed is not None: + x = x+pos_embed + x = self.proj(x) + x = self.proj_drop(x) + return x + +# from nnFormer +class SwinTransformerBlock(nn.Module): + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + + self.attn = WindowAttention( + dim, window_size=to_3tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + + def forward(self, x, mask_matrix): + + B, L, C = x.shape + S, H, W = self.input_resolution + + assert L == S * H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, S, H, W, C) + + # pad feature maps to multiples of window size + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + pad_g = (self.window_size - S % self.window_size) % self.window_size + + x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b, 0, pad_g)) + _, Sp, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size,-self.shift_size), dims=(1, 2,3)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size * self.window_size, + C) + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask,pos_embed=None) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Sp, Hp, Wp) + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size, self.shift_size), dims=(1, 2, 3)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0 or pad_g > 0: + x = x[:, :S, :H, :W, :].contiguous() + + x = x.view(B, S * H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + +# from nnFormer +class PatchMerging(nn.Module): + + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Conv3d(dim,dim*2,kernel_size=3,stride=2,padding=1) + + self.norm = norm_layer(dim) + + def forward(self, x, S, H, W): + + B, L, C = x.shape + assert L == H * W * S, "input feature has wrong size" + x = x.view(B, S, H, W, C) + + x = F.gelu(x) + x = self.norm(x) + x=x.permute(0,4,1,2,3).contiguous() + x=self.reduction(x) + x=x.permute(0,2,3,4,1).contiguous().view(B,-1,2*C) + + return x + +# from nnFormer +class Patch_Expanding(nn.Module): + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + + self.norm = norm_layer(dim) + self.up=nn.ConvTranspose3d(dim,dim//2,2,2) + def forward(self, x, S, H, W): + + + B, L, C = x.shape + assert L == H * W * S, "input feature has wrong size" + + x = x.view(B, S, H, W, C) + + + + x = self.norm(x) + x=x.permute(0,4,1,2,3).contiguous() + x = self.up(x) + x = ContiguousGrad.apply(x) + x=x.permute(0,2,3,4,1).contiguous().view(B,-1,C//2) + + return x + +# from nnFormer +class BasicLayer(nn.Module): + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=True + ): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + # build blocks + + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, S, H, W): + + + # calculate attention mask for SW-MSA + Sp = int(np.ceil(S / self.window_size)) * self.window_size + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Sp, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + s_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for s in s_slices: + for h in h_slices: + for w in w_slices: + img_mask[:, s, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, + self.window_size * self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + for blk in self.blocks: + + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, S, H, W) + Ws, Wh, Ww = (S + 1) // 2, (H + 1) // 2, (W + 1) // 2 + return x, S, H, W, x_down, Ws, Wh, Ww + else: + return x, S, H, W, x, S, H, W + +# from nnFormer +class BasicLayer_up(nn.Module): + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + upsample=True + ): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + + + # build blocks + self.blocks = nn.ModuleList() + self.blocks.append( + SwinTransformerBlock_kv( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 , + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[0] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) + ) + for i in range(depth-1): + self.blocks.append( + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=window_size // 2 , + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i+1] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) + ) + + + + self.Upsample = upsample(dim=2*dim, norm_layer=norm_layer) + def forward(self, x,skip, S, H, W): + + + x_up = self.Upsample(x, S, H, W) + + x = x_up + skip + S, H, W = S * 2, H * 2, W * 2 + # calculate attention mask for SW-MSA + Sp = int(np.ceil(S / self.window_size)) * self.window_size + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Sp, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + s_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for s in s_slices: + for h in h_slices: + for w in w_slices: + img_mask[:, s, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, + self.window_size * self.window_size * self.window_size) # 3d��3��winds�˻�����Ŀ�Ǻܴ�ģ�����winds����̫�� + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + x = self.blocks[0](x, attn_mask,skip=skip,x_up=x_up) + for i in range(self.depth-1): + x = self.blocks[i+1](x,attn_mask) + + return x, S, H, W + + +# from nnFormer +class project(nn.Module): + def __init__(self,in_dim,out_dim,stride,padding,activate,norm,last=False): + super().__init__() + self.out_dim=out_dim + self.conv1=nn.Conv3d(in_dim,out_dim,kernel_size=3,stride=stride,padding=padding) + self.conv2=nn.Conv3d(out_dim,out_dim,kernel_size=3,stride=1,padding=1) + self.activate=activate() + self.norm1=norm(out_dim) + self.last=last + if not last: + self.norm2=norm(out_dim) + + def forward(self,x): + x=self.conv1(x) + x=self.activate(x) + #norm1 + Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.norm1(x) + x = x.transpose(1, 2).contiguous().view(-1, self.out_dim, Ws, Wh, Ww) + + + x=self.conv2(x) + if not self.last: + x=self.activate(x) + #norm2 + Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.norm2(x) + x = x.transpose(1, 2).contiguous().view(-1, self.out_dim, Ws, Wh, Ww) + return x + + +# from nnFormer +class PatchEmbed_backup(nn.Module): + def __init__(self, patch_size=4, in_chans=4, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_3tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + stride1=[patch_size[0]//2,patch_size[1]//2,patch_size[2]//2] + stride2=[patch_size[0]//2,patch_size[1]//2,patch_size[2]//2] + self.proj1 = project(in_chans,embed_dim//2,stride1,1,nn.GELU,nn.LayerNorm,False) + self.proj2 = project(embed_dim//2,embed_dim,stride2,1,nn.GELU,nn.LayerNorm,True) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, S, H, W = x.size() + if W % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) + if H % self.patch_size[1] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) + if S % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - S % self.patch_size[0])) + x = self.proj1(x) # B C Ws Wh Ww + x = self.proj2(x) # B C Ws Wh Ww + if self.norm is not None: + Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.norm(x) + x = x.transpose(1, 2).contiguous().view(-1, self.embed_dim, Ws, Wh, Ww) + + return x + + +class PatchEmbed(nn.Module): + """ + replace patch embed with conv layers""" + def __init__(self, in_chns=1, ft_chns = [32, 64, 128], dropout = [0, 0, 0.2]): + super().__init__() + self.in_conv= ConvBlock(in_chns, ft_chns[0], dropout[0]) + self.down1 = DownBlock(ft_chns[0], ft_chns[1], dropout[1]) + self.down2 = DownBlock(ft_chns[1], ft_chns[2], dropout[2]) + + + def forward(self, x): + """Forward function.""" + x0 = self.in_conv(x) + x1 = self.down1(x0) + x2 = self.down2(x1) + return x2 + +# from nnFormer +class Encoder(nn.Module): + + def __init__(self, + pretrain_img_size=224, + patch_size=4, + in_chans=1 , + embed_dim=96, + depths=[2, 2, 2, 2], + num_heads=[4, 8, 16, 32], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=True, + out_indices=(0, 1, 2, 3) + ): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + + self.num_layers = len(depths) + print("number of layers in encoder", self.num_layers, depths) + self.embed_dim = embed_dim + self.patch_norm = patch_norm + self.out_indices = out_indices + + # split image into non-overlapping patches + # self.patch_embed = PatchEmbed( + # patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + # norm_layer=norm_layer if self.patch_norm else None) + self.patch_embed = PatchEmbed(in_chans, ft_chns=[embed_dim // 4, embed_dim //2, embed_dim]) + + + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + input_resolution=( + pretrain_img_size[0] // patch_size[0] // 2 ** i_layer, pretrain_img_size[1] // patch_size[1] // 2 ** i_layer, + pretrain_img_size[2] // patch_size[2] // 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size[i_layer], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum( + depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging + if (i_layer < self.num_layers - 1) else None + ) + self.layers.append(layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + + def forward(self, x): + """Forward function.""" + + x = self.patch_embed(x) + down=[] + Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) + + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.pos_drop(x) + + + for i in range(self.num_layers): + layer = self.layers[i] + x_out, S, H, W, x, Ws, Wh, Ww = layer(x, Ws, Wh, Ww) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + + out = x_out.view(-1, S, H, W, self.num_features[i]).permute(0, 4, 1, 2, 3).contiguous() + + down.append(out) + return down + + +# from nnFormer +class Decoder(nn.Module): + def __init__(self, + pretrain_img_size, + embed_dim, + patch_size=4, + depths=[2,2,2], + num_heads=[24,12,6], + window_size=4, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm + ): + super().__init__() + + + self.num_layers = len(depths) + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers)[::-1]: + + layer = BasicLayer_up( + dim=int(embed_dim * 2 ** (len(depths)-i_layer-1)), + input_resolution=( + pretrain_img_size[0] // patch_size[0] // 2 ** (len(depths)-i_layer-1), pretrain_img_size[1] // patch_size[1] // 2 ** (len(depths)-i_layer-1), + pretrain_img_size[2] // patch_size[2] // 2 ** (len(depths)-i_layer-1)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size[i_layer], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum( + depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + upsample=Patch_Expanding + ) + self.layers.append(layer) + self.num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + def forward(self,x,skips): + + outs=[] + S, H, W = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2).contiguous() + for index,i in enumerate(skips): + i = i.flatten(2).transpose(1, 2).contiguous() + skips[index]=i + x = self.pos_drop(x) + + for i in range(self.num_layers)[::-1]: + + layer = self.layers[i] + + x, S, H, W, = layer(x,skips[i], S, H, W) + out = x.view(-1, S, H, W, self.num_features[i]) + outs.append(out) + return outs + + +class final_patch_expanding(nn.Module): + def __init__(self,dim,num_class,patch_size): + super().__init__() + self.up=nn.ConvTranspose3d(dim,num_class,patch_size,patch_size) + + def forward(self,x): + x=x.permute(0,4,1,2,3).contiguous() + x=self.up(x) + + + return x + + + + + +class HiFormer_v1(nn.Module): + def __init__(self, params): + """ + replace the embedding layer with convolutional blocks + """ + super(HiFormer_v1, self).__init__() + # crop_size=[96,96,96], + # embedding_dim=192, + # input_channels=1, + # num_classes=9, + # conv_op=nn.Conv3d, + # depths=[2,2,2,2], + # num_heads=[6, 12, 24, 48], + # patch_size=[4,4,4], + # window_size=[4,4,8,4], + # deep_supervision=False): + + crop_size = params["input_size"] + embed_dim = params.get("embedding_dim", 192) + input_channels = params["in_chns"] + num_classes = params["class_num"] + self.conv_op = nn.Conv3d + depths = params.get("depths", [2, 2, 2, 2]) + num_heads = params.get("num_heads", [6, 12, 24, 48]) + patch_size = params.get("patch_size", [4, 4, 4]) # for patch embedding + window_size = params.get("window_size", [4, 4, 8, 4]) # for swin transformer window + self._deep_supervision = params.get("deep_supervision", False) + self.do_ds = params.get("deep_supervision", False) + + + self.num_classes = num_classes + self.upscale_logits_ops = [] + self.upscale_logits_ops.append(lambda x: x) + + self.model_down=Encoder(pretrain_img_size=crop_size,window_size=window_size,embed_dim=embed_dim, + patch_size=patch_size,depths=depths,num_heads=num_heads,in_chans=input_channels, out_indices=range(len(depths))) + self.decoder=Decoder(pretrain_img_size=crop_size,embed_dim=embed_dim,window_size=window_size[::-1][1:],patch_size=patch_size,num_heads=num_heads[::-1][:-1],depths=depths[::-1][1:]) + + self.final=[] + if self.do_ds: + + for i in range(len(depths)-1): + self.final.append(final_patch_expanding(embed_dim*2**i,num_classes,patch_size=patch_size)) + + else: + self.final.append(final_patch_expanding(embed_dim,num_classes,patch_size=patch_size)) + + self.final=nn.ModuleList(self.final) + + + def forward(self, x): + + + seg_outputs=[] + skips = self.model_down(x) + neck=skips[-1] + + out=self.decoder(neck,skips) + + + + if self.do_ds: + for i in range(len(out)): + seg_outputs.append(self.final[-(i+1)](out[i])) + + + return seg_outputs[::-1] + else: + seg_outputs.append(self.final[0](out[-1])) + return seg_outputs[-1] + + +if __name__ == "__main__": + # params = {"input_size": [96, 96, 96], + # "in_chns": 1, + # "depth": [2, 2, 2, 2], + # "num_heads": [6, 12, 24, 48], + # "window_size": [6, 6, 6, 3], + # "class_num": 5} + params = {"input_size": [96, 96, 96], + "in_chns": 1, + "depths": [2, 2, 2], + "num_heads": [6, 12, 24], + "window_size": [6, 6, 6], + "class_num": 5} + Net = HiFormer_v1(params) + Net = Net.double() + + x = np.random.rand(1, 1, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + print(y.shape) + + + diff --git a/pymic/net/net3d/trans3d/HiFormer_v2.py b/pymic/net/net3d/trans3d/HiFormer_v2.py new file mode 100644 index 0000000..7d4c440 --- /dev/null +++ b/pymic/net/net3d/trans3d/HiFormer_v2.py @@ -0,0 +1,381 @@ + +import torch +import numpy as np +import torch.utils.checkpoint as checkpoint +from einops import rearrange +from copy import deepcopy +from torch import nn +from pymic.net.net3d.trans3d.HiFormer_v1 import BasicLayer + +class ConvBlock(nn.Module): + """ + 2D or 3D convolutional block + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + """ + def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0): + super(ConvBlock, self).__init__() + assert(dim == 2 or dim == 3) + if(dim == 2): + kernel_size = [1, 3, 3] + padding = [0, 1, 1] + else: + kernel_size = 3 + padding = 1 + + self.conv_conv = nn.Sequential( + nn.BatchNorm3d(in_channels), + nn.PReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), + nn.BatchNorm3d(out_channels), + nn.PReLU(), + nn.Dropout(dropout_p), + nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding), + ) + + def forward(self, x): + return self.conv_conv(x) + + +class DownSample(nn.Module): + def __init__(self, in_channels, out_channels, dim = 2, first_layer = False): + super(DownSample, self).__init__() + assert(dim == 2 or dim == 3) + if(dim == 2): + kernel_size = [1, 3, 3] + stride = [1, 2, 2] + padding = [0, 1, 1] + else: + kernel_size = 3 + stride = 2 + padding = 1 + + if(first_layer): + self.down = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, + padding=padding, stride = stride) + else: + self.down = nn.Sequential( + nn.BatchNorm3d(in_channels), + nn.PReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, + padding=padding, stride = stride), + ) + + def forward(self, x): + return self.down(x) + + + +class ConvTransBlock(nn.Module): + def __init__(self, + input_resolution= [32, 32, 32], + chns=96, + depth=2, + num_head=4, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=True, + ): + super().__init__() + self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) + self.trans = BasicLayer( + dim= chns, + input_resolution= input_resolution, + depth=depth, + num_heads=num_head, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + norm_layer=norm_layer, + downsample= None + ) + self.norm_layer = nn.LayerNorm(chns) + self.pos_drop = nn.Dropout(p=drop_rate) + + def forward(self, x): + """Forward function.""" + x1 = self.conv(x) + C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.pos_drop(x) + x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) + # x2 = self.norm_layer(x2) + x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() + return x1 + x2 + + +class UpCatBlock(nn.Module): + """ + 3D upsampling followed by ConvBlock + + :param in_channels1: (int) Channel number of high-level features. + :param in_channels2: (int) Channel number of low-level features. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + :param trilinear: (bool) Use trilinear for up-sampling (by default). + If False, deconvolution is used for up-sampling. + """ + def __init__(self, chns_l, chns_h, up_dim = 3, conv_dim = 3): + super(UpCatBlock, self).__init__() + assert(up_dim == 2 or up_dim == 3) + if(up_dim == 2): + kernel_size, stride = [1, 2, 2], [1, 2, 2] + else: + kernel_size, stride = 2, 2 + self.up = nn.ConvTranspose3d(chns_h, chns_l, + kernel_size = kernel_size, stride=stride) + + if(conv_dim == 2): + kernel_size, padding = [1, 3, 3], [0, 1, 1] + else: + kernel_size, padding = 3, 1 + self.conv = nn.Sequential( + nn.BatchNorm3d(chns_l*2), + nn.PReLU(), + nn.Conv3d(chns_l*2, chns_l, kernel_size=kernel_size, padding=padding), + ) + + def forward(self, x_l, x_h): + # print("input shapes", x1.shape, x2.shape) + # print("after upsample", x1.shape) + y = torch.cat([x_l, self.up(x_h)], dim=1) + return self.conv(y) + +class Encoder(nn.Module): + def __init__(self, + in_chns = 1 , + ft_chns = [48, 192, 384, 768], + input_size= [32, 128, 128], + down_dims = [2, 2, 3, 3], + conv_dims = [2, 3, 3, 3], + dropout = [0, 0.2, 0.2, 0.2], + depths = [2, 2, 2], + num_heads = [4, 8, 16], + window_sizes = [6, 6, 6], + ): + super().__init__() + + self.down1 = DownSample(in_chns, ft_chns[0], down_dims[0], first_layer=True) + self.down2 = DownSample(ft_chns[0], ft_chns[1], down_dims[1]) + self.down3 = DownSample(ft_chns[1], ft_chns[2], down_dims[2]) + self.down4 = DownSample(ft_chns[2], ft_chns[3], down_dims[3]) + + self.conv1 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) + self.conv2 = ConvBlock(ft_chns[1], ft_chns[1], conv_dims[1], dropout[1]) + + down_scales = [] + for i in range(4): + down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] + down_scales.append(down_scale) + + r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] + r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] + r_t4 = [r_t3[i] // down_scales[3][i] for i in range(3)] + + self.conv_t2 = ConvTransBlock(chns = ft_chns[1], + input_resolution = r_t2, + window_size = window_sizes[0], + depth = depths[0], + num_head = num_heads[0], + drop_rate = dropout[1], + attn_drop_rate=dropout[1] + ) + self.conv_t3 = ConvTransBlock(chns = ft_chns[2], + input_resolution = r_t3, + window_size = window_sizes[1], + depth = depths[1], + num_head = num_heads[1], + drop_rate = dropout[2], + attn_drop_rate=dropout[2] + ) + self.conv_t4 = ConvTransBlock(chns = ft_chns[3], + input_resolution = r_t4, + window_size = window_sizes[2], + depth = depths[2], + num_head = num_heads[2], + drop_rate = dropout[3], + attn_drop_rate=dropout[3] + ) + + + + def forward(self, x): + """Forward function.""" + x1 = self.conv1(self.down1(x)) + x2 = self.conv2(self.down2(x1)) + x2 = self.conv_t2(x2) + x3 = self.conv_t3(self.down3(x2)) + x4 = self.conv_t4(self.down4(x3)) + + return x1, x2, x3, x4 + +class Decoder(nn.Module): + """ + Decoder of 3D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, + ft_chns = [48, 192, 384, 768], + input_size = [32, 128, 128], + down_dims = [2, 2, 3, 3], + conv_dims = [2, 3, 3, 3], + dropout = [0, 0, 0.2, 0.2], + depths = [2, 2, 2], + num_heads = [4, 8, 16], + window_sizes = [6, 6, 6], + class_num = 2, + multiscale_pred = False + ): + super(Decoder, self).__init__() + + self.up1 = UpCatBlock(ft_chns[0], ft_chns[1], down_dims[1], conv_dims[0]) + self.up2 = UpCatBlock(ft_chns[1], ft_chns[2], down_dims[2], conv_dims[1]) + self.up3 = UpCatBlock(ft_chns[2], ft_chns[3], down_dims[3], conv_dims[2]) + + down_scales = [] + for i in range(4): + down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] + down_scales.append(down_scale) + + r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] + r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] + + self.conv1 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) + self.conv2 = ConvTransBlock(chns = ft_chns[1], + input_resolution = r_t2, + window_size = window_sizes[0], + depth = depths[0], + num_head = num_heads[0], + drop_rate = dropout[1], + attn_drop_rate=dropout[1] + ) + self.conv3 = ConvTransBlock(chns = ft_chns[2], + input_resolution = r_t3, + window_size = window_sizes[1], + depth = depths[1], + num_head = num_heads[1], + drop_rate = dropout[2], + attn_drop_rate=dropout[2] + ) + + kernel_size, stride = 2, 2 + if down_dims[0] == 2: + kernel_size, stride = [1, 2, 2], [1, 2, 2] + self.out_conv0 = nn.ConvTranspose3d(ft_chns[0], class_num, + kernel_size = kernel_size, stride= stride) + + self.mul_pred = multiscale_pred + if(self.mul_pred): + self.out_conv1 = nn.Conv3d(ft_chns[0], class_num, kernel_size = 1) + self.out_conv2 = nn.Conv3d(ft_chns[1], class_num, kernel_size = 1) + self.out_conv3 = nn.Conv3d(ft_chns[2], class_num, kernel_size = 1) + + def forward(self, x): + x1, x2, x3, x4 = x + x_d3 = self.conv3(self.up3(x3, x4)) + x_d2 = self.conv2(self.up2(x2, x_d3)) + x_d1 = self.conv1(self.up1(x1, x_d2)) + + output = self.out_conv0(x_d1) + if(self.mul_pred): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) + output = [output, output1, output2, output3] + return output + +class HiFormer_v2(nn.Module): + def __init__(self, params): + """ + replace the embedding layer with convolutional blocks + """ + super(HiFormer_v2, self).__init__() + in_chns = params["in_chns"] + class_num = params["class_num"] + input_size = params["input_size"] + ft_chns = params.get("feature_chns", [48, 192, 384, 764]) + down_dims = params.get("down_dims", [2, 2, 3, 3]) + conv_dims = params.get("conv_dims", [2, 3, 3, 3]) + dropout = params.get('dropout', [0, 0.2, 0.2, 0.2]) + depths = params.get("depths", [2, 2, 2]) + num_heads = params.get("num_heads", [4, 8, 16]) + window_sizes= params.get("window_sizes", [6, 6, 6]) + multiscale_pred = params.get("multiscale_pred", False) + + self.encoder = Encoder(in_chns, + ft_chns = ft_chns, + input_size = input_size, + down_dims = down_dims, + conv_dims = conv_dims, + dropout = dropout, + depths = depths, + num_heads = num_heads, + window_sizes= window_sizes) + + self.decoder = Decoder(ft_chns = ft_chns, + input_size = input_size, + down_dims = down_dims, + conv_dims = conv_dims, + dropout = dropout, + depths = depths, + num_heads = num_heads, + window_sizes= window_sizes, + class_num = class_num, + multiscale_pred = multiscale_pred + ) + + def forward(self, x): + x = self.encoder(x) + x = self.decoder(x) + return x + + +if __name__ == "__main__": + params = {"input_size": [32, 128, 128], + "in_chns": 1, + "down_dims": [2, 2, 3, 3], + "conv_dims": [2, 3, 3, 3], + "feature_chns": [96, 192, 384, 768], + "class_num": 5, + "multiscale_pred": True} + Net = HiFormer_v2(params) + Net = Net.double() + + x = np.random.rand(1, 1, 32, 128, 128) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + if(params['multiscale_pred']): + for yi in y: + print(yi.shape) + else: + print(y.shape) + + + diff --git a/pymic/net/net3d/trans3d/HiFormer_v3.py b/pymic/net/net3d/trans3d/HiFormer_v3.py new file mode 100644 index 0000000..2f8c831 --- /dev/null +++ b/pymic/net/net3d/trans3d/HiFormer_v3.py @@ -0,0 +1,455 @@ + +import torch +import numpy as np +import torch.utils.checkpoint as checkpoint +from einops import rearrange +from copy import deepcopy +from torch import nn +from pymic.net.net3d.trans3d.HiFormer_v1 import BasicLayer + +class ConvBlock(nn.Module): + """ + 2D or 3D convolutional block + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + """ + def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0): + super(ConvBlock, self).__init__() + assert(dim == 2 or dim == 3) + if(dim == 2): + kernel_size = [1, 3, 3] + padding = [0, 1, 1] + else: + kernel_size = 3 + padding = 1 + + self.conv_conv = nn.Sequential( + nn.BatchNorm3d(in_channels), + nn.PReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), + nn.BatchNorm3d(out_channels), + nn.PReLU(), + nn.Dropout(dropout_p), + nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding), + ) + + def forward(self, x): + return self.conv_conv(x) + + +class DownSample(nn.Module): + def __init__(self, in_channels, out_channels, dim = 2, first_layer = False): + super(DownSample, self).__init__() + assert(dim == 2 or dim == 3) + if(dim == 2): + kernel_size = [1, 3, 3] + stride = [1, 2, 2] + padding = [0, 1, 1] + else: + kernel_size = 3 + stride = 2 + padding = 1 + + if(first_layer): + self.down = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, + padding=padding, stride = stride) + else: + self.down = nn.Sequential( + nn.BatchNorm3d(in_channels), + nn.PReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, + padding=padding, stride = stride), + ) + + def forward(self, x): + return self.down(x) + + + +class ConvTransBlock_backup(nn.Module): + def __init__(self, + input_resolution= [32, 32, 32], + chns=96, + depth=2, + num_head=4, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=True, + ): + super().__init__() + self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) + self.trans = BasicLayer( + dim= chns, + input_resolution= input_resolution, + depth=depth, + num_heads=num_head, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + norm_layer=norm_layer, + downsample= None + ) + self.norm_layer = nn.LayerNorm(chns) + self.pos_drop = nn.Dropout(p=drop_rate) + + def forward(self, x): + """Forward function.""" + x1 = self.conv(x) + C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.pos_drop(x) + x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) + # x2 = self.norm_layer(x2) + x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() + return x1 + x2 + +# only using the conv block +class ConvTransBlock(nn.Module): + def __init__(self, + input_resolution= [32, 32, 32], + chns=96, + depth=2, + num_head=4, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=True, + ): + super().__init__() + self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) + # self.trans = BasicLayer( + # dim= chns, + # input_resolution= input_resolution, + # depth=depth, + # num_heads=num_head, + # window_size=window_size, + # mlp_ratio=mlp_ratio, + # qkv_bias=qkv_bias, + # qk_scale=qk_scale, + # drop=drop_rate, + # attn_drop=attn_drop_rate, + # drop_path=drop_path_rate, + # norm_layer=norm_layer, + # downsample= None + # ) + # self.norm_layer = nn.LayerNorm(chns) + # self.pos_drop = nn.Dropout(p=drop_rate) + + def forward(self, x): + """Forward function.""" + x1 = self.conv(x) + return x1 + # C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) + # x = x.flatten(2).transpose(1, 2).contiguous() + # x = self.pos_drop(x) + # x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) + # # x2 = self.norm_layer(x2) + # x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() + # return x1 + x2 + +class UpCatBlock(nn.Module): + """ + 3D upsampling followed by ConvBlock + + :param in_channels1: (int) Channel number of high-level features. + :param in_channels2: (int) Channel number of low-level features. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + :param trilinear: (bool) Use trilinear for up-sampling (by default). + If False, deconvolution is used for up-sampling. + """ + def __init__(self, chns_l, chns_h, up_dim = 3, conv_dim = 3): + super(UpCatBlock, self).__init__() + assert(up_dim == 2 or up_dim == 3) + if(up_dim == 2): + kernel_size, stride = [1, 2, 2], [1, 2, 2] + else: + kernel_size, stride = 2, 2 + self.up = nn.ConvTranspose3d(chns_h, chns_l, + kernel_size = kernel_size, stride=stride) + + if(conv_dim == 2): + kernel_size, padding = [1, 3, 3], [0, 1, 1] + else: + kernel_size, padding = 3, 1 + self.conv = nn.Sequential( + nn.BatchNorm3d(chns_l*2), + nn.PReLU(), + nn.Conv3d(chns_l*2, chns_l, kernel_size=kernel_size, padding=padding), + ) + + def forward(self, x_l, x_h): + # print("input shapes", x1.shape, x2.shape) + # print("after upsample", x1.shape) + y = torch.cat([x_l, self.up(x_h)], dim=1) + return self.conv(y) + +class Encoder(nn.Module): + def __init__(self, + in_chns = 1 , + ft_chns = [48, 192, 384, 768], + input_size= [32, 128, 128], + down_dims = [2, 2, 3, 3], + conv_dims = [2, 3, 3, 3], + dropout = [0, 0.2, 0.2, 0.2], + depths = [2, 2, 2], + num_heads = [4, 8, 16], + window_sizes = [6, 6, 6], + high_res = False, + ): + super().__init__() + self.high_res = high_res + + self.down1 = DownSample(in_chns, ft_chns[0], down_dims[0], first_layer=True) + self.down2 = DownSample(ft_chns[0], ft_chns[1], down_dims[1]) + self.down3 = DownSample(ft_chns[1], ft_chns[2], down_dims[2]) + self.down4 = DownSample(ft_chns[2], ft_chns[3], down_dims[3]) + + if(high_res): + self.conv0 = ConvBlock(in_chns, ft_chns[0] // 2, 3, 0) + self.conv1 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) + self.conv2 = ConvBlock(ft_chns[1], ft_chns[1], conv_dims[1], dropout[1]) + + down_scales = [] + for i in range(4): + down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] + down_scales.append(down_scale) + + r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] + r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] + r_t4 = [r_t3[i] // down_scales[3][i] for i in range(3)] + + self.conv_t2 = ConvTransBlock(chns = ft_chns[1], + input_resolution = r_t2, + window_size = window_sizes[0], + depth = depths[0], + num_head = num_heads[0], + drop_rate = dropout[1], + attn_drop_rate=dropout[1] + ) + self.conv_t3 = ConvTransBlock(chns = ft_chns[2], + input_resolution = r_t3, + window_size = window_sizes[1], + depth = depths[1], + num_head = num_heads[1], + drop_rate = dropout[2], + attn_drop_rate=dropout[2] + ) + self.conv_t4 = ConvTransBlock(chns = ft_chns[3], + input_resolution = r_t4, + window_size = window_sizes[2], + depth = depths[2], + num_head = num_heads[2], + drop_rate = dropout[3], + attn_drop_rate=dropout[3] + ) + + + + def forward(self, x): + """Forward function.""" + if(self.high_res): + x0 = self.conv0(x) + x1 = self.conv1(self.down1(x)) + x2 = self.conv2(self.down2(x1)) + x2 = self.conv_t2(x2) + x3 = self.conv_t3(self.down3(x2)) + x4 = self.conv_t4(self.down4(x3)) + if(self.high_res): + return x0, x1, x2, x3, x4 + else: + return x1, x2, x3, x4 + +class Decoder(nn.Module): + """ + Decoder of 3D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, + ft_chns = [48, 192, 384, 768], + input_size = [32, 128, 128], + down_dims = [2, 2, 3, 3], + conv_dims = [2, 3, 3, 3], + dropout = [0, 0, 0.2, 0.2], + depths = [2, 2, 2], + num_heads = [4, 8, 16], + window_sizes = [6, 6, 6], + high_res = False, + class_num = 2, + multiscale_pred = False + ): + super(Decoder, self).__init__() + self.high_res = high_res + if(self.high_res): + self.up0 = UpCatBlock(ft_chns[0] // 2, ft_chns[0], down_dims[0], 3) + self.conv0 = ConvBlock(ft_chns[0] // 2, ft_chns[0] // 2, 3, 0) + self.up1 = UpCatBlock(ft_chns[0], ft_chns[1], down_dims[1], conv_dims[0]) + self.up2 = UpCatBlock(ft_chns[1], ft_chns[2], down_dims[2], conv_dims[1]) + self.up3 = UpCatBlock(ft_chns[2], ft_chns[3], down_dims[3], conv_dims[2]) + + down_scales = [] + for i in range(4): + down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] + down_scales.append(down_scale) + + r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] + r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] + + self.conv1 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) + self.conv2 = ConvTransBlock(chns = ft_chns[1], + input_resolution = r_t2, + window_size = window_sizes[0], + depth = depths[0], + num_head = num_heads[0], + drop_rate = dropout[1], + attn_drop_rate=dropout[1] + ) + self.conv3 = ConvTransBlock(chns = ft_chns[2], + input_resolution = r_t3, + window_size = window_sizes[1], + depth = depths[1], + num_head = num_heads[1], + drop_rate = dropout[2], + attn_drop_rate=dropout[2] + ) + + kernel_size, stride = 2, 2 + if down_dims[0] == 2: + kernel_size, stride = [1, 2, 2], [1, 2, 2] + if(self.high_res): + self.out_conv0 = nn.Conv3d(ft_chns[0] // 2, class_num, + kernel_size = [1, 3, 3], padding = [0, 1, 1]) + else: + self.out_conv0 = nn.ConvTranspose3d(ft_chns[0], class_num, + kernel_size = kernel_size, stride= stride) + + self.mul_pred = multiscale_pred + if(self.mul_pred): + self.out_conv1 = nn.Conv3d(ft_chns[0], class_num, kernel_size = 1) + self.out_conv2 = nn.Conv3d(ft_chns[1], class_num, kernel_size = 1) + self.out_conv3 = nn.Conv3d(ft_chns[2], class_num, kernel_size = 1) + + def forward(self, x): + if(self.high_res): + x0, x1, x2, x3, x4 = x + else: + x1, x2, x3, x4 = x + x_d3 = self.conv3(self.up3(x3, x4)) + x_d2 = self.conv2(self.up2(x2, x_d3)) + x_d1 = self.conv1(self.up1(x1, x_d2)) + if(self.high_res): + x_d0 = self.conv0(self.up0(x0, x_d1)) + output = self.out_conv0(x_d0) + else: + output = self.out_conv0(x_d1) + if(self.mul_pred): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) + output = [output, output1, output2, output3] + return output + +class HiFormer_v3(nn.Module): + def __init__(self, params): + """ + replace the embedding layer with convolutional blocks + """ + super(HiFormer_v3, self).__init__() + in_chns = params["in_chns"] + class_num = params["class_num"] + input_size = params["input_size"] + ft_chns = params.get("feature_chns", [48, 192, 384, 764]) + down_dims = params.get("down_dims", [2, 2, 3, 3]) + conv_dims = params.get("conv_dims", [2, 3, 3, 3]) + dropout = params.get('dropout', [0, 0.2, 0.2, 0.2]) + high_res = params.get("high_res", False) + depths = params.get("depths", [2, 2, 2]) + num_heads = params.get("num_heads", [4, 8, 16]) + window_sizes= params.get("window_sizes", [6, 6, 6]) + multiscale_pred = params.get("multiscale_pred", False) + + self.encoder = Encoder(in_chns, + ft_chns = ft_chns, + input_size = input_size, + down_dims = down_dims, + conv_dims = conv_dims, + dropout = dropout, + depths = depths, + num_heads = num_heads, + window_sizes= window_sizes, + high_res = high_res) + + self.decoder = Decoder(ft_chns = ft_chns, + input_size = input_size, + down_dims = down_dims, + conv_dims = conv_dims, + dropout = dropout, + depths = depths, + num_heads = num_heads, + window_sizes= window_sizes, + high_res = high_res, + class_num = class_num, + multiscale_pred = multiscale_pred + ) + + def forward(self, x): + x = self.encoder(x) + x = self.decoder(x) + return x + + +if __name__ == "__main__": + params = {"input_size": [64, 96, 96], + "in_chns": 1, + "down_dims": [3, 3, 3, 3], + "conv_dims": [3, 3, 3, 3], + "feature_chns": [96, 192, 384, 768], + "high_res": True, + "class_num": 5, + "multiscale_pred": True} + Net = HiFormer_v3(params) + Net = Net.double() + + x = np.random.rand(1, 1, 64, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + if(params['multiscale_pred']): + for yi in y: + print(yi.shape) + else: + print(y.shape) + + + diff --git a/pymic/net/net3d/trans3d/HiFormer_v4.py b/pymic/net/net3d/trans3d/HiFormer_v4.py new file mode 100644 index 0000000..f0c6087 --- /dev/null +++ b/pymic/net/net3d/trans3d/HiFormer_v4.py @@ -0,0 +1,455 @@ + +import torch +import numpy as np +import torch.utils.checkpoint as checkpoint +from einops import rearrange +from copy import deepcopy +from torch import nn +from pymic.net.net3d.trans3d.HiFormer_v1 import BasicLayer + +class ConvBlock(nn.Module): + """ + 2D or 3D convolutional block + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + """ + def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0): + super(ConvBlock, self).__init__() + assert(dim == 2 or dim == 3) + if(dim == 2): + kernel_size = [1, 3, 3] + padding = [0, 1, 1] + else: + kernel_size = 3 + padding = 1 + + self.conv_conv = nn.Sequential( + nn.BatchNorm3d(in_channels), + nn.PReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), + nn.BatchNorm3d(out_channels), + nn.PReLU(), + nn.Dropout(dropout_p), + nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding), + ) + + def forward(self, x): + return self.conv_conv(x) + + +class DownSample(nn.Module): + def __init__(self, in_channels, out_channels, down_dim = 3, conv_dim = 3): + super(DownSample, self).__init__() + assert(down_dim == 2 or down_dim == 3) + assert(conv_dim == 2 or conv_dim == 3) + + kernel_size = [1, 2, 2] if(down_dim == 2) else 2 + self.pool = nn.MaxPool3d(kernel_size) + + if(conv_dim == 2): + kernel_size = [1, 3, 3] + padding = [0, 1, 1] + else: + kernel_size = 3 + padding = 1 + + self.conv = nn.Sequential( + nn.BatchNorm3d(in_channels), + nn.PReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), + ) + + def forward(self, x): + return self.conv(self.pool(x)) + + + +# class ConvTransBlock(nn.Module): +# def __init__(self, +# input_resolution= [32, 32, 32], +# chns=96, +# depth=2, +# num_head=4, +# window_size=7, +# mlp_ratio=4., +# qkv_bias=True, +# qk_scale=None, +# drop_rate=0., +# attn_drop_rate=0., +# drop_path_rate=0.2, +# norm_layer=nn.LayerNorm, +# patch_norm=True, +# ): +# super().__init__() +# self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) +# self.trans = BasicLayer( +# dim= chns, +# input_resolution= input_resolution, +# depth=depth, +# num_heads=num_head, +# window_size=window_size, +# mlp_ratio=mlp_ratio, +# qkv_bias=qkv_bias, +# qk_scale=qk_scale, +# drop=drop_rate, +# attn_drop=attn_drop_rate, +# drop_path=drop_path_rate, +# norm_layer=norm_layer, +# downsample= None +# ) +# self.norm_layer = nn.LayerNorm(chns) +# self.pos_drop = nn.Dropout(p=drop_rate) + +# def forward(self, x): +# """Forward function.""" +# x1 = self.conv(x) +# C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) +# x = x.flatten(2).transpose(1, 2).contiguous() +# x = self.pos_drop(x) +# x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) +# # x2 = self.norm_layer(x2) +# x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() +# return x1 + x2 + +# only using the conv block +class ConvTransBlock(nn.Module): + def __init__(self, + input_resolution= [32, 32, 32], + chns=96, + depth=2, + num_head=4, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=True, + ): + super().__init__() + self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) + # self.trans = BasicLayer( + # dim= chns, + # input_resolution= input_resolution, + # depth=depth, + # num_heads=num_head, + # window_size=window_size, + # mlp_ratio=mlp_ratio, + # qkv_bias=qkv_bias, + # qk_scale=qk_scale, + # drop=drop_rate, + # attn_drop=attn_drop_rate, + # drop_path=drop_path_rate, + # norm_layer=norm_layer, + # downsample= None + # ) + # self.norm_layer = nn.LayerNorm(chns) + # self.pos_drop = nn.Dropout(p=drop_rate) + + def forward(self, x): + """Forward function.""" + x1 = self.conv(x) + return x1 + # C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) + # x = x.flatten(2).transpose(1, 2).contiguous() + # x = self.pos_drop(x) + # x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) + # # x2 = self.norm_layer(x2) + # x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() + # return x1 + x2 + +class ConvLayer(nn.Module): + """ + 2D or 3D convolutional block + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + """ + def __init__(self, in_channels, out_channels, kernel = 1, padding = 0): + super(ConvLayer, self).__init__() + + self.conv = nn.Sequential( + nn.BatchNorm3d(in_channels), + nn.PReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=kernel, padding=padding), + ) + + def forward(self, x): + return self.conv(x) + +class UpCatBlock(nn.Module): + """ + 3D upsampling followed by ConvBlock + + :param in_channels1: (int) Channel number of high-level features. + :param in_channels2: (int) Channel number of low-level features. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + :param trilinear: (bool) Use trilinear for up-sampling (by default). + If False, deconvolution is used for up-sampling. + """ + def __init__(self, chns_l, chns_h, up_dim = 3, conv_dim = 3): + super(UpCatBlock, self).__init__() + assert(up_dim == 2 or up_dim == 3) + if(up_dim == 2): + kernel_size, stride = [1, 2, 2], [1, 2, 2] + else: + kernel_size, stride = 2, 2 + + self.up = nn.Sequential( + nn.BatchNorm3d(chns_h), + nn.PReLU(), + nn.ConvTranspose3d(chns_h, chns_l, kernel_size = kernel_size, stride=stride) + ) + + if(conv_dim == 2): + kernel_size, padding = [1, 3, 3], [0, 1, 1] + else: + kernel_size, padding = 3, 1 + self.conv = nn.Sequential( + nn.BatchNorm3d(chns_l*2), + nn.PReLU(), + nn.Conv3d(chns_l*2, chns_l, kernel_size=kernel_size, padding=padding), + ) + + def forward(self, x_l, x_h): + # print("input shapes", x1.shape, x2.shape) + # print("after upsample", x1.shape) + y = torch.cat([x_l, self.up(x_h)], dim=1) + return self.conv(y) + +class Encoder(nn.Module): + def __init__(self, + in_chns = 1 , + ft_chns = [24, 48, 192, 384, 768], + input_size= [32, 128, 128], + down_dims = [3, 3, 3, 3, 3], + conv_dims = [3, 3, 3, 3, 3], + dropout = [0, 0, 0.2, 0.2, 0.2], + depths = [2, 2, 2], + num_heads = [4, 8, 16], + window_sizes = [6, 6, 6], + ): + super().__init__() + self.proj = nn.Conv3d(in_chns, ft_chns[0], kernel_size=3, padding=1) + self.conv0 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) + self.conv1 = ConvBlock(ft_chns[1], ft_chns[1], conv_dims[1], dropout[1]) + self.conv2 = ConvBlock(ft_chns[2], ft_chns[2], conv_dims[2], dropout[2]) + + self.down1 = DownSample(ft_chns[0], ft_chns[1], down_dims[0], conv_dims[1]) + self.down2 = DownSample(ft_chns[1], ft_chns[2], down_dims[1], conv_dims[2]) + self.down3 = DownSample(ft_chns[2], ft_chns[3], down_dims[2], conv_dims[3]) + self.down4 = DownSample(ft_chns[3], ft_chns[4], down_dims[3], conv_dims[4]) + + down_scales = [] + for i in range(4): + down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] + down_scales.append(down_scale) + + r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] + r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] + r_t4 = [r_t3[i] // down_scales[3][i] for i in range(3)] + + self.conv_t2 = ConvTransBlock(chns = ft_chns[2], + input_resolution = r_t2, + window_size = window_sizes[0], + depth = depths[0], + num_head = num_heads[0], + drop_rate = dropout[2], + attn_drop_rate=dropout[2] + ) + self.conv_t3 = ConvTransBlock(chns = ft_chns[3], + input_resolution = r_t3, + window_size = window_sizes[1], + depth = depths[1], + num_head = num_heads[1], + drop_rate = dropout[3], + attn_drop_rate=dropout[3] + ) + self.conv_t4 = ConvTransBlock(chns = ft_chns[4], + input_resolution = r_t4, + window_size = window_sizes[2], + depth = depths[2], + num_head = num_heads[2], + drop_rate = dropout[4], + attn_drop_rate=dropout[4] + ) + + + + def forward(self, x): + """Forward function.""" + x0 = self.conv0(self.proj(x)) + x1 = self.conv1(self.down1(x0)) + x2 = self.conv2(self.down2(x1)) + x2 = self.conv_t2(x2) + x3 = self.conv_t3(self.down3(x2)) + x4 = self.conv_t4(self.down4(x3)) + return x0, x1, x2, x3, x4 + + +class Decoder(nn.Module): + """ + Decoder of 3D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, + ft_chns = [24, 48, 192, 384, 768], + input_size= [32, 128, 128], + down_dims = [3, 3, 3, 3, 3], + conv_dims = [3, 3, 3, 3, 3], + dropout = [0, 0, 0.2, 0.2, 0.2], + depths = [2, 2, 2], + num_heads = [4, 8, 16], + window_sizes = [6, 6, 6], + class_num = 2, + multiscale_pred = False + ): + super(Decoder, self).__init__() + # self.up0 = UpCatBlock(ft_chns[0] // 2, ft_chns[0], down_dims[0], 3) + # self.conv0 = ConvBlock(ft_chns[0] // 2, ft_chns[0] // 2, 3, 0) + self.up1 = UpCatBlock(ft_chns[0], ft_chns[1], down_dims[0], conv_dims[0]) + self.up2 = UpCatBlock(ft_chns[1], ft_chns[2], down_dims[1], conv_dims[1]) + self.up3 = UpCatBlock(ft_chns[2], ft_chns[3], down_dims[2], conv_dims[2]) + self.up4 = UpCatBlock(ft_chns[3], ft_chns[4], down_dims[3], conv_dims[3]) + + down_scales = [] + for i in range(4): + down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] + down_scales.append(down_scale) + + r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] + r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] + + self.conv0 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) + self.conv1 = ConvBlock(ft_chns[1], ft_chns[1], conv_dims[1], dropout[1]) + self.conv2 = ConvTransBlock(chns = ft_chns[2], + input_resolution = r_t2, + window_size = window_sizes[0], + depth = depths[0], + num_head = num_heads[0], + drop_rate = dropout[2], + attn_drop_rate=dropout[2] + ) + self.conv3 = ConvTransBlock(chns = ft_chns[3], + input_resolution = r_t3, + window_size = window_sizes[1], + depth = depths[1], + num_head = num_heads[1], + drop_rate = dropout[3], + attn_drop_rate=dropout[3] + ) + + self.out_conv0 = ConvLayer(ft_chns[0], class_num) + + self.mul_pred = multiscale_pred + if(self.mul_pred): + self.out_conv1 = ConvLayer(ft_chns[1], class_num) + self.out_conv2 = ConvLayer(ft_chns[2], class_num) + self.out_conv3 = ConvLayer(ft_chns[3], class_num) + + def forward(self, x): + x0, x1, x2, x3, x4 = x + + x_d3 = self.conv3(self.up4(x3, x4)) + x_d2 = self.conv2(self.up3(x2, x_d3)) + x_d1 = self.conv1(self.up2(x1, x_d2)) + x_d0 = self.conv0(self.up1(x0, x_d1)) + output = self.out_conv0(x_d0) + + if(self.mul_pred): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) + output = [output, output1, output2, output3] + return output + +class HiFormer_v4(nn.Module): + def __init__(self, params): + """ + replace the embedding layer with convolutional blocks + """ + super(HiFormer_v4, self).__init__() + in_chns = params["in_chns"] + class_num = params["class_num"] + input_size = params["input_size"] + ft_chns = params.get("feature_chns", [32, 64, 128, 256, 512]) + down_dims = params.get("down_dims", [3, 3, 3, 3, 3]) + conv_dims = params.get("conv_dims", [3, 3, 3, 3, 3]) + dropout = params.get('dropout', [0, 0, 0.2, 0.2, 0.2]) + depths = params.get("depths", [2, 2, 2]) + num_heads = params.get("num_heads", [4, 8, 16]) + window_sizes= params.get("window_sizes", [6, 6, 6]) + multiscale_pred = params.get("multiscale_pred", False) + + self.encoder = Encoder(in_chns, + ft_chns = ft_chns, + input_size = input_size, + down_dims = down_dims, + conv_dims = conv_dims, + dropout = dropout, + depths = depths, + num_heads = num_heads, + window_sizes= window_sizes) + + self.decoder = Decoder(ft_chns = ft_chns, + input_size = input_size, + down_dims = down_dims, + conv_dims = conv_dims, + dropout = dropout, + depths = depths, + num_heads = num_heads, + window_sizes= window_sizes, + class_num = class_num, + multiscale_pred = multiscale_pred + ) + + def forward(self, x): + x = self.encoder(x) + x = self.decoder(x) + return x + + +if __name__ == "__main__": + params = {"input_size": [64, 96, 96], + "in_chns": 1, + "down_dims": [3, 3, 3, 3, 3], + "conv_dims": [3, 3, 3, 3, 3], + "feature_chns": [32, 64, 128, 256, 512], + "class_num": 5, + "multiscale_pred": True} + Net = HiFormer_v4(params) + Net = Net.double() + + x = np.random.rand(1, 1, 64, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + if(params['multiscale_pred']): + for yi in y: + print(yi.shape) + else: + print(y.shape) + + + diff --git a/pymic/net/net3d/trans3d/HiFormer_v5.py b/pymic/net/net3d/trans3d/HiFormer_v5.py new file mode 100644 index 0000000..5fcef5a --- /dev/null +++ b/pymic/net/net3d/trans3d/HiFormer_v5.py @@ -0,0 +1,308 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import torch.nn as nn +import numpy as np +from torch.nn.functional import interpolate + + +class ConvBlock(nn.Module): + """ + 2D or 3D convolutional block + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + """ + def __init__(self, in_channels, out_channels, dropout_p = 0.0, dim = 3): + super(ConvBlock, self).__init__() + assert(dim == 2 or dim == 3) + if(dim == 2): + kernel_size = [1, 3, 3] + padding = [0, 1, 1] + else: + kernel_size = 3 + padding = 1 + + self.conv_conv = nn.Sequential( + nn.BatchNorm3d(in_channels), + nn.LeakyReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), + nn.BatchNorm3d(out_channels), + nn.LeakyReLU(), + nn.Dropout(dropout_p), + nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding), + ) + + def forward(self, x): + return self.conv_conv(x) + +class ConvLayer(nn.Module): + """ + 2D or 3D convolutional block + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + """ + def __init__(self, in_channels, out_channels, kernel = 1, padding = 0): + super(ConvLayer, self).__init__() + + self.conv = nn.Sequential( + nn.BatchNorm3d(in_channels), + nn.LeakyReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=kernel, padding=padding), + ) + + def forward(self, x): + return self.conv(x) + +class DownBlock(nn.Module): + """ + 3D downsampling followed by ConvBlock + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + """ + def __init__(self, in_channels, out_channels, dropout_p): + super(DownBlock, self).__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool3d(2), + ConvBlock(in_channels, out_channels, dropout_p) + ) + + def forward(self, x): + return self.maxpool_conv(x) + +class UpBlock(nn.Module): + """ + 3D upsampling followed by ConvBlock + + :param in_channels1: (int) Channel number of high-level features. + :param in_channels2: (int) Channel number of low-level features. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + :param trilinear: (bool) Use trilinear for up-sampling (by default). + If False, deconvolution is used for up-sampling. + """ + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, + trilinear=True): + super(UpBlock, self).__init__() + self.trilinear = trilinear + if trilinear: + self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) + self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) + else: + self.up = nn.Sequential( + nn.BatchNorm3d(in_channels1), + nn.LeakyReLU(), + nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) + ) + self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) + + def forward(self, x1, x2): + if self.trilinear: + x1 = self.conv1x1(x1) + x1 = self.up(x1) + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + +class Encoder(nn.Module): + """ + Encoder of 3D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + """ + def __init__(self, params): + super(Encoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + + self.proj = nn.Conv3d(self.in_chns, self.ft_chns[0], kernel_size=3, padding=1) + self.in_conv= ConvBlock(self.ft_chns[0], self.ft_chns[0], self.dropout[0]) + self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) + self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) + self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) + if(len(self.ft_chns) == 5): + self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) + + def forward(self, x): + x0 = self.in_conv(self.proj(x)) + x1 = self.down1(x0) + x2 = self.down2(x1) + x3 = self.down3(x2) + output = [x0, x1, x2, x3] + if(len(self.ft_chns) == 5): + x4 = self.down4(x3) + output.append(x4) + return output + +class Decoder(nn.Module): + """ + Decoder of 3D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, params): + super(Decoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + self.n_class = self.params['class_num'] + self.trilinear = self.params.get('trilinear', True) + self.mul_pred = self.params.get('multiscale_pred', False) + + assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + + if(len(self.ft_chns) == 5): + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear) + self.out_conv = ConvLayer(self.ft_chns[0], self.n_class) + if(self.mul_pred): + self.out_conv1 = ConvLayer(self.ft_chns[1], self.n_class) + self.out_conv2 = ConvLayer(self.ft_chns[2], self.n_class) + self.out_conv3 = ConvLayer(self.ft_chns[3], self.n_class) + + def forward(self, x): + if(len(self.ft_chns) == 5): + assert(len(x) == 5) + x0, x1, x2, x3, x4 = x + x_d3 = self.up1(x4, x3) + else: + assert(len(x) == 4) + x0, x1, x2, x3 = x + x_d3 = x3 + x_d2 = self.up2(x_d3, x2) + x_d1 = self.up3(x_d2, x1) + x_d0 = self.up4(x_d1, x0) + output = self.out_conv(x_d0) + if(self.mul_pred): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) + output = [output, output1, output2, output3] + return output + +class HiFormer_v5(nn.Module): + """ + An implementation of the U-Net. + + * Reference: Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: + 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. + `MICCAI (2) 2016: 424-432. `_ + + Note that there are some modifications from the original paper, such as + the use of batch normalization, dropout, leaky relu and deep supervision. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using trilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, params): + super(HiFormer_v5, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + self.n_class = self.params['class_num'] + self.trilinear = self.params['trilinear'] + self.mul_pred = self.params['multiscale_pred'] + assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + + self.proj = nn.Conv3d(self.in_chns, self.ft_chns[0], kernel_size=3, padding=1) + self.in_conv= ConvBlock(self.ft_chns[0], self.ft_chns[0], self.dropout[0]) + self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) + self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) + self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) + if(len(self.ft_chns) == 5): + self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], + dropout_p = self.dropout[3], trilinear=self.trilinear) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], + dropout_p = self.dropout[2], trilinear=self.trilinear) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], + dropout_p = self.dropout[1], trilinear=self.trilinear) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], + dropout_p = self.dropout[0], trilinear=self.trilinear) + + self.out_conv = ConvLayer(self.ft_chns[0], self.n_class) + if(self.mul_pred): + self.out_conv1 = ConvLayer(self.ft_chns[1], self.n_class) + self.out_conv2 = ConvLayer(self.ft_chns[2], self.n_class) + self.out_conv3 = ConvLayer(self.ft_chns[3], self.n_class) + + def forward(self, x): + x0 = self.in_conv(self.proj(x)) + x1 = self.down1(x0) + x2 = self.down2(x1) + x3 = self.down3(x2) + if(len(self.ft_chns) == 5): + x4 = self.down4(x3) + x_d3 = self.up1(x4, x3) + else: + x_d3 = x3 + x_d2 = self.up2(x_d3, x2) + x_d1 = self.up3(x_d2, x1) + x_d0 = self.up4(x_d1, x0) + output = self.out_conv(x_d0) + if(self.mul_pred): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) + output = [output, output1, output2, output3] + return output + +if __name__ == "__main__": + params = {'in_chns':4, + 'class_num': 2, + 'feature_chns':[32, 64, 128, 256, 512], + 'dropout' : [0, 0, 0, 0, 0.5], + 'trilinear': False, + 'multiscale_pred': False} + Net = HiFormer_v5(params) + Net = Net.double() + + x = np.random.rand(4, 4, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + y = y.detach().numpy() + print(y.shape) diff --git a/pymic/net/net3d/trans3d/MedFormer_v1.py b/pymic/net/net3d/trans3d/MedFormer_v1.py new file mode 100644 index 0000000..1f2ed54 --- /dev/null +++ b/pymic/net/net3d/trans3d/MedFormer_v1.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import math +import torch +import torch.nn as nn +import numpy as np +from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm +from pymic.net.net3d.unet3d import Encoder, Decoder + +class Attention(nn.Module): + def __init__(self, params): + super(Attention, self).__init__() + hidden_size = params["attention_hidden_size"] + self.num_attention_heads = params["attention_num_heads"] + self.attention_head_size = int(hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = Linear(hidden_size, self.all_head_size) + self.key = Linear(hidden_size, self.all_head_size) + self.value = Linear(hidden_size, self.all_head_size) + + self.out = Linear(hidden_size, hidden_size) + self.attn_dropout = Dropout(params["attention_dropout_rate"]) + self.proj_dropout = Dropout(params["attention_dropout_rate"]) + + self.softmax = Softmax(dim=-1) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.softmax(attention_scores) + # weights = attention_probs if self.vis else None + attention_probs = self.attn_dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + attention_output = self.out(context_layer) + attention_output = self.proj_dropout(attention_output) + return attention_output + +class MLP(nn.Module): + def __init__(self, params): + super(MLP, self).__init__() + hidden_size = params["attention_hidden_size"] + mlp_dim = params["attention_mlp_dim"] + self.fc1 = Linear(hidden_size, mlp_dim) + self.fc2 = Linear(mlp_dim, hidden_size) + self.act_fn = torch.nn.functional.gelu + self.dropout = Dropout(params["attention_dropout_rate"]) + + self._init_weights() + + def _init_weights(self): + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + nn.init.normal_(self.fc1.bias, std=1e-6) + nn.init.normal_(self.fc2.bias, std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x = self.act_fn(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + +class Block(nn.Module): + def __init__(self, params): + super(Block, self).__init__() + hidden_size = params["attention_hidden_size"] + self.attention_norm = LayerNorm(hidden_size, eps=1e-6) + self.ffn_norm = LayerNorm(hidden_size, eps=1e-6) + self.ffn = MLP(params) + self.attn = Attention(params) + + def forward(self, x): + # convert the tensor shape from [B, C, D, H, W] to [B, DHW, C] + [B, C, D, H, W] = list(x.shape) + new_shape = [B, C, D*H*W] + x = torch.reshape(x, new_shape) + x = torch.transpose(x, 1, 2) + + h = x + x = self.attention_norm(x) + x = self.attn(x) + x = x + h + + h = x + x = self.ffn_norm(x) + x = self.ffn(x) + x = x + h + + # convert the result back to [B, C, D, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, [B, C, D, H, W]) + return x + +class MedFormerV1(nn.Module): + """ + An implementation of the U-Net. + + * Reference: Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: + 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. + `MICCAI (2) 2016: 424-432. `_ + + Note that there are some modifications from the original paper, such as + the use of batch normalization, dropout, leaky relu and deep supervision. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using trilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + :param deep_supervise: (bool) Using deep supervision for training or not. + """ + def __init__(self, params): + super(MedFormerV1, self).__init__() + self.params = params + self.encoder = Encoder(params) + self.decoder = Decoder(params) + self.attn = Block(params) + + def forward(self, x): + f = self.encoder(x) + f[-1] = self.attn(f[-1]) + output = self.decoder(f) + return output + +if __name__ == "__main__": + params = {'in_chns':4, + 'class_num': 2, + 'feature_chns':[16, 32, 64, 128], + 'dropout' : [0, 0, 0, 0.5], + 'trilinear': True, + 'deep_supervise': True, + 'attention_hidden_size': 128, + 'attention_num_heads': 4, + 'attention_mlp_dim': 256, + 'attention_dropout_rate': 0.2} + Net = MedFormerV1(params) + Net = Net.double() + + x = np.random.rand(1, 4, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + print("output length", len(y)) + for yi in y: + yi = yi.detach().numpy() + print(yi.shape) diff --git a/pymic/net/net3d/trans3d/MedFormer_v2.py b/pymic/net/net3d/trans3d/MedFormer_v2.py new file mode 100644 index 0000000..00cb295 --- /dev/null +++ b/pymic/net/net3d/trans3d/MedFormer_v2.py @@ -0,0 +1,464 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import math +import copy +import torch +import torch.nn as nn +import numpy as np +import torch.nn.functional as F +from pymic.net.net3d.unet3d import ConvBlock, Encoder, Decoder +from pymic.net.net3d.trans3d.MedFormer_v1 import Block +from timm.models.layers import DropPath, to_3tuple, trunc_normal_ + + +# code from nnFormer +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + + B, S, H, W, C = x.shape + x = x.view(B, S // window_size, window_size, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, S, H, W): + + B = int(windows.shape[0] / (S * H * W / window_size / window_size / window_size)) + x = windows.view(B, S // window_size, H // window_size, W // window_size, window_size, window_size, window_size, -1) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, S, H, W, -1) + return x + + +class WindowAttention(nn.Module): + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), + num_heads)) + + # get pair-wise relative position index for each token inside the window + coords_s = torch.arange(self.window_size[0]) + coords_h = torch.arange(self.window_size[1]) + coords_w = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid([coords_s, coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + + relative_coords[:, :, 0] *= 3 * self.window_size[1] - 1 + relative_coords[:, :, 1] *= 2 * self.window_size[1] - 1 + + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + trunc_normal_(self.relative_position_bias_table, std=.02) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None,pos_embed=None): + + B_, N, C = x.shape + + qkv = self.qkv(x) + + qkv=qkv.reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1).contiguous()) + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] * self.window_size[2], + self.window_size[0] * self.window_size[1] * self.window_size[2], -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous() + if pos_embed is not None: + x = x+pos_embed + x = self.proj(x) + x = self.proj_drop(x) + return x + +class SwinTransformerBlock(nn.Module): + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + + self.attn = WindowAttention( + dim, window_size=to_3tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + + def forward(self, x, mask_matrix): + + B, L, C = x.shape + S, H, W = self.input_resolution + + assert L == S * H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, S, H, W, C) + + # pad feature maps to multiples of window size + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + pad_g = (self.window_size - S % self.window_size) % self.window_size + + x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b, 0, pad_g)) + _, Sp, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size,-self.shift_size), dims=(1, 2,3)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size * self.window_size, + C) + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask,pos_embed=None) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Sp, Hp, Wp) + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size, self.shift_size), dims=(1, 2, 3)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0 or pad_g > 0: + x = x[:, :S, :H, :W, :].contiguous() + + x = x.view(B, S * H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + +class BasicLayer(nn.Module): + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=True + ): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + # build blocks + + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, S, H, W): + + + # calculate attention mask for SW-MSA + Sp = int(np.ceil(S / self.window_size)) * self.window_size + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Sp, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + s_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for s in s_slices: + for h in h_slices: + for w in w_slices: + img_mask[:, s, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, + self.window_size * self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + for blk in self.blocks: + + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, S, H, W) + Ws, Wh, Ww = (S + 1) // 2, (H + 1) // 2, (W + 1) // 2 + return x, S, H, W, x_down, Ws, Wh, Ww + else: + return x, S, H, W, x, S, H, W + + +class AttUpBlock(nn.Module): + """ + 3D upsampling followed by ConvBlock + + :param in_channels1: (int) Channel number of high-level features. + :param in_channels2: (int) Channel number of low-level features. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + :param trilinear: (bool) Use trilinear for up-sampling (by default). + If False, deconvolution is used for up-sampling. + """ + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, + trilinear=True, with_att = False, att_params = None): + super(AttUpBlock, self).__init__() + self.trilinear = trilinear + self.with_att = with_att + if trilinear: + self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) + self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) + else: + self.up = nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) + self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) + if(self.with_att): + input_resolution = att_params['input_resolution'] + depth = att_params['depth'] + num_heads = att_params['num_heads'] + self.attn = BasicLayer(out_channels, input_resolution, depth, num_heads, downsample=None) + + def forward(self, x1, x2): + if self.trilinear: + x1 = self.conv1x1(x1) + x1 = self.up(x1) + x = torch.cat([x2, x1], dim=1) + x = self.conv(x) + if(self.with_att): + [B, C, D, H, W] = list(x.shape) + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.attn(x, D, H, W)[0] + x = x.view(-1, D, H, W, C).permute(0, 4, 1, 2, 3).contiguous() + return x + +class AttDecoder(nn.Module): + """ + Decoder of 3D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + """ + def __init__(self, params): + super(AttDecoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + self.n_class = self.params['class_num'] + self.trilinear = self.params.get('trilinear', True) + self.mul_pred = self.params['multiscale_pred'] + + assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + + if(len(self.ft_chns) == 5): + self.up1 = AttUpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear) + att_params = {"input_resolution": [24, 24, 24], "depth": 2, "num_heads": 4} + self.up2 = AttUpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear, True, att_params) + att_params = {"input_resolution": [48, 48, 48], "depth": 2, "num_heads": 4} + self.up3 = AttUpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear, True, att_params) + self.up4 = AttUpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear) + self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) + if(self.mul_pred): + self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) + self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) + self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) + + def forward(self, x): + if(len(self.ft_chns) == 5): + assert(len(x) == 5) + x0, x1, x2, x3, x4 = x + x_d3 = self.up1(x4, x3) + else: + assert(len(x) == 4) + x0, x1, x2, x3 = x + x_d3 = x3 + x_d2 = self.up2(x_d3, x2) + x_d1 = self.up3(x_d2, x1) + x_d0 = self.up4(x_d1, x0) + output = self.out_conv(x_d0) + if(self.mul_pred): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) + output = [output, output1, output2, output3] + return output + +class MedFormerV2(nn.Module): + """ + An implementation of the U-Net. + + * Reference: Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: + 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. + `MICCAI (2) 2016: 424-432. `_ + + Note that there are some modifications from the original paper, such as + the use of batch normalization, dropout, leaky relu and deep supervision. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using trilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + """ + def __init__(self, params): + super(MedFormerV2, self).__init__() + self.params = params + self.encoder = Encoder(params) + self.decoder = AttDecoder(params) + self.attn = Block(params) + + def forward(self, x): + f = self.encoder(x) + f[-1] = self.attn(f[-1]) + output = self.decoder(f) + return output + +if __name__ == "__main__": + params = {'in_chns':4, + 'class_num': 2, + 'feature_chns':[16, 32, 64, 128], + 'dropout' : [0, 0, 0, 0.5], + 'trilinear': True, + 'multiscale_pred': True, + 'attention_hidden_size': 128, + 'attention_num_heads': 4, + 'attention_mlp_dim': 256, + 'attention_dropout_rate': 0.2} + + Net = MedFormerV2(params) + Net = Net.double() + + x = np.random.rand(1, 4, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + print("output length", len(y)) + for yi in y: + yi = yi.detach().numpy() + print(yi.shape) diff --git a/pymic/net/net3d/trans3d/MedFormer_v3.py b/pymic/net/net3d/trans3d/MedFormer_v3.py new file mode 100644 index 0000000..f119a9c --- /dev/null +++ b/pymic/net/net3d/trans3d/MedFormer_v3.py @@ -0,0 +1,255 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import numpy as np +import torch +import torch.nn as nn +from torch.nn.functional import interpolate +from pymic.net.net3d.unet3d import ConvBlock, Encoder +from pymic.net.net3d.trans3d.MedFormer_v1 import Block +from pymic.net.net3d.trans3d.MedFormer_v2 import SwinTransformerBlock, window_partition + +class GLAttLayer(nn.Module): + def __init__(self, + dim, + input_resolution, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + # build blocks + + self.lcl_att = SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path) + self.adpool = nn.AdaptiveAvgPool3d([12, 12, 12]) + + params = {'attention_hidden_size': dim, + 'attention_num_heads': 4, + 'attention_mlp_dim': dim, + 'attention_dropout_rate': 0.2} + self.glb_att = Block(params) + self.conv1x1 = nn.Sequential( + nn.Conv3d(2*dim, dim, kernel_size=1), + nn.BatchNorm3d(dim), + nn.LeakyReLU()) + + def forward(self, x): + [B, C, S, H, W] = list(x.shape) + # calculate attention mask for SW-MSA + Sp = int(np.ceil(S / self.window_size)) * self.window_size + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Sp, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + s_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for s in s_slices: + for h in h_slices: + for w in w_slices: + img_mask[:, s, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, + self.window_size * self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + # for local attention + xl = x.flatten(2).transpose(1, 2).contiguous() + xl = self.lcl_att(xl, attn_mask) + xl = xl.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() + + # for global attention + xg = self.adpool(x) + xg = self.glb_att(xg) + xg = interpolate(xg, [S, H, W], mode = 'trilinear') + out = torch.cat([xl, xg], dim=1) + out = self.conv1x1(out) + return out + +class AttUpBlock(nn.Module): + """ + 3D upsampling followed by ConvBlock + + :param in_channels1: (int) Channel number of high-level features. + :param in_channels2: (int) Channel number of low-level features. + :param out_channels: (int) Output channel number. + :param dropout_p: (int) Dropout probability. + :param trilinear: (bool) Use trilinear for up-sampling (by default). + If False, deconvolution is used for up-sampling. + """ + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, + trilinear=True, with_att = False, att_params = None): + super(AttUpBlock, self).__init__() + self.trilinear = trilinear + self.with_att = with_att + if trilinear: + self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) + self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) + else: + self.up = nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) + self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) + if(self.with_att): + input_resolution = att_params['input_resolution'] + num_heads = att_params['num_heads'] + window_size = att_params['window_size'] + self.attn = GLAttLayer(out_channels, input_resolution, num_heads, window_size, 2.0) + + def forward(self, x1, x2): + if self.trilinear: + x1 = self.conv1x1(x1) + x1 = self.up(x1) + x = torch.cat([x2, x1], dim=1) + x = self.conv(x) + if(self.with_att): + x = self.attn(x) + return x + + +class AttDecoder(nn.Module): + """ + Decoder of 3D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + """ + def __init__(self, params): + super(AttDecoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + self.n_class = self.params['class_num'] + self.trilinear = self.params.get('trilinear', True) + self.mul_pred = self.params['multiscale_pred'] + + assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + + if(len(self.ft_chns) == 5): + self.up1 = AttUpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear) + att_params = {"input_resolution": [24, 24, 24], "num_heads": 4, "window_size": 7} + self.up2 = AttUpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear, True, att_params) + att_params = {"input_resolution": [48, 48, 48], "num_heads": 4, "window_size": 7} + self.up3 = AttUpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear, True, att_params) + self.up4 = AttUpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear) + self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) + if(self.mul_pred): + self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) + self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) + self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) + + def forward(self, x): + if(len(self.ft_chns) == 5): + assert(len(x) == 5) + x0, x1, x2, x3, x4 = x + x_d3 = self.up1(x4, x3) + else: + assert(len(x) == 4) + x0, x1, x2, x3 = x + x_d3 = x3 + x_d2 = self.up2(x_d3, x2) + x_d1 = self.up3(x_d2, x1) + x_d0 = self.up4(x_d1, x0) + output = self.out_conv(x_d0) + if(self.mul_pred): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) + output = [output, output1, output2, output3] + return output + +class MedFormerV3(nn.Module): + """ + An implementation of the U-Net. + + * Reference: Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: + 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. + `MICCAI (2) 2016: 424-432. `_ + + Note that there are some modifications from the original paper, such as + the use of batch normalization, dropout, leaky relu and deep supervision. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param trilinear: (bool) Using trilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + """ + def __init__(self, params): + super(MedFormerV3, self).__init__() + self.params = params + self.encoder = Encoder(params) + self.decoder = AttDecoder(params) + params["attention_hidden_size"] = params['feature_chns'][-1] + params["attention_mlp_dim"] = params['feature_chns'][-1] + self.attn = Block(params) + + def forward(self, x): + f = self.encoder(x) + f[-1] = self.attn(f[-1]) + output = self.decoder(f) + return output + +if __name__ == "__main__": + params = {'in_chns':4, + 'class_num': 2, + 'feature_chns':[16, 32, 64, 128], + 'dropout' : [0, 0, 0, 0.5], + 'trilinear': True, + 'multiscale_pred': True, + 'attention_num_heads': 4, + 'attention_dropout_rate': 0.2} + + Net = MedFormerV3(params) + Net = Net.double() + + x = np.random.rand(2, 4, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + print("output length", len(y)) + for yi in y: + yi = yi.detach().numpy() + print(yi.shape) diff --git a/pymic/net/net3d/trans3d/MedFormer_va1.py b/pymic/net/net3d/trans3d/MedFormer_va1.py new file mode 100644 index 0000000..27dfa3e --- /dev/null +++ b/pymic/net/net3d/trans3d/MedFormer_va1.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import math +import torch +import torch.nn as nn +import numpy as np +from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm +from pymic.net.net3d.unet3d import Decoder + +class EmbeddingBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, padding, stride): + super(EmbeddingBlock, self).__init__() + self.out_channels = out_channels + self.conv1 = nn.Conv3d(in_channels, out_channels//2, kernel_size=kernel_size, padding=padding, stride = stride) + self.conv2 = nn.Conv3d(out_channels//2, out_channels, kernel_size=1) + self.act = nn.GELU() + self.norm1 = nn.LayerNorm(out_channels//2) + self.norm2 = nn.LayerNorm(out_channels) + + + def forward(self, x): + x = self.act(self.conv1(x)) + # norm 1 + Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.norm1(x) + x = x.transpose(1, 2).contiguous().view(-1, self.out_channels // 2, Ws, Wh, Ww) + + x = self.act(self.conv2(x)) + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.norm2(x) + x = x.transpose(1, 2).contiguous().view(-1, self.out_channels, Ws, Wh, Ww) + + return x + +class Encoder(nn.Module): + """ + Encoder of 3D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + """ + def __init__(self, params): + super(Encoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + assert(len(self.ft_chns) == 4) + + self.down0 = EmbeddingBlock(self.in_chns, self.ft_chns[0], 3, 1, 1) + self.down1 = EmbeddingBlock(self.in_chns, self.ft_chns[1], 2, 0, 2) + self.down2 = EmbeddingBlock(self.in_chns, self.ft_chns[2], 4, 0, 4) + self.down3 = EmbeddingBlock(self.in_chns, self.ft_chns[3], 8, 0, 8) + + def forward(self, x): + x0 = self.down0(x) + x1 = self.down1(x) + x2 = self.down2(x) + x3 = self.down3(x) + output = [x0, x1, x2, x3] + return output + +class MedFormerVA1(nn.Module): + def __init__(self, params): + super(MedFormerVA1, self).__init__() + self.params = params + self.encoder = Encoder(params) + self.decoder = Decoder(params) + + def forward(self, x): + f = self.encoder(x) + output = self.decoder(f) + return output + + +if __name__ == "__main__": + params = {'in_chns':1, + 'class_num': 8, + 'feature_chns':[16, 32, 64, 128], + 'dropout' : [0, 0, 0, 0.5], + 'trilinear': True, + 'deep_supervise': True, + 'attention_hidden_size': 128, + 'attention_num_heads': 4, + 'attention_mlp_dim': 256, + 'attention_dropout_rate': 0.2} + Net = MedFormerVA1(params) + Net = Net.double() + + x = np.random.rand(1, 1, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + print("output length", len(y)) + for yi in y: + yi = yi.detach().numpy() + print(yi.shape) \ No newline at end of file diff --git a/pymic/net/net3d/trans3d/__init__.py b/pymic/net/net3d/trans3d/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pymic/net/net3d/trans3d/nnFormer_wrap.py b/pymic/net/net3d/trans3d/nnFormer_wrap.py new file mode 100644 index 0000000..35593a4 --- /dev/null +++ b/pymic/net/net3d/trans3d/nnFormer_wrap.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import math +import torch +import torch.nn as nn +import numpy as np +from nnformer.network_architecture.nnFormer_tumor import nnFormer + +class nnFormer_wrap(nn.Module): + def __init__(self, params): + super(nnFormer_wrap, self).__init__() + patch_size = params["patch_size"] # 96x96x96 + n_class = params['class_num'] + in_chns = params['in_chns'] + # https://github.com/282857341/nnFormer/blob/main/nnformer/network_architecture/nnFormer_tumor.py + self.nnformer = nnFormer(crop_size = patch_size, + embedding_dim=192, + input_channels = in_chns, + num_classes = n_class, + conv_op=nn.Conv3d, + depths =[2,2,2,2], + num_heads = [6, 12, 24, 48], + patch_size = [4,4,4], + window_size= [4,4,8,4], + deep_supervision=False) + + def forward(self, x): + return self.nnformer(x) + +if __name__ == "__main__": + params = {"patch_size": [96, 96, 96], + "in_chns": 1, + "class_num": 5} + Net = nnFormer_wrap(params) + Net = Net.double() + + x = np.random.rand(1, 1, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + print(y.shape) diff --git a/pymic/net/net3d/trans3d/unetr.py b/pymic/net/net3d/trans3d/unetr.py new file mode 100644 index 0000000..ea90b2f --- /dev/null +++ b/pymic/net/net3d/trans3d/unetr.py @@ -0,0 +1,227 @@ +from __future__ import print_function, division + +import torch +import torch.nn as nn + +from monai.networks.blocks import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock +from monai.networks.blocks.dynunet_block import UnetOutBlock +from monai.networks.nets import ViT + + +class UNETR(nn.Module): + """ + UNETR based on: "Hatamizadeh et al., + UNETR: Transformers for 3D Medical Image Segmentation " + """ + + def __init__(self, params): + # in_channels: int, + # out_channels: int, + # img_size: Tuple[int, int, int], + # feature_size: int = 16, + # hidden_size: int = 768, + # mlp_dim: int = 3072, + # num_heads: int = 12, + # pos_embed: str = "perceptron", + # norm_name: Union[Tuple, str] = "instance", + # conv_block: bool = False, + # res_block: bool = True, + # dropout_rate: float = 0.0, + # ) -> None: + """ + Args: + in_channels: dimension of input channels. + out_channels: dimension of output channels. + img_size: dimension of input image. + feature_size: dimension of network feature size. + hidden_size: dimension of hidden layer. + mlp_dim: dimension of feedforward layer. + num_heads: number of attention heads. + pos_embed: position embedding layer type. + norm_name: feature normalization type and arguments. + conv_block: bool argument to determine if convolutional block is used. + res_block: bool argument to determine if residual block is used. + dropout_rate: faction of the input units to drop. + Examples:: + # for single channel input 4-channel output with patch size of (96,96,96), feature size of 32 and batch norm + >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch') + # for 4-channel input 3-channel output with patch size of (128,128,128), conv position embedding and instance norm + >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance') + """ + + super().__init__() + in_channels = params['in_chns'] + out_channels = params['class_num'] + img_size = params['img_size'] + feature_size = 16 + hidden_size = 768 + mlp_dim = 3072 + num_heads = 12 + pos_embed = "perceptron" + norm_name = "instance" + conv_block = False + res_block = True + dropout_rate = 0.0 + + if not (0 <= dropout_rate <= 1): + raise AssertionError("dropout_rate should be between 0 and 1.") + + if hidden_size % num_heads != 0: + raise AssertionError("hidden size should be divisible by num_heads.") + + if pos_embed not in ["conv", "perceptron"]: + raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") + + self.num_layers = 12 + self.patch_size = (16, 16, 16) + self.feat_size = ( + img_size[0] // self.patch_size[0], + img_size[1] // self.patch_size[1], + img_size[2] // self.patch_size[2], + ) + self.hidden_size = hidden_size + self.classification = False + self.vit = ViT( + in_channels=in_channels, + img_size=img_size, + patch_size=self.patch_size, + hidden_size=hidden_size, + mlp_dim=mlp_dim, + num_layers=self.num_layers, + num_heads=num_heads, + pos_embed=pos_embed, + classification=self.classification, + dropout_rate=dropout_rate, + ) + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=in_channels, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=res_block, + ) + self.encoder2 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=hidden_size, + out_channels=feature_size * 2, + num_layer=2, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name=norm_name, + conv_block=conv_block, + res_block=res_block, + ) + self.encoder3 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=hidden_size, + out_channels=feature_size * 4, + num_layer=1, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name=norm_name, + conv_block=conv_block, + res_block=res_block, + ) + self.encoder4 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=hidden_size, + out_channels=feature_size * 8, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name=norm_name, + conv_block=conv_block, + res_block=res_block, + ) + self.decoder5 = UnetrUpBlock( + spatial_dims=3, + in_channels=hidden_size, + out_channels=feature_size * 8, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=res_block, + ) + self.decoder4 = UnetrUpBlock( + spatial_dims=3, + in_channels=feature_size * 8, + out_channels=feature_size * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=res_block, + ) + self.decoder3 = UnetrUpBlock( + spatial_dims=3, + in_channels=feature_size * 4, + out_channels=feature_size * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=res_block, + ) + self.decoder2 = UnetrUpBlock( + spatial_dims=3, + in_channels=feature_size * 2, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=res_block, + ) + self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) # type: ignore + + def proj_feat(self, x, hidden_size, feat_size): + x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) + x = x.permute(0, 4, 1, 2, 3).contiguous() + return x + + def load_from(self, weights): + with torch.no_grad(): + res_weight = weights + # copy weights from patch embedding + for i in weights["state_dict"]: + print(i) + self.vit.patch_embedding.position_embeddings.copy_( + weights["state_dict"]["module.transformer.patch_embedding.position_embeddings_3d"] + ) + self.vit.patch_embedding.cls_token.copy_( + weights["state_dict"]["module.transformer.patch_embedding.cls_token"] + ) + self.vit.patch_embedding.patch_embeddings[1].weight.copy_( + weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings.1.weight"] + ) + self.vit.patch_embedding.patch_embeddings[1].bias.copy_( + weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings.1.bias"] + ) + + # copy weights from encoding blocks (default: num of blocks: 12) + for bname, block in self.vit.blocks.named_children(): + print(block) + block.loadFrom(weights, n_block=bname) + # last norm layer of transformer + self.vit.norm.weight.copy_(weights["state_dict"]["module.transformer.norm.weight"]) + self.vit.norm.bias.copy_(weights["state_dict"]["module.transformer.norm.bias"]) + + def forward(self, x_in): + x, hidden_states_out = self.vit(x_in) + enc1 = self.encoder1(x_in) + x2 = hidden_states_out[3] + enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size)) + x3 = hidden_states_out[6] + enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size)) + x4 = hidden_states_out[9] + enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size)) + dec4 = self.proj_feat(x, self.hidden_size, self.feat_size) + dec3 = self.decoder5(dec4, enc4) + dec2 = self.decoder4(dec3, enc3) + dec1 = self.decoder3(dec2, enc2) + out = self.decoder2(dec1, enc1) + logits = self.out(out) + return logits + diff --git a/pymic/net/net3d/trans3d/unetr_pp.py b/pymic/net/net3d/trans3d/unetr_pp.py new file mode 100644 index 0000000..3ef6736 --- /dev/null +++ b/pymic/net/net3d/trans3d/unetr_pp.py @@ -0,0 +1,460 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Sequence, Tuple, Union +from pymic.net.net3d.trans3d.unetr_pp_block import UnetOutBlock, UnetResBlock, get_conv_layer +from timm.models.layers import trunc_normal_ +from monai.utils import optional_import +from monai.networks.blocks.convolutions import Convolution +from monai.networks.layers.factories import Act, Norm +from monai.networks.layers.utils import get_act_layer, get_norm_layer + +einops, _ = optional_import("einops") + +class LayerNorm(nn.Module): + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape,) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + +class EPA(nn.Module): + """ + Efficient Paired Attention Block, based on: "Shaker et al., + UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation" + """ + def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False, + channel_attn_drop=0.1, spatial_attn_drop=0.1): + super().__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1)) + + # qkvv are 4 linear layers (query_shared, key_shared, value_spatial, value_channel) + self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias) + + # E and F are projection matrices with shared weights used in spatial attention module to project + # keys and values from HWD-dimension to P-dimension + self.E = self.F = nn.Linear(input_size, proj_size) + + self.attn_drop = nn.Dropout(channel_attn_drop) + self.attn_drop_2 = nn.Dropout(spatial_attn_drop) + + self.out_proj = nn.Linear(hidden_size, int(hidden_size // 2)) + self.out_proj2 = nn.Linear(hidden_size, int(hidden_size // 2)) + + def forward(self, x): + B, N, C = x.shape + + qkvv = self.qkvv(x).reshape(B, N, 4, self.num_heads, C // self.num_heads) + + qkvv = qkvv.permute(2, 0, 3, 1, 4) + + q_shared, k_shared, v_CA, v_SA = qkvv[0], qkvv[1], qkvv[2], qkvv[3] + + q_shared = q_shared.transpose(-2, -1) + k_shared = k_shared.transpose(-2, -1) + v_CA = v_CA.transpose(-2, -1) + v_SA = v_SA.transpose(-2, -1) + + k_shared_projected = self.E(k_shared) + + v_SA_projected = self.F(v_SA) + + q_shared = torch.nn.functional.normalize(q_shared, dim=-1) + k_shared = torch.nn.functional.normalize(k_shared, dim=-1) + + attn_CA = (q_shared @ k_shared.transpose(-2, -1)) * self.temperature + + attn_CA = attn_CA.softmax(dim=-1) + attn_CA = self.attn_drop(attn_CA) + + x_CA = (attn_CA @ v_CA).permute(0, 3, 1, 2).reshape(B, N, C) + + attn_SA = (q_shared.permute(0, 1, 3, 2) @ k_shared_projected) * self.temperature2 + + attn_SA = attn_SA.softmax(dim=-1) + attn_SA = self.attn_drop_2(attn_SA) + + x_SA = (attn_SA @ v_SA_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(B, N, C) + + # Concat fusion + x_SA = self.out_proj(x_SA) + x_CA = self.out_proj2(x_CA) + x = torch.cat((x_SA, x_CA), dim=-1) + return x + + @torch.jit.ignore + def no_weight_decay(self): + return {'temperature', 'temperature2'} + + +class TransformerBlock(nn.Module): + """ + A transformer block, based on: "Shaker et al., + UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation" + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + proj_size: int, + num_heads: int, + dropout_rate: float = 0.0, + pos_embed=False, + ) -> None: + """ + Args: + input_size: the size of the input for each stage. + hidden_size: dimension of hidden layer. + proj_size: projection size for keys and values in the spatial attention module. + num_heads: number of attention heads. + dropout_rate: faction of the input units to drop. + pos_embed: bool argument to determine if positional embedding is used. + + """ + + super().__init__() + + if not (0 <= dropout_rate <= 1): + raise ValueError("dropout_rate should be between 0 and 1.") + + if hidden_size % num_heads != 0: + print("Hidden size is ", hidden_size) + print("Num heads is ", num_heads) + raise ValueError("hidden_size should be divisible by num_heads.") + + self.norm = nn.LayerNorm(hidden_size) + self.gamma = nn.Parameter(1e-6 * torch.ones(hidden_size), requires_grad=True) + self.epa_block = EPA(input_size=input_size, hidden_size=hidden_size, proj_size=proj_size, num_heads=num_heads, channel_attn_drop=dropout_rate,spatial_attn_drop=dropout_rate) + self.conv51 = UnetResBlock(3, hidden_size, hidden_size, kernel_size=3, stride=1, norm_name="batch") + self.conv8 = nn.Sequential(nn.Dropout3d(0.1, False), nn.Conv3d(hidden_size, hidden_size, 1)) + + self.pos_embed = None + if pos_embed: + self.pos_embed = nn.Parameter(torch.zeros(1, input_size, hidden_size)) + + def forward(self, x): + B, C, H, W, D = x.shape + + x = x.reshape(B, C, H * W * D).permute(0, 2, 1) + + if self.pos_embed is not None: + x = x + self.pos_embed + attn = x + self.gamma * self.epa_block(self.norm(x)) + + attn_skip = attn.reshape(B, H, W, D, C).permute(0, 4, 1, 2, 3) # (B, C, H, W, D) + attn = self.conv51(attn_skip) + x = attn_skip + self.conv8(attn) + + return x + +class UnetrPPEncoder(nn.Module): + def __init__(self, input_size=[32 * 32 * 32, 16 * 16 * 16, 8 * 8 * 8, 4 * 4 * 4],dims=[32, 64, 128, 256], + proj_size =[64,64,64,32], depths=[3, 3, 3, 3], num_heads=4, spatial_dims=3, in_channels=1, dropout=0.0, transformer_dropout_rate=0.15 ,**kwargs): + super().__init__() + + self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers + stem_layer = nn.Sequential( + get_conv_layer(spatial_dims, in_channels, dims[0], kernel_size=(2, 4, 4), stride=(2, 4, 4), + dropout=dropout, conv_only=True, ), + get_norm_layer(name=("group", {"num_groups": in_channels}), channels=dims[0]), + ) + self.downsample_layers.append(stem_layer) + for i in range(3): + downsample_layer = nn.Sequential( + get_conv_layer(spatial_dims, dims[i], dims[i + 1], kernel_size=(2, 2, 2), stride=(2, 2, 2), + dropout=dropout, conv_only=True, ), + get_norm_layer(name=("group", {"num_groups": dims[i]}), channels=dims[i + 1]), + ) + self.downsample_layers.append(downsample_layer) + + self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple Transformer blocks + for i in range(4): + stage_blocks = [] + for j in range(depths[i]): + stage_blocks.append(TransformerBlock(input_size=input_size[i], hidden_size=dims[i], proj_size=proj_size[i], num_heads=num_heads, + dropout_rate=transformer_dropout_rate, pos_embed=True)) + self.stages.append(nn.Sequential(*stage_blocks)) + self.hidden_states = [] + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, (LayerNorm, nn.LayerNorm)): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward_features(self, x): + hidden_states = [] + + x = self.downsample_layers[0](x) + x = self.stages[0](x) + + hidden_states.append(x) + + for i in range(1, 4): + x = self.downsample_layers[i](x) + x = self.stages[i](x) + if i == 3: # Reshape the output of the last stage + x = einops.rearrange(x, "b c h w d -> b (h w d) c") + hidden_states.append(x) + return x, hidden_states + + def forward(self, x): + x, hidden_states = self.forward_features(x) + return x, hidden_states + + +class UnetrUpBlock(nn.Module): + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + upsample_kernel_size: Union[Sequence[int], int], + norm_name: Union[Tuple, str], + proj_size: int = 64, + num_heads: int = 4, + out_size: int = 0, + depth: int = 3, + conv_decoder: bool = False, + ) -> None: + """ + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + kernel_size: convolution kernel size. + upsample_kernel_size: convolution kernel size for transposed convolution layers. + norm_name: feature normalization type and arguments. + proj_size: projection size for keys and values in the spatial attention module. + num_heads: number of heads inside each EPA module. + out_size: spatial size for each decoder. + depth: number of blocks for the current decoder stage. + """ + + super().__init__() + upsample_stride = upsample_kernel_size + self.transp_conv = get_conv_layer( + spatial_dims, + in_channels, + out_channels, + kernel_size=upsample_kernel_size, + stride=upsample_stride, + conv_only=True, + is_transposed=True, + ) + + # 4 feature resolution stages, each consisting of multiple residual blocks + self.decoder_block = nn.ModuleList() + + # If this is the last decoder, use ConvBlock(UnetResBlock) instead of EPA_Block (see suppl. material in the paper) + if conv_decoder == True: + self.decoder_block.append( + UnetResBlock(spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, + norm_name=norm_name, )) + else: + stage_blocks = [] + for j in range(depth): + stage_blocks.append(TransformerBlock(input_size=out_size, hidden_size= out_channels, proj_size=proj_size, num_heads=num_heads, + dropout_rate=0.15, pos_embed=True)) + self.decoder_block.append(nn.Sequential(*stage_blocks)) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, (nn.LayerNorm)): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, inp, skip): + + out = self.transp_conv(inp) + out = out + skip + out = self.decoder_block[0](out) + + return out + + +class UNETR_PP(nn.Module): + """ + UNETR++ based on: "Shaker et al., + UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation" + """ + + def __init__(self, params): + """ + Args: + in_channels: dimension of input channels. + out_channels: dimension of output channels. + img_size: dimension of input image. + feature_size: dimension of network feature size. + hidden_size: dimension of the last encoder. + num_heads: number of attention heads. + pos_embed: position embedding layer type. + norm_name: feature normalization type and arguments. + dropout_rate: faction of the input units to drop. + depths: number of blocks for each stage. + dims: number of channel maps for the stages. + conv_op: type of convolution operation. + do_ds: use deep supervision to compute the loss. + + """ + super().__init__() + in_channels = params['in_chns'] + out_channels = params['class_num'] + img_size = params['img_size'] + feature_size = params.get('feature_size', 16) + hidden_size = params.get('hidden_size', 256) + num_heads = params.get('num_heads', 4) + pos_embed = params.get('pos_embed', "perceptron") + norm_name = params.get('norm_name', "instance") + dropout_rate = params.get('dropout_rate', 0.0) + depths = params.get('depths', [3, 3, 3, 3]) + dims = params.get('dims', [32, 64, 128, 256]) + conv_op = nn.Conv3d + do_ds = params.get('deep_supervise', True) + + self.do_ds = do_ds + self.conv_op = conv_op + self.num_classes = out_channels + if not (0 <= dropout_rate <= 1): + raise AssertionError("dropout_rate should be between 0 and 1.") + + if pos_embed not in ["conv", "perceptron"]: + raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") + + self.patch_size = (2, 4, 4) + self.feat_size = ( + img_size[0] // self.patch_size[0] // 8, # 8 is the downsampling happened through the four encoders stages + img_size[1] // self.patch_size[1] // 8, # 8 is the downsampling happened through the four encoders stages + img_size[2] // self.patch_size[2] // 8, # 8 is the downsampling happened through the four encoders stages + ) + self.hidden_size = hidden_size + + self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads) + + self.encoder1 = UnetResBlock( + spatial_dims=3, + in_channels=in_channels, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + ) + self.decoder5 = UnetrUpBlock( + spatial_dims=3, + in_channels=feature_size * 16, + out_channels=feature_size * 8, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + out_size=8 * 8 * 8, + ) + self.decoder4 = UnetrUpBlock( + spatial_dims=3, + in_channels=feature_size * 8, + out_channels=feature_size * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + out_size=16 * 16 * 16, + ) + self.decoder3 = UnetrUpBlock( + spatial_dims=3, + in_channels=feature_size * 4, + out_channels=feature_size * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + out_size=32 * 32 * 32, + ) + self.decoder2 = UnetrUpBlock( + spatial_dims=3, + in_channels=feature_size * 2, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=(2, 4, 4), + norm_name=norm_name, + out_size=64 * 128 * 128, + conv_decoder=True, + ) + self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) + if self.do_ds: + self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels) + self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels) + + def proj_feat(self, x, hidden_size, feat_size): + x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) + x = x.permute(0, 4, 1, 2, 3).contiguous() + return x + + def forward(self, x_in): + x_output, hidden_states = self.unetr_pp_encoder(x_in) + + convBlock = self.encoder1(x_in) + + # Four encoders + enc1 = hidden_states[0] + enc2 = hidden_states[1] + enc3 = hidden_states[2] + enc4 = hidden_states[3] + + # Four decoders + dec4 = self.proj_feat(enc4, self.hidden_size, self.feat_size) + dec3 = self.decoder5(dec4, enc3) + dec2 = self.decoder4(dec3, enc2) + dec1 = self.decoder3(dec2, enc1) + + out = self.decoder2(dec1, convBlock) + if self.do_ds: + logits = [self.out1(out), self.out2(dec1), self.out3(dec2)] + else: + logits = self.out1(out) + + return logits + + +if __name__ == "__main__": + params = {'in_chns': 1, + 'class_num': 2, + 'img_size': [64, 128, 128] + } + net = UNETR_PP(params) + net.double() + + x = np.random.rand(2, 1, 64, 128, 128) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = net(xt) + print(len(y)) + for yi in y: + yi = yi.detach().numpy() + print(yi.shape) \ No newline at end of file diff --git a/pymic/net/net3d/trans3d/unetr_pp_block.py b/pymic/net/net3d/trans3d/unetr_pp_block.py new file mode 100644 index 0000000..89a8769 --- /dev/null +++ b/pymic/net/net3d/trans3d/unetr_pp_block.py @@ -0,0 +1,278 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import numpy as np +import torch +import torch.nn as nn +from typing import Optional, Sequence, Tuple, Union +from monai.networks.blocks.convolutions import Convolution +from monai.networks.layers.factories import Act, Norm +from monai.networks.layers.utils import get_act_layer, get_norm_layer + + +class UnetResBlock(nn.Module): + """ + A skip-connection based module that can be used for DynUNet, based on: + `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. + `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + kernel_size: convolution kernel size. + stride: convolution stride. + norm_name: feature normalization type and arguments. + act_name: activation layer type and arguments. + dropout: dropout probability. + + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + stride: Union[Sequence[int], int], + norm_name: Union[Tuple, str], + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), + dropout: Optional[Union[Tuple, str, float]] = None, + ): + super().__init__() + self.conv1 = get_conv_layer( + spatial_dims, + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + dropout=dropout, + conv_only=True, + ) + self.conv2 = get_conv_layer( + spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True + ) + self.lrelu = get_act_layer(name=act_name) + self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) + self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) + self.downsample = in_channels != out_channels + stride_np = np.atleast_1d(stride) + if not np.all(stride_np == 1): + self.downsample = True + if self.downsample: + self.conv3 = get_conv_layer( + spatial_dims, in_channels, out_channels, kernel_size=1, stride=stride, dropout=dropout, conv_only=True + ) + self.norm3 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) + + def forward(self, inp): + residual = inp + out = self.conv1(inp) + out = self.norm1(out) + out = self.lrelu(out) + out = self.conv2(out) + out = self.norm2(out) + if hasattr(self, "conv3"): + residual = self.conv3(residual) + if hasattr(self, "norm3"): + residual = self.norm3(residual) + out += residual + out = self.lrelu(out) + return out + + +class UnetBasicBlock(nn.Module): + """ + A CNN module module that can be used for DynUNet, based on: + `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. + `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + kernel_size: convolution kernel size. + stride: convolution stride. + norm_name: feature normalization type and arguments. + act_name: activation layer type and arguments. + dropout: dropout probability. + + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + stride: Union[Sequence[int], int], + norm_name: Union[Tuple, str], + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), + dropout: Optional[Union[Tuple, str, float]] = None, + ): + super().__init__() + self.conv1 = get_conv_layer( + spatial_dims, + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + dropout=dropout, + conv_only=True, + ) + self.conv2 = get_conv_layer( + spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True + ) + self.lrelu = get_act_layer(name=act_name) + self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) + self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) + + def forward(self, inp): + out = self.conv1(inp) + out = self.norm1(out) + out = self.lrelu(out) + out = self.conv2(out) + out = self.norm2(out) + out = self.lrelu(out) + return out + + +class UnetUpBlock(nn.Module): + """ + An upsampling module that can be used for DynUNet, based on: + `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. + `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + kernel_size: convolution kernel size. + stride: convolution stride. + upsample_kernel_size: convolution kernel size for transposed convolution layers. + norm_name: feature normalization type and arguments. + act_name: activation layer type and arguments. + dropout: dropout probability. + trans_bias: transposed convolution bias. + + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + stride: Union[Sequence[int], int], + upsample_kernel_size: Union[Sequence[int], int], + norm_name: Union[Tuple, str], + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), + dropout: Optional[Union[Tuple, str, float]] = None, + trans_bias: bool = False, + ): + super().__init__() + upsample_stride = upsample_kernel_size + self.transp_conv = get_conv_layer( + spatial_dims, + in_channels, + out_channels, + kernel_size=upsample_kernel_size, + stride=upsample_stride, + dropout=dropout, + bias=trans_bias, + conv_only=True, + is_transposed=True, + ) + self.conv_block = UnetBasicBlock( + spatial_dims, + out_channels + out_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + dropout=dropout, + norm_name=norm_name, + act_name=act_name, + ) + + def forward(self, inp, skip): + # number of channels for skip should equals to out_channels + out = self.transp_conv(inp) + out = torch.cat((out, skip), dim=1) + out = self.conv_block(out) + return out + + +class UnetOutBlock(nn.Module): + def __init__( + self, spatial_dims: int, in_channels: int, out_channels: int, dropout: Optional[Union[Tuple, str, float]] = None + ): + super().__init__() + self.conv = get_conv_layer( + spatial_dims, in_channels, out_channels, kernel_size=1, stride=1, dropout=dropout, bias=True, conv_only=True + ) + + def forward(self, inp): + return self.conv(inp) + + +def get_conv_layer( + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int] = 3, + stride: Union[Sequence[int], int] = 1, + act: Optional[Union[Tuple, str]] = Act.PRELU, + norm: Union[Tuple, str] = Norm.INSTANCE, + dropout: Optional[Union[Tuple, str, float]] = None, + bias: bool = False, + conv_only: bool = True, + is_transposed: bool = False, +): + padding = get_padding(kernel_size, stride) + output_padding = None + if is_transposed: + output_padding = get_output_padding(kernel_size, stride, padding) + return Convolution( + spatial_dims, + in_channels, + out_channels, + strides=stride, + kernel_size=kernel_size, + act=act, + norm=norm, + dropout=dropout, + bias=bias, + conv_only=conv_only, + is_transposed=is_transposed, + padding=padding, + output_padding=output_padding, + ) + + +def get_padding( + kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int] +) -> Union[Tuple[int, ...], int]: + + kernel_size_np = np.atleast_1d(kernel_size) + stride_np = np.atleast_1d(stride) + padding_np = (kernel_size_np - stride_np + 1) / 2 + if np.min(padding_np) < 0: + raise AssertionError("padding value should not be negative, please change the kernel size and/or stride.") + padding = tuple(int(p) for p in padding_np) + + return padding if len(padding) > 1 else padding[0] + + +def get_output_padding( + kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], padding: Union[Sequence[int], int] +) -> Union[Tuple[int, ...], int]: + kernel_size_np = np.atleast_1d(kernel_size) + stride_np = np.atleast_1d(stride) + padding_np = np.atleast_1d(padding) + + out_padding_np = 2 * padding_np + stride_np - kernel_size_np + if np.min(out_padding_np) < 0: + raise AssertionError("out_padding value should not be negative, please change the kernel size and/or stride.") + out_padding = tuple(int(p) for p in out_padding_np) + + return out_padding if len(out_padding) > 1 else out_padding[0] diff --git a/pymic/net_run/agent_rec.py b/pymic/net_run/agent_rec.py index 188a5fc..4860144 100644 --- a/pymic/net_run/agent_rec.py +++ b/pymic/net_run/agent_rec.py @@ -76,6 +76,7 @@ def training(self): # for debug # from pymic.io.image_read_write import save_nd_array_as_image + # print(inputs.shape) # for i in range(inputs.shape[0]): # image_i = inputs[i][0] # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) @@ -168,7 +169,8 @@ def train_valid(self): ckpt_prefix = self.config['training'].get('ckpt_prefix', None) if(ckpt_prefix is None): ckpt_prefix = ckpt_dir.split('/')[-1] - iter_start = self.config['training']['iter_start'] + # iter_start = self.config['training']['iter_start'] + iter_start = 0 iter_max = self.config['training']['iter_max'] iter_valid = self.config['training']['iter_valid'] iter_save = self.config['training'].get('iter_save', None) @@ -193,7 +195,7 @@ def train_valid(self): else: self.net.load_state_dict(self.checkpoint['model_state_dict']) self.min_val_loss = self.checkpoint.get('valid_loss', 10000) - # self.max_val_it = self.checkpoint['iteration'] + iter_start = self.checkpoint['iteration'] self.max_val_it = iter_start self.best_model_wts = self.checkpoint['model_state_dict'] diff --git a/pymic/net_run/self_sup/__init__.py b/pymic/net_run/self_sup/__init__.py index 55f26bf..73308e6 100644 --- a/pymic/net_run/self_sup/__init__.py +++ b/pymic/net_run/self_sup/__init__.py @@ -1,2 +1,3 @@ from __future__ import absolute_import -from pymic.net_run.self_sup.self_sl_agent import SelfSLSegAgent \ No newline at end of file +from pymic.net_run.self_sup.self_sl_agent import SelfSLSegAgent +from pymic.net_run.self_sup.self_patch_mix_agent import SelfSLPatchMixAgent \ No newline at end of file diff --git a/pymic/net_run/self_sup/self_patch_mix_agent.py b/pymic/net_run/self_sup/self_patch_mix_agent.py new file mode 100644 index 0000000..e30a131 --- /dev/null +++ b/pymic/net_run/self_sup/self_patch_mix_agent.py @@ -0,0 +1,144 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import copy +import os +import sys +import shutil +import time +import logging +import scipy +import torch +import torchvision.transforms as transforms +import numpy as np +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from datetime import datetime +from random import random +from torch.optim import lr_scheduler +from tensorboardX import SummaryWriter +from pymic.io.image_read_write import save_nd_array_as_image +from pymic.io.nifty_dataset import NiftyDataset +from pymic.net.net_dict_seg import SegNetDict +from pymic.net_run.agent_abstract import NetRunAgent +from pymic.net_run.infer_func import Inferer +from pymic.loss.loss_dict_seg import SegLossDict +from pymic.loss.seg.combined import CombinedLoss +from pymic.loss.seg.deep_sup import DeepSuperviseLoss +from pymic.loss.seg.util import get_soft_label +from pymic.loss.seg.util import reshape_prediction_and_ground_truth +from pymic.loss.seg.util import get_classwise_dice +from pymic.transform.trans_dict import TransformDict +from pymic.util.post_process import PostProcessDict +from pymic.util.image_process import convert_label +from pymic.util.parse_config import * +from pymic.io.image_read_write import save_nd_array_as_image +from pymic.net_run.self_sup.util import patch_mix +from pymic.net_run.agent_seg import SegmentationAgent + +class SelfSLPatchMixAgent(SegmentationAgent): + """ + Abstract class for self-supervised segmentation. + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. + """ + def __init__(self, config, stage = 'train'): + super(SelfSLPatchMixAgent, self).__init__(config, stage) + + def training(self): + class_num = self.config['network']['class_num'] + iter_valid = self.config['training']['iter_valid'] + fg_num = self.config['network']['class_num'] - 1 + patch_num = self.config['patch_mix']['patch_num_range'] + size_d = self.config['patch_mix']['patch_depth_range'] + size_h = self.config['patch_mix']['patch_height_range'] + size_w = self.config['patch_mix']['patch_width_range'] + + train_loss = 0 + train_dice_list = [] + self.net.train() + for it in range(iter_valid): + try: + data = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data = next(self.trainIter) + # get the inputs + inputs = self.convert_tensor_type(data['image']) + inputs, labels_prob = patch_mix(inputs, fg_num, patch_num, size_d, size_h, size_w) + + # # for debug + # if(it==10): + # break + # for i in range(inputs.shape[0]): + # image_i = inputs[i][0] + # label_i = np.argmax(labels_prob[i], axis = 0) + # # pixw_i = pix_w[i][0] + # print(image_i.shape, label_i.shape) + # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) + # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) + # # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) + # save_nd_array_as_image(image_i, image_name, reference_name = None) + # save_nd_array_as_image(label_i, label_name, reference_name = None) + # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) + # continue + + inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + outputs = self.net(inputs) + loss = self.get_loss_value(data, outputs, labels_prob) + loss.backward() + self.optimizer.step() + train_loss = train_loss + loss.item() + # get dice evaluation for each class + if(isinstance(outputs, tuple) or isinstance(outputs, list)): + outputs = outputs[0] + outputs_argmax = torch.argmax(outputs, dim = 1, keepdim = True) + soft_out = get_soft_label(outputs_argmax, class_num, self.tensor_type) + soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob) + dice_list = get_classwise_dice(soft_out, labels_prob) + train_dice_list.append(dice_list.cpu().numpy()) + train_avg_loss = train_loss / iter_valid + train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) + train_avg_dice = train_cls_dice[1:].mean() + + train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ + 'class_dice': train_cls_dice} + return train_scalers + +def main(): + cfg_file = str(sys.argv[1]) + if(not os.path.isfile(cfg_file)): + raise ValueError("The config file does not exist: " + cfg_file) + config = parse_config(cfg_file) + config = synchronize_config(config) + log_dir = config['training']['ckpt_save_dir'] + if(not os.path.exists(log_dir)): + os.makedirs(log_dir, exist_ok=True) + dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] + shutil.copy(cfg_file, log_dir + "/" + dst_cfg) + if sys.version.startswith("3.9"): + logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), + level=logging.INFO, format='%(message)s', force=True) # for python 3.9 + else: + logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), + level=logging.INFO, format='%(message)s') # for python 3.6 + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging_config(config) + agent = SelfSLPatchMixAgent(config) + agent.run() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pymic/net_run/self_sup/util.py b/pymic/net_run/self_sup/util.py new file mode 100644 index 0000000..e131941 --- /dev/null +++ b/pymic/net_run/self_sup/util.py @@ -0,0 +1,167 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import os +import torch +import random +import numpy as np +from scipy import ndimage +from pymic.io.image_read_write import * +from pymic.util.image_process import * +from pymic.util.general import get_one_hot_seg + +def get_human_region_mask(img): + """ + Get the mask of human region in CT volumes + """ + dim = len(img.shape) + if( dim == 4): + img = img[0] + mask = np.asarray(img > -600) + se = np.ones([3,3,3]) + mask = ndimage.binary_opening(mask, se, iterations = 2) + mask = get_largest_k_components(mask, 1) + mask_close = ndimage.binary_closing(mask, se, iterations = 3) + + D, H, W = mask.shape + for d in [1, 2, D-3, D-2]: + mask_close[d] = mask[d] + for d in [0, -1, int(D/2)]: + mask_close[d, 1:-1, 1:-1] = np.ones((H-2, W-2)) + + bg = get_largest_k_components(1- mask_close, 1) + fg = 1 - bg + se = np.ones([3,3,3]) + fg = ndimage.binary_opening(fg, se, iterations = 1) + fg = get_largest_k_components(fg, 1) + if(dim == 4): + fg = np.expand_dims(fg, 0) + fg = np.asarray(fg, np.uint8) + return fg + +def crop_ct_scan(input_img, output_img, input_lab = None, output_lab = None): + """ + Crop a CT scan based on the bounding box of the human region. + """ + img_obj = sitk.ReadImage(input_img) + img = sitk.GetArrayFromImage(img_obj) + mask = np.asarray(img > -600) + se = np.ones([3,3,3]) + mask = ndimage.binary_opening(mask, se, iterations = 2) + mask = get_largest_k_components(mask, 1) + bbmin, bbmax = get_ND_bounding_box(mask, margin = [5, 10, 10]) + img_sub = crop_ND_volume_with_bounding_box(img, bbmin, bbmax) + img_sub_obj = sitk.GetImageFromArray(img_sub) + img_sub_obj.SetSpacing(img_obj.GetSpacing()) + sitk.WriteImage(img_sub_obj, output_img) + if(input_lab is not None): + lab_obj = sitk.ReadImage(input_lab) + lab = sitk.GetArrayFromImage(lab_obj) + lab_sub = crop_ND_volume_with_bounding_box(lab, bbmin, bbmax) + lab_sub_obj = sitk.GetImageFromArray(lab_sub) + lab_sub_obj.SetSpacing(img_obj.GetSpacing()) + sitk.WriteImage(lab_sub_obj, output_lab) + + +def patch_mix(x, fg_num, patch_num, size_d, size_h, size_w): + """ + Copy a sub region of an impage and paste to another one to generate + images and labels for self-supervised segmentation. + """ + N, C, D, H, W = list(x.shape) + fg_mask = torch.zeros_like(x) + # generate mask + for n in range(N): + p_num = random.randint(patch_num[0], patch_num[1]) + for i in range(p_num): + d = random.randint(size_d[0], size_d[1]) + h = random.randint(size_h[0], size_h[1]) + w = random.randint(size_w[0], size_w[1]) + d_c = random.randint(0, D) + h_c = random.randint(0, H) + w_c = random.randint(0, W) + d0, d1 = max(0, d_c - d), min(D, d_c + d) + h0, h1 = max(0, h_c - h), min(H, h_c + h) + w0, w1 = max(0, w_c - w), min(W, w_c + w) + temp_m = torch.ones([C, d1-d0, h1-h0, w1-w0]) * random.randint(1, fg_num) + fg_mask[n, :, d0:d1, h0:h1, w0:w1] = temp_m + fg_w = fg_mask * 1.0 / fg_num + x_roll = torch.roll(x, 1, 0) + x_fuse = fg_w*x_roll + (1.0 - fg_w)*x + y_prob = get_one_hot_seg(fg_mask.to(torch.int32), fg_num + 1) + return x_fuse, y_prob + +def create_mixed_dataset(input_dir, output_dir, fg_num = 1, crop_num = 1, + mask = 'default', data_format = "nii.gz"): + """ + Create dataset based on patch mix. + + :param input_dir: (str) The path of folder for input images + :param output_dir: (str) The path of folder for output images + :param fg_num: (int) The number of foreground classes + :param crop_num: (int) The number of patches to crop for each input image + :param mask: ND array to specify a mask, or 'default' or None. If default, + a mask for body region is automatically generated (just for CT). + :param data_format: (str) The format of images. + """ + img_names = os.listdir(input_dir) + img_names = [item for item in img_names if item.endswith(data_format)] + img_names = sorted(img_names) + out_img_dir = output_dir + "/image" + out_lab_dir = output_dir + "/label" + if(not os.path.exists(out_img_dir)): + os.mkdir(out_img_dir) + if(not os.path.exists(out_lab_dir)): + os.mkdir(out_lab_dir) + + img_num = len(img_names) + print("image number", img_num) + i_range = range(img_num) + j_range = list(i_range) + random.shuffle(j_range) + for i in i_range: + print(i, img_names[i]) + j = j_range[i] + if(i == j): + j = i + 1 if i < img_num - 1 else 0 + img_i = load_image_as_nd_array(input_dir + "/" + img_names[i])['data_array'] + img_j = load_image_as_nd_array(input_dir + "/" + img_names[j])['data_array'] + + chns = img_i.shape[0] + # random crop to patch size + if(mask == 'default'): + mask_i = get_human_region_mask(img_i) + mask_j = get_human_region_mask(img_j) + for k in range(crop_num): + if(mask is None): + img_ik = random_crop_ND_volume(img_i, [chns, 96, 96, 96]) + img_jk = random_crop_ND_volume(img_j, [chns, 96, 96, 96]) + else: + img_ik = random_crop_ND_volume_with_mask(img_i, [chns, 96, 96, 96], mask_i) + img_jk = random_crop_ND_volume_with_mask(img_j, [chns, 96, 96, 96], mask_j) + C, D, H, W = img_ik.shape + # generate mask + fg_mask = np.zeros_like(img_ik, np.uint8) + patch_num = random.randint(4, 40) + for patch in range(patch_num): + d = random.randint(4, 20) # half of window size + h = random.randint(4, 40) + w = random.randint(4, 40) + d_c = random.randint(0, D) + h_c = random.randint(0, H) + w_c = random.randint(0, W) + d0, d1 = max(0, d_c - d), min(D, d_c + d) + h0, h1 = max(0, h_c - h), min(H, h_c + h) + w0, w1 = max(0, w_c - w), min(W, w_c + w) + temp_m = np.ones([C, d1-d0, h1-h0, w1-w0]) * random.randint(1, fg_num) + fg_mask[:, d0:d1, h0:h1, w0:w1] = temp_m + fg_w = fg_mask * 1.0 / fg_num + x_fuse = fg_w*img_jk + (1.0 - fg_w)*img_ik + + out_name = img_names[i] + if crop_num > 1: + out_name = out_name.replace(".nii.gz", "_{0:}.nii.gz".format(k)) + save_nd_array_as_image(x_fuse[0], out_img_dir + "/" + out_name, + reference_name = input_dir + "/" + img_names[i]) + save_nd_array_as_image(fg_mask[0], out_lab_dir + "/" + out_name, + reference_name = input_dir + "/" + img_names[i]) + diff --git a/pymic/transform/label_convert.py b/pymic/transform/label_convert.py index e3accdf..00c505e 100644 --- a/pymic/transform/label_convert.py +++ b/pymic/transform/label_convert.py @@ -152,29 +152,16 @@ def __call__(self, sample): return sample -class SelfSuperviseLabel(AbstractTransform): +class SelfReconstructionLabel(AbstractTransform): """ - Convert one-channel partial label map to one-hot multi-channel probability map. - This is used for segmentation tasks only. In the input label map, 0 represents the - background class, 1 to C-1 represent the foreground classes, and C represents - unlabeled pixels. In the output dictionary, `label_prob` is the one-hot probability - map, and `pixel_weight` represents a weighting map, where the weight for a pixel - is 0 if the label is unkown. - - The arguments should be written in the `params` dictionary, and it has the - following fields: - - :param `PartialLabelToProbability_class_num`: (int) The class number for the - segmentation task. - :param `PartialLabelToProbability_inverse`: (optional, bool) - Is inverse transform needed for inference. Default is `False`. + Used for self-supervised learning with image reconstruction tasks. """ def __init__(self, params): """ class_num (int): the class number in the label map """ - super(SelfSuperviseLabel, self).__init__(params) - self.inverse = params.get('SelfSuperviseLabel_inverse'.lower(), False) + super(SelfReconstructionLabel, self).__init__(params) + self.inverse = params.get('SelfReconstructionLabel_inverse'.lower(), False) def __call__(self, sample): image = sample['image'] @@ -183,3 +170,41 @@ def __call__(self, sample): return sample +class MaskedImageModelingLabel(AbstractTransform): + """ + Used for self-supervised learning with image reconstruction tasks. + Only reconstruct the masked region in the input. + The input images is masked in local patches. + """ + def __init__(self, params): + """ + class_num (int): the class number in the label map + """ + super(MaskedImageModelingLabel, self).__init__(params) + self.patch_size = params.get('MaskedImageModelingLabel_patch_size'.lower(), [16, 16, 16]) + self.masking_ratio = params.get('MaskedImageModelingLabel_ratio'.lower(), 0.15) + self.inverse = params.get('MaskedImageModelingLabel_inverse'.lower(), False) + + def __call__(self, sample): + image = sample['image'] + C, D, H, W = image.shape + patch_size = self.patch_size + mask = np.ones([D, H, W], np.float32) + grid_size = [math.ceil((image.shape[i+1] + 0.0) / patch_size[i]) for i in range(3)] + for d in range(grid_size[0]): + d0 = d*patch_size[0] + for h in range(grid_size[1]): + h0 = h*patch_size[1] + for w in range(grid_size[2]): + w0 = w*patch_size[2] + if(random.random() > self.masking_ratio): + continue + d1 = min(d0 + patch_size[0], D) + h1 = min(h0 + patch_size[1], H) + w1 = min(w0 + patch_size[2], W) + mask[d0:d1, h0:h1, w0:w1] = np.zeros([d1 - d0, h1 - h0, w1 - w0]) + sample['pixel_weight'] = 1 - mask + sample['image'] = image * mask + sample['label'] = image + return sample + diff --git a/pymic/transform/mix.py b/pymic/transform/mix.py new file mode 100644 index 0000000..fe2f315 --- /dev/null +++ b/pymic/transform/mix.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import copy +import json +import math +import random +import numpy as np +from pymic.transform.abstract_transform import AbstractTransform +from pymic.util.image_process import * +try: # SciPy >= 0.19 + from scipy.special import comb +except ImportError: + from scipy.misc import comb + + +class CopyPaste(AbstractTransform): + """ + In-painting of an input image, used for self-supervised learning + """ + def __init__(self, params): + super(CopyPaste, self).__init__(params) + self.inverse = params.get('CopyPaste_inverse'.lower(), False) + self.block_range = params.get('CopyPaste_block_range'.lower(), (1, 6)) + self.block_size_min = params.get('CopyPaste_block_size_min'.lower(), None) + self.block_size_max = params.get('CopyPaste_block_size_max'.lower(), None) + + def __call__(self, sample): + image= sample['image'] + img_shape = image.shape + img_dim = len(img_shape) - 1 + assert(img_dim == 2 or img_dim == 3) + + if(self.block_size_min is None): + block_size_min = [img_shape[1+i]//6 for i in range(img_dim)] + elif(isinstance(self.block_size_min, int)): + block_size_min = [self.block_size_min] * img_dim + else: + assert(len(self.block_size_min) == img_dim) + block_size_min = self.block_size_min + + if(self.block_size_max is None): + block_size_max = [img_shape[1+i]//3 for i in range(img_dim)] + elif(isinstance(self.block_size_min, int)): + block_size_max = [self.block_size_max] * img_dim + else: + assert(len(self.block_size_max) == img_dim) + block_size_max = self.block_size_max + block_num = random.randint(self.block_range[0], self.block_range[1]) + + for n in range(block_num): + block_size = [random.randint(block_size_min[i], block_size_max[i]) \ + for i in range(img_dim)] + coord_min = [random.randint(3, img_shape[1+i] - block_size[i] - 3) \ + for i in range(img_dim)] + if(img_dim == 2): + random_block = np.random.rand(img_shape[0], block_size[0], block_size[1]) + image[:, coord_min[0]:coord_min[0] + block_size[0], + coord_min[1]:coord_min[1] + block_size[1]] = random_block + else: + random_block = np.random.rand(img_shape[0], block_size[0], + block_size[1], block_size[2]) + image[:, coord_min[0]:coord_min[0] + block_size[0], + coord_min[1]:coord_min[1] + block_size[1], + coord_min[2]:coord_min[2] + block_size[2]] = random_block + sample['image'] = image + return sample \ No newline at end of file diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index c1ecdc2..5dac73b 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -71,7 +71,8 @@ 'RandomRotate': RandomRotate, 'ReduceLabelDim': ReduceLabelDim, 'Rescale': Rescale, - 'SelfSuperviseLabel': SelfSuperviseLabel, + 'SelfReconstructionLabel': SelfReconstructionLabel, + 'MaskedImageModelingLabel': MaskedImageModelingLabel, 'OutPainting': OutPainting, 'Pad': Pad, } From bebf534ef9673961806bea4703cb6946560b8834 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 22 May 2023 12:40:04 +0800 Subject: [PATCH 12/86] update random crop allow default setting of foreground labels --- pymic/transform/crop.py | 21 +++++++++++---------- pymic/util/image_process.py | 36 +++++++++++++++++------------------- 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index 95e3489..b8130f9 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -221,7 +221,8 @@ class RandomCrop(CenterCrop): `RandomCrop_foreground_focus` is True. :param `RandomCrop_mask_label`: (optional, None, or list/tuple) Specifying the foreground labels for foreground focus cropping when - `RandomCrop_foreground_focus` is True. + `RandomCrop_foreground_focus` is True. If it is None (by default), + the mask label will be the list of all the foreground classes. :param `RandomCrop_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `True`. """ @@ -229,7 +230,7 @@ def __init__(self, params): self.output_size = params['RandomCrop_output_size'.lower()] self.fg_focus = params.get('RandomCrop_foreground_focus'.lower(), False) self.fg_ratio = params.get('RandomCrop_foreground_ratio'.lower(), 0.5) - self.mask_label = params.get('RandomCrop_mask_label'.lower(), [1]) + self.mask_label = params.get('RandomCrop_mask_label'.lower(), None) self.inverse = params.get('RandomCrop_inverse'.lower(), True) self.task = params['Task'.lower()] assert isinstance(self.output_size, (list, tuple)) @@ -246,16 +247,16 @@ def _get_crop_param(self, sample): crop_margin = [input_shape[i] - self.output_size[i] for i in range(input_dim)] crop_min = [0 if item == 0 else random.randint(0, item) for item in crop_margin] crop_max = [crop_min[i] + self.output_size[i] for i in range(input_dim)] + if(self.fg_focus and random.random() < self.fg_ratio): label = sample['label'][0] - mask = np.zeros_like(label) - for temp_lab in self.mask_label: - mask = np.maximum(mask, label == temp_lab) - if(mask.max() > 0): - crop_min, crop_max = get_random_box_from_mask(mask, self.output_size) - # to avoid Typeerror: object of type int64 is not json serializable - crop_min = [int(i) for i in crop_min] - crop_max = [int(i) for i in crop_max] + if(self.mask_label is None): + mask_label = np.unique(label)[1:] + else: + mask_label = self.mask_label + random_label = random.choice(mask_label) + crop_min, crop_max = get_random_box_from_mask(label == random_label, self.output_size) + crop_min = [0] + crop_min crop_max = [chns] + crop_max diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index d5d7a7e..98409d7 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -165,23 +165,19 @@ def random_crop_ND_volume(volume, out_shape): return crop_volume def get_random_box_from_mask(mask, out_shape): - mask_shape = mask.shape - dim = len(out_shape) - left_margin = [int(out_shape[i]/2) for i in range(dim)] - right_margin= [mask_shape[i] - (out_shape[i] - left_margin[i]) + 1 for i in range(dim)] - - valid_center_shape = [right_margin[i] - left_margin[i] for i in range(dim)] - valid_mask = np.zeros(mask_shape) - valid_mask = set_ND_volume_roi_with_bounding_box_range(valid_mask, - left_margin, right_margin, np.ones(valid_center_shape)) - valid_mask = valid_mask * mask - - indexes = np.where(valid_mask) + indexes = np.where(mask) voxel_num = len(indexes[0]) - j = random.randint(0, voxel_num - 1) - bb_c = [indexes[i][j] for i in range(dim)] - bb_min = [bb_c[i] - left_margin[i] for i in range(dim)] + dim = len(out_shape) + left_bound = [int(out_shape[i]/2) for i in range(dim)] + right_bound = [mask.shape[i] - (out_shape[i] - left_bound[i]) for i in range(dim)] + + j = random.randint(0, voxel_num - 1) + bb_c = [int(indexes[i][j]) for i in range(dim)] + bb_c = [max(left_bound[i], bb_c[i]) for i in range(dim)] + bb_c = [min(right_bound[i], bb_c[i]) for i in range(dim)] + bb_min = [bb_c[i] - left_bound[i] for i in range(dim)] bb_max = [bb_min[i] + out_shape[i] for i in range(dim)] + return bb_min, bb_max def random_crop_ND_volume_with_mask(volume, out_shape, mask): @@ -234,7 +230,8 @@ def get_largest_k_components(image, k = 1): :param image: The input ND array for binary segmentation. :param k: (int) The value of k. - :return: An output array with only the largest K components of the input. + :return: An output array (k == 1) or a list of ND array (k>1) + with only the largest K components of the input. """ dim = len(image.shape) if(image.sum() == 0 ): @@ -247,11 +244,12 @@ def get_largest_k_components(image, k = 1): sizes = ndimage.sum(image, labeled_array, range(1, numpatches + 1)) sizes_sort = sorted(sizes, reverse = True) kmin = min(k, numpatches) - output = np.zeros_like(image) + output = [] for i in range(kmin): labeli = np.where(sizes == sizes_sort[i])[0] + 1 - output = output + np.asarray(labeled_array == labeli, np.uint8) - return output + output_i = np.asarray(labeled_array == labeli, np.uint8) + output.append(output_i) + return output[0] if k == 1 else output def get_euclidean_distance(image, dim = 3, spacing = [1.0, 1.0, 1.0]): """ From 11ffcd01ce337d51d75130c297e8047ae6e002eb Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 2 Jun 2023 22:05:24 +0800 Subject: [PATCH 13/86] update infer_func update multi-scale prediction with gaussian weight --- pymic/net_run/agent_rec.py | 56 +++++++++++++++++----------- pymic/net_run/agent_seg.py | 8 +++- pymic/net_run/infer_func.py | 14 +++---- pymic/net_run/self_sup/util.py | 38 ++++++++++++------- pymic/util/image_process.py | 68 ++++++++++++++++++++++++++++++++++ 5 files changed, 142 insertions(+), 42 deletions(-) diff --git a/pymic/net_run/agent_rec.py b/pymic/net_run/agent_rec.py index 4860144..1e58bc6 100644 --- a/pymic/net_run/agent_rec.py +++ b/pymic/net_run/agent_rec.py @@ -16,6 +16,7 @@ from pymic.net_run.infer_func import Inferer from pymic.net_run.agent_seg import SegmentationAgent from pymic.loss.seg.mse import MAELoss, MSELoss +from pymic.util.general import mixup, tensor_shape_match ReconstructionLossDict = { 'MAELoss': MAELoss, @@ -165,6 +166,7 @@ def train_valid(self): else: self.device = torch.device("cuda:{0:}".format(device_ids[0])) self.net.to(self.device) + ckpt_dir = self.config['training']['ckpt_save_dir'] ckpt_prefix = self.config['training'].get('ckpt_prefix', None) if(ckpt_prefix is None): @@ -186,20 +188,31 @@ def train_valid(self): self.max_val_it = 0 self.best_model_wts = None self.checkpoint = None - if(iter_start > 0): - checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start) - self.checkpoint = torch.load(checkpoint_file, map_location = self.device) - # assert(self.checkpoint['iteration'] == iter_start) - if(len(device_ids) > 1): - self.net.module.load_state_dict(self.checkpoint['model_state_dict']) + # initialize the network with pre-trained weights + ckpt_init_name = self.config['training'].get('ckpt_init_name', None) + ckpt_init_mode = self.config['training'].get('ckpt_init_mode', 0) + ckpt_for_optm = None + if(ckpt_init_name is not None): + checkpoint = torch.load(ckpt_dir + "/" + ckpt_init_name, map_location = self.device) + pretrained_dict = checkpoint['model_state_dict'] + model_dict = self.net.module.state_dict() if (len(device_ids) > 1) else self.net.state_dict() + pretrained_dict = {k: v for k, v in pretrained_dict.items() if \ + k in model_dict and tensor_shape_match(pretrained_dict[k], model_dict[k])} + logging.info("Initializing the following parameters with pre-trained model") + for k in pretrained_dict: + logging.info(k) + if (len(device_ids) > 1): + self.net.module.load_state_dict(pretrained_dict, strict = False) else: - self.net.load_state_dict(self.checkpoint['model_state_dict']) - self.min_val_loss = self.checkpoint.get('valid_loss', 10000) - iter_start = self.checkpoint['iteration'] - self.max_val_it = iter_start - self.best_model_wts = self.checkpoint['model_state_dict'] + self.net.load_state_dict(pretrained_dict, strict = False) + if(ckpt_init_mode > 0): # Load other information + self.min_val_loss = self.checkpoint.get('valid_loss', 10000) + iter_start = checkpoint['iteration'] + self.max_val_it = iter_start + self.best_model_wts = checkpoint['model_state_dict'] + ckpt_for_optm = checkpoint - self.create_optimizer(self.get_parameters_to_update()) + self.create_optimizer(self.get_parameters_to_update(), ckpt_for_optm) self.create_loss_calculator() self.trainIter = iter(self.train_loader) @@ -231,6 +244,16 @@ def train_valid(self): self.best_model_wts = copy.deepcopy(self.net.module.state_dict()) else: self.best_model_wts = copy.deepcopy(self.net.state_dict()) + + save_dict = {'iteration': self.max_val_it, + 'valid_loss': self.min_val_loss, + 'model_state_dict': self.best_model_wts, + 'optimizer_state_dict': self.optimizer.state_dict()} + save_name = "{0:}/{1:}_best.pt".format(ckpt_dir, ckpt_prefix) + torch.save(save_dict, save_name) + txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') + txt_file.write(str(self.max_val_it)) + txt_file.close() stop_now = True if(early_stop_it is not None and \ self.glob_it - self.max_val_it > early_stop_it) else False @@ -249,15 +272,6 @@ def train_valid(self): logging.info("The training is early stopped") break # save the best performing checkpoint - save_dict = {'iteration': self.max_val_it, - 'valid_loss': self.min_val_loss, - 'model_state_dict': self.best_model_wts, - 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it) - torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') - txt_file.write(str(self.max_val_it)) - txt_file.close() logging.info('The best performing iter is {0:}, valid loss {1:}'.format(\ self.max_val_it, self.min_val_loss)) self.summ_writer.close() diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index b007574..3d2fb21 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -89,7 +89,12 @@ def create_network(self): logging.info('parameter number {0:}'.format(param_number)) def get_parameters_to_update(self): - return self.net.parameters() + if hasattr(self.net, "get_parameters_to_update"): + params = self.net.get_parameters_to_update() + else: + params = self.net.parameters() + return params + def create_loss_calculator(self): if(self.loss_dict is None): @@ -401,6 +406,7 @@ def test_time_dropout(m): raise ValueError("ckpt_mode should be 3 if ckpt_name is a list") # load network parameters and set the network as evaluation mode + print("ckpt name", ckpt_name) checkpoint = torch.load(ckpt_name, map_location = device) self.net.load_state_dict(checkpoint['model_state_dict']) diff --git a/pymic/net_run/infer_func.py b/pymic/net_run/infer_func.py index d2b9af2..b0190ad 100644 --- a/pymic/net_run/infer_func.py +++ b/pymic/net_run/infer_func.py @@ -146,7 +146,8 @@ def __infer_with_sliding_window(self, image): output_shape_i = [batch_size, class_num] + \ [int(img_shape[d] * scale_list[i][d]) for d in range(img_dim)] output_list.append(torch.zeros(output_shape_i).to(image.device)) - + temp_ws = [interpolate(temp_w, scale_factor = scale_list[i]) for i in range(out_num)] + weights = [interpolate(weight, scale_factor = scale_list[i]) for i in range(out_num)] for w_i in range(0, window_num, window_batch): for k in range(window_batch): if(w_i + k >= window_num): @@ -167,14 +168,13 @@ def __infer_with_sliding_window(self, image): c0_i = [int(c0[d] * scale_list[i][d]) for d in range(img_dim)] c1_i = [int(c1[d] * scale_list[i][d]) for d in range(img_dim)] if(img_dim == 2): - output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1]] += patches_out[i][k] * temp_w - weight[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_w + output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1]] += patches_out[i][k] * temp_ws[i] + weights[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1]] += temp_ws[i] else: - output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1], c0_i[2]:c1_i[2]] += patches_out[i][k] * temp_w - weight[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_w + output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1], c0_i[2]:c1_i[2]] += patches_out[i][k] * temp_ws[i] + weights[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1], c0_i[2]:c1_i[2]] += temp_ws[i] for i in range(out_num): - weight_i = interpolate(weight, scale_factor = scale_list[i]) - output_list[i] = output_list[i] / weight_i + output_list[i] = output_list[i] / weights[i] return output_list def run(self, model, image): diff --git a/pymic/net_run/self_sup/util.py b/pymic/net_run/self_sup/util.py index e131941..9cffaa7 100644 --- a/pymic/net_run/self_sup/util.py +++ b/pymic/net_run/self_sup/util.py @@ -20,17 +20,26 @@ def get_human_region_mask(img): se = np.ones([3,3,3]) mask = ndimage.binary_opening(mask, se, iterations = 2) mask = get_largest_k_components(mask, 1) - mask_close = ndimage.binary_closing(mask, se, iterations = 3) + mask_close = ndimage.binary_closing(mask, se, iterations = 2) D, H, W = mask.shape for d in [1, 2, D-3, D-2]: mask_close[d] = mask[d] - for d in [0, -1, int(D/2)]: - mask_close[d, 1:-1, 1:-1] = np.ones((H-2, W-2)) + for d in range(0, D, 2): + mask_close[d, 2:-2, 2:-2] = np.ones((H-4, W-4)) - bg = get_largest_k_components(1- mask_close, 1) + # get background component + bg = np.zeros_like(mask) + bgs = get_largest_k_components(1- mask_close, 10) + for bgi in bgs: + indices = np.where(bgi) + if(bgi.sum() < 1000): + break + if(indices[0].min() == 0 or indices[1].min() == 0 or indices[2].min() ==0 or \ + indices[0].max() == D-1 or indices[1].max() == H-1 or indices[2].max() ==W-1): + bg = bg + bgi fg = 1 - bg - se = np.ones([3,3,3]) + fg = ndimage.binary_opening(fg, se, iterations = 1) fg = get_largest_k_components(fg, 1) if(dim == 4): @@ -91,7 +100,7 @@ def patch_mix(x, fg_num, patch_num, size_d, size_h, size_w): return x_fuse, y_prob def create_mixed_dataset(input_dir, output_dir, fg_num = 1, crop_num = 1, - mask = 'default', data_format = "nii.gz"): + mask_dir = None, data_format = "nii.gz"): """ Create dataset based on patch mix. @@ -128,16 +137,19 @@ def create_mixed_dataset(input_dir, output_dir, fg_num = 1, crop_num = 1, chns = img_i.shape[0] # random crop to patch size - if(mask == 'default'): + if(mask_dir is None): mask_i = get_human_region_mask(img_i) mask_j = get_human_region_mask(img_j) + else: + mask_i = load_image_as_nd_array(mask_dir + "/" + img_names[i])['data_array'] + mask_j = load_image_as_nd_array(mask_dir + "/" + img_names[j])['data_array'] for k in range(crop_num): - if(mask is None): - img_ik = random_crop_ND_volume(img_i, [chns, 96, 96, 96]) - img_jk = random_crop_ND_volume(img_j, [chns, 96, 96, 96]) - else: - img_ik = random_crop_ND_volume_with_mask(img_i, [chns, 96, 96, 96], mask_i) - img_jk = random_crop_ND_volume_with_mask(img_j, [chns, 96, 96, 96], mask_j) + # if(mask is None): + # img_ik = random_crop_ND_volume(img_i, [chns, 96, 96, 96]) + # img_jk = random_crop_ND_volume(img_j, [chns, 96, 96, 96]) + # else: + img_ik = random_crop_ND_volume_with_mask(img_i, [chns, 96, 96, 96], mask_i) + img_jk = random_crop_ND_volume_with_mask(img_j, [chns, 96, 96, 96], mask_j) C, D, H, W = img_ik.shape # generate mask fg_mask = np.zeros_like(img_ik, np.uint8) diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index 98409d7..8ae8e80 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -1,10 +1,13 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, print_function +import csv import random +import pandas as pd import numpy as np import SimpleITK as sitk from scipy import ndimage +from pymic.io.image_read_write import load_image_as_nd_array def get_ND_bounding_box(volume, margin = None): """ @@ -315,3 +318,68 @@ def resample_sitk_image_to_given_spacing(image, spacing, order): out_img.SetSpacing(spacing) out_img.SetDirection(image.GetDirection()) return out_img + +def get_image_info(img_names): + space0, space1, slices = [], [], [] + for img_name in img_names: + img_obj = sitk.ReadImage(img_name) + img_arr = sitk.GetArrayFromImage(img_obj) + spacing = img_obj.GetSpacing() + slices.append(img_arr.shape[0]) + space0.append(spacing[0]) + space1.append(spacing[2]) + print(img_name, spacing, img_arr.shape) + + space0 = np.asarray(space0) + space1 = np.asarray(space1) + slices = np.asarray(slices) + print("intra-slice spacing") + print(space0.min(), space0.max(), space0.mean()) + print("inter-slice spacing") + print(space1.min(), space1.max(), space1.mean()) + print("slice number") + print(slices.min(), slices.max(), slices.mean()) + +def get_average_mean_std(data_dir, data_csv): + df = pd.read_csv(data_csv) + mean_list, std_list = [], [] + for i in range(len(df)): + img_name = data_dir + "/" + df.iloc[i, 0] + lab_name = data_dir + "/" + df.iloc[i, 1] + img = load_image_as_nd_array(img_name)["data_array"][0] + lab = load_image_as_nd_array(lab_name)["data_array"][0] + voxels = img[lab>0] + mean = voxels.mean() + std = voxels.std() + mean_list.append(mean) + std_list.append(std) + print(img_name, mean, std) + mean = np.asarray(mean_list).mean() + std = np.asarray(std_list).mean() + print("mean and std value", mean, std) + +def get_label_info(data_dir, label_csv, class_num): + df = pd.read_csv(label_csv) + size_list = [] + # mean_list, std_list = [], [] + num_no_tumor = 0 + for i in range(len(df)): + lab_name = data_dir + "/" + df.iloc[i, 1] + lab = load_image_as_nd_array(lab_name)["data_array"][0] + size_per_class = [] + for c in range(1, class_num): + labc = lab == c + size_per_class.append(np.sum(labc)) + if(np.sum(labc) == 0): + num_no_tumor = num_no_tumor + 1 + size_list.append(size_per_class) + print(lab_name, size_per_class) + size = np.asarray(size_list) + size_min = size.min(axis = 0) + size_max = size.max(axis = 0) + size_mean = size.mean(axis = 0) + + print("size min", size_min) + print("size max", size_max) + print("size mean", size_mean) + print("case number without tumor", num_no_tumor) \ No newline at end of file From 258041aa2d398a134f243b599b4d4b753d6c947e Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 5 Jun 2023 13:09:46 +0800 Subject: [PATCH 14/86] update ce and dice loss set pixel weight and class weight --- pymic/loss/seg/ce.py | 25 +++++++++++++------------ pymic/loss/seg/dice.py | 14 +++++++++++--- pymic/net_run/agent_seg.py | 20 +++++++++++++++----- 3 files changed, 39 insertions(+), 20 deletions(-) diff --git a/pymic/loss/seg/ce.py b/pymic/loss/seg/ce.py index 529482b..9524d57 100644 --- a/pymic/loss/seg/ce.py +++ b/pymic/loss/seg/ce.py @@ -23,6 +23,7 @@ def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] soft_y = loss_input_dict['ground_truth'] pix_w = loss_input_dict.get('pixel_weight', None) + cls_w = loss_input_dict.get('class_weight', None) if(isinstance(predict, (list, tuple))): predict = predict[0] @@ -34,7 +35,10 @@ def forward(self, loss_input_dict): # for numeric stability predict = predict * 0.999 + 5e-4 ce = - soft_y* torch.log(predict) - ce = torch.sum(ce, dim = 1) # shape is [N] + if(cls_w is not None): + ce = torch.sum(ce*cls_w, dim = 1) + else: + ce = torch.sum(ce, dim = 1) # shape is [N] if(pix_w is None): ce = torch.mean(ce) else: @@ -61,12 +65,12 @@ class GeneralizedCELoss(AbstractSegLoss): def __init__(self, params): super(GeneralizedCELoss, self).__init__(params) self.q = params.get('loss_gce_q', 0.5) - self.enable_pix_weight = params.get('loss_with_pixel_weight', False) - self.cls_weight = params.get('loss_class_weight', None) def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] - soft_y = loss_input_dict['ground_truth'] + soft_y = loss_input_dict['ground_truth'] + pix_w = loss_input_dict.get('pixel_weight', None) + cls_w = loss_input_dict.get('class_weight', None) if(isinstance(predict, (list, tuple))): predict = predict[0] @@ -76,17 +80,14 @@ def forward(self, loss_input_dict): soft_y = reshape_tensor_to_2D(soft_y) gce = (1.0 - torch.pow(predict, self.q)) / self.q * soft_y - if(self.cls_weight is not None): - gce = torch.sum(gce * self.cls_w, dim = 1) + if(cls_w is not None): + gce = torch.sum(gce * cls_w, dim = 1) else: gce = torch.sum(gce, dim = 1) - if(self.enable_pix_weight): - pix_w = loss_input_dict.get('pixel_weight', None) - if(pix_w is None): - raise ValueError("Pixel weight is enabled but not defined") - pix_w = reshape_tensor_to_2D(pix_w) - gce = torch.sum(gce * pix_w) / torch.sum(pix_w) + if(pix_w is not None): + pix_w = torch.squeeze(reshape_tensor_to_2D(pix_w)) + gce = torch.sum(gce * pix_w) / torch.sum(pix_w) else: gce = torch.mean(gce) return gce diff --git a/pymic/loss/seg/dice.py b/pymic/loss/seg/dice.py index 350e0c4..2c2df32 100644 --- a/pymic/loss/seg/dice.py +++ b/pymic/loss/seg/dice.py @@ -20,6 +20,8 @@ def __init__(self, params = None): def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] soft_y = loss_input_dict['ground_truth'] + pix_w = loss_input_dict.get('pixel_weight', None) + cls_w = loss_input_dict.get('class_weight', None) if(isinstance(predict, (list, tuple))): predict = predict[0] @@ -27,9 +29,15 @@ def forward(self, loss_input_dict): predict = nn.Softmax(dim = 1)(predict) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) - dice_score = get_classwise_dice(predict, soft_y) - dice_loss = 1.0 - dice_score.mean() - return dice_loss + if(pix_w is not None): + pix_w = reshape_tensor_to_2D(pix_w) + dice_loss = 1.0 - get_classwise_dice(predict, soft_y, pix_w) + if(cls_w is not None): + weighted_loss = dice_loss * cls_w + avg_loss = weighted_loss.sum() / cls_w.sum() + else: + avg_loss = dice_loss.mean() + return avg_loss class BinaryDiceLoss(AbstractSegLoss): ''' diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 3d2fb21..70f01b3 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -118,11 +118,21 @@ def create_loss_calculator(self): def get_loss_value(self, data, pred, gt, param = None): loss_input_dict = {'prediction':pred, 'ground_truth': gt} - if data.get('pixel_weight', None) is not None: - if(isinstance(pred, tuple) or isinstance(pred, list)): - loss_input_dict['pixel_weight'] = data['pixel_weight'].to(pred[0].device) - else: - loss_input_dict['pixel_weight'] = data['pixel_weight'].to(pred.device) + if(isinstance(pred, tuple) or isinstance(pred, list)): + device = pred[0].device + else: + device = pred.device + pixel_weight = data.get('pixel_weight', None) + if(pixel_weight is not None): + loss_input_dict['pixel_weight'] = pixel_weight.to(device) + + class_weight = self.config['training'].get('class_weight', None) + if(class_weight is not None): + class_num = self.config['network']['class_num'] + assert(len(class_weight) == class_num) + class_weight = torch.from_numpy(np.asarray(class_weight)) + class_weight = self.convert_tensor_type(class_weight) + loss_input_dict['class_weight'] = class_weight.to(device) loss_value = self.loss_calculator(loss_input_dict) return loss_value From d74ac93760a76a1cb32e947aaefe545627bfc9ce Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 27 Jun 2023 11:10:23 +0800 Subject: [PATCH 15/86] Update net_dict_seg.py --- pymic/net/net_dict_seg.py | 56 +++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index e6c10bd..ffaa023 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -21,24 +21,24 @@ from pymic.net.net2d.unet2d_attention import AttentionUNet2D from pymic.net.net2d.unet2d_nest import NestedUNet2D from pymic.net.net2d.unet2d_scse import UNet2D_ScSE -from pymic.net.net2d.trans2d.transunet import TransUNet -from pymic.net.net2d.trans2d.swinunet import SwinUNet +# from pymic.net.net2d.trans2d.transunet import TransUNet +# from pymic.net.net2d.trans2d.swinunet import SwinUNet from pymic.net.net3d.unet2d5 import UNet2D5 from pymic.net.net3d.unet3d import UNet3D from pymic.net.net3d.unet3d_scse import UNet3D_ScSE from pymic.net.net3d.unet3d_dual_branch import UNet3D_DualBranch -from pymic.net.net3d.trans3d.nnFormer_wrap import nnFormer_wrap -from pymic.net.net3d.trans3d.unetr import UNETR -from pymic.net.net3d.trans3d.unetr_pp import UNETR_PP -from pymic.net.net3d.trans3d.MedFormer_v1 import MedFormerV1 -from pymic.net.net3d.trans3d.MedFormer_v2 import MedFormerV2 -from pymic.net.net3d.trans3d.MedFormer_v3 import MedFormerV3 -from pymic.net.net3d.trans3d.MedFormer_va1 import MedFormerVA1 -from pymic.net.net3d.trans3d.HiFormer_v1 import HiFormer_v1 -from pymic.net.net3d.trans3d.HiFormer_v2 import HiFormer_v2 -from pymic.net.net3d.trans3d.HiFormer_v3 import HiFormer_v3 -from pymic.net.net3d.trans3d.HiFormer_v4 import HiFormer_v4 -from pymic.net.net3d.trans3d.HiFormer_v5 import HiFormer_v5 +# from pymic.net.net3d.trans3d.nnFormer_wrap import nnFormer_wrap +# from pymic.net.net3d.trans3d.unetr import UNETR +# from pymic.net.net3d.trans3d.unetr_pp import UNETR_PP +# from pymic.net.net3d.trans3d.MedFormer_v1 import MedFormerV1 +# from pymic.net.net3d.trans3d.MedFormer_v2 import MedFormerV2 +# from pymic.net.net3d.trans3d.MedFormer_v3 import MedFormerV3 +# from pymic.net.net3d.trans3d.MedFormer_va1 import MedFormerVA1 +# from pymic.net.net3d.trans3d.HiFormer_v1 import HiFormer_v1 +# from pymic.net.net3d.trans3d.HiFormer_v2 import HiFormer_v2 +# from pymic.net.net3d.trans3d.HiFormer_v3 import HiFormer_v3 +# from pymic.net.net3d.trans3d.HiFormer_v4 import HiFormer_v4 +# from pymic.net.net3d.trans3d.HiFormer_v5 import HiFormer_v5 SegNetDict = { 'UNet2D': UNet2D, @@ -48,22 +48,22 @@ 'AttentionUNet2D': AttentionUNet2D, 'NestedUNet2D': NestedUNet2D, 'UNet2D_ScSE': UNet2D_ScSE, - 'TransUNet': TransUNet, - 'SwinUNet': SwinUNet, + # 'TransUNet': TransUNet, + # 'SwinUNet': SwinUNet, 'UNet2D5': UNet2D5, 'UNet3D': UNet3D, 'UNet3D_ScSE': UNet3D_ScSE, 'UNet3D_DualBranch': UNet3D_DualBranch, - 'nnFormer': nnFormer_wrap, - 'UNETR': UNETR, - 'UNETR_PP': UNETR_PP, - 'MedFormerV1': MedFormerV1, - 'MedFormerV2': MedFormerV2, - 'MedFormerV3': MedFormerV3, - 'MedFormerVA1':MedFormerVA1, - 'HiFormer_v1': HiFormer_v1, - 'HiFormer_v2': HiFormer_v2, - 'HiFormer_v3': HiFormer_v3, - 'HiFormer_v4': HiFormer_v4, - 'HiFormer_v5': HiFormer_v5 + # 'nnFormer': nnFormer_wrap, + # 'UNETR': UNETR, + # 'UNETR_PP': UNETR_PP, + # 'MedFormerV1': MedFormerV1, + # 'MedFormerV2': MedFormerV2, + # 'MedFormerV3': MedFormerV3, + # 'MedFormerVA1':MedFormerVA1, + # 'HiFormer_v1': HiFormer_v1, + # 'HiFormer_v2': HiFormer_v2, + # 'HiFormer_v3': HiFormer_v3, + # 'HiFormer_v4': HiFormer_v4, + # 'HiFormer_v5': HiFormer_v5 } From f43f409d1f0023e5b07fcfb72ec53ed45fc1ee95 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 30 Jun 2023 18:37:29 +0800 Subject: [PATCH 16/86] Update train.py fix issues for datetime --- pymic/net_run/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index 0167f2f..426b620 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -78,11 +78,12 @@ def main(): os.makedirs(log_dir, exist_ok=True) dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] shutil.copy(cfg_file, log_dir + "/" + dst_cfg) + datetime_str = str(datetime.now())[:-7].replace(":", "_") if sys.version.startswith("3.9"): - logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), + logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), level=logging.INFO, format='%(message)s', force=True) # for python 3.9 else: - logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), + logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), level=logging.INFO, format='%(message)s') # for python 3.6 logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) From a7d179fee45edfe006bf8b9c863da2a8190b4b3c Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 18 Jul 2023 09:56:17 +0800 Subject: [PATCH 17/86] Update image_process.py --- pymic/util/image_process.py | 50 +++++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index 8ae8e80..4143f39 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -319,26 +319,44 @@ def resample_sitk_image_to_given_spacing(image, spacing, order): out_img.SetDirection(image.GetDirection()) return out_img -def get_image_info(img_names): - space0, space1, slices = [], [], [] +def get_image_info(img_names, output_csv = None): + spacing_list, shape_list = [], [] for img_name in img_names: img_obj = sitk.ReadImage(img_name) img_arr = sitk.GetArrayFromImage(img_obj) spacing = img_obj.GetSpacing() - slices.append(img_arr.shape[0]) - space0.append(spacing[0]) - space1.append(spacing[2]) - print(img_name, spacing, img_arr.shape) - - space0 = np.asarray(space0) - space1 = np.asarray(space1) - slices = np.asarray(slices) - print("intra-slice spacing") - print(space0.min(), space0.max(), space0.mean()) - print("inter-slice spacing") - print(space1.min(), space1.max(), space1.mean()) - print("slice number") - print(slices.min(), slices.max(), slices.mean()) + shape = img_arr.shape + spacing_list.append(spacing) + shape_list.append(shape) + print(img_name, spacing, shape) + spacings = np.asarray(spacing_list) + shapes = np.asarray(shape_list) + spacing_min = spacings.min(axis = 0) + spacing_max = spacings.max(axis = 0) + spacing_median = np.percentile(spacings, 50, axis = 0) + print("spacing min", spacing_min) + print("spacing max", spacing_max) + print("spacing median", spacing_median) + + shape_min = shapes.min(axis = 0) + shape_max = shapes.max(axis = 0) + shape_median = np.percentile(shapes, 50, axis = 0) + print("shape min", shape_min) + print("shape max", shape_max) + print("shape median", shape_median) + + if(output_csv is not None): + img_names_short = [item.split("/")[-1] for item in img_names] + img_names_short.extend(["spacing min", "spacing max", "spacing median", + "shape min", "shape max", "shape median"]) + spacing_list.extend([spacing_min, spacing_max, spacing_median, + shape_min, shape_max, shape_median]) + shape_list.extend(['']* 6) + out_dict = {"img_name": img_names_short, + "spacing": spacing_list, + "shape": shape_list} + df = pd.DataFrame.from_dict(out_dict) + df.to_csv(output_csv, index=False) def get_average_mean_std(data_dir, data_csv): df = pd.read_csv(data_csv) From 5db981ffbc35e5e7be69e1a15eae8b3d81419709 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 18 Jul 2023 20:47:54 +0800 Subject: [PATCH 18/86] Update image_read_write.py --- pymic/io/image_read_write.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pymic/io/image_read_write.py b/pymic/io/image_read_write.py index cb65e19..5278959 100644 --- a/pymic/io/image_read_write.py +++ b/pymic/io/image_read_write.py @@ -23,11 +23,9 @@ def load_nifty_volume_as_4d_array(filename): spacing = img_obj.GetSpacing() direction = img_obj.GetDirection() shape = data_array.shape - if(len(shape) == 4): - assert(shape[3] == 1) - elif(len(shape) == 3): + if(len(shape) == 3): data_array = np.expand_dims(data_array, axis = 0) - else: + elif(len(shape) > 4 or len(shape) < 3): raise ValueError("unsupported image dim: {0:}".format(len(shape))) output = {} output['data_array'] = data_array From d69a415a57e762b6aef71aee1a3ddeac2e52b62b Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 19 Jul 2023 09:40:27 +0800 Subject: [PATCH 19/86] Update rescale.py --- pymic/transform/rescale.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/pymic/transform/rescale.py b/pymic/transform/rescale.py index 355712e..36f122b 100644 --- a/pymic/transform/rescale.py +++ b/pymic/transform/rescale.py @@ -100,27 +100,26 @@ def __init__(self, params): self.ratio0 = params["RandomRescale_lower_bound".lower()] self.ratio1 = params["RandomRescale_upper_bound".lower()] self.prob = params.get('RandomRescale_probability'.lower(), 0.5) - self.inverse = params.get("RandomRescale_inverse".lower(), True) + self.inverse = params.get("RandomRescale_inverse".lower(), False) assert isinstance(self.ratio0, (float, list, tuple)) assert isinstance(self.ratio1, (float, list, tuple)) def __call__(self, sample): - # if(random.random() > self.prob): - # print("rescale not started") - # sample['RandomRescale_triggered'] = False - # return sample - # else: - # print("rescale started") - # sample['RandomRescale_triggered'] = True + image = sample['image'] input_shape = image.shape input_dim = len(input_shape) - 1 - + assert(input_dim == len(self.ratio0) and input_dim == len(self.ratio1)) + if isinstance(self.ratio0, (list, tuple)): - for i in range(len(self.ratio0)): + for i in range(input_dim): + if(self.ratio0[i] is None): + self.ratio0[i] = 1.0 + if(self.ratio1[i] is None): + self.ratio1[i] = 1.0 assert(self.ratio0[i] <= self.ratio1[i]) scale = [self.ratio0[i] + random.random()*(self.ratio1[i] - self.ratio0[i]) \ - for i in range(len(self.ratio0))] + for i in range(input_dim)] else: scale = self.ratio0 + random.random()*(self.ratio1 - self.ratio0) scale = [scale] * input_dim @@ -130,12 +129,12 @@ def __call__(self, sample): sample['image'] = image_t sample['RandomRescale_Param'] = json.dumps(input_shape) if('label' in sample and \ - self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): label = sample['label'] label = ndimage.interpolation.zoom(label, scale, order = 0) sample['label'] = label if('pixel_weight' in sample and \ - self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): weight = sample['pixel_weight'] weight = ndimage.interpolation.zoom(weight, scale, order = 1) sample['pixel_weight'] = weight @@ -143,8 +142,6 @@ def __call__(self, sample): return sample def inverse_transform_for_prediction(self, sample): - if(not sample['RandomRescale_triggered']): - return sample if(isinstance(sample['RandomRescale_Param'], list) or \ isinstance(sample['RandomRescale_Param'], tuple)): origin_shape = json.loads(sample['RandomRescale_Param'][0]) From c09e00a45c8081e82126abea47e65bb7126f5349 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 20 Jul 2023 08:37:53 +0800 Subject: [PATCH 20/86] Update image_read_write.py --- pymic/io/image_read_write.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pymic/io/image_read_write.py b/pymic/io/image_read_write.py index 5278959..6c8c6b0 100644 --- a/pymic/io/image_read_write.py +++ b/pymic/io/image_read_write.py @@ -97,7 +97,10 @@ def save_array_as_nifty_volume(data, image_name, reference_name = None): #img.CopyInformation(img_ref) img.SetSpacing(img_ref.GetSpacing()) img.SetOrigin(img_ref.GetOrigin()) - img.SetDirection(img_ref.GetDirection()) + direction0 = img_ref.GetDirection() + direction1 = img.GetDirection() + if(len(direction0) == len(direction1)): + img.SetDirection(direction0) sitk.WriteImage(img, image_name) def save_array_as_rgb_image(data, image_name): From ac96e04cfbba183903e623548440cd584076d4a6 Mon Sep 17 00:00:00 2001 From: taigw Date: Sun, 6 Aug 2023 10:26:03 +0800 Subject: [PATCH 21/86] update transforms --- pymic/transform/crop.py | 21 ++++++---- pymic/transform/intensity.py | 74 +++++++++++++++++++++++++++++++----- pymic/transform/rescale.py | 70 ++++++++++++++++++++++++++++++++++ pymic/transform/rotate.py | 4 +- 4 files changed, 151 insertions(+), 18 deletions(-) diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index b8130f9..36a0ca9 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -153,7 +153,6 @@ def _get_crop_param(self, sample): crop_min = [0] + crop_min crop_max = list(input_shape[0:1]) + crop_max sample['CropWithBoundingBox_Param'] = json.dumps((input_shape, crop_min, crop_max)) - print("for crop", crop_min, crop_max) return sample, crop_min, crop_max def _get_param_for_inverse_transform(self, sample): @@ -248,7 +247,8 @@ def _get_crop_param(self, sample): crop_min = [0 if item == 0 else random.randint(0, item) for item in crop_margin] crop_max = [crop_min[i] + self.output_size[i] for i in range(input_dim)] - if(self.fg_focus and random.random() < self.fg_ratio): + label_exist = False if ('label' not in sample or sample['label']) is None else True + if(label_exist and self.fg_focus and random.random() < self.fg_ratio): label = sample['label'][0] if(self.mask_label is None): mask_label = np.unique(label)[1:] @@ -279,26 +279,33 @@ class RandomResizedCrop(CenterCrop): :param `RandomResizedCrop_output_size`: (list/tuple) Desired output size [D, H, W]. The output channel is the same as the input channel. - :param `RandomResizedCrop_scale_range`: (list/tuple) Range of scale, e.g. (0.08, 1.0). + :param `RandomResizedCrop_scale_lower_bound`: (list/tuple) Lower bound of the range of scale + for each dimension. e.g. (1.0, 0.5, 0.5). + param `RandomResizedCrop_scale_upper_bound`: (list/tuple) Upper bound of the range of scale + for each dimension. e.g. (1.0, 2.0, 2.0). :param `RandomResizedCrop_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `False`. Currently, the inverse transform is not supported, and this transform is assumed to be used only during training stage. """ def __init__(self, params): self.output_size = params['RandomResizedCrop_output_size'.lower()] - self.scale = params['RandomResizedCrop_scale_range'.lower()] + self.scale_lower = params['RandomResizedCrop_scale_lower_bound'.lower()] + self.scale_upper = params['RandomResizedCrop_scale_upper_bound'.lower()] self.inverse = params.get('RandomResizedCrop_inverse'.lower(), False) self.task = params['Task'.lower()] assert isinstance(self.output_size, (list, tuple)) - assert isinstance(self.scale, (list, tuple)) + assert isinstance(self.scale_lower, (list, tuple)) + assert isinstance(self.scale_upper, (list, tuple)) def __call__(self, sample): image = sample['image'] channel, input_size = image.shape[0], image.shape[1:] input_dim = len(input_size) assert(input_dim == len(self.output_size)) - scale = self.scale[0] + random.random()*(self.scale[1] - self.scale[0]) - crop_size = [int(self.output_size[i] * scale) for i in range(input_dim)] + scale = [self.scale_lower[i] + (self.scale_upper[i] - self.scale_lower[i]) * random.random() \ + for i in range(input_dim)] + + crop_size = [int(self.output_size[i] * scale[i]) for i in range(input_dim)] crop_margin = [input_size[i] - crop_size[i] for i in range(input_dim)] pad_image = False if(min(crop_margin) < 0): diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index 1a13190..9dbf6d9 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -171,21 +171,31 @@ def __call__(self, sample): class NonLinearTransform(AbstractTransform): def __init__(self, params): super(NonLinearTransform, self).__init__(params) - self.inverse = params.get('NonLinearTransform_inverse'.lower(), False) + self.channels = params['NonLinearTransform_channels'.lower()] self.prob = params.get('NonLinearTransform_probability'.lower(), 0.5) + self.inverse = params.get('NonLinearTransform_inverse'.lower(), False) + def __call__(self, sample): if(random.random() > self.prob): return sample - image= sample['image'] - points = [[0, 0], [random.random(), random.random()], [random.random(), random.random()], [1, 1]] - xvals, yvals = bezier_curve(points, nTimes=100000) - if random.random() < 0.5: # Half change to get flip - xvals = np.sort(xvals) - else: - xvals, yvals = np.sort(xvals), np.sort(yvals) - image = np.interp(image, xvals, yvals) + image = sample['image'] + for chn in self.channels: + points = [[0, 0], [random.random(), random.random()], [random.random(), random.random()], [1, 1]] + xvals, yvals = bezier_curve(points, nTimes=10000) + if random.random() < 0.5: # Half change to get flip + xvals = np.sort(xvals) + else: + xvals, yvals = np.sort(xvals), np.sort(yvals) + # normalize the image intensity to [0, 1] before the non-linear tranform + img_c = image[chn] + v_min = img_c.min() + v_max = img_c.max() + if(v_min < v_max): + img_c = (img_c - v_min)/(v_max - v_min) + img_c = np.interp(img_c, xvals, yvals) + image[chn] = img_c * (v_max - v_min) + v_min sample['image'] = image return sample @@ -392,4 +402,50 @@ def __call__(self, sample): sample = self.inpaint(sample) else: sample = self.outpaint(sample) + return sample + +class PatchSwaping(AbstractTransform): + """ + Apply patch swaping for context restoration in self-supervised learning. + Reference: Liang Chen et al., Self-supervised learning for medical image analysis + using image context restoration, Medical Image Analysis, 2019. + """ + def __init__(self, params): + super(PatchSwaping, self).__init__(params) + self.inverse = params.get('PatchSwaping_inverse'.lower(), False) + self.swap_t = params.get('PatchSwaping_swap_time'.lower(), (1, 6)) + self.patch_size_min = params.get('PatchSwaping_patch_size_min'.lower(), None) + self.patch_size_max = params.get('PatchSwaping_patch_size_max'.lower(), None) + + def __call__(self, sample): + + image= sample['image'] + img_shape = image.shape + img_dim = len(img_shape) - 1 + assert(img_dim == 2 or img_dim == 3) + img_out = image + + C, D, H, W = image.shape + patch_size = [random.randint(self.patch_size_min[i], self.patch_size_max[i]) for \ + i in range(img_dim)] + + coordinate_list = [] + for d in range(0, D-patch_size[0], patch_size[0]): + for h in range(0, H-patch_size[1], patch_size[1]): + for w in range(0, W-patch_size[2], patch_size[2]): + coordinate_list.append((d, h, w)) + random.shuffle(coordinate_list) + + for t in range(self.swap_t): + pos_a0 = coordinate_list[2*t] + pos_b0 = coordinate_list[2*t + 1] + pos_a1 = [pos_a0[i] + patch_size[i] for i in range(img_dim)] + pos_b1 = [pos_b0[i] + patch_size[i] for i in range(img_dim)] + img_out[:, pos_a0[0]:pos_a1[0], pos_a0[1]:pos_a1[1], pos_a0[2]:pos_a1[2]] = \ + image[:, pos_b0[0]:pos_b1[0], pos_b0[1]:pos_b1[1], pos_b0[2]:pos_b1[2]] + img_out[:, pos_b0[0]:pos_b1[0], pos_b0[1]:pos_b1[1], pos_b0[2]:pos_b1[2]] = \ + image[:, pos_a0[0]:pos_a1[0], pos_a0[1]:pos_a1[1], pos_a0[2]:pos_a1[2]] + + sample['image'] = img_out + sample['label'] = image return sample \ No newline at end of file diff --git a/pymic/transform/rescale.py b/pymic/transform/rescale.py index 36f122b..fa4f052 100644 --- a/pymic/transform/rescale.py +++ b/pymic/transform/rescale.py @@ -154,6 +154,76 @@ def inverse_transform_for_prediction(self, sample): i in range(origin_dim)] scale = [1.0, 1.0] + scale + output_predict = ndimage.interpolation.zoom(predict, scale, order = 1) + sample['predict'] = output_predict + return sample + + +class Resample(Rescale): + """Resample the image to a given spatial resolution. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `Rescale_output_size`: (list/tuple or int) The output size along each spatial axis, + such as [D, H, W] or [H, W]. If D is None, the input image is only reslcaled in 2D. + If int, the smallest axis is matched to output_size keeping aspect ratio the same + as the input. + :param `Rescale_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `True`. + """ + def __init__(self, params): + super(Rescale, self).__init__(params) + self.output_spacing = params["Resample_output_spacing".lower()] + self.ignore_zspacing= params.get("Resample_ignore_zspacing_range".lower(), None) + self.inverse = params.get("Resample_inverse".lower(), True) + # assert isinstance(self.output_size, (int, list, tuple)) + + def __call__(self, sample): + image = sample['image'] + input_shape = image.shape + input_dim = len(input_shape) - 1 + spacing = sample['spacing'] + out_spacing = [item for item in self.output_spacing] + for i in range(input_dim): + out_spacing[i] = spacing[i] if out_spacing[i] is None else out_spacing[i] + if(self.ignore_zspacing is not None): + if(spacing[0] > self.ignore_zspacing[0] and spacing[0] < self.ignore_zspacing[1]): + out_spacing[0] = spacing[0] + scale = [spacing[i] / out_spacing[i] for i in range(input_dim)] + scale = [1.0] + scale + + image_t = ndimage.interpolation.zoom(image, scale, order = 1) + + sample['image'] = image_t + sample['spacing'] = out_spacing + sample['Resample_origin_shape'] = json.dumps(input_shape) + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + label = sample['label'] + label = ndimage.interpolation.zoom(label, scale, order = 0) + sample['label'] = label + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + weight = sample['pixel_weight'] + weight = ndimage.interpolation.zoom(weight, scale, order = 1) + sample['pixel_weight'] = weight + + return sample + + def inverse_transform_for_prediction(self, sample): + if(isinstance(sample['Resample_origin_shape'], list) or \ + isinstance(sample['Resample_origin_shape'], tuple)): + origin_shape = json.loads(sample['Resample_origin_shape'][0]) + else: + origin_shape = json.loads(sample['Resample_origin_shape']) + origin_dim = len(origin_shape) - 1 + predict = sample['predict'] + input_shape = predict.shape + scale = [(origin_shape[1:][i] + 0.0)/input_shape[2:][i] for \ + i in range(origin_dim)] + scale = [1.0, 1.0] + scale + output_predict = ndimage.interpolation.zoom(predict, scale, order = 1) sample['predict'] = output_predict return sample \ No newline at end of file diff --git a/pymic/transform/rotate.py b/pymic/transform/rotate.py index 65e5328..5f85e28 100644 --- a/pymic/transform/rotate.py +++ b/pymic/transform/rotate.py @@ -34,8 +34,8 @@ class RandomRotate(AbstractTransform): def __init__(self, params): super(RandomRotate, self).__init__(params) self.angle_range_d = params['RandomRotate_angle_range_d'.lower()] - self.angle_range_h = params['RandomRotate_angle_range_h'.lower()] - self.angle_range_w = params['RandomRotate_angle_range_w'.lower()] + self.angle_range_h = params.get('RandomRotate_angle_range_h'.lower(), None) + self.angle_range_w = params.get('RandomRotate_angle_range_w'.lower(), None) self.prob = params.get('RandomRotate_probability'.lower(), 0.5) self.inverse = params.get('RandomRotate_inverse'.lower(), True) From 05d1748b10870760d336f36a53fcbc57798c5354 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 10 Aug 2023 09:17:21 +0800 Subject: [PATCH 22/86] update transform --- pymic/transform/flip.py | 2 +- pymic/transform/intensity.py | 3 ++ pymic/transform/mix.py | 85 ++++++++++++++++++++++++++++++++++- pymic/transform/trans_dict.py | 4 ++ 4 files changed, 92 insertions(+), 2 deletions(-) diff --git a/pymic/transform/flip.py b/pymic/transform/flip.py index 24cafb4..486180c 100644 --- a/pymic/transform/flip.py +++ b/pymic/transform/flip.py @@ -54,7 +54,7 @@ def __call__(self, sample): image_t = np.flip(image, flip_axis).copy() sample['image'] = image_t if('label' in sample and \ - self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): sample['label'] = np.flip(sample['label'] , flip_axis).copy() if('pixel_weight' in sample and \ self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index 9dbf6d9..e5e30f6 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -98,6 +98,7 @@ def __init__(self, params): self.channels = params['GammaCorrection_channels'.lower()] self.gamma_min = params['GammaCorrection_gamma_min'.lower()] self.gamma_max = params['GammaCorrection_gamma_max'.lower()] + self.flip_prob = params.get('GammaCorrection_intensity_flip_probability'.lower(), 0.2) self.prob = params.get('GammaCorrection_probability'.lower(), 0.5) self.inverse = params.get('GammaCorrection_inverse'.lower(), False) @@ -112,6 +113,8 @@ def __call__(self, sample): v_max = img_c.max() if(v_min < v_max): img_c = (img_c - v_min)/(v_max - v_min) + if(np.random.uniform() < self.flip_prob): + img_c = 1.0 - img_c img_c = np.power(img_c, gamma_c)*(v_max - v_min) + v_min image[chn] = img_c diff --git a/pymic/transform/mix.py b/pymic/transform/mix.py index fe2f315..6e2fb8e 100644 --- a/pymic/transform/mix.py +++ b/pymic/transform/mix.py @@ -63,4 +63,87 @@ def __call__(self, sample): coord_min[1]:coord_min[1] + block_size[1], coord_min[2]:coord_min[2] + block_size[2]] = random_block sample['image'] = image - return sample \ No newline at end of file + return sample + +class PatchMix(AbstractTransform): + """ + In-painting of an input image, used for self-supervised learning + """ + def __init__(self, params): + super(PatchMix, self).__init__(params) + self.inverse = params.get('PatchMix_inverse'.lower(), False) + self.crop_size = params.get('PatchMix_crop_size'.lower(), [64, 128, 128]) + self.fg_cls_num = params.get('PatchMix_cls_num'.lower(), [4, 40]) + self.patch_num_range= params.get('PatchMix_patch_range'.lower(), [4, 40]) + self.patch_size_min = params.get('PatchMix_patch_size_min'.lower(), [4, 4, 4]) + self.patch_size_max = params.get('PatchMix_patch_size_max'.lower(), [20, 40, 40]) + + def __call__(self, sample): + x0, x1 = self._random_crop_and_flip(sample) + C, D, H, W = x0.shape + # generate mask + fg_mask = np.zeros_like(x0, np.uint8) + patch_num = random.randint(self.patch_num_range[0], self.patch_num_range[1]) + for patch in range(patch_num): + d = random.randint(self.patch_size_min[0], self.patch_size_max[0]) + h = random.randint(self.patch_size_min[1], self.patch_size_max[1]) + w = random.randint(self.patch_size_min[2], self.patch_size_max[2]) + d_c = random.randint(0, D) + h_c = random.randint(0, H) + w_c = random.randint(0, W) + d0, d1 = max(0, d_c - d // 2), min(D, d_c + d // 2) + h0, h1 = max(0, h_c - h // 2), min(H, h_c + h // 2) + w0, w1 = max(0, w_c - w // 2), min(W, w_c + w // 2) + temp_m = np.ones([C, d1-d0, h1-h0, w1-w0]) * random.randint(1, self.fg_cls_num) + fg_mask[:, d0:d1, h0:h1, w0:w1] = temp_m + fg_w = fg_mask * 1.0 / self.fg_cls_num + x_fuse = fg_w*x0 + (1.0 - fg_w)*x1 # x1 is used as background + + sample['image'] = x_fuse + sample['label'] = fg_mask + return sample + + def _random_crop_and_flip(self, sample): + input_shape = sample['image'].shape + input_dim = len(input_shape) - 1 + assert(input_dim == 3) + + if('label' in sample): + # get the center for crop randomly + mask = sample['label'] > 0 + C, D, H, W = input_shape + size_h = [i// 2 for i in self.crop_size] + temp_mask = np.zeros_like(mask) + temp_mask[:,size_h[0]:D-size_h[0]+1,size_h[1]:H-size_h[1]+1,size_h[2]:W-size_h[2]+1] = \ + np.ones([C, D-self.crop_size[0]+1, H-self.crop_size[1]+1, W-self.crop_size[2]+1]) + mask = mask * temp_mask + indices = np.where(mask) + n0 = random.randint(0, len(indices[0])-1) + n1 = random.randint(0, len(indices[0])-1) + center0 = [indices[i][n0] for i in range(1, 4)] + center1 = [indices[i][n1] for i in range(1, 4)] + crop_min0 = [center0[i] - size_h[i] for i in range(3)] + crop_min1 = [center1[i] - size_h[i] for i in range(3)] + else: + crop_margin = [input_shape[1+i] - self.crop_size[i] for i in range(input_dim)] + crop_min0 = [0 if item == 0 else random.randint(0, item) for item in crop_margin] + crop_min1 = [0 if item == 0 else random.randint(0, item) for item in crop_margin] + + patches = [] + for crop_min in [crop_min0, crop_min1]: + crop_max = [crop_min[i] + self.crop_size[i] for i in range(input_dim)] + crop_min = [0] + crop_min + crop_max = [C] + crop_max + x = crop_ND_volume_with_bounding_box(sample['image'], crop_min, crop_max) + flip_axis = [] + if(random.random() > 0.5): + flip_axis.append(-1) + if(random.random() > 0.5): + flip_axis.append(-2) + if(random.random() > 0.5): + flip_axis.append(-3) + if(len(flip_axis) > 0): + x = np.flip(x, flip_axis).copy() + patches.append(x) + + return patches \ No newline at end of file diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index 5dac73b..f779d00 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -40,6 +40,7 @@ from pymic.transform.threshold import * from pymic.transform.normalize import * from pymic.transform.crop import * +from pymic.transform.mix import * from pymic.transform.label_convert import * TransformDict = { @@ -71,8 +72,11 @@ 'RandomRotate': RandomRotate, 'ReduceLabelDim': ReduceLabelDim, 'Rescale': Rescale, + 'Resample': Resample, 'SelfReconstructionLabel': SelfReconstructionLabel, 'MaskedImageModelingLabel': MaskedImageModelingLabel, 'OutPainting': OutPainting, 'Pad': Pad, + 'PatchSwaping':PatchSwaping, + 'PatchMix': PatchMix } From b526f936e78f54f7a68157fe0a93e871c52b94da Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 2 Sep 2023 15:03:15 +0800 Subject: [PATCH 23/86] update transform allow default channel setting --- pymic/transform/intensity.py | 29 ++++++++++++++++------------- pymic/transform/normalize.py | 13 +++++++------ 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index e5e30f6..ffa5141 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -95,18 +95,20 @@ class GammaCorrection(AbstractTransform): """ def __init__(self, params): super(GammaCorrection, self).__init__(params) - self.channels = params['GammaCorrection_channels'.lower()] - self.gamma_min = params['GammaCorrection_gamma_min'.lower()] - self.gamma_max = params['GammaCorrection_gamma_max'.lower()] - self.flip_prob = params.get('GammaCorrection_intensity_flip_probability'.lower(), 0.2) + self.channels = params.get('GammaCorrection_channels'.lower(), None) + self.gamma_min = params.get('GammaCorrection_gamma_min'.lower(), 0.7) + self.gamma_max = params.get('GammaCorrection_gamma_max'.lower(), 1.5) + self.flip_prob = params.get('GammaCorrection_intensity_flip_probability'.lower(), 0.0) self.prob = params.get('GammaCorrection_probability'.lower(), 0.5) self.inverse = params.get('GammaCorrection_inverse'.lower(), False) def __call__(self, sample): - if(np.random.uniform() > self.prob): - return sample image= sample['image'] + if(self.channels is None): + self.channels = range(image.shape[0]) for chn in self.channels: + if(np.random.uniform() > self.prob): + continue gamma_c = random.random() * (self.gamma_max - self.gamma_min) + self.gamma_min img_c = image[chn] v_min = img_c.min() @@ -138,20 +140,21 @@ class GaussianNoise(AbstractTransform): """ def __init__(self, params): super(GaussianNoise, self).__init__(params) - self.channels = params['GaussianNoise_channels'.lower()] + self.channels = params.get('GaussianNoise_channels'.lower(), None) self.mean = params['GaussianNoise_mean'.lower()] self.std = params['GaussianNoise_std'.lower()] self.prob = params.get('GaussianNoise_probability'.lower(), 0.5) self.inverse = params.get('GaussianNoise_inverse'.lower(), False) def __call__(self, sample): - if(np.random.uniform() > self.prob): - return sample - image= sample['image'] + image = sample['image'] + if(self.channels is None): + self.channels = range(image.shape[0]) for chn in self.channels: - img_c = image[chn] - noise = np.random.normal(self.mean, self.std, img_c.shape) - image[chn] = img_c + noise + if(np.random.uniform() < self.prob): + img_c = image[chn] + noise = np.random.normal(self.mean, self.std, img_c.shape) + image[chn] = img_c + noise sample['image'] = image return sample diff --git a/pymic/transform/normalize.py b/pymic/transform/normalize.py index 77852d2..6531f17 100644 --- a/pymic/transform/normalize.py +++ b/pymic/transform/normalize.py @@ -34,7 +34,7 @@ class NormalizeWithMeanStd(AbstractTransform): """ def __init__(self, params): super(NormalizeWithMeanStd, self).__init__(params) - self.chns = params['NormalizeWithMeanStd_channels'.lower()] + self.chns = params.get('NormalizeWithMeanStd_channels'.lower(), None) self.mean = params.get('NormalizeWithMeanStd_mean'.lower(), None) self.std = params.get('NormalizeWithMeanStd_std'.lower(), None) self.ingore_np = params.get('NormalizeWithMeanStd_ignore_non_positive'.lower(), False) @@ -42,13 +42,14 @@ def __init__(self, params): def __call__(self, sample): image= sample['image'] - chns = self.chns if self.chns is not None else range(image.shape[0]) + if(self.chns is None): + self.chns = range(image.shape[0]) if(self.mean is None): - self.mean = [None] * len(chns) - self.std = [None] * len(chns) + self.mean = [None] * len(self.chns) + self.std = [None] * len(self.chns) - for i in range(len(chns)): - chn = chns[i] + for i in range(len(self.chns)): + chn = self.chns[i] chn_mean, chn_std = self.mean[i], self.std[i] if(chn_mean is None): if(self.ingore_np): From 73a13a39ab9b36d95cd75cbc93a1df38558d5215 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 2 Sep 2023 15:24:08 +0800 Subject: [PATCH 24/86] Update unetr_pp.py --- pymic/net/net3d/trans3d/unetr_pp.py | 63 ++++++++++++++++------------- 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/pymic/net/net3d/trans3d/unetr_pp.py b/pymic/net/net3d/trans3d/unetr_pp.py index 3ef6736..a4ab7e6 100644 --- a/pymic/net/net3d/trans3d/unetr_pp.py +++ b/pymic/net/net3d/trans3d/unetr_pp.py @@ -155,7 +155,6 @@ def __init__( def forward(self, x): B, C, H, W, D = x.shape - x = x.reshape(B, C, H * W * D).permute(0, 2, 1) if self.pos_embed is not None: @@ -170,12 +169,13 @@ def forward(self, x): class UnetrPPEncoder(nn.Module): def __init__(self, input_size=[32 * 32 * 32, 16 * 16 * 16, 8 * 8 * 8, 4 * 4 * 4],dims=[32, 64, 128, 256], - proj_size =[64,64,64,32], depths=[3, 3, 3, 3], num_heads=4, spatial_dims=3, in_channels=1, dropout=0.0, transformer_dropout_rate=0.15 ,**kwargs): + proj_size =[64,64,64,32], depths=[3, 3, 3, 3], num_heads=4, spatial_dims=3, + in_channels=1, dropout=0.0, transformer_dropout_rate=0.15, kernel_size=(2,4,4), **kwargs): super().__init__() self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers stem_layer = nn.Sequential( - get_conv_layer(spatial_dims, in_channels, dims[0], kernel_size=(2, 4, 4), stride=(2, 4, 4), + get_conv_layer(spatial_dims, in_channels, dims[0], kernel_size=kernel_size, stride=kernel_size, dropout=dropout, conv_only=True, ), get_norm_layer(name=("group", {"num_groups": in_channels}), channels=dims[0]), ) @@ -209,7 +209,6 @@ def _init_weights(self, m): def forward_features(self, x): hidden_states = [] - x = self.downsample_layers[0](x) x = self.stages[0](x) @@ -330,6 +329,7 @@ def __init__(self, params): in_channels = params['in_chns'] out_channels = params['class_num'] img_size = params['img_size'] + self.res_mode= params.get("resolution_mode", 1) feature_size = params.get('feature_size', 16) hidden_size = params.get('hidden_size', 256) num_heads = params.get('num_heads', 4) @@ -350,15 +350,20 @@ def __init__(self, params): if pos_embed not in ["conv", "perceptron"]: raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") - self.patch_size = (2, 4, 4) + kernel_ds = [4, 2, 1] + kernel_d = kernel_ds[self.res_mode] + self.patch_size = (kernel_d, 4, 4) + self.feat_size = ( img_size[0] // self.patch_size[0] // 8, # 8 is the downsampling happened through the four encoders stages img_size[1] // self.patch_size[1] // 8, # 8 is the downsampling happened through the four encoders stages img_size[2] // self.patch_size[2] // 8, # 8 is the downsampling happened through the four encoders stages ) + self.hidden_size = hidden_size - self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads) + self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads, + in_channels=in_channels, kernel_size=self.patch_size) self.encoder1 = UnetResBlock( spatial_dims=3, @@ -395,20 +400,21 @@ def __init__(self, params): norm_name=norm_name, out_size=32 * 32 * 32, ) + self.decoder2 = UnetrUpBlock( spatial_dims=3, in_channels=feature_size * 2, out_channels=feature_size, kernel_size=3, - upsample_kernel_size=(2, 4, 4), + upsample_kernel_size= self.patch_size, norm_name=norm_name, - out_size=64 * 128 * 128, + out_size= kernel_d*32 * 128 * 128, conv_decoder=True, ) self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) - if self.do_ds: - self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels) - self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels) + # if self.do_ds: + self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels) + self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels) def proj_feat(self, x, hidden_size, feat_size): x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) @@ -442,19 +448,22 @@ def forward(self, x_in): if __name__ == "__main__": - params = {'in_chns': 1, - 'class_num': 2, - 'img_size': [64, 128, 128] - } - net = UNETR_PP(params) - net.double() - - x = np.random.rand(2, 1, 64, 128, 128) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = net(xt) - print(len(y)) - for yi in y: - yi = yi.detach().numpy() - print(yi.shape) \ No newline at end of file + depths = [128, 64, 32] + for i in range(3): + params = {'in_chns': 4, + 'class_num': 2, + 'img_size': [depths[i], 128, 128], + 'resolution_mode': i + } + net = UNETR_PP(params) + net.double() + + x = np.random.rand(2, 4, depths[i], 128, 128) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = net(xt) + print(len(y)) + for yi in y: + yi = yi.detach().numpy() + print(yi.shape) \ No newline at end of file From c57f50e442a338b8be0d3e10a7d2cc51b1f15981 Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 13 Sep 2023 10:50:22 +0800 Subject: [PATCH 25/86] update crop update functions for random crop --- pymic/transform/crop.py | 61 +++++++++++++++++++++++++------------ pymic/util/image_process.py | 50 +++++++++++++++++++++++------- 2 files changed, 81 insertions(+), 30 deletions(-) diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index 36a0ca9..b4d0b63 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -255,7 +255,7 @@ def _get_crop_param(self, sample): else: mask_label = self.mask_label random_label = random.choice(mask_label) - crop_min, crop_max = get_random_box_from_mask(label == random_label, self.output_size) + crop_min, crop_max = get_random_box_from_mask(label == random_label, self.output_size, mode = 1) crop_min = [0] + crop_min crop_max = [chns] + crop_max @@ -289,8 +289,11 @@ class RandomResizedCrop(CenterCrop): """ def __init__(self, params): self.output_size = params['RandomResizedCrop_output_size'.lower()] - self.scale_lower = params['RandomResizedCrop_scale_lower_bound'.lower()] - self.scale_upper = params['RandomResizedCrop_scale_upper_bound'.lower()] + self.scale_lower = params['RandomResizedCrop_resize_lower_bound'.lower()] + self.scale_upper = params['RandomResizedCrop_resize_upper_bound'.lower()] + self.prob = params.get('RandomResizedCrop_resize_prob'.lower(), 0.5) + self.fg_ratio = params.get('RandomResizedCrop_foreground_ratio'.lower(), 0.0) + self.mask_label = params.get('RandomResizedCrop_mask_label'.lower(), None) self.inverse = params.get('RandomResizedCrop_inverse'.lower(), False) self.task = params['Task'.lower()] assert isinstance(self.output_size, (list, tuple)) @@ -302,14 +305,19 @@ def __call__(self, sample): channel, input_size = image.shape[0], image.shape[1:] input_dim = len(input_size) assert(input_dim == len(self.output_size)) - scale = [self.scale_lower[i] + (self.scale_upper[i] - self.scale_lower[i]) * random.random() \ - for i in range(input_dim)] - - crop_size = [int(self.output_size[i] * scale[i]) for i in range(input_dim)] + + # get the resized crop size + resize = random.random() < self.prob + if(resize): + scale = [self.scale_lower[i] + (self.scale_upper[i] - self.scale_lower[i]) * random.random() \ + for i in range(input_dim)] + crop_size = [int(self.output_size[i] * scale[i]) for i in range(input_dim)] + else: + crop_size = self.output_size + crop_margin = [input_size[i] - crop_size[i] for i in range(input_dim)] - pad_image = False - if(min(crop_margin) < 0): - pad_image = True + pad_image = min(crop_margin) < 0 + if(pad_image): # pad the image if necessary pad_size = [max(0, -crop_margin[i]) for i in range(input_dim)] pad_lower = [int(pad_size[i] / 2) for i in range(input_dim)] pad_upper = [pad_size[i] - pad_lower[i] for i in range(input_dim)] @@ -317,16 +325,29 @@ def __call__(self, sample): pad = tuple([(0, 0)] + pad) image = np.pad(image, pad, 'reflect') crop_margin = [max(0, crop_margin[i]) for i in range(input_dim)] - - crop_min = [random.randint(0, item) for item in crop_margin] - crop_max = [crop_min[i] + crop_size[i] for i in range(input_dim)] + # ge the bounding box for crop + if(random.random() < self.fg_ratio): + label = sample['label'] + if(pad_image): + label = np.pad(label, pad, 'reflect') + label = label[0] + if(self.mask_label is None): + mask_label = np.unique(label)[1:] + else: + mask_label = self.mask_label + random_label = random.choice(mask_label) + crop_min, crop_max = get_random_box_from_mask(label == random_label, crop_size, mode = 1) + else: + crop_min = [random.randint(0, item) for item in crop_margin] + crop_max = [crop_min[i] + crop_size[i] for i in range(input_dim)] crop_min = [0] + crop_min crop_max = [channel] + crop_max image_t = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) - scale = [(self.output_size[i] + 0.0)/crop_size[i] for i in range(input_dim)] - scale = [1.0] + scale - image_t = ndimage.interpolation.zoom(image_t, scale, order = 1) + if(resize): + scale = [(self.output_size[i] + 0.0)/crop_size[i] for i in range(input_dim)] + scale = [1.0] + scale + image_t = ndimage.interpolation.zoom(image_t, scale, order = 1) sample['image'] = image_t if('label' in sample and \ @@ -336,8 +357,9 @@ def __call__(self, sample): label = np.pad(label, pad, 'reflect') crop_max[0] = label.shape[0] label = crop_ND_volume_with_bounding_box(label, crop_min, crop_max) - order = 0 if(self.task == TaskType.SEGMENTATION) else 1 - label = ndimage.interpolation.zoom(label, scale, order = order) + if(resize): + order = 0 if(self.task == TaskType.SEGMENTATION) else 1 + label = ndimage.interpolation.zoom(label, scale, order = order) sample['label'] = label if('pixel_weight' in sample and \ self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): @@ -346,6 +368,7 @@ def __call__(self, sample): weight = np.pad(weight, pad, 'reflect') crop_max[0] = weight.shape[0] weight = crop_ND_volume_with_bounding_box(weight, crop_min, crop_max) - weight = ndimage.interpolation.zoom(weight, scale, order = 1) + if(resize): + weight = ndimage.interpolation.zoom(weight, scale, order = 1) sample['pixel_weight'] = weight return sample \ No newline at end of file diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index 4143f39..d6a7220 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -167,18 +167,46 @@ def random_crop_ND_volume(volume, out_shape): crop_volume = crop_ND_volume_with_bounding_box(image_pad, bb_min, bb_max) return crop_volume -def get_random_box_from_mask(mask, out_shape): - indexes = np.where(mask) - voxel_num = len(indexes[0]) - dim = len(out_shape) - left_bound = [int(out_shape[i]/2) for i in range(dim)] - right_bound = [mask.shape[i] - (out_shape[i] - left_bound[i]) for i in range(dim)] +def get_random_box_from_mask(mask, out_shape, mode = 0): + """ + get a bounding box of a subvolume according to a mask + + mode == 0: The output bounding box should be a sub region of the mask region + mode == 1: The center point of the output bounding box can be ahy where of the mask region + """ + dim = len(out_shape) + left_margin = [int(out_shape[i]/2) for i in range(dim)] + right_margin = [out_shape[i] - left_margin[i] for i in range(dim)] + + if(mode == 0): + bb_mask_min, bb_mask_max = get_ND_bounding_box(mask) + bb_valid_min, bb_valid_max = [], [] + for i in range(dim): + mask_size = bb_mask_max[i] - bb_mask_min[i] + if(mask_size > out_shape[i]): + valid_left = bb_mask_min[i] + left_margin[i] + valid_right = bb_mask_max[i] - right_margin[i] + else: + valid_left = (bb_mask_max[i] - bb_mask_min[i]) // 2 + valid_right = valid_left + 1 + bb_valid_min.append(valid_left) + bb_valid_max.append(valid_right) + + valid_region_shape = [bb_valid_max[i] - bb_valid_min[i] for i in range(dim)] + valid_mask = np.zeros_like(mask) + valid_mask = set_ND_volume_roi_with_bounding_box_range(valid_mask, + bb_valid_min, bb_valid_max, np.ones(valid_region_shape, np.bool), addition = True) + valid_mask = valid_mask * mask + else: + valid_mask = mask + indices = np.where(valid_mask) + voxel_num = len(indices[0]) j = random.randint(0, voxel_num - 1) - bb_c = [int(indexes[i][j]) for i in range(dim)] - bb_c = [max(left_bound[i], bb_c[i]) for i in range(dim)] - bb_c = [min(right_bound[i], bb_c[i]) for i in range(dim)] - bb_min = [bb_c[i] - left_bound[i] for i in range(dim)] + bb_c = [int(indices[i][j]) for i in range(dim)] + bb_min = [max(0, bb_c[i] - left_margin[i]) for i in range(dim)] + mask_shape = np.shape(mask) + bb_min = [min(bb_min[i], mask_shape[i] - out_shape[i]) for i in range(dim)] bb_max = [bb_min[i] + out_shape[i] for i in range(dim)] return bb_min, bb_max @@ -205,7 +233,7 @@ def random_crop_ND_volume_with_mask(volume, out_shape, mask): pad = [(ml[i], mr[i]) for i in range(dim)] pad = tuple(pad) image_pad = np.pad(volume, pad, 'reflect') - mask_pad = np.pad(mask, pad, 'reflect') + mask_pad = np.pad(mask, pad, 'constant') bb_min, bb_max = get_random_box_from_mask(mask_pad, out_shape) # left_margin = [int(out_shape[i]/2) for i in range(dim)] From fbf1b26d6f26a4fb79be6741a2d4344e5b6dfb76 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 30 Oct 2023 16:21:36 +0800 Subject: [PATCH 26/86] update 2D transformers update 2D transformers --- pymic/net/net2d/trans2d/__init__.py | 0 pymic/net/net2d/trans2d/swinunet.py | 124 ++++ pymic/net/net2d/trans2d/swinunet_sys.py | 749 ++++++++++++++++++++ pymic/net/net2d/trans2d/transunet.py | 491 +++++++++++++ pymic/net/net2d/trans2d/transunet_cfg.py | 135 ++++ pymic/net/net2d/trans2d/transunet_resnet.py | 164 +++++ 6 files changed, 1663 insertions(+) create mode 100644 pymic/net/net2d/trans2d/__init__.py create mode 100644 pymic/net/net2d/trans2d/swinunet.py create mode 100644 pymic/net/net2d/trans2d/swinunet_sys.py create mode 100644 pymic/net/net2d/trans2d/transunet.py create mode 100644 pymic/net/net2d/trans2d/transunet_cfg.py create mode 100644 pymic/net/net2d/trans2d/transunet_resnet.py diff --git a/pymic/net/net2d/trans2d/__init__.py b/pymic/net/net2d/trans2d/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pymic/net/net2d/trans2d/swinunet.py b/pymic/net/net2d/trans2d/swinunet.py new file mode 100644 index 0000000..f35539a --- /dev/null +++ b/pymic/net/net2d/trans2d/swinunet.py @@ -0,0 +1,124 @@ +# -*- coding: utf-8 -*- +""" +code adapted from: https://github.com/HuCaoFighting/Swin-Unet + +""" +from __future__ import print_function, division + +import copy +import numpy as np +import torch +import torch.nn as nn + +from pymic.net.net2d.trans2d.swinunet_sys import SwinTransformerSys + +class SwinUNet(nn.Module): + """ + Implementatin of Swin-UNet. + + * Reference: Hu Cao, Yueyue Wang et al: + Swin-Unet: Unet-Like Pure Transformer for Medical Image Segmentation. + `ECCV 2022 Workshops. `_ + + Note that the input channel can only be 1 or 3, and the input image size should be 224x224. + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param img_size: (tuple) The input image size, should be [224, 224]. + :param class_num: (int) The class number for segmentation task. + """ + def __init__(self, params): + super(SwinUNet, self).__init__() + img_size = params['img_size'] + if(isinstance(img_size, tuple) or isinstance(img_size, list)): + img_size = img_size[0] + self.num_classes = params['class_num'] + self.swin_unet = SwinTransformerSys(img_size = img_size, num_classes=self.num_classes) + # self.swin_unet = SwinTransformerSys(img_size=config.DATA.IMG_SIZE, + # patch_size=config.MODEL.SWIN.PATCH_SIZE, + # in_chans=config.MODEL.SWIN.IN_CHANS, + # num_classes=self.num_classes, + # embed_dim=config.MODEL.SWIN.EMBED_DIM, + # depths=config.MODEL.SWIN.DEPTHS, + # num_heads=config.MODEL.SWIN.NUM_HEADS, + # window_size=config.MODEL.SWIN.WINDOW_SIZE, + # mlp_ratio=config.MODEL.SWIN.MLP_RATIO, + # qkv_bias=config.MODEL.SWIN.QKV_BIAS, + # qk_scale=config.MODEL.SWIN.QK_SCALE, + # drop_rate=config.MODEL.DROP_RATE, + # drop_path_rate=config.MODEL.DROP_PATH_RATE, + # ape=config.MODEL.SWIN.APE, + # patch_norm=config.MODEL.SWIN.PATCH_NORM, + # use_checkpoint=config.TRAIN.USE_CHECKPOINT) + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + if x.size()[1] == 1: + x = x.repeat(1,3,1,1) + logits = self.swin_unet(x) + + if(len(x_shape) == 5): + new_shape = [N, D] + list(logits.shape)[1:] + logits = torch.reshape(logits, new_shape) + logits = torch.transpose(logits, 1, 2) + + return logits + + def load_from(self, config): + pretrained_path = config.MODEL.PRETRAIN_CKPT + if pretrained_path is not None: + print("pretrained_path:{}".format(pretrained_path)) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + pretrained_dict = torch.load(pretrained_path, map_location=device) + if "model" not in pretrained_dict: + print("---start load pretrained modle by splitting---") + pretrained_dict = {k[17:]:v for k,v in pretrained_dict.items()} + for k in list(pretrained_dict.keys()): + if "output" in k: + print("delete key:{}".format(k)) + del pretrained_dict[k] + msg = self.swin_unet.load_state_dict(pretrained_dict,strict=False) + # print(msg) + return + pretrained_dict = pretrained_dict['model'] + print("---start load pretrained modle of swin encoder---") + + model_dict = self.swin_unet.state_dict() + full_dict = copy.deepcopy(pretrained_dict) + for k, v in pretrained_dict.items(): + if "layers." in k: + current_layer_num = 3-int(k[7:8]) + current_k = "layers_up." + str(current_layer_num) + k[8:] + full_dict.update({current_k:v}) + for k in list(full_dict.keys()): + if k in model_dict: + if full_dict[k].shape != model_dict[k].shape: + print("delete:{};shape pretrain:{};shape model:{}".format(k,v.shape,model_dict[k].shape)) + del full_dict[k] + + msg = self.swin_unet.load_state_dict(full_dict, strict=False) + # print(msg) + else: + print("none pretrain") + + +if __name__ == "__main__": + params = {'img_size': [224, 224], + 'class_num': 2} + net = SwinUNet(params) + net.double() + + x = np.random.rand(4, 3, 224, 224) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = net(xt) + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) \ No newline at end of file diff --git a/pymic/net/net2d/trans2d/swinunet_sys.py b/pymic/net/net2d/trans2d/swinunet_sys.py new file mode 100644 index 0000000..a6e3552 --- /dev/null +++ b/pymic/net/net2d/trans2d/swinunet_sys.py @@ -0,0 +1,749 @@ +# -*- coding: utf-8 -*- +""" +code adapted from: https://github.com/HuCaoFighting/Swin-Unet + +""" +from __future__ import print_function, division + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from einops import rearrange +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + +class PatchExpand(nn.Module): + def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity() + self.norm = norm_layer(dim // dim_scale) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + x = self.expand(x) + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) + x = x.view(B,-1,C//4) + x= self.norm(x) + + return x + +class FinalPatchExpand_X4(nn.Module): + def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.dim_scale = dim_scale + self.expand = nn.Linear(dim, 16*dim, bias=False) + self.output_dim = dim + self.norm = norm_layer(self.output_dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + x = self.expand(x) + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2)) + x = x.view(B,-1,self.output_dim) + x= self.norm(x) + + return x + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + +class BasicLayer_up(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if upsample is not None: + self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) + else: + self.upsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.upsample is not None: + x = self.upsample(x) + return x + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformerSys(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, final_upsample="expand_first", **kwargs): + super().__init__() + + print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(depths, + depths_decoder,drop_path_rate,num_classes)) + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.num_features_up = int(embed_dim * 2) + self.mlp_ratio = mlp_ratio + self.final_upsample = final_upsample + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build encoder and bottleneck layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + # build decoder layers + self.layers_up = nn.ModuleList() + self.concat_back_dim = nn.ModuleList() + for i_layer in range(self.num_layers): + concat_linear = nn.Linear(2*int(embed_dim*2**(self.num_layers-1-i_layer)), + int(embed_dim*2**(self.num_layers-1-i_layer))) if i_layer > 0 else nn.Identity() + if i_layer ==0 : + layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), + patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer) + else: + layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), + input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), + patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), + depth=depths[(self.num_layers-1-i_layer)], + num_heads=num_heads[(self.num_layers-1-i_layer)], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])], + norm_layer=norm_layer, + upsample=PatchExpand if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers_up.append(layer_up) + self.concat_back_dim.append(concat_linear) + + self.norm = norm_layer(self.num_features) + self.norm_up= norm_layer(self.embed_dim) + + if self.final_upsample == "expand_first": + print("---final upsample expand_first---") + self.up = FinalPatchExpand_X4(input_resolution=(img_size//patch_size,img_size//patch_size),dim_scale=4,dim=embed_dim) + self.output = nn.Conv2d(in_channels=embed_dim,out_channels=self.num_classes,kernel_size=1,bias=False) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + #Encoder and Bottleneck + def forward_features(self, x): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + x_downsample = [] + + for layer in self.layers: + x_downsample.append(x) + x = layer(x) + + x = self.norm(x) # B L C + + return x, x_downsample + + #Dencoder and Skip connection + def forward_up_features(self, x, x_downsample): + for inx, layer_up in enumerate(self.layers_up): + if inx == 0: + x = layer_up(x) + else: + x = torch.cat([x,x_downsample[3-inx]],-1) + x = self.concat_back_dim[inx](x) + x = layer_up(x) + + x = self.norm_up(x) # B L C + + return x + + def up_x4(self, x): + H, W = self.patches_resolution + B, L, C = x.shape + assert L == H*W, "input features has wrong size" + + if self.final_upsample=="expand_first": + x = self.up(x) + x = x.view(B,4*H,4*W,-1) + x = x.permute(0,3,1,2) #B,C,H,W + x = self.output(x) + + return x + + def forward(self, x): + x, x_downsample = self.forward_features(x) + x = self.forward_up_features(x,x_downsample) + x = self.up_x4(x) + + return x + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return \ No newline at end of file diff --git a/pymic/net/net2d/trans2d/transunet.py b/pymic/net/net2d/trans2d/transunet.py new file mode 100644 index 0000000..9db5d2d --- /dev/null +++ b/pymic/net/net2d/trans2d/transunet.py @@ -0,0 +1,491 @@ +# -*- coding: utf-8 -*- +""" +code adapted from: https://github.com/Beckschen/TransUNet +""" +from __future__ import print_function, division + +import copy +# import logging +import math +import torch +import torch.nn as nn +from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm +from torch.nn.modules.utils import _pair + +import numpy as np +from scipy import ndimage +from os.path import join as pjoin +import pymic.net.net2d.trans2d.transunet_cfg as configs +from pymic.net.net2d.trans2d.transunet_resnet import ResNetV2 + + +VIT_CONFIGS = { + 'ViT-B_16': configs.get_b16_config(), + 'ViT-B_32': configs.get_b32_config(), + 'ViT-L_16': configs.get_l16_config(), + 'ViT-L_32': configs.get_l32_config(), + 'ViT-H_14': configs.get_h14_config(), + 'R50-ViT-B_16': configs.get_r50_b16_config(), + 'R50-ViT-L_16': configs.get_r50_l16_config(), + 'testing': configs.get_testing(), +} + +ATTENTION_Q = "MultiHeadDotProductAttention_1/query" +ATTENTION_K = "MultiHeadDotProductAttention_1/key" +ATTENTION_V = "MultiHeadDotProductAttention_1/value" +ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" +FC_0 = "MlpBlock_3/Dense_0" +FC_1 = "MlpBlock_3/Dense_1" +ATTENTION_NORM = "LayerNorm_0" +MLP_NORM = "LayerNorm_2" + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} + +class Attention(nn.Module): + def __init__(self, config, vis): + super(Attention, self).__init__() + self.vis = vis + self.num_attention_heads = config.transformer["num_heads"] + self.attention_head_size = int(config.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = Linear(config.hidden_size, self.all_head_size) + self.key = Linear(config.hidden_size, self.all_head_size) + self.value = Linear(config.hidden_size, self.all_head_size) + + self.out = Linear(config.hidden_size, config.hidden_size) + self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) + self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) + + self.softmax = Softmax(dim=-1) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.softmax(attention_scores) + weights = attention_probs if self.vis else None + attention_probs = self.attn_dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + attention_output = self.out(context_layer) + attention_output = self.proj_dropout(attention_output) + return attention_output, weights + +class Mlp(nn.Module): + def __init__(self, config): + super(Mlp, self).__init__() + self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) + self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) + self.act_fn = ACT2FN["gelu"] + self.dropout = Dropout(config.transformer["dropout_rate"]) + + self._init_weights() + + def _init_weights(self): + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + nn.init.normal_(self.fc1.bias, std=1e-6) + nn.init.normal_(self.fc2.bias, std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x = self.act_fn(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + +class Embeddings(nn.Module): + """Construct the embeddings from patch, position embeddings. + """ + def __init__(self, config, img_size, in_channels=3): + super(Embeddings, self).__init__() + self.hybrid = None + self.config = config + img_size = _pair(img_size) + + if config.patches.get("grid") is not None: # ResNet + grid_size = config.patches["grid"] + patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1]) + patch_size_real = (patch_size[0] * 16, patch_size[1] * 16) + n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1]) + self.hybrid = True + else: + patch_size = _pair(config.patches["size"]) + n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) + self.hybrid = False + + if self.hybrid: + self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor) + in_channels = self.hybrid_model.width * 16 + self.patch_embeddings = Conv2d(in_channels=in_channels, + out_channels=config.hidden_size, + kernel_size=patch_size, + stride=patch_size) + self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size)) + + self.dropout = Dropout(config.transformer["dropout_rate"]) + + + def forward(self, x): + if self.hybrid: + x, features = self.hybrid_model(x) + else: + features = None + x = self.patch_embeddings(x) # (B, hidden. n_patches^(1/2), n_patches^(1/2)) + x = x.flatten(2) + x = x.transpose(-1, -2) # (B, n_patches, hidden) + + embeddings = x + self.position_embeddings + embeddings = self.dropout(embeddings) + return embeddings, features + +class Block(nn.Module): + def __init__(self, config, vis): + super(Block, self).__init__() + self.hidden_size = config.hidden_size + self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn = Mlp(config) + self.attn = Attention(config, vis) + + def forward(self, x): + h = x + x = self.attention_norm(x) + x, weights = self.attn(x) + x = x + h + + h = x + x = self.ffn_norm(x) + x = self.ffn(x) + x = x + h + return x, weights + + def load_from(self, weights, n_block): + ROOT = f"Transformer/encoderblock_{n_block}" + with torch.no_grad(): + query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t() + key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() + value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t() + out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t() + + query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) + key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) + value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) + out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) + + self.attn.query.weight.copy_(query_weight) + self.attn.key.weight.copy_(key_weight) + self.attn.value.weight.copy_(value_weight) + self.attn.out.weight.copy_(out_weight) + self.attn.query.bias.copy_(query_bias) + self.attn.key.bias.copy_(key_bias) + self.attn.value.bias.copy_(value_bias) + self.attn.out.bias.copy_(out_bias) + + mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() + mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() + mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() + mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() + + self.ffn.fc1.weight.copy_(mlp_weight_0) + self.ffn.fc2.weight.copy_(mlp_weight_1) + self.ffn.fc1.bias.copy_(mlp_bias_0) + self.ffn.fc2.bias.copy_(mlp_bias_1) + + self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) + self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) + self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) + self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) + +class Encoder(nn.Module): + def __init__(self, config, vis): + super(Encoder, self).__init__() + self.vis = vis + self.layer = nn.ModuleList() + self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) + for _ in range(config.transformer["num_layers"]): + layer = Block(config, vis) + self.layer.append(copy.deepcopy(layer)) + + def forward(self, hidden_states): + attn_weights = [] + for layer_block in self.layer: + hidden_states, weights = layer_block(hidden_states) + if self.vis: + attn_weights.append(weights) + encoded = self.encoder_norm(hidden_states) + return encoded, attn_weights + +class Transformer(nn.Module): + def __init__(self, config, img_size, vis): + super(Transformer, self).__init__() + self.embeddings = Embeddings(config, img_size=img_size) + self.encoder = Encoder(config, vis) + + def forward(self, input_ids): + embedding_output, features = self.embeddings(input_ids) + encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden) + return encoded, attn_weights, features + +class Conv2dReLU(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding=0, + stride=1, + use_batchnorm=True, + ): + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=not (use_batchnorm), + ) + relu = nn.ReLU(inplace=True) + + bn = nn.BatchNorm2d(out_channels) + + super(Conv2dReLU, self).__init__(conv, bn, relu) + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + skip_channels=0, + use_batchnorm=True, + ): + super().__init__() + self.conv1 = Conv2dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.conv2 = Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.up = nn.UpsamplingBilinear2d(scale_factor=2) + + def forward(self, x, skip=None): + x = self.up(x) + if skip is not None: + x = torch.cat([x, skip], dim=1) + x = self.conv1(x) + x = self.conv2(x) + return x + + +class SegmentationHead(nn.Sequential): + + def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1): + conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) + upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() + super().__init__(conv2d, upsampling) + + +class DecoderCup(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + head_channels = 512 + self.conv_more = Conv2dReLU( + config.hidden_size, + head_channels, + kernel_size=3, + padding=1, + use_batchnorm=True, + ) + decoder_channels = config.decoder_channels + in_channels = [head_channels] + list(decoder_channels[:-1]) + out_channels = decoder_channels + + if self.config.n_skip != 0: + skip_channels = self.config.skip_channels + for i in range(4-self.config.n_skip): # re-select the skip channels according to n_skip + skip_channels[3-i]=0 + + else: + skip_channels=[0,0,0,0] + + blocks = [ + DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels) + ] + self.blocks = nn.ModuleList(blocks) + + def forward(self, hidden_states, features=None): + B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden) + h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch)) + x = hidden_states.permute(0, 2, 1) + x = x.contiguous().view(B, hidden, h, w) + x = self.conv_more(x) + for i, decoder_block in enumerate(self.blocks): + if features is not None: + skip = features[i] if (i < self.config.n_skip) else None + else: + skip = None + x = decoder_block(x, skip=skip) + return x + +class TransUNet(nn.Module): + """ + Implementatin of TransUNet. + + * Reference: Jieneng Chen, Yongyi Lu et al: + TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation. + `Arxiv 2021. `_ + + Note that the input channel can only be 1 or 3, and the input image size should be 256x256. + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param img_size: (tuple) The input image size, should be [256, 256]. + :param class_num: (int) The class number for segmentation task. + :param vit_name: (string) The name for vit backbone. It can be one of the following: 'ViT-B_16', + 'ViT-B_32','ViT-L_16', 'ViT-L_32', 'ViT-H_14'. 'R50-ViT-B_16', 'R50-ViT-L_16'. + By default, it is 'R50-ViT-B_16'. + """ + def __init__(self, params): + super(TransUNet, self).__init__() + vit_name = params.get("vit_name", 'R50-ViT-B_16') + img_size = params['img_size'] + vis = params.get("vis", False) + self.config = VIT_CONFIGS[vit_name] + self.num_classes = params['class_num'] + self.zero_head = params.get("zero_head", False) + + self.classifier = self.config.classifier + self.transformer = Transformer(self.config, img_size, vis) + self.decoder = DecoderCup(self.config) + self.segmentation_head = SegmentationHead( + in_channels=self.config['decoder_channels'][-1], + out_channels=self.num_classes, + kernel_size=3, + ) + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + if x.size()[1] == 1: + x = x.repeat(1,3,1,1) + elif(x.size()[1] !=3): + raise ValueError("The input channel number should be 1 or 3 for TransUNet") + x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden) + x = self.decoder(x, features) + logits = self.segmentation_head(x) + + if(len(x_shape) == 5): + new_shape = [N, D] + list(logits.shape)[1:] + logits = torch.reshape(logits, new_shape) + logits = torch.transpose(logits, 1, 2) + + return logits + + def load_from(self, weights): + with torch.no_grad(): + + res_weight = weights + self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) + self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) + + self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) + self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) + + posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) + + posemb_new = self.transformer.embeddings.position_embeddings + if posemb.size() == posemb_new.size(): + self.transformer.embeddings.position_embeddings.copy_(posemb) + elif posemb.size()[1]-1 == posemb_new.size()[1]: + posemb = posemb[:, 1:] + self.transformer.embeddings.position_embeddings.copy_(posemb) + else: + # logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size())) + ntok_new = posemb_new.size(1) + if self.classifier == "seg": + _, posemb_grid = posemb[:, :1], posemb[0, 1:] + gs_old = int(np.sqrt(len(posemb_grid))) + gs_new = int(np.sqrt(ntok_new)) + print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) + posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) + zoom = (gs_new / gs_old, gs_new / gs_old, 1) + posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) # th2np + posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) + posemb = posemb_grid + self.transformer.embeddings.position_embeddings.copy_(np2th(posemb)) + + # Encoder whole + for bname, block in self.transformer.encoder.named_children(): + for uname, unit in block.named_children(): + unit.load_from(weights, n_block=uname) + + if self.transformer.embeddings.hybrid: + self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True)) + gn_weight = np2th(res_weight["gn_root/scale"]).view(-1) + gn_bias = np2th(res_weight["gn_root/bias"]).view(-1) + self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight) + self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias) + + for bname, block in self.transformer.embeddings.hybrid_model.body.named_children(): + for uname, unit in block.named_children(): + unit.load_from(res_weight, n_block=bname, n_unit=uname) + +if __name__ == "__main__": + params = {'img_size': [256, 256], + 'class_num': 2} + net = TransUNet(params) + net.double() + + for c in [1,3]: + x = np.random.rand(4, c, 256, 256) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = net(xt) + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) \ No newline at end of file diff --git a/pymic/net/net2d/trans2d/transunet_cfg.py b/pymic/net/net2d/trans2d/transunet_cfg.py new file mode 100644 index 0000000..aab62d4 --- /dev/null +++ b/pymic/net/net2d/trans2d/transunet_cfg.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +""" +code adapted from: https://github.com/Beckschen/TransUNet +""" +import ml_collections + +def get_b16_config(): + """Returns the ViT-B/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 768 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 3072 + config.transformer.num_heads = 12 + config.transformer.num_layers = 12 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + + config.classifier = 'seg' + config.representation_size = None + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz' + config.patch_size = 16 + + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_testing(): + """Returns a minimal configuration for testing.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 1 + config.transformer.num_heads = 1 + config.transformer.num_layers = 1 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + return config + +def get_r50_b16_config(): + """Returns the Resnet50 + ViT-B/16 configuration.""" + config = get_b16_config() + config.patches.grid = (16, 16) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'seg' + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 2 + config.n_skip = 3 + config.activation = 'softmax' + + return config + + +def get_b32_config(): + """Returns the ViT-B/32 configuration.""" + config = get_b16_config() + config.patches.size = (32, 32) + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz' + return config + + +def get_l16_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1024 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 4096 + config.transformer.num_heads = 16 + config.transformer.num_layers = 24 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.representation_size = None + + # custom + config.classifier = 'seg' + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_r50_l16_config(): + """Returns the Resnet50 + ViT-L/16 configuration. customized """ + config = get_l16_config() + config.patches.grid = (16, 16) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'seg' + config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_l32_config(): + """Returns the ViT-L/32 configuration.""" + config = get_l16_config() + config.patches.size = (32, 32) + return config + + +def get_h14_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (14, 14)}) + config.hidden_size = 1280 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 5120 + config.transformer.num_heads = 16 + config.transformer.num_layers = 32 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + + return config \ No newline at end of file diff --git a/pymic/net/net2d/trans2d/transunet_resnet.py b/pymic/net/net2d/trans2d/transunet_resnet.py new file mode 100644 index 0000000..144a268 --- /dev/null +++ b/pymic/net/net2d/trans2d/transunet_resnet.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +""" +code adapted from: https://github.com/Beckschen/TransUNet +""" +from __future__ import print_function, division + +from os.path import join as pjoin +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +class StdConv2d(nn.Conv2d): + + def forward(self, x): + w = self.weight + v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) + w = (w - m) / torch.sqrt(v + 1e-5) + return F.conv2d(x, w, self.bias, self.stride, self.padding, + self.dilation, self.groups) + + +def conv3x3(cin, cout, stride=1, groups=1, bias=False): + return StdConv2d(cin, cout, kernel_size=3, stride=stride, + padding=1, bias=bias, groups=groups) + + +def conv1x1(cin, cout, stride=1, bias=False): + return StdConv2d(cin, cout, kernel_size=1, stride=stride, + padding=0, bias=bias) + + +class PreActBottleneck(nn.Module): + """Pre-activation (v2) bottleneck block. + """ + + def __init__(self, cin, cout=None, cmid=None, stride=1): + super().__init__() + cout = cout or cin + cmid = cmid or cout//4 + + self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv1 = conv1x1(cin, cmid, bias=False) + self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!! + self.gn3 = nn.GroupNorm(32, cout, eps=1e-6) + self.conv3 = conv1x1(cmid, cout, bias=False) + self.relu = nn.ReLU(inplace=True) + + if (stride != 1 or cin != cout): + # Projection also with pre-activation according to paper. + self.downsample = conv1x1(cin, cout, stride, bias=False) + self.gn_proj = nn.GroupNorm(cout, cout) + + def forward(self, x): + + # Residual branch + residual = x + if hasattr(self, 'downsample'): + residual = self.downsample(x) + residual = self.gn_proj(residual) + + # Unit's branch + y = self.relu(self.gn1(self.conv1(x))) + y = self.relu(self.gn2(self.conv2(y))) + y = self.gn3(self.conv3(y)) + + y = self.relu(residual + y) + return y + + def load_from(self, weights, n_block, n_unit): + conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True) + conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True) + conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True) + + gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")]) + gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")]) + + gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")]) + gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")]) + + gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")]) + gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")]) + + self.conv1.weight.copy_(conv1_weight) + self.conv2.weight.copy_(conv2_weight) + self.conv3.weight.copy_(conv3_weight) + + self.gn1.weight.copy_(gn1_weight.view(-1)) + self.gn1.bias.copy_(gn1_bias.view(-1)) + + self.gn2.weight.copy_(gn2_weight.view(-1)) + self.gn2.bias.copy_(gn2_bias.view(-1)) + + self.gn3.weight.copy_(gn3_weight.view(-1)) + self.gn3.bias.copy_(gn3_bias.view(-1)) + + if hasattr(self, 'downsample'): + proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True) + proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")]) + proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")]) + + self.downsample.weight.copy_(proj_conv_weight) + self.gn_proj.weight.copy_(proj_gn_weight.view(-1)) + self.gn_proj.bias.copy_(proj_gn_bias.view(-1)) + +class ResNetV2(nn.Module): + """Implementation of Pre-activation (v2) ResNet mode.""" + + def __init__(self, block_units, width_factor): + super().__init__() + width = int(64 * width_factor) + self.width = width + + self.root = nn.Sequential(OrderedDict([ + ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)), + ('gn', nn.GroupNorm(32, width, eps=1e-6)), + ('relu', nn.ReLU(inplace=True)), + # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)) + ])) + + self.body = nn.Sequential(OrderedDict([ + ('block1', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)], + ))), + ('block2', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)], + ))), + ('block3', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)], + ))), + ])) + + def forward(self, x): + features = [] + b, c, in_size, _ = x.size() + x = self.root(x) + features.append(x) + x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x) + for i in range(len(self.body)-1): + x = self.body[i](x) + right_size = int(in_size / 4 / (i+1)) + if x.size()[2] != right_size: + pad = right_size - x.size()[2] + assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size) + feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device) + feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:] + else: + feat = x + features.append(feat) + x = self.body[-1](x) + return x, features[::-1] \ No newline at end of file From 317da9febf192b3cbac047d5b883e80c8855a4c3 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 31 Oct 2023 16:58:43 +0800 Subject: [PATCH 27/86] add mcnet add mcnet for semi-supervised segmentation --- pymic/net/net2d/unet2d.py | 145 +++++++++++++------------- pymic/net/net2d/unet2d_dual_branch.py | 17 ++- pymic/net/net2d/unet2d_mcnet.py | 69 ++++++++++++ pymic/net/net2d/unet2d_urpc.py | 132 +++++++++++++++++++++++ pymic/net/net_dict_seg.py | 17 +-- pymic/net_run/semi_sup/ssl_mcnet.py | 129 +++++++++++++++++++++++ 6 files changed, 432 insertions(+), 77 deletions(-) create mode 100644 pymic/net/net2d/unet2d_mcnet.py create mode 100644 pymic/net/net2d/unet2d_urpc.py create mode 100644 pymic/net_run/semi_sup/ssl_mcnet.py diff --git a/pymic/net/net2d/unet2d.py b/pymic/net/net2d/unet2d.py index 9acc0ad..758bfe5 100644 --- a/pymic/net/net2d/unet2d.py +++ b/pymic/net/net2d/unet2d.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import logging import torch import torch.nn as nn import numpy as np @@ -56,22 +57,33 @@ class UpBlock(nn.Module): :param in_channels2: (int) Channel number of low-level features. :param out_channels: (int) Output channel number. :param dropout_p: (int) Dropout probability. - :param bilinear: (bool) Use bilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Bilinear`), 3 (`Bicubic`). The default value + is 2 (`Bilinear`). """ def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - bilinear=True): + up_mode = 2): super(UpBlock, self).__init__() - self.bilinear = bilinear - if bilinear: - self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + if(isinstance(up_mode, int)): + up_mode_values = ["transconv", "nearest", "bilinear", "bicubic"] + if(up_mode > 3): + raise ValueError("The upsample mode should be 0-3, but {0:} is given.".format(up_mode)) + self.up_mode = up_mode_values[up_mode] else: + self.up_mode = up_mode.lower() + + if (self.up_mode == "transconv"): self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2) + else: + self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) + if(self.up_mode == "nearest"): + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode) + else: + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode, align_corners=True) self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) def forward(self, x1, x2): - if self.bilinear: + if self.up_mode != "transconv": x1 = self.conv1x1(x1) x1 = self.up(x1) x = torch.cat([x2, x1], dim=1) @@ -129,8 +141,10 @@ class Decoder(nn.Module): :param dropout: (list) The dropout ratio for each resolution level. The length should be the same as that of `feature_chns`. :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (or `Nearest`), 2 (or `Bilinear`), 3 (or `Bicubic`). + The default value is 2 (or `Bilinear`). + :param multiscale_pred: (bool) Get multiscale prediction. """ def __init__(self, params): super(Decoder, self).__init__() @@ -139,17 +153,23 @@ def __init__(self, params): self.ft_chns = self.params['feature_chns'] self.dropout = self.params['dropout'] self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] + self.up_mode = self.params['up_mode'] + self.mul_pred = self.params['multiscale_pred'] assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) if(len(self.ft_chns) == 5): - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear) + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.up_mode) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.up_mode) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.up_mode) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.up_mode) self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1) + if(self.mul_pred): + self.out_conv1 = nn.Conv2d(self.ft_chns[1], self.n_class, kernel_size = 1) + self.out_conv2 = nn.Conv2d(self.ft_chns[2], self.n_class, kernel_size = 1) + self.out_conv3 = nn.Conv2d(self.ft_chns[3], self.n_class, kernel_size = 1) + def forward(self, x): if(len(self.ft_chns) == 5): assert(len(x) == 5) @@ -163,6 +183,11 @@ def forward(self, x): x_d1 = self.up3(x_d2, x1) x_d0 = self.up4(x_d1, x0) output = self.out_conv(x_d0) + if(self.mul_pred): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv1(x_d2) + output3 = self.out_conv1(x_d3) + output = [output, output1, output2, output3] return output class UNet2D(nn.Module): @@ -180,43 +205,39 @@ class UNet2D(nn.Module): following fields: :param in_chns: (int) Input channel number. + :param class_num: (int) The class number for segmentation task. + + Optional parameters: + :param feature_chns: (list) Feature channel for each resolution level. The length should be 4 or 5, such as [16, 32, 64, 128, 256]. :param dropout: (list) The dropout ratio for each resolution level. The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (or `Nearest`), 2 (or `Bilinear`), 3 (or `Bicubic`). + The default value is 2 (or `Bilinear`). :param multiscale_pred: (bool) Get multiscale prediction. """ def __init__(self, params): super(UNet2D, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] - self.mul_pred = self.params['multiscale_pred'] - - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - if(len(self.ft_chns) == 5): - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear) - - self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1) - if(self.mul_pred): - self.out_conv1 = nn.Conv2d(self.ft_chns[1], self.n_class, kernel_size = 1) - self.out_conv2 = nn.Conv2d(self.ft_chns[2], self.n_class, kernel_size = 1) - self.out_conv3 = nn.Conv2d(self.ft_chns[3], self.n_class, kernel_size = 1) + params = self.get_default_parameters(params) + for p in params: + print(p, params[p]) + self.encoder = Encoder(params) + self.decoder = Decoder(params) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], + 'up_mode': 2, + 'multiscale_pred': False + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params def forward(self, x): x_shape = list(x.shape) @@ -226,42 +247,26 @@ def forward(self, x): x = torch.transpose(x, 1, 2) x = torch.reshape(x, new_shape) - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - if(len(self.ft_chns) == 5): - x4 = self.down4(x3) - x_d3 = self.up1(x4, x3) - else: - x_d3 = x3 - x_d2 = self.up2(x_d3, x2) - x_d1 = self.up3(x_d2, x1) - x_d0 = self.up4(x_d1, x0) - output = self.out_conv(x_d0) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - - if(len(x_shape) == 5): + f = self.encoder(x) + output = self.decoder(f) + if(len(x_shape) == 5): + if(isinstance(output, (list,tuple))): for i in range(len(output)): new_shape = [N, D] + list(output[i].shape)[1:] output[i] = torch.transpose(torch.reshape(output[i], new_shape), 1, 2) - elif(len(x_shape) == 5): - new_shape = [N, D] + list(output.shape)[1:] - output = torch.reshape(output, new_shape) - output = torch.transpose(output, 1, 2) - + else: + new_shape = [N, D] + list(output.shape)[1:] + output = torch.transpose(torch.reshape(output, new_shape), 1, 2) + return output + if __name__ == "__main__": params = {'in_chns':4, 'feature_chns':[2, 8, 32, 48, 64], 'dropout': [0, 0, 0.3, 0.4, 0.5], 'class_num': 2, - 'bilinear': True, + 'up_mode': 0, 'multiscale_pred': False} Net = UNet2D(params) Net = Net.double() diff --git a/pymic/net/net2d/unet2d_dual_branch.py b/pymic/net/net2d/unet2d_dual_branch.py index 828bdfe..19a0788 100644 --- a/pymic/net/net2d/unet2d_dual_branch.py +++ b/pymic/net/net2d/unet2d_dual_branch.py @@ -25,11 +25,26 @@ class UNet2D_DualBranch(nn.Module): """ def __init__(self, params): super(UNet2D_DualBranch, self).__init__() - self.output_mode = params.get("output_mode", "average") + params = self.get_default_parameters(params) + self.output_mode = params["output_mode"] self.encoder = Encoder(params) self.decoder1 = Decoder(params) self.decoder2 = Decoder(params) + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], + 'up_mode': 2, + 'multiscale_pred': False, + 'output_mode': "average" + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + def forward(self, x): x_shape = list(x.shape) if(len(x_shape) == 5): diff --git a/pymic/net/net2d/unet2d_mcnet.py b/pymic/net/net2d/unet2d_mcnet.py new file mode 100644 index 0000000..c3e8a5f --- /dev/null +++ b/pymic/net/net2d/unet2d_mcnet.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch.nn as nn +from pymic.net.net2d.unet2d import * + + +class MCNet2D(nn.Module): + """ + A tri-branch network using UNet2D as backbone. + + * Reference: Yicheng Wu, Zongyuan Ge et al. Mutual consistency learning for + semi-supervised medical image segmentation. + `MIA 2022. `_ + + The original code is at: https://github.com/ycwu1997/MC-Net + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.UNet2D` for details. + """ + +class MCNet2D(nn.Module): + def __init__(self, params): + super(MCNet2D, self).__init__() + in_chns = params['in_chns'] + class_num = params['class_num'] + params1 = {'in_chns': in_chns, + 'feature_chns': [16, 32, 64, 128, 256], + 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], + 'class_num': class_num, + 'up_mode': 0, + 'multiscale_pred': False } + params2 = {'in_chns': in_chns, + 'feature_chns': [16, 32, 64, 128, 256], + 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], + 'class_num': class_num, + 'up_mode': 1, + 'multiscale_pred': False} + params3 = {'in_chns': in_chns, + 'feature_chns': [16, 32, 64, 128, 256], + 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], + 'class_num': class_num, + 'up_mode': 2, + 'multiscale_pred': False} + self.encoder = Encoder(params1) + self.decoder1 = Decoder(params1) + self.decoder2 = Decoder(params2) + self.decoder3 = Decoder(params3) + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + feature = self.encoder(x) + output1 = self.decoder1(feature) + new_shape = [N, D] + list(output1.shape)[1:] + output1 = torch.transpose(torch.reshape(output1, new_shape), 1, 2) + if(not self.training): + return output1 + output2 = self.decoder2(feature) + output3 = self.decoder3(feature) + if(len(x_shape) == 5): + output2 = torch.transpose(torch.reshape(output2, new_shape), 1, 2) + output3 = torch.transpose(torch.reshape(output3, new_shape), 1, 2) + return output1, output2, output3 diff --git a/pymic/net/net2d/unet2d_urpc.py b/pymic/net/net2d/unet2d_urpc.py new file mode 100644 index 0000000..ee8ab7c --- /dev/null +++ b/pymic/net/net2d/unet2d_urpc.py @@ -0,0 +1,132 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import torch.nn as nn +import numpy as np +from torch.distributions.uniform import Uniform +from pymic.net.net2d.unet2d import ConvBlock, DownBlock, UpBlock + +def FeatureDropout(x): + attention = torch.mean(x, dim=1, keepdim=True) + max_val, _ = torch.max(attention.view( + x.size(0), -1), dim=1, keepdim=True) + threshold = max_val * np.random.uniform(0.7, 0.9) + threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention) + drop_mask = (attention < threshold).float() + x = x.mul(drop_mask) + return x + +class FeatureNoise(nn.Module): + def __init__(self, uniform_range=0.3): + super(FeatureNoise, self).__init__() + self.uni_dist = Uniform(-uniform_range, uniform_range) + + def feature_based_noise(self, x): + noise_vector = self.uni_dist.sample( + x.shape[1:]).to(x.device).unsqueeze(0) + x_noise = x.mul(noise_vector) + x + return x_noise + + def forward(self, x): + x = self.feature_based_noise(x) + return x + +class UNet2D_URPC(nn.Module): + """ + An modification the U-Net to obtain multi-scale prediction according to + the URPC paper. + + * Reference: Xiangde Luo, Guotai Wang*, Wenjun Liao, Jieneng Chen, Tao Song, Yinan Chen, + Shichuan Zhang, Dimitris N. Metaxas, Shaoting Zhang. + Semi-Supervised Medical Image Segmentation via Uncertainty Rectified Pyramid Consistency . + `Medical Image Analysis 2022. `_ + + Also see: https://github.com/HiLab-git/SSL4MIS/blob/master/code/networks/unet.py + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param bilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + """ + def __init__(self, params): + super(UNet2D_URPC, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + self.n_class = self.params['class_num'] + self.bilinear = self.params['bilinear'] + assert(len(self.ft_chns) == 5) + + self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) + self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) + self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) + self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) + self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], 0.0, self.bilinear) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], 0.0, self.bilinear) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], 0.0, self.bilinear) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], 0.0, self.bilinear) + + self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, + kernel_size = 3, padding = 1) + self.out_conv_dp4 = nn.Conv2d(self.ft_chns[4], self.n_class, + kernel_size=3, padding=1) + self.out_conv_dp3 = nn.Conv2d(self.ft_chns[3], self.n_class, + kernel_size=3, padding=1) + self.out_conv_dp2 = nn.Conv2d(self.ft_chns[2], self.n_class, + kernel_size=3, padding=1) + self.out_conv_dp1 = nn.Conv2d(self.ft_chns[1], self.n_class, + kernel_size=3, padding=1) + self.feature_noise = FeatureNoise() + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + x0 = self.in_conv(x) + x1 = self.down1(x0) + x2 = self.down2(x1) + x3 = self.down3(x2) + x4 = self.down4(x3) + + x = self.up1(x4, x3) + if self.training: + x = nn.functional.dropout(x, p=0.5) + dp3_out = self.out_conv_dp3(x) + + x = self.up2(x, x2) + if self.training: + x = FeatureDropout(x) + dp2_out = self.out_conv_dp2(x) + + x = self.up3(x, x1) + if self.training: + x = self.feature_noise(x) + dp1_out = self.out_conv_dp1(x) + + x = self.up4(x, x0) + dp0_out = self.out_conv(x) + + out_shape = list(dp0_out.shape)[2:] + dp3_out = nn.functional.interpolate(dp3_out, out_shape) + dp2_out = nn.functional.interpolate(dp2_out, out_shape) + dp1_out = nn.functional.interpolate(dp1_out, out_shape) + out = [dp0_out, dp1_out, dp2_out, dp3_out] + + if(len(x_shape) == 5): + new_shape = [N, D] + list(dp0_out.shape)[1:] + for i in range(len(out)): + out[i] = torch.transpose(torch.reshape(out[i], new_shape), 1, 2) + return out \ No newline at end of file diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index ffaa023..e1e250c 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -7,6 +7,7 @@ * UNet2D_CCT :mod:`pymic.net.net2d.unet2d_cct.UNet2D_CCT` * UNet2D_ScSE :mod:`pymic.net.net2d.unet2d_scse.UNet2D_ScSE` * AttentionUNet2D :mod:`pymic.net.net2d.unet2d_attention.AttentionUNet2D` +* MCNet2D :mod:`pymic.net.net2d.unet2d_mcnet.MCNet2D` * NestedUNet2D :mod:`pymic.net.net2d.unet2d_nest.NestedUNet2D` * COPLENet :mod:`pymic.net.net2d.cople_net.COPLENet` * UNet2D5 :mod:`pymic.net.net3d.unet2d5.UNet2D5` @@ -17,12 +18,13 @@ from pymic.net.net2d.unet2d import UNet2D from pymic.net.net2d.unet2d_dual_branch import UNet2D_DualBranch from pymic.net.net2d.unet2d_cct import UNet2D_CCT +from pymic.net.net2d.unet2d_mcnet import MCNet2D from pymic.net.net2d.cople_net import COPLENet from pymic.net.net2d.unet2d_attention import AttentionUNet2D from pymic.net.net2d.unet2d_nest import NestedUNet2D from pymic.net.net2d.unet2d_scse import UNet2D_ScSE -# from pymic.net.net2d.trans2d.transunet import TransUNet -# from pymic.net.net2d.trans2d.swinunet import SwinUNet +from pymic.net.net2d.trans2d.transunet import TransUNet +from pymic.net.net2d.trans2d.swinunet import SwinUNet from pymic.net.net3d.unet2d5 import UNet2D5 from pymic.net.net3d.unet3d import UNet3D from pymic.net.net3d.unet3d_scse import UNet3D_ScSE @@ -39,24 +41,26 @@ # from pymic.net.net3d.trans3d.HiFormer_v3 import HiFormer_v3 # from pymic.net.net3d.trans3d.HiFormer_v4 import HiFormer_v4 # from pymic.net.net3d.trans3d.HiFormer_v5 import HiFormer_v5 +# from pymic.net.net3d.trans3d.SwitchNet import SwitchNet SegNetDict = { 'UNet2D': UNet2D, 'UNet2D_DualBranch': UNet2D_DualBranch, 'UNet2D_CCT': UNet2D_CCT, + 'MCNet2D': MCNet2D, 'COPLENet': COPLENet, 'AttentionUNet2D': AttentionUNet2D, 'NestedUNet2D': NestedUNet2D, 'UNet2D_ScSE': UNet2D_ScSE, - # 'TransUNet': TransUNet, - # 'SwinUNet': SwinUNet, + 'TransUNet': TransUNet, + 'SwinUNet': SwinUNet, 'UNet2D5': UNet2D5, 'UNet3D': UNet3D, 'UNet3D_ScSE': UNet3D_ScSE, 'UNet3D_DualBranch': UNet3D_DualBranch, # 'nnFormer': nnFormer_wrap, - # 'UNETR': UNETR, - # 'UNETR_PP': UNETR_PP, + 'UNETR': UNETR, + 'UNETR_PP': UNETR_PP, # 'MedFormerV1': MedFormerV1, # 'MedFormerV2': MedFormerV2, # 'MedFormerV3': MedFormerV3, @@ -66,4 +70,5 @@ # 'HiFormer_v3': HiFormer_v3, # 'HiFormer_v4': HiFormer_v4, # 'HiFormer_v5': HiFormer_v5 + # 'SwitchNet': SwitchNet } diff --git a/pymic/net_run/semi_sup/ssl_mcnet.py b/pymic/net_run/semi_sup/ssl_mcnet.py new file mode 100644 index 0000000..66e1034 --- /dev/null +++ b/pymic/net_run/semi_sup/ssl_mcnet.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from pymic.loss.seg.util import get_soft_label +from pymic.loss.seg.util import reshape_prediction_and_ground_truth +from pymic.loss.seg.util import get_classwise_dice +from pymic.net_run.semi_sup import SSLSegAgent +from pymic.util.ramps import get_rampup_ratio + +def sharpening(P, T = 0.1): + T = 1.0/T + P_sharpen = P**T / (P**T + (1-P)**T) + return P_sharpen + +class SSLMCNet(SSLSegAgent): + """ + Mutual Consistency Learning for semi-supervised segmentation. It requires a network + with multiple decoders for learning, such as `pymic.net.net2d.unet2d_mcnet.MCNet2D`. + + * Reference: Yicheng Wu, Zongyuan Ge et al. Mutual consistency learning for + semi-supervised medical image segmentation. + `MIA 2022. `_ + + The original code is at: https://github.com/ycwu1997/MC-Net + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. + """ + def training(self): + class_num = self.config['network']['class_num'] + iter_valid = self.config['training']['iter_valid'] + ssl_cfg = self.config['semi_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = ssl_cfg.get('rampup_start', 0) + rampup_end = ssl_cfg.get('rampup_end', iter_max) + temperature = ssl_cfg.get('temperature', 0.1) + unsup_loss_name = ssl_cfg.get('unsupervised_loss', "MSE") + train_loss = 0 + train_loss_sup = 0 + train_loss_reg = 0 + train_dice_list = [] + self.net.train() + + for it in range(iter_valid): + try: + data_lab = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data_lab = next(self.trainIter) + try: + data_unlab = next(self.trainIter_unlab) + except StopIteration: + self.trainIter_unlab = iter(self.train_loader_unlab) + data_unlab = next(self.trainIter_unlab) + + # get the inputs + x0 = self.convert_tensor_type(data_lab['image']) + y0 = self.convert_tensor_type(data_lab['label_prob']) + x1 = self.convert_tensor_type(data_unlab['image']) + inputs = torch.cat([x0, x1], dim = 0) + inputs, y0 = inputs.to(self.device), y0.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward pass to obtain multiple predictions + outputs = self.net(inputs) + num_outputs = len(outputs) + n0 = list(x0.shape)[0] + p0 = F.softmax(outputs[0], dim=1)[:n0] + # for probability prediction and pseudo respectively + p_ori = torch.zeros((num_outputs,) + outputs[0].shape) + y_psu = torch.zeros((num_outputs,) + outputs[0].shape) + + # get supervised loss + loss_sup = 0 + for idx in range(num_outputs): + p0i = outputs[idx][:n0] + loss_sup += self.get_loss_value(data_lab, p0i, y0) + + # get pseudo labels + p_i = F.softmax(outputs[idx], dim=1) + p_ori[idx] = p_i + y_psu[idx] = sharpening(p_i, temperature) + + # get regularization loss + loss_reg = 0.0 + for i in range(num_outputs): + for j in range(num_outputs): + if (i!=j): + loss_reg += F.mse_loss(p_ori[i], y_psu[j], reduction='mean') + + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio + loss = loss_sup + regular_w*loss_reg + + loss.backward() + self.optimizer.step() + + train_loss = train_loss + loss.item() + train_loss_sup = train_loss_sup + loss_sup.item() + train_loss_reg = train_loss_reg + loss_reg.item() + # get dice evaluation for each class in annotated images + if(isinstance(p0, tuple) or isinstance(p0, list)): + p0 = p0[0] + p0_argmax = torch.argmax(p0, dim = 1, keepdim = True) + p0_soft = get_soft_label(p0_argmax, class_num, self.tensor_type) + p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) + dice_list = get_classwise_dice(p0_soft, y0) + train_dice_list.append(dice_list.cpu().numpy()) + train_avg_loss = train_loss / iter_valid + train_avg_loss_sup = train_loss_sup / iter_valid + train_avg_loss_reg = train_loss_reg / iter_valid + train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) + train_avg_dice = train_cls_dice[1:].mean() + + train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, + 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + return train_scalers From 649e62d59a0582c1d06fc0d7f4ffdc3fb9b1cd3f Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 31 Oct 2023 21:23:43 +0800 Subject: [PATCH 28/86] update docs update docs for mecnet --- docs/source/api.rst | 3 --- docs/source/pymic.net.net2d.rst | 8 ++++++++ docs/source/pymic.net_run.semi_sup.rst | 8 ++++++++ pymic/net/net2d/unet2d_mcnet.py | 7 ++----- pymic/net/net_dict_seg.py | 4 ++-- pymic/net_run/semi_sup/__init__.py | 13 +++++++------ pymic/net_run/semi_sup/ssl_mcnet.py | 9 +++++++-- 7 files changed, 34 insertions(+), 18 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 206000d..d09809c 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -9,8 +9,5 @@ API pymic.loss pymic.net pymic.net_run - pymic.net_run_nll - pymic.net_run_ssl - pymic.net_run_wsl pymic.transform pymic.util \ No newline at end of file diff --git a/docs/source/pymic.net.net2d.rst b/docs/source/pymic.net.net2d.rst index d978dfe..bd54bc6 100644 --- a/docs/source/pymic.net.net2d.rst +++ b/docs/source/pymic.net.net2d.rst @@ -52,6 +52,14 @@ pymic.net.net2d.unet2d\_dual\_branch module :undoc-members: :show-inheritance: +pymic.net.net2d.unet2d\_mcnet module +------------------------------------------- + +.. automodule:: pymic.net.net2d.unet2d_mcnet + :members: + :undoc-members: + :show-inheritance: + pymic.net.net2d.unet2d\_nest module ----------------------------------- diff --git a/docs/source/pymic.net_run.semi_sup.rst b/docs/source/pymic.net_run.semi_sup.rst index 6ed157d..15692b2 100644 --- a/docs/source/pymic.net_run.semi_sup.rst +++ b/docs/source/pymic.net_run.semi_sup.rst @@ -28,6 +28,14 @@ pymic.net\_run.semi\_sup.ssl\_cps module :undoc-members: :show-inheritance: +pymic.net\_run.semi\_sup.ssl\_mcnet module +---------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_mcnet + :members: + :undoc-members: + :show-inheritance: + pymic.net\_run.semi\_sup.ssl\_em module --------------------------------------- diff --git a/pymic/net/net2d/unet2d_mcnet.py b/pymic/net/net2d/unet2d_mcnet.py index c3e8a5f..76ee5de 100644 --- a/pymic/net/net2d/unet2d_mcnet.py +++ b/pymic/net/net2d/unet2d_mcnet.py @@ -4,22 +4,19 @@ import torch.nn as nn from pymic.net.net2d.unet2d import * - class MCNet2D(nn.Module): """ A tri-branch network using UNet2D as backbone. * Reference: Yicheng Wu, Zongyuan Ge et al. Mutual consistency learning for - semi-supervised medical image segmentation. - `MIA 2022. `_ + semi-supervised medical image segmentation. + `Medical Image Analysis 2022. `_ The original code is at: https://github.com/ycwu1997/MC-Net The parameters for the backbone should be given in the `params` dictionary. See :mod:`pymic.net.net2d.unet2d.UNet2D` for details. """ - -class MCNet2D(nn.Module): def __init__(self, params): super(MCNet2D, self).__init__() in_chns = params['in_chns'] diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index e1e250c..8862fd1 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -59,8 +59,8 @@ 'UNet3D_ScSE': UNet3D_ScSE, 'UNet3D_DualBranch': UNet3D_DualBranch, # 'nnFormer': nnFormer_wrap, - 'UNETR': UNETR, - 'UNETR_PP': UNETR_PP, + # 'UNETR': UNETR, + # 'UNETR_PP': UNETR_PP, # 'MedFormerV1': MedFormerV1, # 'MedFormerV2': MedFormerV2, # 'MedFormerV3': MedFormerV3, diff --git a/pymic/net_run/semi_sup/__init__.py b/pymic/net_run/semi_sup/__init__.py index be753c2..39ca5dc 100644 --- a/pymic/net_run/semi_sup/__init__.py +++ b/pymic/net_run/semi_sup/__init__.py @@ -1,4 +1,5 @@ from __future__ import absolute_import +# from . import * from pymic.net_run.semi_sup.ssl_abstract import SSLSegAgent from pymic.net_run.semi_sup.ssl_em import SSLEntropyMinimization from pymic.net_run.semi_sup.ssl_mt import SSLMeanTeacher @@ -8,9 +9,9 @@ from pymic.net_run.semi_sup.ssl_urpc import SSLURPC -SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization, - 'MeanTeacher': SSLMeanTeacher, - 'UAMT': SSLUncertaintyAwareMeanTeacher, - 'CCT': SSLCCT, - 'CPS': SSLCPS, - 'URPC': SSLURPC} \ No newline at end of file +# SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization, +# 'MeanTeacher': SSLMeanTeacher, +# 'UAMT': SSLUncertaintyAwareMeanTeacher, +# 'CCT': SSLCCT, +# 'CPS': SSLCPS, +# 'URPC': SSLURPC} \ No newline at end of file diff --git a/pymic/net_run/semi_sup/ssl_mcnet.py b/pymic/net_run/semi_sup/ssl_mcnet.py index 66e1034..955357d 100644 --- a/pymic/net_run/semi_sup/ssl_mcnet.py +++ b/pymic/net_run/semi_sup/ssl_mcnet.py @@ -22,8 +22,8 @@ class SSLMCNet(SSLSegAgent): with multiple decoders for learning, such as `pymic.net.net2d.unet2d_mcnet.MCNet2D`. * Reference: Yicheng Wu, Zongyuan Ge et al. Mutual consistency learning for - semi-supervised medical image segmentation. - `MIA 2022. `_ + semi-supervised medical image segmentation. + `Medical Image Analysis 2022. `_ The original code is at: https://github.com/ycwu1997/MC-Net @@ -34,6 +34,11 @@ class SSLMCNet(SSLSegAgent): In the configuration dictionary, in addition to the four sections (`dataset`, `network`, `training` and `inference`) used in fully supervised learning, an extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. + + Special parameters required for MCNet in `semi_supervised_learning` section: + + :param temperature: (float) temperature for label sharpening. The default value is 0.1. + """ def training(self): class_num = self.config['network']['class_num'] From ad5957e083b26e82fad5e27d90c09bbf1ba5b573 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 4 Nov 2023 14:10:44 +0800 Subject: [PATCH 29/86] add multi-net --- pymic/net/multi_net.py | 32 ++++++++++++++++++++ pymic/net_run/agent_seg.py | 49 +++++++++++++++++++++---------- pymic/net_run/semi_sup/ssl_cps.py | 28 +----------------- 3 files changed, 66 insertions(+), 43 deletions(-) create mode 100644 pymic/net/multi_net.py diff --git a/pymic/net/multi_net.py b/pymic/net/multi_net.py new file mode 100644 index 0000000..8807b0c --- /dev/null +++ b/pymic/net/multi_net.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import torch.nn as nn + +class MultiNet(nn.Module): + ''' + A combination of multiple networks. + Parameters should be saved in the `params` dictionary. + + :param `net_names`: (list) A list of network class name. + :param `infer_mode`: (int) Mode for inference. 0: only use the first network. + 1: taking an average of all the networks. + ''' + def __init__(self, net_dict, params): + super(MultiNet, self).__init__() + net_names = params['net_names'] # should be a list of network class name + self.output_mode = params.get('infer_mode', 0) + self.networks = nn.ModuleList([net_dict[item](params) for item in net_names]) + + def forward(self, x): + if(self.training): + output = [net(x) for net in self.networks] + else: + output = self.networks[0](x) + if(self.output_mode == 1): + for i in range(1, len(self.networks)): + output += self.networks[i](x) + output = output / len(self.networks) + return output + \ No newline at end of file diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 70f01b3..65c2b68 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -18,6 +18,7 @@ from pymic.io.image_read_write import save_nd_array_as_image from pymic.io.nifty_dataset import NiftyDataset from pymic.net.net_dict_seg import SegNetDict +from pymic.net.multi_net import MultiNet from pymic.net_run.agent_abstract import NetRunAgent from pymic.net_run.infer_func import Inferer from pymic.loss.loss_dict_seg import SegLossDict @@ -78,9 +79,12 @@ def get_stage_dataset_from_config(self, stage): def create_network(self): if(self.net is None): net_name = self.config['network']['net_type'] - if(net_name not in self.net_dict): - raise ValueError("Undefined network {0:}".format(net_name)) - self.net = self.net_dict[net_name](self.config['network']) + if(isinstance(net_name, (tuple, list))): + self.net = MultiNet(self.net_dict, self.config['network']) + else: + if(net_name not in self.net_dict): + raise ValueError("Undefined network {0:}".format(net_name)) + self.net = self.net_dict[net_name](self.config['network']) if(self.tensor_type == 'float'): self.net.float() else: @@ -164,10 +168,11 @@ def training(self): if(mixup_prob > 0 and random() < mixup_prob): inputs, labels_prob = mixup(inputs, labels_prob) - # # for debug + # for debug # for i in range(inputs.shape[0]): # image_i = inputs[i][0] - # label_i = labels_prob[i][1] + # # label_i = labels_prob[i][1] + # label_i = np.argmax(labels_prob[i], axis = 0) # # pixw_i = pix_w[i][0] # print(image_i.shape, label_i.shape) # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) @@ -176,7 +181,7 @@ def training(self): # save_nd_array_as_image(image_i, image_name, reference_name = None) # save_nd_array_as_image(label_i, label_name, reference_name = None) # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) - # # continue + # continue inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) @@ -271,6 +276,27 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + def load_pretrained_weights(self, network, pretrained_dict, device_ids): + if(len(device_ids) > 1): + if(hasattr(network.module, "get_parameters_to_load")): + model_dict = network.module.get_parameters_to_load() + else: + model_dict = network.module.state_dict() + else: + if(hasattr(network, "get_parameters_to_load")): + model_dict = network.get_parameters_to_load() + else: + model_dict = network.state_dict() + pretrained_dict = {k: v for k, v in pretrained_dict.items() if \ + k in model_dict and tensor_shape_match(pretrained_dict[k], model_dict[k])} + logging.info("Initializing the following parameters with pre-trained model") + for k in pretrained_dict: + logging.info(k) + if (len(device_ids) > 1): + network.module.load_state_dict(pretrained_dict, strict = False) + else: + network.load_state_dict(pretrained_dict, strict = False) + def train_valid(self): device_ids = self.config['training']['gpus'] if(len(device_ids) > 1): @@ -310,16 +336,7 @@ def train_valid(self): if(ckpt_init_name is not None): checkpoint = torch.load(ckpt_dir + "/" + ckpt_init_name, map_location = self.device) pretrained_dict = checkpoint['model_state_dict'] - model_dict = self.net.module.state_dict() if (len(device_ids) > 1) else self.net.state_dict() - pretrained_dict = {k: v for k, v in pretrained_dict.items() if \ - k in model_dict and tensor_shape_match(pretrained_dict[k], model_dict[k])} - logging.info("Initializing the following parameters with pre-trained model") - for k in pretrained_dict: - logging.info(k) - if (len(device_ids) > 1): - self.net.module.load_state_dict(pretrained_dict, strict = False) - else: - self.net.load_state_dict(pretrained_dict, strict = False) + self.load_pretrained_weights(self.net, pretrained_dict, device_ids) if(ckpt_init_mode > 0): # Load other information self.max_val_dice = checkpoint.get('valid_pred', 0) diff --git a/pymic/net_run/semi_sup/ssl_cps.py b/pymic/net_run/semi_sup/ssl_cps.py index 4a3be9c..db1fb28 100644 --- a/pymic/net_run/semi_sup/ssl_cps.py +++ b/pymic/net_run/semi_sup/ssl_cps.py @@ -3,30 +3,12 @@ import logging import numpy as np import torch -import torch.nn as nn from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.net_run.semi_sup import SSLSegAgent -from pymic.net.net_dict_seg import SegNetDict from pymic.util.ramps import get_rampup_ratio -class BiNet(nn.Module): - def __init__(self, params): - super(BiNet, self).__init__() - net_name = params['net_type'] - self.net1 = SegNetDict[net_name](params) - self.net2 = SegNetDict[net_name](params) - - def forward(self, x): - out1 = self.net1(x) - out2 = self.net2(x) - - if(self.training): - return out1, out2 - else: - return (out1 + out2) / 2 - class SSLCPS(SSLSegAgent): """ Using cross pseudo supervision for semi-supervised segmentation. @@ -47,14 +29,6 @@ class SSLCPS(SSLSegAgent): def __init__(self, config, stage = 'train'): super(SSLCPS, self).__init__(config, stage) - def create_network(self): - if(self.net is None): - self.net = BiNet(self.config['network']) - if(self.tensor_type == 'float'): - self.net.float() - else: - self.net.double() - def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] @@ -89,7 +63,7 @@ def training(self): # zero the parameter gradients self.optimizer.zero_grad() - outputs1, outputs2 = self.net(inputs) + outputs1, outputs2 = self.net(inputs) outputs_soft1 = torch.softmax(outputs1, dim=1) outputs_soft2 = torch.softmax(outputs2, dim=1) From 6017047877c3860326b283cd20ee9c30347917e9 Mon Sep 17 00:00:00 2001 From: taigw Date: Sun, 5 Nov 2023 16:26:13 +0800 Subject: [PATCH 30/86] add CANet --- pymic/net/multi_net.py | 2 +- pymic/net/net2d/canet_module.py | 578 +++++++++++++++++++++ pymic/net/net2d/unet2d_attention.py | 15 +- pymic/net/net2d/unet2d_canet.py | 756 ++++++++++++++++++++++++++++ pymic/net/net2d/unet2d_mcnet.py | 2 +- pymic/util/evaluation_seg.py | 55 +- 6 files changed, 1379 insertions(+), 29 deletions(-) create mode 100644 pymic/net/net2d/canet_module.py create mode 100644 pymic/net/net2d/unet2d_canet.py diff --git a/pymic/net/multi_net.py b/pymic/net/multi_net.py index 8807b0c..78209b1 100644 --- a/pymic/net/multi_net.py +++ b/pymic/net/multi_net.py @@ -15,7 +15,7 @@ class MultiNet(nn.Module): ''' def __init__(self, net_dict, params): super(MultiNet, self).__init__() - net_names = params['net_names'] # should be a list of network class name + net_names = params['net_type'] # should be a list of network class name self.output_mode = params.get('infer_mode', 0) self.networks = nn.ModuleList([net_dict[item](params) for item in net_names]) diff --git a/pymic/net/net2d/canet_module.py b/pymic/net/net2d/canet_module.py new file mode 100644 index 0000000..097a4f1 --- /dev/null +++ b/pymic/net/net2d/canet_module.py @@ -0,0 +1,578 @@ +# -*- coding: utf-8 -*- +""" +Building blcoks for CA-Net. + +Oringinal file is on `Github. +`_ +""" + +from __future__ import print_function, division +import torch +import torch.nn as nn +import functools +from torch.nn import functional as F + + +class conv_block(nn.Module): + def __init__(self, ch_in, ch_out, drop_out=False): + super(conv_block, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), + nn.BatchNorm2d(ch_out), + nn.ReLU(inplace=True), + nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True), + nn.BatchNorm2d(ch_out), + nn.ReLU(inplace=True), + ) + self.dropout = drop_out + + def forward(self, x): + x = self.conv(x) + if self.dropout: + x = nn.Dropout2d(0.5)(x) + return x + + +# # UpCat(nn.Module) for U-net UP convolution +class UpCat(nn.Module): + def __init__(self, in_feat, out_feat, is_deconv=True): + super(UpCat, self).__init__() + if is_deconv: + self.up = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2) + else: + self.up = nn.Upsample(scale_factor=2, mode='bilinear') + + def forward(self, inputs, down_outputs): + # TODO: Upsampling required after deconv? + outputs = self.up(down_outputs) + offset = inputs.size()[3] - outputs.size()[3] + if offset == 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2]), out=None).unsqueeze( + 3).cuda() + outputs = torch.cat([outputs, addition], dim=3) + elif offset > 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2], offset), out=None).cuda() + outputs = torch.cat([outputs, addition], dim=3) + out = torch.cat([inputs, outputs], dim=1) + + return out + + +# # UpCatconv(nn.Module) for up convolution +class UpCatconv(nn.Module): + def __init__(self, in_feat, out_feat, is_deconv=True, drop_out=False): + super(UpCatconv, self).__init__() + + if is_deconv: + self.conv = conv_block(in_feat, out_feat, drop_out=drop_out) + self.up = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2) + else: + self.conv = conv_block(in_feat + out_feat, out_feat, drop_out=drop_out) + self.up = nn.Upsample(scale_factor=2, mode='bilinear') + + def forward(self, inputs, down_outputs): + # TODO: Upsampling required after deconv + outputs = self.up(down_outputs) + offset = inputs.size()[3] - outputs.size()[3] + if offset == 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2]), out=None).unsqueeze( + 3).cuda() + outputs = torch.cat([outputs, addition], dim=3) + elif offset > 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2], offset), out=None).cuda() + outputs = torch.cat([outputs, addition], dim=3) + out = self.conv(torch.cat([inputs, outputs], dim=1)) + + return out + + +class UnetDsv3(nn.Module): + def __init__(self, in_size, out_size, scale_factor): + super(UnetDsv3, self).__init__() + self.dsv = nn.Sequential(nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0), + nn.Upsample(size=scale_factor, mode='bilinear'), ) + + def forward(self, input): + return self.dsv(input) + + +###### Intial weights ##### +def weights_init_normal(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + nn.init.normal(m.weight.data, 0.0, 0.02) + elif classname.find('Linear') != -1: + nn.init.normal(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm') != -1: + nn.init.normal(m.weight.data, 1.0, 0.02) + nn.init.constant(m.bias.data, 0.0) + + +def weights_init_xavier(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + nn.init.xavier_normal(m.weight.data, gain=1) + elif classname.find('Linear') != -1: + nn.init.xavier_normal(m.weight.data, gain=1) + elif classname.find('BatchNorm') != -1: + nn.init.normal(m.weight.data, 1.0, 0.02) + nn.init.constant(m.bias.data, 0.0) + + +def weights_init_kaiming(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in') + elif classname.find('Linear') != -1: + nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in') + elif classname.find('BatchNorm') != -1: + nn.init.normal(m.weight.data, 1.0, 0.02) + nn.init.constant(m.bias.data, 0.0) + + +def weights_init_orthogonal(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + nn.init.orthogonal(m.weight.data, gain=1) + elif classname.find('Linear') != -1: + nn.init.orthogonal(m.weight.data, gain=1) + elif classname.find('BatchNorm') != -1: + nn.init.normal(m.weight.data, 1.0, 0.02) + nn.init.constant(m.bias.data, 0.0) + + +def init_weights(net, init_type='normal'): + #print('initialization method [%s]' % init_type) + if init_type == 'normal': + net.apply(weights_init_normal) + elif init_type == 'xavier': + net.apply(weights_init_xavier) + elif init_type == 'kaiming': + net.apply(weights_init_kaiming) + elif init_type == 'orthogonal': + net.apply(weights_init_orthogonal) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + + +def get_norm_layer(norm_type='instance'): + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) + elif norm_type == 'none': + norm_layer = None + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm_type) + return norm_layer + + +###### For attention ###### +class _GridAttentionBlockND(nn.Module): + def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation', + sub_sample_factor=(2,2,2)): + super(_GridAttentionBlockND, self).__init__() + + assert dimension in [2, 3] + assert mode in ['concatenation', 'concatenation_debug', 'concatenation_residual'] + + # Downsampling rate for the input featuremap + if isinstance(sub_sample_factor, tuple): self.sub_sample_factor = sub_sample_factor + elif isinstance(sub_sample_factor, list): self.sub_sample_factor = tuple(sub_sample_factor) + else: self.sub_sample_factor = tuple([sub_sample_factor]) * dimension + + # Default parameter set + self.mode = mode + self.dimension = dimension + self.sub_sample_kernel_size = self.sub_sample_factor + + # Number of channels (pixel dimensions) + self.in_channels = in_channels + self.gating_channels = gating_channels + self.inter_channels = inter_channels + + if self.inter_channels is None: + self.inter_channels = in_channels // 2 + if self.inter_channels == 0: + self.inter_channels = 1 + + if dimension == 3: + conv_nd = nn.Conv3d + bn = nn.BatchNorm3d + self.upsample_mode = 'trilinear' + elif dimension == 2: + conv_nd = nn.Conv2d + bn = nn.BatchNorm2d + self.upsample_mode = 'bilinear' + else: + raise NotImplemented + + # Output transform + self.W = nn.Sequential( + conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), + bn(self.in_channels), + ) + + # Theta^T * x_ij + Phi^T * gating_signal + bias + self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=True) + self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels, + kernel_size=(1, 1), stride=1, padding=0, bias=True) + self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True) + + # Initialise weights + for m in self.children(): + init_weights(m, init_type='kaiming') + + # Define the operation + if mode == 'concatenation': + self.operation_function = self._concatenation + elif mode == 'concatenation_debug': + self.operation_function = self._concatenation_debug + elif mode == 'concatenation_residual': + self.operation_function = self._concatenation_residual + else: + raise NotImplementedError('Unknown operation function.') + + + def forward(self, x, g): + ''' + :param x: (b, c, t, h, w) + :param g: (b, g_d) + :return: + ''' + + output = self.operation_function(x, g) + return output + + def _concatenation(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.relu(theta_x + phi_g, inplace=True) + + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + sigm_psi_f = F.sigmoid(self.psi(f)) + + # upsample the attentions and multiply + sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + def _concatenation_debug(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.softplus(theta_x + phi_g) + + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + sigm_psi_f = F.sigmoid(self.psi(f)) + + # upsample the attentions and multiply + sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + + def _concatenation_residual(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.relu(theta_x + phi_g, inplace=True) + + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + f = self.psi(f).view(batch_size, 1, -1) + sigm_psi_f = F.softmax(f, dim=2).view(batch_size, 1, *theta_x.size()[2:]) + + # upsample the attentions and multiply + sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + +class GridAttentionBlock2D(_GridAttentionBlockND): + def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', + sub_sample_factor=(2, 2)): + super(GridAttentionBlock2D, self).__init__(in_channels, + inter_channels=inter_channels, + gating_channels=gating_channels, + dimension=2, mode=mode, + sub_sample_factor=sub_sample_factor, + ) + + +class GridAttentionBlock3D(_GridAttentionBlockND): + def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', + sub_sample_factor=(2,2,2)): + super(GridAttentionBlock3D, self).__init__(in_channels, + inter_channels=inter_channels, + gating_channels=gating_channels, + dimension=3, mode=mode, + sub_sample_factor=sub_sample_factor, + ) + +class _GridAttentionBlockND_TORR(nn.Module): + def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation', + sub_sample_factor=(1,1,1), bn_layer=True, use_W=True, use_phi=True, use_theta=True, use_psi=True, nonlinearity1='relu'): + super(_GridAttentionBlockND_TORR, self).__init__() + + assert dimension in [2, 3] + assert mode in ['concatenation', 'concatenation_softmax', + 'concatenation_sigmoid', 'concatenation_mean', + 'concatenation_range_normalise', 'concatenation_mean_flow'] + + # Default parameter set + self.mode = mode + self.dimension = dimension + self.sub_sample_factor = sub_sample_factor if isinstance(sub_sample_factor, tuple) else tuple([sub_sample_factor])*dimension + self.sub_sample_kernel_size = self.sub_sample_factor + + # Number of channels (pixel dimensions) + self.in_channels = in_channels + self.gating_channels = gating_channels + self.inter_channels = inter_channels + + if self.inter_channels is None: + self.inter_channels = in_channels // 2 + if self.inter_channels == 0: + self.inter_channels = 1 + + if dimension == 3: + conv_nd = nn.Conv3d + bn = nn.BatchNorm3d + self.upsample_mode = 'trilinear' + elif dimension == 2: + conv_nd = nn.Conv2d + bn = nn.BatchNorm2d + self.upsample_mode = 'bilinear' + else: + raise NotImplemented + + # initialise id functions + # Theta^T * x_ij + Phi^T * gating_signal + bias + self.W = lambda x: x + self.theta = lambda x: x + self.psi = lambda x: x + self.phi = lambda x: x + self.nl1 = lambda x: x + + if use_W: + if bn_layer: + self.W = nn.Sequential( + conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), + bn(self.in_channels), + ) + else: + self.W = conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0) + + if use_theta: + self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=False) + + + if use_phi: + self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels, + kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=False) + + + if use_psi: + self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True) + + + if nonlinearity1: + if nonlinearity1 == 'relu': + self.nl1 = lambda x: F.relu(x, inplace=True) + + if 'concatenation' in mode: + self.operation_function = self._concatenation + else: + raise NotImplementedError('Unknown operation function.') + + # Initialise weights + for m in self.children(): + init_weights(m, init_type='kaiming') + + + if use_psi and self.mode == 'concatenation_sigmoid': + nn.init.constant(self.psi.bias.data, 3.0) + + if use_psi and self.mode == 'concatenation_softmax': + nn.init.constant(self.psi.bias.data, 10.0) + + # if use_psi and self.mode == 'concatenation_mean': + # nn.init.constant(self.psi.bias.data, 3.0) + + # if use_psi and self.mode == 'concatenation_range_normalise': + # nn.init.constant(self.psi.bias.data, 3.0) + + parallel = False + if parallel: + if use_W: self.W = nn.DataParallel(self.W) + if use_phi: self.phi = nn.DataParallel(self.phi) + if use_psi: self.psi = nn.DataParallel(self.psi) + if use_theta: self.theta = nn.DataParallel(self.theta) + + def forward(self, x, g): + ''' + :param x: (b, c, t, h, w) + :param g: (b, g_d) + :return: + ''' + + output = self.operation_function(x, g) + return output + + def _concatenation(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + ############################# + # compute compatibility score + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) + # phi => (b, c, t, h, w) -> (b, i_c, t, h, w) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # nl(theta.x + phi.g + bias) -> f = (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + + f = theta_x + phi_g + f = self.nl1(f) + + psi_f = self.psi(f) + + ############################################ + # normalisation -- scale compatibility score + # psi^T . f -> (b, 1, t/s1, h/s2, w/s3) + if self.mode == 'concatenation_softmax': + sigm_psi_f = F.softmax(psi_f.view(batch_size, 1, -1), dim=2) + sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:]) + elif self.mode == 'concatenation_mean': + psi_f_flat = psi_f.view(batch_size, 1, -1) + psi_f_sum = torch.sum(psi_f_flat, dim=2)#clamp(1e-6) + psi_f_sum = psi_f_sum[:,:,None].expand_as(psi_f_flat) + + sigm_psi_f = psi_f_flat / psi_f_sum + sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:]) + elif self.mode == 'concatenation_mean_flow': + psi_f_flat = psi_f.view(batch_size, 1, -1) + ss = psi_f_flat.shape + psi_f_min = psi_f_flat.min(dim=2)[0].view(ss[0],ss[1],1) + psi_f_flat = psi_f_flat - psi_f_min + psi_f_sum = torch.sum(psi_f_flat, dim=2).view(ss[0],ss[1],1).expand_as(psi_f_flat) + + sigm_psi_f = psi_f_flat / psi_f_sum + sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:]) + elif self.mode == 'concatenation_range_normalise': + psi_f_flat = psi_f.view(batch_size, 1, -1) + ss = psi_f_flat.shape + psi_f_max = torch.max(psi_f_flat, dim=2)[0].view(ss[0], ss[1], 1) + psi_f_min = torch.min(psi_f_flat, dim=2)[0].view(ss[0], ss[1], 1) + + sigm_psi_f = (psi_f_flat - psi_f_min) / (psi_f_max - psi_f_min).expand_as(psi_f_flat) + sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:]) + + elif self.mode == 'concatenation_sigmoid': + sigm_psi_f = F.sigmoid(psi_f) + else: + raise NotImplementedError + + # sigm_psi_f is attention map! upsample the attentions and multiply + sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + +class GridAttentionBlock2D_TORR(_GridAttentionBlockND_TORR): + def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', + sub_sample_factor=(1,1), bn_layer=True, + use_W=True, use_phi=True, use_theta=True, use_psi=True, + nonlinearity1='relu'): + super(GridAttentionBlock2D_TORR, self).__init__(in_channels, + inter_channels=inter_channels, + gating_channels=gating_channels, + dimension=2, mode=mode, + sub_sample_factor=sub_sample_factor, + bn_layer=bn_layer, + use_W=use_W, + use_phi=use_phi, + use_theta=use_theta, + use_psi=use_psi, + nonlinearity1=nonlinearity1) + + +class GridAttentionBlock3D_TORR(_GridAttentionBlockND_TORR): + def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', + sub_sample_factor=(1,1,1), bn_layer=True): + super(GridAttentionBlock3D_TORR, self).__init__(in_channels, + inter_channels=inter_channels, + gating_channels=gating_channels, + dimension=3, mode=mode, + sub_sample_factor=sub_sample_factor, + bn_layer=bn_layer) + + +class MultiAttentionBlock(nn.Module): + def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor): + super(MultiAttentionBlock, self).__init__() + self.gate_block_1 = GridAttentionBlock2D(in_channels=in_size, gating_channels=gate_size, + inter_channels=inter_size, mode=nonlocal_mode, + sub_sample_factor=sub_sample_factor) + self.gate_block_2 = GridAttentionBlock2D(in_channels=in_size, gating_channels=gate_size, + inter_channels=inter_size, mode=nonlocal_mode, + sub_sample_factor=sub_sample_factor) + self.combine_gates = nn.Sequential(nn.Conv2d(in_size*2, in_size, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(in_size), + nn.ReLU(inplace=True)) + + # initialise the blocks + for m in self.children(): + if m.__class__.__name__.find('GridAttentionBlock2D') != -1: continue + init_weights(m, init_type='kaiming') + + def forward(self, input, gating_signal): + gate_1, attention_1 = self.gate_block_1(input, gating_signal) + gate_2, attention_2 = self.gate_block_2(input, gating_signal) + + return self.combine_gates(torch.cat([gate_1, gate_2], 1)), torch.cat([attention_1, attention_2], 1) \ No newline at end of file diff --git a/pymic/net/net2d/unet2d_attention.py b/pymic/net/net2d/unet2d_attention.py index 6afdfdc..36faec8 100644 --- a/pymic/net/net2d/unet2d_attention.py +++ b/pymic/net/net2d/unet2d_attention.py @@ -4,14 +4,7 @@ import torch import torch.nn as nn from pymic.net.net2d.unet2d import * -""" -A Reimplementation of the attention U-Net paper: - Ozan Oktay, Jo Schlemper et al.: - Attentin U-Net: Looking Where to Look for the Pancreas. MIDL, 2018. -Note that there are some modifications from the original paper, such as -the use of batch normalization, dropout, and leaky relu here. -""" class AttentionGateBlock(nn.Module): def __init__(self, chns_l, chns_h): """ @@ -80,6 +73,14 @@ def forward(self, x1, x2): return self.conv(x) class AttentionUNet2D(UNet2D): + """ + A Reimplementation of the attention U-Net paper: + Ozan Oktay, Jo Schlemper et al.: + Attentin U-Net: Looking Where to Look for the Pancreas. MIDL, 2018. + + Note that there are some modifications from the original paper, such as + the use of batch normalization, dropout, and leaky relu here. + """ def __init__(self, params): super(AttentionUNet2D, self).__init__(params) self.up1 = UpBlockWithAttention(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = 0.0) diff --git a/pymic/net/net2d/unet2d_canet.py b/pymic/net/net2d/unet2d_canet.py new file mode 100644 index 0000000..7aeba84 --- /dev/null +++ b/pymic/net/net2d/unet2d_canet.py @@ -0,0 +1,756 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +from pymic.net.net2d.canet_module import * + + +def conv3x3(in_planes, out_planes, stride=1, bias=False, group=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,padding=1, groups=group, bias=bias) + + +class SE_Conv_Block(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, drop_out=False): + super(SE_Conv_Block, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes * 2) + self.bn2 = nn.BatchNorm2d(planes * 2) + self.conv3 = conv3x3(planes * 2, planes) + self.bn3 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + self.dropout = drop_out + + self.fc1 = nn.Linear(in_features=planes * 2, out_features=round(planes / 2)) + self.fc2 = nn.Linear(in_features=round(planes / 2), out_features=planes * 2) + self.sigmoid = nn.Sigmoid() + + self.downchannel = None + if inplanes != planes: + self.downchannel = nn.Sequential(nn.Conv2d(inplanes, planes * 2, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * 2),) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downchannel is not None: + residual = self.downchannel(x) + + original_out = out + out1 = out + # For global average pool + out = F.adaptive_avg_pool2d(out, (1,1)) + out = out.view(out.size(0), -1) + out = self.fc1(out) + out = self.relu(out) + out = self.fc2(out) + out = self.sigmoid(out) + out = out.view(out.size(0), out.size(1), 1, 1) + avg_att = out + out = out * original_out + # For global maximum pool + out1 = F.adaptive_max_pool2d(out1, (1,1)) + out1 = out1.view(out1.size(0), -1) + out1 = self.fc1(out1) + out1 = self.relu(out1) + out1 = self.fc2(out1) + out1 = self.sigmoid(out1) + out1 = out1.view(out1.size(0), out1.size(1), 1, 1) + max_att = out1 + out1 = out1 * original_out + + att_weight = avg_att + max_att + out += out1 + out += residual + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.relu(out) + if self.dropout: + out = nn.Dropout2d(0.5)(out) + + return out, att_weight + +# # CBAM Convolutional block attention module +class BasicConv(nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, + relu=True, bn=True, bias=False): + super(BasicConv, self).__init__() + self.out_channels = out_planes + self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, groups=groups, bias=bias) + self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None + self.relu = nn.ReLU() if relu else None + + def forward(self, x): + x = self.conv(x) + if self.bn is not None: + x = self.bn(x) + if self.relu is not None: + x = self.relu(x) + return x + + +class Flatten(nn.Module): + def forward(self, x): + return x.view(x.size(0), -1) + + +class ChannelGate(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): + super(ChannelGate, self).__init__() + self.gate_channels = gate_channels + self.mlp = nn.Sequential( + Flatten(), + nn.Linear(gate_channels, gate_channels // reduction_ratio), + nn.ReLU(), + nn.Linear(gate_channels // reduction_ratio, gate_channels) + ) + self.pool_types = pool_types + + def forward(self, x): + channel_att_sum = None + for pool_type in self.pool_types: + if pool_type == 'avg': + avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp(avg_pool) + elif pool_type == 'max': + max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp(max_pool) + elif pool_type == 'lp': + lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp(lp_pool) + elif pool_type == 'lse': + # LSE pool only + lse_pool = logsumexp_2d(x) + channel_att_raw = self.mlp(lse_pool) + + if channel_att_sum is None: + channel_att_sum = channel_att_raw + else: + channel_att_sum = channel_att_sum + channel_att_raw + + # scalecoe = F.sigmoid(channel_att_sum) + # print("channel att_sum", channel_att_sum.shape) + # channel_att_sum = channel_att_sum.reshape(channel_att_sum.shape[0], 4, 4) + # avg_weight = torch.mean(channel_att_sum, dim=2).unsqueeze(2) + # avg_weight = avg_weight.expand(channel_att_sum.shape[0], 4, 4).reshape(channel_att_sum.shape[0], 16) + # scale = F.sigmoid(avg_weight).unsqueeze(2).unsqueeze(3).expand_as(x) + scale = F.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x) + return x * scale, scale + + +def logsumexp_2d(tensor): + tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) + s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) + outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() + return outputs + + +class ChannelPool(nn.Module): + def forward(self, x): + return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1) + + +class SpatialGate(nn.Module): + def __init__(self): + super(SpatialGate, self).__init__() + kernel_size = 7 + self.compress = ChannelPool() + self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) + + def forward(self, x): + x_compress = self.compress(x) + x_out = self.spatial(x_compress) + scale = F.sigmoid(x_out) # broadcasting + return x * scale, scale + +class SpatialAtten(nn.Module): + def __init__(self, in_size, out_size, kernel_size=3, stride=1): + super(SpatialAtten, self).__init__() + self.conv1 = BasicConv(in_size, out_size, kernel_size, stride=stride, + padding=(kernel_size-1) // 2, relu=True) + self.conv2 = BasicConv(out_size, in_size, kernel_size=1, stride=stride, + padding=0, relu=True, bn=False) + + def forward(self, x): + residual = x + x_out = self.conv1(x) + x_out = self.conv2(x_out) + spatial_att = F.sigmoid(x_out) + # .unsqueeze(4).permute(0, 1, 4, 2, 3) + # spatial_att = spatial_att.expand(spatial_att.shape[0], 4, 4, spatial_att.shape[3], spatial_att.shape[4]).reshape( + # spatial_att.shape[0], 16, spatial_att.shape[3], spatial_att.shape[4]) + x_out = residual * spatial_att + + x_out += residual + + return x_out, spatial_att + +class Scale_atten_block(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): + super(Scale_atten_block, self).__init__() + self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) + self.no_spatial = no_spatial + if not no_spatial: + self.SpatialGate = SpatialAtten(gate_channels, gate_channels //reduction_ratio) + + def forward(self, x): + x_out, ca_atten = self.ChannelGate(x) + if not self.no_spatial: + x_out, sa_atten = self.SpatialGate(x_out) + + return x_out, ca_atten, sa_atten + + +class scale_atten_convblock(nn.Module): + def __init__(self, in_size, out_size, stride=1, downsample=None, use_cbam=True, no_spatial=False, drop_out=False): + super(scale_atten_convblock, self).__init__() + self.downsample = downsample + self.stride = stride + self.no_spatial = no_spatial + self.dropout = drop_out + + self.relu = nn.ReLU(inplace=True) + self.conv3 = conv3x3(in_size, out_size) + self.bn3 = nn.BatchNorm2d(out_size) + + if use_cbam: + self.cbam = Scale_atten_block(in_size, reduction_ratio=4, no_spatial=self.no_spatial) # out_size + else: + self.cbam = None + + def forward(self, x): + residual = x + + if self.downsample is not None: + residual = self.downsample(x) + + if not self.cbam is None: + out, scale_c_atten, scale_s_atten = self.cbam(x) + + out += residual + out = self.relu(out) + out = self.conv3(out) + out = self.bn3(out) + out = self.relu(out) + + if self.dropout: + out = nn.Dropout2d(0.5)(out) + + return out + +class _NonLocalBlockND(nn.Module): + def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian', + sub_sample_factor=4, bn_layer=True): + super(_NonLocalBlockND, self).__init__() + + assert dimension in [1, 2, 3] + assert mode in ['embedded_gaussian', 'gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down'] + + # print('Dimension: %d, mode: %s' % (dimension, mode)) + + self.mode = mode + self.dimension = dimension + self.sub_sample_factor = sub_sample_factor if isinstance(sub_sample_factor, list) else [sub_sample_factor] + + self.in_channels = in_channels + self.inter_channels = inter_channels + + if self.inter_channels is None: + self.inter_channels = in_channels // 2 + if self.inter_channels == 0: + self.inter_channels = 1 + + if dimension == 3: + conv_nd = nn.Conv3d + max_pool = nn.MaxPool3d + bn = nn.BatchNorm3d + elif dimension == 2: + conv_nd = nn.Conv2d + max_pool = nn.MaxPool2d + bn = nn.BatchNorm2d + else: + conv_nd = nn.Conv1d + max_pool = nn.MaxPool1d + bn = nn.BatchNorm1d + + self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=1, stride=1, padding=0) + + if bn_layer: + self.W = nn.Sequential( + conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, + kernel_size=1, stride=1, padding=0), + bn(self.in_channels) + ) + nn.init.constant(self.W[1].weight, 0) + nn.init.constant(self.W[1].bias, 0) + else: + self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, + kernel_size=1, stride=1, padding=0) + nn.init.constant(self.W.weight, 0) + nn.init.constant(self.W.bias, 0) + + self.theta = None + self.phi = None + + if mode in ['embedded_gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down']: + self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=1, stride=1, padding=0) + self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=1, stride=1, padding=0) + + if mode in ['concatenation']: + self.wf_phi = nn.Linear(self.inter_channels, 1, bias=False) + self.wf_theta = nn.Linear(self.inter_channels, 1, bias=False) + elif mode in ['concat_proper', 'concat_proper_down']: + self.psi = nn.Conv2d(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, + padding=0, bias=True) + + if mode == 'embedded_gaussian': + self.operation_function = self._embedded_gaussian + elif mode == 'dot_product': + self.operation_function = self._dot_product + elif mode == 'gaussian': + self.operation_function = self._gaussian + elif mode == 'concatenation': + self.operation_function = self._concatenation + elif mode == 'concat_proper': + self.operation_function = self._concatenation_proper + elif mode == 'concat_proper_down': + self.operation_function = self._concatenation_proper_down + else: + raise NotImplementedError('Unknown operation function.') + + if any(ss > 1 for ss in self.sub_sample_factor): + self.g = nn.Sequential(self.g, max_pool(kernel_size=sub_sample_factor)) + if self.phi is None: + self.phi = max_pool(kernel_size=sub_sample_factor) + else: + self.phi = nn.Sequential(self.phi, max_pool(kernel_size=sub_sample_factor)) + if mode == 'concat_proper_down': + self.theta = nn.Sequential(self.theta, max_pool(kernel_size=sub_sample_factor)) + + # Initialise weights + for m in self.children(): + init_weights(m, init_type='kaiming') + + def forward(self, x): + ''' + :param x: (b, c, t, h, w) + :return: + ''' + + output = self.operation_function(x) + return output + + def _embedded_gaussian(self, x): + batch_size = x.size(0) + + # g=>(b, c, t, h, w)->(b, 0.5c, t, h, w)->(b, thw, 0.5c) + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + g_x = g_x.permute(0, 2, 1) + + # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c) + # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw) + # f=>(b, thw, 0.5c)dot(b, 0.5c, twh) = (b, thw, thw) + theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) + f = torch.matmul(theta_x, phi_x) + f_div_C = F.softmax(f, dim=-1) + + # (b, thw, thw)dot(b, thw, 0.5c) = (b, thw, 0.5c)->(b, 0.5c, t, h, w)->(b, c, t, h, w) + y = torch.matmul(f_div_C, g_x) + y = y.permute(0, 2, 1).contiguous() + y = y.view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + z = W_y + x + + return z + + def _gaussian(self, x): + batch_size = x.size(0) + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + g_x = g_x.permute(0, 2, 1) + + theta_x = x.view(batch_size, self.in_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + + if self.sub_sample_factor > 1: + phi_x = self.phi(x).view(batch_size, self.in_channels, -1) + else: + phi_x = x.view(batch_size, self.in_channels, -1) + + f = torch.matmul(theta_x, phi_x) + f_div_C = F.softmax(f, dim=-1) + + y = torch.matmul(f_div_C, g_x) + y = y.permute(0, 2, 1).contiguous() + y = y.view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + z = W_y + x + + return z + + def _dot_product(self, x): + batch_size = x.size(0) + + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + g_x = g_x.permute(0, 2, 1) + + theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) + f = torch.matmul(theta_x, phi_x) + N = f.size(-1) + f_div_C = f / N + + y = torch.matmul(f_div_C, g_x) + y = y.permute(0, 2, 1).contiguous() + y = y.view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + z = W_y + x + + return z + + def _concatenation(self, x): + batch_size = x.size(0) + + # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2) + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + + # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c) + # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw/s**2, 0.5c) + theta_x = self.theta(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1) + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1) + + # theta => (b, thw, 0.5c) -> (b, thw, 1) -> (b, 1, thw) -> (expand) (b, thw/s**2, thw) + # phi => (b, thw/s**2, 0.5c) -> (b, thw/s**2, 1) -> (expand) (b, thw/s**2, thw) + # f=> RELU[(b, thw/s**2, thw) + (b, thw/s**2, thw)] = (b, thw/s**2, thw) + f = self.wf_theta(theta_x).permute(0, 2, 1).repeat(1, phi_x.size(1), 1) + \ + self.wf_phi(phi_x).repeat(1, 1, theta_x.size(1)) + f = F.relu(f, inplace=True) + + # Normalise the relations + N = f.size(-1) + f_div_c = f / N + + # g(x_j) * f(x_j, x_i) + # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw) + y = torch.matmul(g_x, f_div_c) + y = y.contiguous().view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + z = W_y + x + + return z + + def _concatenation_proper(self, x): + batch_size = x.size(0) + + # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2) + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + + # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw) + # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw/s**2) + theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) + + # theta => (b, 0.5c, thw) -> (expand) (b, 0.5c, thw/s**2, thw) + # phi => (b, 0.5c, thw/s**2) -> (expand) (b, 0.5c, thw/s**2, thw) + # f=> RELU[(b, 0.5c, thw/s**2, thw) + (b, 0.5c, thw/s**2, thw)] = (b, 0.5c, thw/s**2, thw) + f = theta_x.unsqueeze(dim=2).repeat(1,1,phi_x.size(2),1) + \ + phi_x.unsqueeze(dim=3).repeat(1,1,1,theta_x.size(2)) + f = F.relu(f, inplace=True) + + # psi -> W_psi^t * f -> (b, 1, thw/s**2, thw) -> (b, thw/s**2, thw) + f = torch.squeeze(self.psi(f), dim=1) + + # Normalise the relations + f_div_c = F.softmax(f, dim=1) + + # g(x_j) * f(x_j, x_i) + # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw) + y = torch.matmul(g_x, f_div_c) + y = y.contiguous().view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + z = W_y + x + + return z + + def _concatenation_proper_down(self, x): + batch_size = x.size(0) + + # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2) + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + + # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw) + # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw/s**2) + theta_x = self.theta(x) + downsampled_size = theta_x.size() + theta_x = theta_x.view(batch_size, self.inter_channels, -1) + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) + + # theta => (b, 0.5c, thw) -> (expand) (b, 0.5c, thw/s**2, thw) + # phi => (b, 0.5, thw/s**2) -> (expand) (b, 0.5c, thw/s**2, thw) + # f=> RELU[(b, 0.5c, thw/s**2, thw) + (b, 0.5c, thw/s**2, thw)] = (b, 0.5c, thw/s**2, thw) + f = theta_x.unsqueeze(dim=2).repeat(1,1,phi_x.size(2),1) + \ + phi_x.unsqueeze(dim=3).repeat(1,1,1,theta_x.size(2)) + f = F.relu(f, inplace=True) + + # psi -> W_psi^t * f -> (b, 0.5c, thw/s**2, thw) -> (b, 1, thw/s**2, thw) -> (b, thw/s**2, thw) + f = torch.squeeze(self.psi(f), dim=1) + + # Normalise the relations + f_div_c = F.softmax(f, dim=1) + + # g(x_j) * f(x_j, x_i) + # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw) + y = torch.matmul(g_x, f_div_c) + y = y.contiguous().view(batch_size, self.inter_channels, *downsampled_size[2:]) + + # upsample the final featuremaps # (b,0.5c,t/s1,h/s2,w/s3) + y = F.upsample(y, size=x.size()[2:], mode='trilinear') + + # attention block output + W_y = self.W(y) + z = W_y + x + + return z + + +class NONLocalBlock2D(_NonLocalBlockND): + def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample_factor=2, bn_layer=True): + super(NONLocalBlock2D, self).__init__(in_channels, + inter_channels=inter_channels, + dimension=2, mode=mode, + sub_sample_factor=sub_sample_factor, + bn_layer=bn_layer) + + +class CANet(nn.Module): + """ + Implementation of CANet (Comprehensive Attention Network) for image segmentation. + + * Reference: R. Gu et al. `CA-Net: Comprehensive Attention Convolutional Neural Networks + for Explainable Medical Image Segmentation `_. + IEEE Transactions on Medical Imaging, 40(2),2021:699-711. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param bilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + """ + def __init__(self, params): #args, in_ch=3, n_classes=2, feature_scale=4, is_deconv=True, is_batchnorm=True, + # nonlocal_mode='concatenation', attention_dsample=(1, 1)): + super(CANet, self).__init__() + self.in_channels = params['in_chns'] + self.num_classes = params['class_num'] + self.is_deconv = params.get('is_deconv', True) + self.is_batchnorm = params.get('is_batchnorm', True) + self.feature_scale = params.get('feature_scale', 4) + nonlocal_mode = 'concatenation' + attention_dsample = (1, 1) + + filters = [64, 128, 256, 512, 1024] + filters = [int(x / self.feature_scale) for x in filters] + + # downsampling + self.conv1 = conv_block(self.in_channels, filters[0]) + self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2)) + + self.conv2 = conv_block(filters[0], filters[1]) + self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2)) + + self.conv3 = conv_block(filters[1], filters[2]) + self.maxpool3 = nn.MaxPool2d(kernel_size=(2, 2)) + + self.conv4 = conv_block(filters[2], filters[3], drop_out=True) + self.maxpool4 = nn.MaxPool2d(kernel_size=(2, 2)) + + self.center = conv_block(filters[3], filters[4], drop_out=True) + + # attention blocks + # self.attentionblock1 = GridAttentionBlock2D(in_channels=filters[0], gating_channels=filters[1], + # inter_channels=filters[0]) + self.attentionblock2 = MultiAttentionBlock(in_size=filters[1], gate_size=filters[2], inter_size=filters[1], + nonlocal_mode=nonlocal_mode, sub_sample_factor=attention_dsample) + self.attentionblock3 = MultiAttentionBlock(in_size=filters[2], gate_size=filters[3], inter_size=filters[2], + nonlocal_mode=nonlocal_mode, sub_sample_factor=attention_dsample) + self.nonlocal4_2 = NONLocalBlock2D(in_channels=filters[4], inter_channels=filters[4] // 4) + + # upsampling + self.up_concat4 = UpCat(filters[4], filters[3], self.is_deconv) + self.up_concat3 = UpCat(filters[3], filters[2], self.is_deconv) + self.up_concat2 = UpCat(filters[2], filters[1], self.is_deconv) + self.up_concat1 = UpCat(filters[1], filters[0], self.is_deconv) + self.up4 = SE_Conv_Block(filters[4], filters[3], drop_out=True) + self.up3 = SE_Conv_Block(filters[3], filters[2]) + self.up2 = SE_Conv_Block(filters[2], filters[1]) + self.up1 = SE_Conv_Block(filters[1], filters[0]) + + # For deep supervision, project the multi-scale feature maps to the same number of channels + self.dsv1 = nn.Conv2d(in_channels=filters[0], out_channels=filters[0]//2, kernel_size=1) + self.dsv2 = nn.Conv2d(in_channels=filters[1], out_channels=filters[0]//2, kernel_size=1) + self.dsv3 = nn.Conv2d(in_channels=filters[2], out_channels=filters[0]//2, kernel_size=1) + self.dsv4 = nn.Conv2d(in_channels=filters[3], out_channels=filters[0]//2, kernel_size=1) + + self.scale_att = scale_atten_convblock(in_size=filters[0]//2 * 4, out_size=filters[0]) + self.final = nn.Conv2d(filters[0], self.num_classes, kernel_size=1) + + def forward(self, inputs): + # Feature Extraction + conv1 = self.conv1(inputs) + maxpool1 = self.maxpool1(conv1) + + conv2 = self.conv2(maxpool1) + maxpool2 = self.maxpool2(conv2) + + conv3 = self.conv3(maxpool2) + maxpool3 = self.maxpool3(conv3) + + conv4 = self.conv4(maxpool3) + maxpool4 = self.maxpool4(conv4) + + # Gating Signal Generation + center = self.center(maxpool4) + + # Attention Mechanism + # Upscaling Part (Decoder) + up4 = self.up_concat4(conv4, center) + g_conv4 = self.nonlocal4_2(up4) + + up4, att_weight4 = self.up4(g_conv4) + g_conv3, att3 = self.attentionblock3(conv3, up4) + + # atten3_map = att3.cpu().detach().numpy().astype(np.float) + # atten3_map = ndimage.interpolation.zoom(atten3_map, [1.0, 1.0, 224 / atten3_map.shape[2], + # 300 / atten3_map.shape[3]], order=0) + + up3 = self.up_concat3(g_conv3, up4) + up3, att_weight3 = self.up3(up3) + g_conv2, att2 = self.attentionblock2(conv2, up3) + + up2 = self.up_concat2(g_conv2, up3) + up2, att_weight2 = self.up2(up2) + + up1 = self.up_concat1(conv1, up2) + up1, att_weight1 = self.up1(up1) + + # Deep Supervision + dsv1 = self.dsv1(up1) + dsv2 = F.interpolate(self.dsv2(up2), dsv1.shape[2:], mode = 'bilinear') + dsv3 = F.interpolate(self.dsv3(up3), dsv1.shape[2:], mode = 'bilinear') + dsv4 = F.interpolate(self.dsv4(up4), dsv1.shape[2:], mode = 'bilinear') + + dsv_cat = torch.cat([dsv1, dsv2, dsv3, dsv4], dim=1) + out = self.scale_att(dsv_cat) + + out = self.final(out) + + return out + +if __name__ == "__main__": + params = {'in_chns':3, + 'class_num':2} + Net = CANet(params) + Net = Net.double() + + x = np.random.rand(4, 3, 224, 224) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) +axpool3(conv3) + + conv4 = self.conv4(maxpool3) + maxpool4 = self.maxpool4(conv4) + + # Gating Signal Generation + center = self.center(maxpool4) + + # Attention Mechanism + # Upscaling Part (Decoder) + up4 = self.up_concat4(conv4, center) + g_conv4 = self.nonlocal4_2(up4) + + up4, att_weight4 = self.up4(g_conv4) + g_conv3, att3 = self.attentionblock3(conv3, up4) + + # atten3_map = att3.cpu().detach().numpy().astype(np.float) + # atten3_map = ndimage.interpolation.zoom(atten3_map, [1.0, 1.0, 224 / atten3_map.shape[2], + # 300 / atten3_map.shape[3]], order=0) + + up3 = self.up_concat3(g_conv3, up4) + up3, att_weight3 = self.up3(up3) + g_conv2, att2 = self.attentionblock2(conv2, up3) + + # atten2_map = att2.cpu().detach().numpy().astype(np.float) + # atten2_map = ndimage.interpolation.zoom(atten2_map, [1.0, 1.0, 224 / atten2_map.shape[2], + # 300 / atten2_map.shape[3]], order=0) + + up2 = self.up_concat2(g_conv2, up3) + up2, att_weight2 = self.up2(up2) + # g_conv1, att1 = self.attentionblock1(conv1, up2) + + # atten1_map = att1.cpu().detach().numpy().astype(np.float) + # atten1_map = ndimage.interpolation.zoom(atten1_map, [1.0, 1.0, 224 / atten1_map.shape[2], + # 300 / atten1_map.shape[3]], order=0) + up1 = self.up_concat1(conv1, up2) + up1, att_weight1 = self.up1(up1) + + # Deep Supervision + dsv4 = self.dsv4(up4) + dsv3 = self.dsv3(up3) + dsv2 = self.dsv2(up2) + dsv1 = self.dsv1(up1) + dsv_cat = torch.cat([dsv1, dsv2, dsv3, dsv4], dim=1) + out = self.scale_att(dsv_cat) + + out = self.final(out) + + return out + + + +if __name__ == "__main__": + params = {'in_chns':3, + 'class_num':2} + Net = CANet(params) + Net = Net.double() + + x = np.random.rand(2, 3, 256, 256) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) diff --git a/pymic/net/net2d/unet2d_mcnet.py b/pymic/net/net2d/unet2d_mcnet.py index 76ee5de..be5b16b 100644 --- a/pymic/net/net2d/unet2d_mcnet.py +++ b/pymic/net/net2d/unet2d_mcnet.py @@ -39,7 +39,7 @@ def __init__(self, params): 'class_num': class_num, 'up_mode': 2, 'multiscale_pred': False} - self.encoder = Encoder(params1) + self.encoder = Encoder(params1) self.decoder1 = Decoder(params1) self.decoder2 = Decoder(params2) self.decoder3 = Decoder(params3) diff --git a/pymic/util/evaluation_seg.py b/pymic/util/evaluation_seg.py index 836401d..a9b114b 100644 --- a/pymic/util/evaluation_seg.py +++ b/pymic/util/evaluation_seg.py @@ -108,22 +108,30 @@ def binary_hd95(s, g, spacing = None): """ s_edge = get_edge_points(s) g_edge = get_edge_points(g) - image_dim = len(s.shape) - assert(image_dim == len(g.shape)) - if(spacing == None): - spacing = [1.0] * image_dim + ns = s_edge.sum() + ng = g_edge.sum() + if(ns + ng == 0): + hd95 = 0.0 + elif(ns * ng == 0): + hd95 = 100.0 else: - assert(image_dim == len(spacing)) - s_dis = ndimage.distance_transform_edt(1-s_edge, sampling = spacing) - g_dis = ndimage.distance_transform_edt(1-g_edge, sampling = spacing) - - dist_list1 = s_dis[g_edge > 0] - dist_list1 = sorted(dist_list1) - dist1 = dist_list1[int(len(dist_list1)*0.95)] - dist_list2 = g_dis[s_edge > 0] - dist_list2 = sorted(dist_list2) - dist2 = dist_list2[int(len(dist_list2)*0.95)] - return max(dist1, dist2) + image_dim = len(s.shape) + assert(image_dim == len(g.shape)) + if(spacing == None): + spacing = [1.0] * image_dim + else: + assert(image_dim == len(spacing)) + s_dis = ndimage.distance_transform_edt(1-s_edge, sampling = spacing) + g_dis = ndimage.distance_transform_edt(1-g_edge, sampling = spacing) + + dist_list1 = s_dis[g_edge > 0] + dist_list1 = sorted(dist_list1) + dist1 = dist_list1[int(len(dist_list1)*0.95)] + dist_list2 = g_dis[s_edge > 0] + dist_list2 = sorted(dist_list2) + dist2 = dist_list2[int(len(dist_list2)*0.95)] + hd95 = max(dist1, dist2) + return hd95 def binary_assd(s, g, spacing = None): @@ -150,9 +158,14 @@ def binary_assd(s, g, spacing = None): ns = s_edge.sum() ng = g_edge.sum() - s_dis_g_edge = s_dis * g_edge - g_dis_s_edge = g_dis * s_edge - assd = (s_dis_g_edge.sum() + g_dis_s_edge.sum()) / (ns + ng) + if(ns + ng == 0): + assd = 0.0 + elif(ns*ng == 0): + assd = 20.0 + else: + s_dis_g_edge = s_dis * g_edge + g_dis_s_edge = g_dis * s_edge + assd = (s_dis_g_edge.sum() + g_dis_s_edge.sum()) / (ns + ng) return assd # relative volume error evaluation @@ -315,8 +328,10 @@ def evaluation(config): # save the result as csv if(output_name is None): - output_name = "{0:}/eval_{1:}.csv".format(seg_root, metric) - with open(output_name, mode='w') as csv_file: + metric_output_name = "{0:}/eval_{1:}.csv".format(seg_root, metric) + else: + metric_output_name = output_name + with open(metric_output_name, mode='w') as csv_file: csv_writer = csv.writer(csv_file, delimiter=',', quotechar='"',quoting=csv.QUOTE_MINIMAL) head = ['image'] + ["class_{0:}".format(i) for i in label_list] From 28000e230842e4a69155c5d3ed15ac146aecb5e0 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 10 Nov 2023 11:44:06 +0800 Subject: [PATCH 31/86] update unet2d only give one prediction for inference --- pymic/net/net2d/unet2d.py | 28 ++++--- pymic/net/net2d/unet2d_urpc.py | 132 --------------------------------- 2 files changed, 16 insertions(+), 144 deletions(-) delete mode 100644 pymic/net/net2d/unet2d_urpc.py diff --git a/pymic/net/net2d/unet2d.py b/pymic/net/net2d/unet2d.py index 758bfe5..7d14a2e 100644 --- a/pymic/net/net2d/unet2d.py +++ b/pymic/net/net2d/unet2d.py @@ -153,8 +153,8 @@ def __init__(self, params): self.ft_chns = self.params['feature_chns'] self.dropout = self.params['dropout'] self.n_class = self.params['class_num'] - self.up_mode = self.params['up_mode'] - self.mul_pred = self.params['multiscale_pred'] + self.up_mode = self.params.get('up_mode', 2) + self.mul_pred = self.params.get('multiscale_pred', False) assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) @@ -183,10 +183,10 @@ def forward(self, x): x_d1 = self.up3(x_d2, x1) x_d0 = self.up4(x_d1, x0) output = self.out_conv(x_d0) - if(self.mul_pred): + if(self.mul_pred and self.training): output1 = self.out_conv1(x_d1) - output2 = self.out_conv1(x_d2) - output3 = self.out_conv1(x_d3) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) output = [output, output1, output2, output3] return output @@ -263,19 +263,23 @@ def forward(self, x): if __name__ == "__main__": params = {'in_chns':4, - 'feature_chns':[2, 8, 32, 48, 64], + 'feature_chns':[16, 32, 64, 128, 256], 'dropout': [0, 0, 0.3, 0.4, 0.5], 'class_num': 2, 'up_mode': 0, - 'multiscale_pred': False} + 'multiscale_pred': True} Net = UNet2D(params) Net = Net.double() - x = np.random.rand(4, 4, 10, 96, 96) + x = np.random.rand(4, 4, 10, 256, 256) xt = torch.from_numpy(x) xt = torch.tensor(xt) - y = Net(xt) - print(len(y.size())) - y = y.detach().numpy() - print(y.shape) + out = Net(xt) + if params['multiscale_pred']: + for y in out: + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) + else: + print(out.shape) diff --git a/pymic/net/net2d/unet2d_urpc.py b/pymic/net/net2d/unet2d_urpc.py deleted file mode 100644 index ee8ab7c..0000000 --- a/pymic/net/net2d/unet2d_urpc.py +++ /dev/null @@ -1,132 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import torch -import torch.nn as nn -import numpy as np -from torch.distributions.uniform import Uniform -from pymic.net.net2d.unet2d import ConvBlock, DownBlock, UpBlock - -def FeatureDropout(x): - attention = torch.mean(x, dim=1, keepdim=True) - max_val, _ = torch.max(attention.view( - x.size(0), -1), dim=1, keepdim=True) - threshold = max_val * np.random.uniform(0.7, 0.9) - threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention) - drop_mask = (attention < threshold).float() - x = x.mul(drop_mask) - return x - -class FeatureNoise(nn.Module): - def __init__(self, uniform_range=0.3): - super(FeatureNoise, self).__init__() - self.uni_dist = Uniform(-uniform_range, uniform_range) - - def feature_based_noise(self, x): - noise_vector = self.uni_dist.sample( - x.shape[1:]).to(x.device).unsqueeze(0) - x_noise = x.mul(noise_vector) + x - return x_noise - - def forward(self, x): - x = self.feature_based_noise(x) - return x - -class UNet2D_URPC(nn.Module): - """ - An modification the U-Net to obtain multi-scale prediction according to - the URPC paper. - - * Reference: Xiangde Luo, Guotai Wang*, Wenjun Liao, Jieneng Chen, Tao Song, Yinan Chen, - Shichuan Zhang, Dimitris N. Metaxas, Shaoting Zhang. - Semi-Supervised Medical Image Segmentation via Uncertainty Rectified Pyramid Consistency . - `Medical Image Analysis 2022. `_ - - Also see: https://github.com/HiLab-git/SSL4MIS/blob/master/code/networks/unet.py - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - """ - def __init__(self, params): - super(UNet2D_URPC, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] - assert(len(self.ft_chns) == 5) - - self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], 0.0, self.bilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], 0.0, self.bilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], 0.0, self.bilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], 0.0, self.bilinear) - - self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, - kernel_size = 3, padding = 1) - self.out_conv_dp4 = nn.Conv2d(self.ft_chns[4], self.n_class, - kernel_size=3, padding=1) - self.out_conv_dp3 = nn.Conv2d(self.ft_chns[3], self.n_class, - kernel_size=3, padding=1) - self.out_conv_dp2 = nn.Conv2d(self.ft_chns[2], self.n_class, - kernel_size=3, padding=1) - self.out_conv_dp1 = nn.Conv2d(self.ft_chns[1], self.n_class, - kernel_size=3, padding=1) - self.feature_noise = FeatureNoise() - - def forward(self, x): - x_shape = list(x.shape) - if(len(x_shape) == 5): - [N, C, D, H, W] = x_shape - new_shape = [N*D, C, H, W] - x = torch.transpose(x, 1, 2) - x = torch.reshape(x, new_shape) - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - x4 = self.down4(x3) - - x = self.up1(x4, x3) - if self.training: - x = nn.functional.dropout(x, p=0.5) - dp3_out = self.out_conv_dp3(x) - - x = self.up2(x, x2) - if self.training: - x = FeatureDropout(x) - dp2_out = self.out_conv_dp2(x) - - x = self.up3(x, x1) - if self.training: - x = self.feature_noise(x) - dp1_out = self.out_conv_dp1(x) - - x = self.up4(x, x0) - dp0_out = self.out_conv(x) - - out_shape = list(dp0_out.shape)[2:] - dp3_out = nn.functional.interpolate(dp3_out, out_shape) - dp2_out = nn.functional.interpolate(dp2_out, out_shape) - dp1_out = nn.functional.interpolate(dp1_out, out_shape) - out = [dp0_out, dp1_out, dp2_out, dp3_out] - - if(len(x_shape) == 5): - new_shape = [N, D] + list(dp0_out.shape)[1:] - for i in range(len(out)): - out[i] = torch.transpose(torch.reshape(out[i], new_shape), 1, 2) - return out \ No newline at end of file From 5948827dd670c53aa1b96fe6faab14af50514298 Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 29 Nov 2023 09:58:36 +0800 Subject: [PATCH 32/86] update unet2d_scse --- pymic/net/net2d/unet2d_scse.py | 143 +++++++++++---------------------- pymic/test/test_net2d.py | 57 +++++++++++++ pymic/test/test_net3d.py | 108 +++++++++++++++++++++++++ 3 files changed, 211 insertions(+), 97 deletions(-) create mode 100644 pymic/test/test_net2d.py create mode 100644 pymic/test/test_net3d.py diff --git a/pymic/net/net2d/unet2d_scse.py b/pymic/net/net2d/unet2d_scse.py index 125843e..54a5d2f 100644 --- a/pymic/net/net2d/unet2d_scse.py +++ b/pymic/net/net2d/unet2d_scse.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn import numpy as np +from pymic.net.net2d.unet2d import UpBlock, Encoder, Decoder, UNet2D from pymic.net.net2d.scse2d import * class ConvScSEBlock(nn.Module): @@ -50,116 +51,64 @@ def __init__(self, in_channels, out_channels, dropout_p): def forward(self, x): return self.maxpool_conv(x) -class UpBlock(nn.Module): +class UpBlockScSE(UpBlock): """Up-sampling followed by `ConvScSEBlock` in U-Net structure. - :param in_channels1: (int) Input channel number for low-resolution feature map. - :param in_channels2: (int) Input channel number for high-resolution feature map. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - :param bilinear: (bool) Use bilinear for up-sampling or not. + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.UpBlock` for details. """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - bilinear=True): - super(UpBlock, self).__init__() - self.bilinear = bilinear - if bilinear: - self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) - else: - self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2) + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, up_mode = 2): + super(UpBlockScSE, self).__init__(in_channels1, in_channels2, out_channels, dropout_p, up_mode) self.conv = ConvScSEBlock(in_channels2 * 2, out_channels, dropout_p) - def forward(self, x1, x2): - if self.bilinear: - x1 = self.conv1x1(x1) - x1 = self.up(x1) - x = torch.cat([x2, x1], dim=1) - return self.conv(x) - -class UNet2D_ScSE(nn.Module): +class EncoderScSE(Encoder): """ - Combining 2D U-Net with SCSE module. + Encoder of 2D UNet with ScSE. - * Reference: Abhijit Guha Roy, Nassir Navab, Christian Wachinger: - Recalibrating Fully Convolutional Networks With Spatial and Channel - "Squeeze and Excitation" Blocks. - `IEEE Trans. Med. Imaging 38(2): 540-549 (2019). `_ - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.Encoder` for details. """ def __init__(self, params): - super(UNet2D_ScSE, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] - assert(len(self.ft_chns) == 5) + super(EncoderScSE, self).__init__(params) self.in_conv= ConvScSEBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = self.dropout[3]) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p = self.dropout[2]) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p = self.dropout[1]) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p = self.dropout[0]) - - self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, - kernel_size = 3, padding = 1) + if(len(self.ft_chns) == 5): + self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - def forward(self, x): - x_shape = list(x.shape) - if(len(x_shape) == 5): - [N, C, D, H, W] = x_shape - new_shape = [N*D, C, H, W] - x = torch.transpose(x, 1, 2) - x = torch.reshape(x, new_shape) - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - x4 = self.down4(x3) +class DecoderScSE(Decoder): + """ + Decoder of 2D UNet with ScSE. + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.Decoder` for details. + """ + def __init__(self, params): + super(DecoderScSE, self).__init__(params) - x = self.up1(x4, x3) - x = self.up2(x, x2) - x = self.up3(x, x1) - x = self.up4(x, x0) - output = self.out_conv(x) - - if(len(x_shape) == 5): - new_shape = [N, D] + list(output.shape)[1:] - output = torch.reshape(output, new_shape) - output = torch.transpose(output, 1, 2) - return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'feature_chns':[2, 8, 32, 48, 64], - 'dropout': [0, 0, 0.3, 0.4, 0.5], - 'class_num': 2, - 'bilinear': True} - Net = UNet2D_ScSE(params) - Net = Net.double() - - x = np.random.rand(4, 4, 10, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print(len(y.size())) - y = y.detach().numpy() - print(y.shape) \ No newline at end of file + + if(len(self.ft_chns) == 5): + self.up1 = UpBlockScSE(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.up_mode) + self.up2 = UpBlockScSE(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.up_mode) + self.up3 = UpBlockScSE(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.up_mode) + self.up4 = UpBlockScSE(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.up_mode) + + +class UNet2D_ScSE(UNet2D): + """ + Combining 2D U-Net with SCSE module. + + * Reference: Abhijit Guha Roy, Nassir Navab, Christian Wachinger: + Recalibrating Fully Convolutional Networks With Spatial and Channel + "Squeeze and Excitation" Blocks. + `IEEE Trans. Med. Imaging 38(2): 540-549 (2019). `_ + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.unet2d` for details. + """ + def __init__(self, params): + super(UNet2D_ScSE, self).__init__(params) + self.encoder = Encoder(params) + self.decoder = Decoder(params) diff --git a/pymic/test/test_net2d.py b/pymic/test/test_net2d.py new file mode 100644 index 0000000..aafaf20 --- /dev/null +++ b/pymic/test/test_net2d.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import numpy as np +from pymic.net.net2d.unet2d import UNet2D +from pymic.net.net2d.unet2d_scse import UNet2D_ScSE + +def test_unet2d(): + params = {'in_chns':4, + 'feature_chns':[16, 32, 64, 128, 256], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'up_mode': 0, + 'multiscale_pred': True} + Net = UNet2D(params) + Net = Net.double() + + x = np.random.rand(4, 4, 10, 256, 256) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + if params['multiscale_pred']: + for y in out: + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) + else: + print(out.shape) + +def test_unet2d_scse(): + params = {'in_chns':4, + 'feature_chns':[16, 32, 64, 128, 256], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'up_mode': 0, + 'multiscale_pred': True} + Net = UNet2D_ScSE(params) + Net = Net.double() + + x = np.random.rand(4, 4, 10, 256, 256) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + if params['multiscale_pred']: + for y in out: + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) + else: + print(out.shape) + +if __name__ == "__main__": + # test_unet2d() + test_unet2d_scse() \ No newline at end of file diff --git a/pymic/test/test_net3d.py b/pymic/test/test_net3d.py new file mode 100644 index 0000000..180dcff --- /dev/null +++ b/pymic/test/test_net3d.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import numpy as np +from pymic.net.net3d.unet3d import UNet3D +from pymic.net.net3d.unet3d_scse import UNet3D_ScSE +from pymic.net.net3d.unet2d5 import UNet2D5 + +def test_unet3d(): + params = {'in_chns':4, + 'class_num': 2, + 'feature_chns':[2, 8, 32, 64], + 'dropout' : [0, 0, 0, 0.5], + 'up_mode': 2, + 'multiscale_pred': False} + Net = UNet3D(params) + Net = Net.double() + + x = np.random.rand(4, 4, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + y = y.detach().numpy() + print(y.shape) + + params = {'in_chns':4, + 'class_num': 2, + 'feature_chns':[2, 8, 32, 64, 128], + 'dropout' : [0, 0, 0, 0.4, 0.5], + 'up_mode': 3, + 'multiscale_pred': True} + Net = UNet3D(params) + Net = Net.double() + + x = np.random.rand(4, 4, 96, 128, 128) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + for y in out: + y = y.detach().numpy() + print(y.shape) + +def test_unet3d_scse(): + params = {'in_chns':4, + 'feature_chns':[2, 8, 32, 48, 64], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'up_mode': 2} + Net = UNet3D_ScSE(params) + Net = Net.double() + + x = np.random.rand(4, 4, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) + +def test_unet2d5(): + params = {'in_chns':4, + 'feature_chns':[8, 16, 32, 64, 128], + 'conv_dims': [2, 2, 3, 3, 3], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'up_mode': 2, + 'multiscale_pred': True} + Net = UNet2D5(params) + Net = Net.double() + + x = np.random.rand(4, 4, 32, 128, 128) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + for y in out: + y = y.detach().numpy() + print(y.shape) + + params = {'in_chns':4, + 'feature_chns':[8, 16, 32, 64, 128], + 'conv_dims': [2, 3, 3, 3, 3], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'up_mode': 0, + 'multiscale_pred': True} + Net = UNet2D5(params) + Net = Net.double() + + x = np.random.rand(4, 4, 64, 128, 128) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + for y in out: + y = y.detach().numpy() + print(y.shape) + +if __name__ == "__main__": + # test_unet3d() + # test_unet3d_scse() + test_unet2d5() + + \ No newline at end of file From 0248db0b164e6994188fd6487d9cca2897486b7b Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 4 Dec 2023 17:14:20 +0800 Subject: [PATCH 33/86] add affine transform add affine transform --- pymic/io/nifty_dataset.py | 19 ++-- pymic/net_run/agent_preprocess.py | 104 +++++++++++++++++++ pymic/net_run/agent_seg.py | 2 + pymic/net_run/get_optimizer.py | 2 +- pymic/net_run/semi_sup/ssl_cps.py | 14 +++ pymic/net_run/semi_sup/ssl_mcnet.py | 9 +- pymic/net_run/train.py | 5 +- pymic/test/test_assd.py | 3 +- pymic/transform/affine.py | 156 ++++++++++++++++++++++++++++ pymic/transform/crop.py | 31 ++++-- pymic/transform/flip.py | 1 - pymic/transform/intensity.py | 7 +- pymic/transform/normalize.py | 24 +++-- pymic/transform/pad.py | 1 - pymic/transform/trans_dict.py | 2 + 15 files changed, 338 insertions(+), 42 deletions(-) create mode 100644 pymic/net_run/agent_preprocess.py create mode 100644 pymic/transform/affine.py diff --git a/pymic/io/nifty_dataset.py b/pymic/io/nifty_dataset.py index 9812d13..438a07b 100644 --- a/pymic/io/nifty_dataset.py +++ b/pymic/io/nifty_dataset.py @@ -38,7 +38,8 @@ def __init__(self, root_dir, csv_file, modal_num = 1, if('label' not in csv_keys): logging.warning("`label` section is not found in the csv file {0:}".format( csv_file) + "\n -- This is only allowed for self-supervised learning" + - "\n -- when `SelfSuperviseLabel` is used in the transform.") + "\n -- when `SelfSuperviseLabel` is used in the transform, or when" + + "\n -- loading the unlabeled data for preprocessing.") self.with_label = False self.image_weight_idx = None self.pixel_weight_idx = None @@ -52,15 +53,15 @@ def __len__(self): def __getlabel__(self, idx): csv_keys = list(self.csv_items.keys()) - label_idx = csv_keys.index('label') - label_name = "{0:}/{1:}".format(self.root_dir, - self.csv_items.iloc[idx, label_idx]) - label = load_image_as_nd_array(label_name)['data_array'] + label_idx = csv_keys.index('label') + label_name = self.csv_items.iloc[idx, label_idx] + label_name_full = "{0:}/{1:}".format(self.root_dir, label_name) + label = load_image_as_nd_array(label_name_full)['data_array'] if(self.task == TaskType.SEGMENTATION): label = np.asarray(label, np.int32) elif(self.task == TaskType.RECONSTRUCTION): label = np.asarray(label, np.float32) - return label + return label, label_name def __get_pixel_weight__(self, idx): weight_name = "{0:}/{1:}".format(self.root_dir, @@ -80,12 +81,14 @@ def __getitem__(self, idx): image_list.append(image_data) image = np.concatenate(image_list, axis = 0) image = np.asarray(image, np.float32) - sample = {'image': image, 'names' : names_list[0], + + sample = {'image': image, 'names' : names_list, 'origin':image_dict['origin'], 'spacing': image_dict['spacing'], 'direction':image_dict['direction']} if (self.with_label): - sample['label'] = self.__getlabel__(idx) + sample['label'], label_name = self.__getlabel__(idx) + sample['names'].append(label_name) assert(image.shape[1:] == sample['label'].shape[1:]) if (self.image_weight_idx is not None): sample['image_weight'] = self.csv_items.iloc[idx, self.image_weight_idx] diff --git a/pymic/net_run/agent_preprocess.py b/pymic/net_run/agent_preprocess.py new file mode 100644 index 0000000..c681de9 --- /dev/null +++ b/pymic/net_run/agent_preprocess.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import os +import sys +import torch +import torchvision.transforms as transforms +from pymic.util.parse_config import * +from pymic.io.image_read_write import save_nd_array_as_image +from pymic.io.nifty_dataset import NiftyDataset +from pymic.transform.trans_dict import TransformDict + + + +class PreprocessAgent(object): + def __init__(self, config): + super(PreprocessAgent, self).__init__() + self.config = config + self.transform_dict = TransformDict + self.task_type = config['dataset']['task_type'] + self.dataloader = None + self.dataloader_unlab= None + + def get_dataset_from_config(self): + root_dir = self.config['dataset']['root_dir'] + modal_num = self.config['dataset'].get('modal_num', 1) + transform_names = self.config['dataset']["transform"] + + self.transform_list = [] + if(transform_names is None or len(transform_names) == 0): + data_transform = None + else: + transform_param = self.config['dataset'] + transform_param['task'] = self.task_type + for name in transform_names: + if(name not in self.transform_dict): + raise(ValueError("Undefined transform {0:}".format(name))) + one_transform = self.transform_dict[name](transform_param) + self.transform_list.append(one_transform) + data_transform = transforms.Compose(self.transform_list) + + data_csv = self.config['dataset'].get('data_csv', None) + data_csv_unlab = self.config['dataset'].get('data_csv_unlab', None) + if(data_csv is not None): + dataset = NiftyDataset(root_dir = root_dir, + csv_file = data_csv, + modal_num = modal_num, + with_label= True, + transform = data_transform, + task = self.task_type) + self.dataloader = torch.utils.data.DataLoader(dataset, + batch_size = 1, shuffle=False, num_workers= 8, + worker_init_fn=None, generator = torch.Generator()) + if(data_csv_unlab is not None): + dataset_unlab = NiftyDataset(root_dir = root_dir, + csv_file = data_csv_unlab, + modal_num = modal_num, + with_label= False, + transform = data_transform, + task = self.task_type) + self.dataloader_unlab = torch.utils.data.DataLoader(dataset_unlab, + batch_size = 1, shuffle=False, num_workers= 8, + worker_init_fn=None, generator = torch.Generator()) + + def run(self): + """ + Do preprocessing for labeled and unlabeled data. + """ + self.get_dataset_from_config() + out_dir = self.config['dataset']['output_dir'] + for dataloader in [self.dataloader, self.dataloader_unlab]: + for item in dataloader: + img = item['image'][0] # the batch size is 1 + # save differnt modaliteis + img_names = item['names'] + spacing = [x.numpy()[0] for x in item['spacing']] + for i in range(img.shape[0]): + image_name = out_dir + "/" + img_names[i][0] + print(image_name) + save_nd_array_as_image(img[i], image_name, reference_name = None, spacing=spacing) + if('label' in item): + lab = item['label'][0] + label_name = out_dir + "/" + img_names[-1][0] + print(label_name) + save_nd_array_as_image(lab[0], label_name, reference_name = None, spacing=spacing) + +def main(): + """ + The main function for data preprocessing. + """ + if(len(sys.argv) < 2): + print('Number of arguments should be 2. e.g.') + print(' pymic_preprocess config.cfg') + exit() + cfg_file = str(sys.argv[1]) + if(not os.path.isfile(cfg_file)): + raise ValueError("The config file does not exist: " + cfg_file) + config = parse_config(cfg_file) + config = synchronize_config(config) + agent = PreprocessAgent(config) + agent.run() + +if __name__ == "__main__": + main() + diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 65c2b68..c376546 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -89,6 +89,8 @@ def create_network(self): self.net.float() else: self.net.double() + if(hasattr(self.net, "set_stage")): + self.net.set_stage(self.stage) param_number = sum(p.numel() for p in self.net.parameters() if p.requires_grad) logging.info('parameter number {0:}'.format(param_number)) diff --git a/pymic/net_run/get_optimizer.py b/pymic/net_run/get_optimizer.py index ad8fda0..53de5e3 100644 --- a/pymic/net_run/get_optimizer.py +++ b/pymic/net_run/get_optimizer.py @@ -14,7 +14,7 @@ def get_optimizer(name, net_params, optim_params): param_group = [{'params': net_params, 'initial_lr': lr}] if(keyword_match(name, "SGD")): return optim.SGD(param_group, lr, - momentum = momentum, weight_decay = weight_decay) + momentum = momentum, weight_decay = weight_decay, nesterov = True) elif(keyword_match(name, "Adam")): return optim.Adam(param_group, lr, weight_decay = weight_decay) elif(keyword_match(name, "SparseAdam")): diff --git a/pymic/net_run/semi_sup/ssl_cps.py b/pymic/net_run/semi_sup/ssl_cps.py index db1fb28..df4b5af 100644 --- a/pymic/net_run/semi_sup/ssl_cps.py +++ b/pymic/net_run/semi_sup/ssl_cps.py @@ -6,6 +6,7 @@ from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice +from pymic.io.image_read_write import save_nd_array_as_image from pymic.net_run.semi_sup import SSLSegAgent from pymic.util.ramps import get_rampup_ratio @@ -57,6 +58,19 @@ def training(self): x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) x1 = self.convert_tensor_type(data_unlab['image']) + + # for debug + # for i in range(x0.shape[0]): + # image_i = x0[i][0] + # label_i = np.argmax(y0[i], axis = 0) + # # pixw_i = pix_w[i][0] + # print(image_i.shape, label_i.shape) + # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) + # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) + # save_nd_array_as_image(image_i, image_name, reference_name = None) + # save_nd_array_as_image(label_i, label_name, reference_name = None) + # continue + inputs = torch.cat([x0, x1], dim = 0) inputs, y0 = inputs.to(self.device), y0.to(self.device) diff --git a/pymic/net_run/semi_sup/ssl_mcnet.py b/pymic/net_run/semi_sup/ssl_mcnet.py index 955357d..66e1034 100644 --- a/pymic/net_run/semi_sup/ssl_mcnet.py +++ b/pymic/net_run/semi_sup/ssl_mcnet.py @@ -22,8 +22,8 @@ class SSLMCNet(SSLSegAgent): with multiple decoders for learning, such as `pymic.net.net2d.unet2d_mcnet.MCNet2D`. * Reference: Yicheng Wu, Zongyuan Ge et al. Mutual consistency learning for - semi-supervised medical image segmentation. - `Medical Image Analysis 2022. `_ + semi-supervised medical image segmentation. + `MIA 2022. `_ The original code is at: https://github.com/ycwu1997/MC-Net @@ -34,11 +34,6 @@ class SSLMCNet(SSLSegAgent): In the configuration dictionary, in addition to the four sections (`dataset`, `network`, `training` and `inference`) used in fully supervised learning, an extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. - - Special parameters required for MCNet in `semi_supervised_learning` section: - - :param temperature: (float) temperature for label sharpening. The default value is 0.1. - """ def training(self): class_num = self.config['network']['class_num'] diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index 0167f2f..426b620 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -78,11 +78,12 @@ def main(): os.makedirs(log_dir, exist_ok=True) dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] shutil.copy(cfg_file, log_dir + "/" + dst_cfg) + datetime_str = str(datetime.now())[:-7].replace(":", "_") if sys.version.startswith("3.9"): - logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), + logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), level=logging.INFO, format='%(message)s', force=True) # for python 3.9 else: - logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), + logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), level=logging.INFO, format='%(message)s') # for python 3.6 logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) diff --git a/pymic/test/test_assd.py b/pymic/test/test_assd.py index 35c1804..6b2732a 100644 --- a/pymic/test/test_assd.py +++ b/pymic/test/test_assd.py @@ -18,7 +18,8 @@ def test_assd_2d(): plt.show() def test_assd_3d(): - img_name = "/home/x/projects/PyMIC_project/PyMIC_examples/seg_ssl/ACDC/result/unet2d_baseline/patient001_frame01.nii.gz" + # img_name = "/home/x/projects/PyMIC_project/PyMIC_examples/seg_ssl/ACDC/result/unet2d_baseline/patient001_frame01.nii.gz" + img_name = "/home/disk4t/data/heart/ACDC/preprocess/patient001_frame12_gt.nii.gz" img_obj = sitk.ReadImage(img_name) spacing = img_obj.GetSpacing() spacing = spacing[::-1] diff --git a/pymic/transform/affine.py b/pymic/transform/affine.py new file mode 100644 index 0000000..552516f --- /dev/null +++ b/pymic/transform/affine.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import json +import math +import random +import numpy as np +from skimage import transform +from scipy import ndimage +from pymic import TaskType +from pymic.transform.abstract_transform import AbstractTransform +from pymic.util.image_process import * + +class Affine(AbstractTransform): + """ + Apply Affine Transform to an ND volume in the x-y plane. + Input shape should be [C, D, H, W] or [C, H, W]. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `Affine_scale_range`: (list or tuple) The range for scaling, e.g., (0.5, 2.0) + :param `Affine_shear_range`: (list or tuple) The range for shearing angle, e.g., (0, 30) + :param `Affine_rotate_range`: (list or tuple) The range for rotation, e.g., (-45, 45) + :param `Affine_output_size`: (None, list or tuple of length 2) The output size after affine transformation. + For 3D volumes, as we only apply affine transformation in x-y plane, the output slice + number will be the same as the input slice number, so only the output height and width + need to be given here, e.g., (H, W). By default (`None`), the output size will be the + same as the input size. + """ + def __init__(self, params): + super(Affine, self).__init__(params) + self.scale_range = params['Affine_scale_range'.lower()] + self.shear_range = params['Affine_shear_range'.lower()] + self.rotat_range = params['Affine_rotate_range'.lower()] + self.output_shape= params.get('Affine_output_size'.lower(), None) + self.inverse = params.get('Affine_inverse'.lower(), True) + + def _get_affine_param(self, sample, output_shape): + """ + output_shape should only has two dimensions, e.g., (H, W) + """ + input_shape = sample['image'].shape + input_dim = len(input_shape) - 1 + assert(len(output_shape) >=2) + + in_y, in_x = input_shape[-2:] + out_y, out_x = output_shape[-2:] + points = [[0, out_y], + [0, 0], + [out_x, 0], + [out_x, out_y]] + + sx = random.random() * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0] + sy = random.random() * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0] + shx = (random.random() * (self.shear_range[1] - self.shear_range[0]) + self.shear_range[0]) * 3.14159/180 + shy = (random.random() * (self.shear_range[1] - self.shear_range[0]) + self.shear_range[0]) * 3.14159/180 + rot = (random.random() * (self.rotat_range[1] - self.rotat_range[0]) + self.rotat_range[0]) * 3.14159/180 + # get affine transform parameters + new_points = [] + for p in points: + x = sx * p[0] * (math.cos(rot) + math.tan(shy) * math.sin(rot)) - \ + sy * p[1] * (math.tan(shx) * math.cos(rot) + math.sin(rot)) + y = sx * p[0] * (math.sin(rot) - math.tan(shy) * math.cos(rot)) - \ + sy * p[1] * (math.tan(shx) * math.sin(rot) - math.cos(rot)) + new_points.append([x,y]) + bb_min = np.array(new_points).min(axis = 0) + bb_max = np.array(new_points).max(axis = 0) + bbx, bby = int(bb_max[0] - bb_min[0]), int(bb_max[1] - bb_min[1]) + # transform the points to the image coordinate + margin_x = in_x - bbx + margin_y = in_y - bby + p0x = random.random() * margin_x if margin_x > 0 else margin_x / 2 + p0y = random.random() * margin_y if margin_y > 0 else margin_y / 2 + dst = [[new_points[i][0] - bb_min[0] + p0x, new_points[i][1] - bb_min[1] + p0y] \ + for i in range(3)] + + tform = transform.AffineTransform() + tform.estimate(np.array(points[:3]), np.array(dst)) + # to do: need to find a solution to save the affine transform matrix + # Use the matplotlib.transforms.Affine2D function to generate transform matrices, + # and the scipy.ndimage.warp function to warp images using the transform matrices. + # The skimage AffineTransform shear functionality is weird, + # and the scipy affine_transform function for warping images swaps the X and Y axes. + # sample['Affine_Param'] = json.dumps((input_shape, tform["matrix"])) + return sample, tform + + def _apply_affine_to_ND_volume(self, image, output_shape, tform, order = 3): + """ + output_shape should only has two dimensions, e.g., (H, W) + """ + dim = len(image.shape) - 1 + if(dim == 2): + C, H, W = image.shape + output = np.zeros([C] + output_shape) + for c in range(C): + output[c] = ndimage.affine_transform(image[c], tform, + output_shape = output_shape, mode='mirror', order = order) + elif(dim == 3): + C, D, H, W = image.shape + output = np.zeros([C, D] + output_shape) + for c in range(C): + for d in range(D): + output[c,d] = ndimage.affine_transform(image[c,d], tform, + output_shape = output_shape, mode='mirror', order = order) + return output + + def __call__(self, sample): + image = sample['image'] + input_shape = sample['image'].shape + output_shape= input_shape if self.output_shape is None else self.output_shape + aff_out_shape = output_shape[-2:] + sample, tform = self._get_affine_param(sample, aff_out_shape) + image_t = self._apply_affine_to_ND_volume(image, aff_out_shape, tform) + sample['image'] = image_t + + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + label = sample['label'] + label = self._apply_affine_to_ND_volume(label, aff_out_shape, tform, order = 0) + sample['label'] = label + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + weight = sample['pixel_weight'] + weight = self._apply_affine_to_ND_volume(weight, aff_out_shape, tform) + sample['pixel_weight'] = weight + return sample + + def _get_param_for_inverse_transform(self, sample): + if(isinstance(sample['Affine_Param'], list) or \ + isinstance(sample['Affine_Param'], tuple)): + params = json.loads(sample['Affine_Param'][0]) + else: + params = json.loads(sample['Affine_Param']) + return params + + # def inverse_transform_for_prediction(self, sample): + # params = self._get_param_for_inverse_transform(sample) + # origin_shape = params[0] + # tform = params[1] + + # predict = sample['predict'] + # if(isinstance(predict, tuple) or isinstance(predict, list)): + # output_predict = [] + # for predict_i in predict: + # aff_out_shape = origin_shape[-2:] + # output_predict_i = self._apply_affine_to_ND_volume(predict_i, + # aff_out_shape, tform.inverse) + # output_predict.append(output_predict_i) + # else: + # aff_out_shape = origin_shape[-2:] + # output_predict = self._apply_affine_to_ND_volume(predict, aff_out_shape, tform.inverse) + + # sample['predict'] = output_predict + # return sample diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index b4d0b63..6977639 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -113,16 +113,20 @@ class CropWithBoundingBox(CenterCrop): :param `CropWithBoundingBox_start`: (None, or list/tuple) The start index along each spatial axis. If None, calculate the start index automatically - so that the cropped region is centered at the non-zero region. + so that the cropped region is centered at the mask region defined by the threshold. :param `CropWithBoundingBox_output_size`: (None or tuple/list): Desired spatial output size. - If None, set it as the size of bounding box of non-zero region. + If None, set it as the size of bounding box of the mask region defined by the threshold. + :param `CropWithBoundingBox_threshold`: (None or float): + Threshold for obtaining a mask. This is used only when + `CropWithBoundingBox_start` is None. Default is 1.0 :param `CropWithBoundingBox_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `True`. """ def __init__(self, params): self.start = params['CropWithBoundingBox_start'.lower()] self.output_size = params['CropWithBoundingBox_output_size'.lower()] + self.threshold = params.get('CropWithBoundingBox_threshold'.lower(), 1.0) self.inverse = params.get('CropWithBoundingBox_inverse'.lower(), True) self.task = params['task'] @@ -130,8 +134,9 @@ def _get_crop_param(self, sample): image = sample['image'] input_shape = sample['image'].shape input_dim = len(input_shape) - 1 - bb_min, bb_max = get_ND_bounding_box(image) - bb_min, bb_max = bb_min[1:], bb_max[1:] + if(self.start is None or self.output_size is None): + bb_min, bb_max = get_ND_bounding_box(image > self.threshold) + bb_min, bb_max = bb_min[1:], bb_max[1:] if(self.start is None): if(self.output_size is None): crop_min, crop_max = bb_min, bb_max @@ -212,7 +217,9 @@ class RandomCrop(CenterCrop): :param `RandomCrop_output_size`: (list/tuple) Desired output size [D, H, W] or [H, W]. The output channel is the same as the input channel. - If D is None for 3D images, the z-axis is not cropped. + If `None` is set for a certain axis, that axis will not be cropped. For example, + for 3D vlumes, (None, H, W) means only crop in 2D, and (D, None, None) means only + crop along the z axis. :param `RandomCrop_foreground_focus`: (optional, bool) If true, allow crop around the foreground. Default is False. :param `RandomCrop_foreground_ratio`: (optional, float) @@ -242,10 +249,16 @@ def _get_crop_param(self, sample): input_shape = image.shape[1:] input_dim = len(input_shape) assert(input_dim == len(self.output_size)) - - crop_margin = [input_shape[i] - self.output_size[i] for i in range(input_dim)] + + output_size = [item for item in self.output_size] + # print("crop input and output size", input_shape, output_size) + for i in range(input_dim): + if(output_size[i] is None): + output_size[i] = input_shape[i] + # print(output_size) + crop_margin = [input_shape[i] - output_size[i] for i in range(input_dim)] crop_min = [0 if item == 0 else random.randint(0, item) for item in crop_margin] - crop_max = [crop_min[i] + self.output_size[i] for i in range(input_dim)] + crop_max = [crop_min[i] + output_size[i] for i in range(input_dim)] label_exist = False if ('label' not in sample or sample['label']) is None else True if(label_exist and self.fg_focus and random.random() < self.fg_ratio): @@ -255,7 +268,7 @@ def _get_crop_param(self, sample): else: mask_label = self.mask_label random_label = random.choice(mask_label) - crop_min, crop_max = get_random_box_from_mask(label == random_label, self.output_size, mode = 1) + crop_min, crop_max = get_random_box_from_mask(label == random_label, output_size, mode = 1) crop_min = [0] + crop_min crop_max = [chns] + crop_max diff --git a/pymic/transform/flip.py b/pymic/transform/flip.py index 486180c..6ea017c 100644 --- a/pymic/transform/flip.py +++ b/pymic/transform/flip.py @@ -6,7 +6,6 @@ import math import random import numpy as np -from scipy import ndimage from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index ffa5141..d39eb2c 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -62,6 +62,7 @@ def __init__(self, params): self.channels = params['IntensityClip_channels'.lower()] self.lower = params.get('IntensityClip_lower'.lower(), None) self.upper = params.get('IntensityClip_upper'.lower(), None) + self.perct = params.get('IntensityClip_percentile_mode'.lower(), False) self.inverse = params.get('IntensityClip_inverse'.lower(), False) def __call__(self, sample): @@ -72,8 +73,12 @@ def __call__(self, sample): lower_c, upper_c = lower[chn], upper[chn] if(lower_c is None): lower_c = np.percentile(image[chn], 0.05) + elif(self.perct): + lower_c = np.percentile(image[chn], lower_c) if(upper_c is None): - upper_c = np.percentile(image[chn, 99.95]) + upper_c = np.percentile(image[chn], 99.95) + elif(self.perct): + upper_c = np.percentile(image[chn], upper_c) image[chn] = np.clip(image[chn], lower_c, upper_c) sample['image'] = image return sample diff --git a/pymic/transform/normalize.py b/pymic/transform/normalize.py index 6531f17..3d28c0d 100644 --- a/pymic/transform/normalize.py +++ b/pymic/transform/normalize.py @@ -24,11 +24,12 @@ class NormalizeWithMeanStd(AbstractTransform): :param `NormalizeWithMeanStd_std`: (list/tuple or None) The std values along each specified channel. If None, the std values are calculated automatically. - :param `NormalizeWithMeanStd_ignore_non_positive`: (optional, bool) - Only used when mean and std are not given. Default is False. - If True, calculate mean and std in the positive region for normalization, - and set non-positive region to random. If False, calculate - the mean and std values in the entire image region. + :param `NormalizeWithMeanStd_mask_threshold`: (optional, float) + Only used when mean and std are not given. Default is 1.0. + Calculate mean and std in the mask region where the intensity is higher than the mask. + :param `NormalizeWithMeanStd_set_background_to_random`: (optional, bool) + Set background region to random or not, and only applicable when + `NormalizeWithMeanStd_mask_threshold` is not None. Default is True. :param `NormalizeWithMeanStd_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `False`. """ @@ -37,8 +38,9 @@ def __init__(self, params): self.chns = params.get('NormalizeWithMeanStd_channels'.lower(), None) self.mean = params.get('NormalizeWithMeanStd_mean'.lower(), None) self.std = params.get('NormalizeWithMeanStd_std'.lower(), None) - self.ingore_np = params.get('NormalizeWithMeanStd_ignore_non_positive'.lower(), False) - self.inverse = params.get('NormalizeWithMeanStd_inverse'.lower(), False) + self.mask_thrd = params.get('NormalizeWithMeanStd_mask_threshold'.lower(), 1.0) + self.bg_random = params.get('NormalizeWithMeanStd_set_background_to_random'.lower(), True) + self.inverse = params.get('NormalizeWithMeanStd_inverse'.lower(), False) def __call__(self, sample): image= sample['image'] @@ -52,8 +54,8 @@ def __call__(self, sample): chn = self.chns[i] chn_mean, chn_std = self.mean[i], self.std[i] if(chn_mean is None): - if(self.ingore_np): - pixels = image[chn][image[chn] > 0] + if(self.mask_thrd is not None): + pixels = image[chn][image[chn] > self.mask_thrd] if(len(pixels) > 0): chn_mean, chn_std = pixels.mean(), pixels.std() + 1e-5 else: @@ -63,9 +65,9 @@ def __call__(self, sample): chn_norm = (image[chn] - chn_mean)/chn_std - if(self.ingore_np): + if(self.mask_thrd is not None and self.bg_random): chn_random = np.random.normal(0, 1, size = chn_norm.shape) - chn_norm[image[chn] <= 0] = chn_random[image[chn] <= 0] + chn_norm[image[chn] <= self.mask_thrd] = chn_random[image[chn] <=self.mask_thrd] image[chn] = chn_norm sample['image'] = image return sample diff --git a/pymic/transform/pad.py b/pymic/transform/pad.py index c9b75fe..8624aa2 100644 --- a/pymic/transform/pad.py +++ b/pymic/transform/pad.py @@ -6,7 +6,6 @@ import math import random import numpy as np -from scipy import ndimage from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index f779d00..a28e848 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -31,6 +31,7 @@ """ from __future__ import print_function, division +from pymic.transform.affine import * from pymic.transform.intensity import * from pymic.transform.flip import * from pymic.transform.pad import * @@ -44,6 +45,7 @@ from pymic.transform.label_convert import * TransformDict = { + 'Affine': Affine, 'ChannelWiseThreshold': ChannelWiseThreshold, 'ChannelWiseThresholdWithNormalize': ChannelWiseThresholdWithNormalize, 'CropWithBoundingBox': CropWithBoundingBox, From 0a65b4994987e4c69604f9d1d2fba44f8726bc0c Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 13 Dec 2023 10:46:34 +0800 Subject: [PATCH 34/86] update segmentation and reconstruction agent --- pymic/net_run/agent_rec.py | 44 +++++++++-------- pymic/net_run/agent_seg.py | 63 ++++++++++++++---------- pymic/net_run/self_sup/util.py | 88 ++++++++++++++++++++++++++++++---- pymic/net_run/train.py | 26 ++-------- pymic/transform/mix.py | 86 +++++++++++++++++---------------- 5 files changed, 188 insertions(+), 119 deletions(-) diff --git a/pymic/net_run/agent_rec.py b/pymic/net_run/agent_rec.py index 1e58bc6..634ea20 100644 --- a/pymic/net_run/agent_rec.py +++ b/pymic/net_run/agent_rec.py @@ -29,14 +29,6 @@ class ReconstructionAgent(SegmentationAgent): """ def __init__(self, config, stage = 'train'): super(ReconstructionAgent, self).__init__(config, stage) - output_act_name = config['network'].get('output_activation', 'sigmoid') - if(output_act_name == "sigmoid"): - self.out_act = nn.Sigmoid() - elif(output_act_name == "tanh"): - self.out_act = nn.Tanh() - else: - raise ValueError("For reconstruction task, only sigmoid and tanh are " + \ - "supported for output_activation.") def create_loss_calculator(self): if(self.loss_dict is None): @@ -48,7 +40,6 @@ def create_loss_calculator(self): raise ValueError("Undefined loss function {0:}".format(loss_name)) else: loss_param = self.config['training'] - loss_param['loss_softmax'] = False base_loss = self.loss_dict[loss_name](self.config['training']) if(self.config['training'].get('deep_supervise', False)): raise ValueError("Deep supervised loss not implemented for reconstruction tasks") @@ -80,8 +71,13 @@ def training(self): # print(inputs.shape) # for i in range(inputs.shape[0]): # image_i = inputs[i][0] + # label_i = label[i][0] # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) + # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) # save_nd_array_as_image(image_i, image_name, reference_name = None) + # save_nd_array_as_image(label_i, label_name, reference_name = None) + # if(it > 10): + # break # return inputs, label = inputs.to(self.device), label.to(self.device) @@ -91,7 +87,18 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) - outputs = self.out_act(outputs) + + # for debug + # if it < 5: + # outputs = nn.Tanh()(outputs) + # for i in range(inputs.shape[0]): + # out_name = "temp/output_{0:}_{1:}.nii.gz".format(it, i) + # output = outputs[i][0] + # output = output.cpu().detach().numpy() + # save_nd_array_as_image(output, out_name, reference_name = None) + # else: + # break + loss = self.get_loss_value(data, outputs, label) loss.backward() self.optimizer.step() @@ -123,7 +130,6 @@ def validation(self): label = self.convert_tensor_type(data['label']) inputs, label = inputs.to(self.device), label.to(self.device) outputs = self.inferer.run(self.net, inputs) - outputs = self.out_act(outputs) # The tensors are on CPU when calculating loss for validation data loss = self.get_loss_value(data, outputs, label) valid_loss_list.append(loss.item()) @@ -293,19 +299,19 @@ def save_outputs(self, data): names, pred = data['names'], data['predict'] if(isinstance(pred, (list, tuple))): pred = pred[0] - if(isinstance(self.out_act, nn.Sigmoid)): - pred = scipy.special.expit(pred) - else: - pred = np.tanh(pred) + pred = np.tanh(pred) # save the output predictions - root_dir = self.config['dataset']['root_dir'] + test_dir = self.config['dataset'].get('test_dir', None) + if(test_dir is None): + test_dir = self.config['dataset']['train_dir'] + for i in range(len(names)): - save_name = names[i].split('/')[-1] if ignore_dir else \ - names[i].replace('/', '_') + save_name = names[i][0].split('/')[-1] if ignore_dir else \ + names[i][0].replace('/', '_') if((filename_replace_source is not None) and (filename_replace_target is not None)): save_name = save_name.replace(filename_replace_source, filename_replace_target) print(save_name) save_name = "{0:}/{1:}".format(output_dir, save_name) - save_nd_array_as_image(pred[i][i], save_name, root_dir + '/' + names[i]) + save_nd_array_as_image(pred[i][i], save_name, test_dir + '/' + names[i][0]) \ No newline at end of file diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index c376546..14f7567 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -39,36 +39,44 @@ def __init__(self, config, stage = 'train'): self.net_dict = SegNetDict self.postprocess_dict = PostProcessDict self.postprocessor = None - - def get_stage_dataset_from_config(self, stage): - assert(stage in ['train', 'valid', 'test']) - root_dir = self.config['dataset']['root_dir'] - modal_num = self.config['dataset'].get('modal_num', 1) + def get_transform_names_and_parameters(self, stage): + """ + Get a list of transform objects for creating a dataset + """ + assert(stage in ['train', 'valid', 'test']) transform_key = stage + '_transform' if(stage == "valid" and transform_key not in self.config['dataset']): transform_key = "train_transform" - transform_names = self.config['dataset'][transform_key] - - self.transform_list = [] - if(transform_names is None or len(transform_names) == 0): - data_transform = None - else: - transform_param = self.config['dataset'] - transform_param['task'] = self.task_type - for name in transform_names: + trans_names = self.config['dataset'][transform_key] + trans_params = self.config['dataset'] + trans_params['task'] = self.task_type + return trans_names, trans_params + + def get_stage_dataset_from_config(self, stage): + trans_names, trans_params = self.get_transform_names_and_parameters(stage) + transform_list = [] + if(trans_names is not None and len(trans_names) > 0): + for name in trans_names: if(name not in self.transform_dict): raise(ValueError("Undefined transform {0:}".format(name))) - one_transform = self.transform_dict[name](transform_param) - self.transform_list.append(one_transform) - data_transform = transforms.Compose(self.transform_list) + one_transform = self.transform_dict[name](trans_params) + transform_list.append(one_transform) + data_transform = transforms.Compose(transform_list) - csv_file = self.config['dataset'].get(stage + '_csv', None) + csv_file = self.config['dataset'].get(stage + '_csv', None) if(stage == 'test'): with_label = False + self.test_transforms = transform_list else: with_label = self.config['dataset'].get(stage + '_label', True) - dataset = NiftyDataset(root_dir = root_dir, + modal_num = self.config['dataset'].get('modal_num', 1) + stage_dir = self.config['dataset'].get('train_dir', None) + if(stage == 'valid' and "valid_dir" in self.config['dataset']): + stage_dir = self.config['dataset']['valid_dir'] + if(stage == 'test' and "test_dir" in self.config['dataset']): + stage_dir = self.config['dataset']['test_dir'] + dataset = NiftyDataset(root_dir = stage_dir, csv_file = csv_file, modal_num = modal_num, with_label= with_label, @@ -471,7 +479,7 @@ def test_time_dropout(m): pred = pred.cpu().numpy() data['predict'] = pred # inverse transform - for transform in self.transform_list[::-1]: + for transform in self.test_transforms[::-1]: if (transform.inverse): data = transform.inverse_transform_for_prediction(data) @@ -525,7 +533,7 @@ def infer_with_multiple_checkpoints(self): pred = np.mean(predict_list, axis=0) data['predict'] = pred # inverse transform - for transform in self.transform_list[::-1]: + for transform in self.test_transforms[::-1]: if (transform.inverse): data = transform.inverse_transform_for_prediction(data) @@ -564,15 +572,18 @@ def save_outputs(self, data): for i in range(len(names)): output[i] = self.postprocessor(output[i]) # save the output and (optionally) probability predictions - root_dir = self.config['dataset']['root_dir'] + test_dir = self.config['dataset'].get('test_dir', None) + if(test_dir is None): + test_dir = self.config['dataset']['train_dir'] + for i in range(len(names)): - save_name = names[i].split('/')[-1] if ignore_dir else \ - names[i].replace('/', '_') + save_name = names[i][0].split('/')[-1] if ignore_dir else \ + names[i][0].replace('/', '_') if((filename_replace_source is not None) and (filename_replace_target is not None)): save_name = save_name.replace(filename_replace_source, filename_replace_target) print(save_name) save_name = "{0:}/{1:}".format(output_dir, save_name) - save_nd_array_as_image(output[i], save_name, root_dir + '/' + names[i]) + save_nd_array_as_image(output[i], save_name, test_dir + '/' + names[i][0]) save_name_split = save_name.split('.') if(not save_prob): @@ -590,4 +601,4 @@ def save_outputs(self, data): prob_save_name = "{0:}_prob_{1:}.{2:}".format(save_prefix, c, save_format) if(len(temp_prob.shape) == 2): temp_prob = np.asarray(temp_prob * 255, np.uint8) - save_nd_array_as_image(temp_prob, prob_save_name, root_dir + '/' + names[i]) + save_nd_array_as_image(temp_prob, prob_save_name, test_dir + '/' + names[i][0]) diff --git a/pymic/net_run/self_sup/util.py b/pymic/net_run/self_sup/util.py index 9cffaa7..0ab1670 100644 --- a/pymic/net_run/self_sup/util.py +++ b/pymic/net_run/self_sup/util.py @@ -19,6 +19,10 @@ def get_human_region_mask(img): mask = np.asarray(img > -600) se = np.ones([3,3,3]) mask = ndimage.binary_opening(mask, se, iterations = 2) + D, H, W = mask.shape + for h in range(H): + if(mask[:,h,:].sum() < 2000): + mask[:,h, :] = np.zeros((D, W)) mask = get_largest_k_components(mask, 1) mask_close = ndimage.binary_closing(mask, se, iterations = 2) @@ -47,20 +51,39 @@ def get_human_region_mask(img): fg = np.asarray(fg, np.uint8) return fg -def crop_ct_scan(input_img, output_img, input_lab = None, output_lab = None): +def get_human_region_mask_fast(img, itk_spacing): + # downsample + D, H, W = img.shape + # scale_down = [1, 1, 1] + if(itk_spacing[2] <= 1): + scale_down = [1/2, 1/2, 1/2] + else: + scale_down = [1, 1/2, 1/2] + img_sub = ndimage.interpolation.zoom(img, scale_down, order = 0) + mask = get_human_region_mask(img_sub) + D1, H1, W1 = mask.shape + scale_up = [D/D1, H/H1, W/W1] + mask = ndimage.interpolation.zoom(mask, scale_up, order = 0) + return mask + +def crop_ct_scan(input_img, output_img, input_lab = None, output_lab = None, z_axis_density = 0.5): """ Crop a CT scan based on the bounding box of the human region. """ img_obj = sitk.ReadImage(input_img) - img = sitk.GetArrayFromImage(img_obj) - mask = np.asarray(img > -600) - se = np.ones([3,3,3]) - mask = ndimage.binary_opening(mask, se, iterations = 2) - mask = get_largest_k_components(mask, 1) - bbmin, bbmax = get_ND_bounding_box(mask, margin = [5, 10, 10]) + img = sitk.GetArrayFromImage(img_obj) + mask = np.asarray(img > -600) + mask2d = np.mean(mask, axis = 0) > z_axis_density + se = np.ones([3,3]) + mask2d = ndimage.binary_opening(mask2d, se, iterations = 2) + mask2d = get_largest_k_components(mask2d, 1) + bbmin, bbmax = get_ND_bounding_box(mask2d, margin = [0, 0]) + bbmin = [0] + bbmin + bbmax = [img.shape[0]] + bbmax img_sub = crop_ND_volume_with_bounding_box(img, bbmin, bbmax) img_sub_obj = sitk.GetImageFromArray(img_sub) img_sub_obj.SetSpacing(img_obj.GetSpacing()) + img_sub_obj.SetDirection(img_obj.GetDirection()) sitk.WriteImage(img_sub_obj, output_img) if(input_lab is not None): lab_obj = sitk.ReadImage(input_lab) @@ -70,6 +93,49 @@ def crop_ct_scan(input_img, output_img, input_lab = None, output_lab = None): lab_sub_obj.SetSpacing(img_obj.GetSpacing()) sitk.WriteImage(lab_sub_obj, output_lab) +def get_human_body_mask_and_crop(input_dir, out_img_dir, out_mask_dir): + if(not os.path.exists(out_img_dir)): + os.mkdir(out_img_dir) + os.mkdir(out_mask_dir) + + img_names = [item for item in os.listdir(input_dir) if "nii.gz" in item] + img_names = sorted(img_names) + for img_name in img_names: + print(img_name) + input_name = input_dir + "/" + img_name + out_name = out_img_dir + "/" + img_name + mask_name = out_mask_dir + "/" + img_name + if(os.path.isfile(out_name)): + continue + img_obj = sitk.ReadImage(input_name) + img = sitk.GetArrayFromImage(img_obj) + spacing = img_obj.GetSpacing() + + # downsample + D, H, W = img.shape + spacing = img_obj.GetSpacing() + # scale_down = [1, 1, 1] + if(spacing[2] <= 1): + scale_down = [1/2, 1/2, 1/2] + else: + scale_down = [1, 1/2, 1/2] + img_sub = ndimage.interpolation.zoom(img, scale_down, order = 0) + mask = get_human_region_mask(img_sub) + D1, H1, W1 = mask.shape + scale_up = [D/D1, H/H1, W/W1] + mask = ndimage.interpolation.zoom(mask, scale_up, order = 0) + + bbmin, bbmax = get_ND_bounding_box(mask) + img_crop = crop_ND_volume_with_bounding_box(img, bbmin, bbmax) + mask_crop = crop_ND_volume_with_bounding_box(mask, bbmin, bbmax) + + out_img_obj = sitk.GetImageFromArray(img_crop) + out_img_obj.SetSpacing(spacing) + sitk.WriteImage(out_img_obj, out_name) + mask_obj = sitk.GetImageFromArray(mask_crop) + mask_obj.CopyInformation(out_img_obj) + sitk.WriteImage(mask_obj, mask_name) + def patch_mix(x, fg_num, patch_num, size_d, size_h, size_w): """ @@ -99,7 +165,7 @@ def patch_mix(x, fg_num, patch_num, size_d, size_h, size_w): y_prob = get_one_hot_seg(fg_mask.to(torch.int32), fg_num + 1) return x_fuse, y_prob -def create_mixed_dataset(input_dir, output_dir, fg_num = 1, crop_num = 1, +def create_mixed_dataset(input_dir, output_dir, fg_num = 1, crop_num = 1, patch_size=[128,128,128], mask_dir = None, data_format = "nii.gz"): """ Create dataset based on patch mix. @@ -136,6 +202,7 @@ def create_mixed_dataset(input_dir, output_dir, fg_num = 1, crop_num = 1, img_j = load_image_as_nd_array(input_dir + "/" + img_names[j])['data_array'] chns = img_i.shape[0] + crop_size = [chns] + patch_size # random crop to patch size if(mask_dir is None): mask_i = get_human_region_mask(img_i) @@ -148,8 +215,9 @@ def create_mixed_dataset(input_dir, output_dir, fg_num = 1, crop_num = 1, # img_ik = random_crop_ND_volume(img_i, [chns, 96, 96, 96]) # img_jk = random_crop_ND_volume(img_j, [chns, 96, 96, 96]) # else: - img_ik = random_crop_ND_volume_with_mask(img_i, [chns, 96, 96, 96], mask_i) - img_jk = random_crop_ND_volume_with_mask(img_j, [chns, 96, 96, 96], mask_j) + + img_ik = random_crop_ND_volume_with_mask(img_i, crop_size, mask_i) + img_jk = random_crop_ND_volume_with_mask(img_j, crop_size, mask_j) C, D, H, W = img_ik.shape # generate mask fg_mask = np.zeros_like(img_ik, np.uint8) diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index 426b620..f2bbe0f 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -11,8 +11,9 @@ from pymic.net_run.agent_seg import SegmentationAgent from pymic.net_run.semi_sup import SSLMethodDict from pymic.net_run.weak_sup import WSLMethodDict +from pymic.net_run.self_sup import SelfSupMethodDict from pymic.net_run.noisy_label import NLLMethodDict -from pymic.net_run.self_sup import SelfSLSegAgent +# from pymic.net_run.self_sup import SelfSLSegAgent def get_seg_rec_agent(config, sup_type): assert(sup_type in ['fully_sup', 'semi_sup', 'self_sup', 'weak_sup', 'noisy_label']) @@ -34,28 +35,7 @@ def get_seg_rec_agent(config, sup_type): elif(sup_type == 'self_sup'): logging.info("\n********** Self Supervised Learning **********\n") method = config['self_supervised_learning']['method_name'] - if(method == "custom"): - pass - elif(method == "model_genesis"): - transforms = ['RandomFlip', 'LocalShuffling', 'NonLinearTransform', 'InOutPainting'] - genesis_cfg = { - 'randomflip_flip_depth': True, - 'randomflip_flip_height': True, - 'randomflip_flip_width': True, - 'localshuffling_probability': 0.5, - 'nonLineartransform_probability': 0.9, - 'inoutpainting_probability': 0.9, - 'inpainting_probability': 0.2 - } - config['dataset']['train_transform'].extend(transforms) - # config['dataset']['valid_transform'].extend(transforms) - config['dataset'].update(genesis_cfg) - logging_config(config['dataset']) - else: - raise ValueError("The specified method {0:} is not implemented. ".format(method) + \ - "Consider to set `self_sl_method = custom` and use customized" + \ - " transforms for self-supervised learning.") - agent = SelfSLSegAgent(config, 'train') + agent = SelfSupMethodDict[method](config, 'train') else: raise ValueError("undefined supervision type: {0:}".format(sup_type)) return agent diff --git a/pymic/transform/mix.py b/pymic/transform/mix.py index 6e2fb8e..6efed6a 100644 --- a/pymic/transform/mix.py +++ b/pymic/transform/mix.py @@ -71,7 +71,8 @@ class PatchMix(AbstractTransform): """ def __init__(self, params): super(PatchMix, self).__init__(params) - self.inverse = params.get('PatchMix_inverse'.lower(), False) + self.inverse = params.get('PatchMix_inverse'.lower(), False) + self.threshold = params.get('PatchMix_threshold'.lower(), 0) self.crop_size = params.get('PatchMix_crop_size'.lower(), [64, 128, 128]) self.fg_cls_num = params.get('PatchMix_cls_num'.lower(), [4, 40]) self.patch_num_range= params.get('PatchMix_patch_range'.lower(), [4, 40]) @@ -79,7 +80,8 @@ def __init__(self, params): self.patch_size_max = params.get('PatchMix_patch_size_max'.lower(), [20, 40, 40]) def __call__(self, sample): - x0, x1 = self._random_crop_and_flip(sample) + x0 = self._random_crop_and_flip(sample) + x1 = self._random_crop_and_flip(sample) C, D, H, W = x0.shape # generate mask fg_mask = np.zeros_like(x0, np.uint8) @@ -104,46 +106,48 @@ def __call__(self, sample): return sample def _random_crop_and_flip(self, sample): - input_shape = sample['image'].shape - input_dim = len(input_shape) - 1 + image = sample['image'] + input_dim = len(image.shape) - 1 assert(input_dim == 3) - + C, D, H, W = image.shape + + half_size = [x // 2 for x in self.crop_size] + dc = random.randint(half_size[0], D - half_size[0]) + image2d = image[0, dc, :, :] + mask2d = np.zeros_like(image2d) + mask2d[half_size[1]:H+1-half_size[1], half_size[2]:W+1-half_size[2]] = \ + np.ones([H-self.crop_size[1]+1, W-self.crop_size[2]+1]) if('label' in sample): - # get the center for crop randomly - mask = sample['label'] > 0 - C, D, H, W = input_shape - size_h = [i// 2 for i in self.crop_size] - temp_mask = np.zeros_like(mask) - temp_mask[:,size_h[0]:D-size_h[0]+1,size_h[1]:H-size_h[1]+1,size_h[2]:W-size_h[2]+1] = \ - np.ones([C, D-self.crop_size[0]+1, H-self.crop_size[1]+1, W-self.crop_size[2]+1]) - mask = mask * temp_mask - indices = np.where(mask) - n0 = random.randint(0, len(indices[0])-1) - n1 = random.randint(0, len(indices[0])-1) - center0 = [indices[i][n0] for i in range(1, 4)] - center1 = [indices[i][n1] for i in range(1, 4)] - crop_min0 = [center0[i] - size_h[i] for i in range(3)] - crop_min1 = [center1[i] - size_h[i] for i in range(3)] - else: - crop_margin = [input_shape[1+i] - self.crop_size[i] for i in range(input_dim)] - crop_min0 = [0 if item == 0 else random.randint(0, item) for item in crop_margin] - crop_min1 = [0 if item == 0 else random.randint(0, item) for item in crop_margin] + temp_mask = sample['label'][0, dc, :, :] > 0 + mask2d = temp_mask * mask2d + elif(self.threshold is not None): + temp_mask = image2d > self.threshold + se = np.ones([3,3]) + temp_mask = ndimage.binary_opening(temp_mask, se, iterations = 2) + temp_mask = get_largest_k_components(temp_mask, 1) + mask2d = temp_mask * mask2d + + indices = np.where(mask2d) + n = random.randint(0, len(indices[0])-1) + center = [indices[i][n] for i in range(2)] + crop_min = [dc - half_size[0], center[0]-half_size[1], center[1] - half_size[2]] + crop_max = [crop_min[i] + self.crop_size[i] for i in range(input_dim)] + crop_min = [0] + crop_min + crop_max = [C] + crop_max + x = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) - patches = [] - for crop_min in [crop_min0, crop_min1]: - crop_max = [crop_min[i] + self.crop_size[i] for i in range(input_dim)] - crop_min = [0] + crop_min - crop_max = [C] + crop_max - x = crop_ND_volume_with_bounding_box(sample['image'], crop_min, crop_max) - flip_axis = [] - if(random.random() > 0.5): - flip_axis.append(-1) - if(random.random() > 0.5): - flip_axis.append(-2) - if(random.random() > 0.5): - flip_axis.append(-3) - if(len(flip_axis) > 0): - x = np.flip(x, flip_axis).copy() - patches.append(x) + flip_axis = [] + if(random.random() > 0.5): + flip_axis.append(-1) + if(random.random() > 0.5): + flip_axis.append(-2) + if(random.random() > 0.5): + flip_axis.append(-3) + if(len(flip_axis) > 0): + x = np.flip(x, flip_axis).copy() - return patches \ No newline at end of file + if(x.shape[1] == 63): + print("crop shape == 63", x.shape) + print(sample['names']) + print(image.shape, crop_min, crop_max) + return x \ No newline at end of file From a906dfc1e23045c9aa0b235adf7ae550889dc9dc Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 3 Jan 2024 10:18:34 +0800 Subject: [PATCH 35/86] update self-supervised learning --- .gitignore | 1 + pymic/io/h5_dataset.py | 7 +- pymic/io/image_read_write.py | 13 +- pymic/io/nifty_dataset.py | 23 +- pymic/loss/seg/abstract.py | 15 +- pymic/loss/seg/ce.py | 15 +- pymic/loss/seg/dice.py | 20 +- pymic/loss/seg/exp_log.py | 4 +- pymic/loss/seg/mse.py | 8 +- pymic/net/net2d/unet2d.py | 44 ++-- pymic/net/net2d/unet2d_canet.py | 66 ------ pymic/net/net_dict_cls.py | 2 +- pymic/net/net_dict_seg.py | 2 + pymic/net_run/agent_abstract.py | 13 +- pymic/net_run/agent_rec.py | 5 +- pymic/net_run/agent_seg.py | 8 +- pymic/net_run/get_optimizer.py | 3 +- pymic/net_run/preprocess.py | 82 +++++++ pymic/net_run/self_sup/__init__.py | 11 +- pymic/net_run/self_sup/self_genesis.py | 51 +++++ pymic/net_run/self_sup/self_patch_swapping.py | 44 ++++ pymic/net_run/self_sup/self_sl_agent.py | 3 +- ...tch_mix_agent.py => self_volume_fusion.py} | 49 +--- pymic/net_run/self_sup/util.py | 121 ++-------- pymic/net_run/semi_sup/__init__.py | 15 +- pymic/net_run/semi_sup/ssl_abstract.py | 6 +- pymic/net_run/semi_sup/ssl_cps.py | 8 +- pymic/net_run/train.py | 6 +- pymic/transform/crop.py | 90 +++++++- pymic/transform/extract_channel.py | 39 ++++ pymic/transform/intensity.py | 210 ++++++++---------- pymic/transform/normalize.py | 8 +- pymic/transform/rescale.py | 1 + pymic/transform/trans_dict.py | 2 + pymic/util/image_process.py | 2 +- 35 files changed, 580 insertions(+), 417 deletions(-) create mode 100644 pymic/net_run/preprocess.py create mode 100644 pymic/net_run/self_sup/self_genesis.py create mode 100644 pymic/net_run/self_sup/self_patch_swapping.py rename pymic/net_run/self_sup/{self_patch_mix_agent.py => self_volume_fusion.py} (71%) create mode 100644 pymic/transform/extract_channel.py diff --git a/.gitignore b/.gitignore index f7da7ac..639f873 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ dist/* *egg*/* *stop* files.txt +pymic/test/runs/* # Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks # Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks diff --git a/pymic/io/h5_dataset.py b/pymic/io/h5_dataset.py index 02f94f3..34fa1a4 100644 --- a/pymic/io/h5_dataset.py +++ b/pymic/io/h5_dataset.py @@ -8,8 +8,9 @@ import pandas as pd from torch.utils.data import Dataset from torch.utils.data.sampler import Sampler +from pymic import TaskType -class H5DataSet(Dataset): +class H5DataSet_backup(Dataset): """ Dataset for loading images stored in h5 format. It generates 4D tensors with dimention order [C, D, H, W] for 3D images, and @@ -39,7 +40,9 @@ def __getitem__(self, idx): if self.transform: sample = self.transform(sample) return sample - + + + class TwoStreamBatchSampler(Sampler): """Iterate two sets of indices diff --git a/pymic/io/image_read_write.py b/pymic/io/image_read_write.py index 6c8c6b0..3aa87bd 100644 --- a/pymic/io/image_read_write.py +++ b/pymic/io/image_read_write.py @@ -79,10 +79,10 @@ def load_image_as_nd_array(image_name): image_name.endswith(".tif") or image_name.endswith(".png")): image_dict = load_rgb_image_as_3d_array(image_name) else: - raise ValueError("unsupported image format") + raise ValueError("unsupported image format: {0:}".format(image_name)) return image_dict -def save_array_as_nifty_volume(data, image_name, reference_name = None): +def save_array_as_nifty_volume(data, image_name, reference_name = None, spacing = [1.0,1.0,1.0]): """ Save a numpy array as nifty image @@ -90,6 +90,7 @@ def save_array_as_nifty_volume(data, image_name, reference_name = None): :param image_name: (str) The ouput file name. :param reference_name: (str) File name of the reference image of which meta information is used. + :param spacing: (list or tuple) the spacing of a volume data when `reference_name` is not provided. """ img = sitk.GetImageFromArray(data) if(reference_name is not None): @@ -101,6 +102,9 @@ def save_array_as_nifty_volume(data, image_name, reference_name = None): direction1 = img.GetDirection() if(len(direction0) == len(direction1)): img.SetDirection(direction0) + else: + nifty_spacing = spacing[1:] + spacing[:1] + img.SetSpacing(nifty_spacing) sitk.WriteImage(img, image_name) def save_array_as_rgb_image(data, image_name): @@ -119,7 +123,7 @@ def save_array_as_rgb_image(data, image_name): img = Image.fromarray(data) img.save(image_name) -def save_nd_array_as_image(data, image_name, reference_name = None): +def save_nd_array_as_image(data, image_name, reference_name = None, spacing = [1.0,1.0,1.0]): """ Save a 3D or 2D numpy array as medical image or RGB image @@ -127,13 +131,14 @@ def save_nd_array_as_image(data, image_name, reference_name = None): [H, W, 3] or [H, W]. :param reference_name: (str) File name of the reference image of which meta information is used. + :param spacing: (list or tuple) the spacing of a volume data when `reference_name` is not provided. """ data_dim = len(data.shape) assert(data_dim == 2 or data_dim == 3) if (image_name.endswith(".nii.gz") or image_name.endswith(".nii") or image_name.endswith(".mha")): assert(data_dim == 3) - save_array_as_nifty_volume(data, image_name, reference_name) + save_array_as_nifty_volume(data, image_name, reference_name, spacing) elif(image_name.endswith(".jpg") or image_name.endswith(".jpeg") or image_name.endswith(".tif") or image_name.endswith(".png")): diff --git a/pymic/io/nifty_dataset.py b/pymic/io/nifty_dataset.py index 438a07b..aefe4da 100644 --- a/pymic/io/nifty_dataset.py +++ b/pymic/io/nifty_dataset.py @@ -3,11 +3,9 @@ import logging import os -import torch import pandas as pd import numpy as np -from torch.utils.data import Dataset, DataLoader -from torchvision import transforms, utils +from torch.utils.data import Dataset from pymic import TaskType from pymic.io.image_read_write import load_image_as_nd_array @@ -70,6 +68,25 @@ def __get_pixel_weight__(self, idx): weight = np.asarray(weight, np.float32) return weight + # def __getitem__(self, idx): + # sample_name = self.csv_items.iloc[idx, 0] + # h5f = h5py.File(self.root_dir + '/' + sample_name, 'r') + # image = np.asarray(h5f['image'][:], np.float32) + + # # this a temporaory process, will be delieted later + # if(len(image.shape) == 3 and image.shape[0] > 1): + # image = np.expand_dims(image, 0) + # sample = {'image': image, 'names':sample_name} + + # if('label' in h5f): + # label = np.asarray(h5f['label'][:], np.uint8) + # if(len(label.shape) == 3 and label.shape[0] > 1): + # label = np.expand_dims(label, 0) + # sample['label'] = label + # if self.transform: + # sample = self.transform(sample) + # return sample + def __getitem__(self, idx): names_list, image_list = [], [] for i in range (self.modal_num): diff --git a/pymic/loss/seg/abstract.py b/pymic/loss/seg/abstract.py index f42d816..68643e8 100644 --- a/pymic/loss/seg/abstract.py +++ b/pymic/loss/seg/abstract.py @@ -16,9 +16,20 @@ class AbstractSegLoss(nn.Module): def __init__(self, params = None): super(AbstractSegLoss, self).__init__() if(params is None): - self.softmax = True + self.acti_func = 'softmax' else: - self.softmax = params.get('loss_softmax', True) + self.acti_func = params.get('loss_acti_func', 'softmax') + + def get_activated_prediction(self, p, acti_func = 'softmax'): + if(acti_func == "softmax"): + p = nn.Softmax(dim = 1)(p) + elif(acti_func == "tanh"): + p = nn.Tanh()(p) + elif(acti_func == "sigmoid"): + p = nn.Sigmoid()(p) + else: + raise ValueError("activation for output is not supported: {0:}".format(acti_func)) + return p def forward(self, loss_input_dict): """ diff --git a/pymic/loss/seg/ce.py b/pymic/loss/seg/ce.py index 9524d57..4edbbc3 100644 --- a/pymic/loss/seg/ce.py +++ b/pymic/loss/seg/ce.py @@ -13,8 +13,10 @@ class CrossEntropyLoss(AbstractSegLoss): The parameters should be written in the `params` dictionary, and it has the following fields: - :param `loss_softmax`: (optional, bool) - Apply softmax to the prediction of network or not. Default is True. + :param `loss_acti_func`: (optional, string) + Apply an activation function to the prediction of network or not, for example, + 'softmax' for image segmentation tasks, 'tanh' for reconstruction tasks, and None + means no activation is used. """ def __init__(self, params = None): super(CrossEntropyLoss, self).__init__(params) @@ -27,8 +29,9 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) + predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) @@ -74,8 +77,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) gce = (1.0 - torch.pow(predict, self.q)) / self.q * soft_y diff --git a/pymic/loss/seg/dice.py b/pymic/loss/seg/dice.py index 2c2df32..c423c2c 100644 --- a/pymic/loss/seg/dice.py +++ b/pymic/loss/seg/dice.py @@ -25,8 +25,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) if(pix_w is not None): @@ -52,8 +52,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = 1.0 - predict[:, :1, :, :, :] soft_y = 1.0 - soft_y[:, :1, :, :, :] predict = reshape_tensor_to_2D(predict) @@ -76,8 +76,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) num_class = list(predict.size())[1] @@ -115,8 +115,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) @@ -149,8 +149,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) diff --git a/pymic/loss/seg/exp_log.py b/pymic/loss/seg/exp_log.py index c1b3f00..8c0d494 100644 --- a/pymic/loss/seg/exp_log.py +++ b/pymic/loss/seg/exp_log.py @@ -32,8 +32,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) diff --git a/pymic/loss/seg/mse.py b/pymic/loss/seg/mse.py index 5b657c5..eb53af4 100644 --- a/pymic/loss/seg/mse.py +++ b/pymic/loss/seg/mse.py @@ -19,8 +19,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) mse = torch.square(predict - soft_y) mse = torch.mean(mse) return mse @@ -44,8 +44,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) mae = torch.abs(predict - soft_y) if(weight is None): mae = torch.mean(mae) diff --git a/pymic/net/net2d/unet2d.py b/pymic/net/net2d/unet2d.py index 7d14a2e..be69f0d 100644 --- a/pymic/net/net2d/unet2d.py +++ b/pymic/net/net2d/unet2d.py @@ -5,7 +5,6 @@ import torch import torch.nn as nn import numpy as np -from torch.nn.functional import interpolate class ConvBlock(nn.Module): """ @@ -61,8 +60,7 @@ class UpBlock(nn.Module): 0 (or `TransConv`), 1 (`Nearest`), 2 (`Bilinear`), 3 (`Bicubic`). The default value is 2 (`Bilinear`). """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - up_mode = 2): + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, up_mode = 2): super(UpBlock, self).__init__() if(isinstance(up_mode, int)): up_mode_values = ["transconv", "nearest", "bilinear", "bicubic"] @@ -144,7 +142,7 @@ class Decoder(nn.Module): :param up_mode: (string or int) The mode for upsampling. The allowed values are: 0 (or `TransConv`), 1 (or `Nearest`), 2 (or `Bilinear`), 3 (or `Bicubic`). The default value is 2 (or `Bilinear`). - :param multiscale_pred: (bool) Get multiscale prediction. + :param multiscale_pred: (bool) Get multi-scale prediction. """ def __init__(self, params): super(Decoder, self).__init__() @@ -165,10 +163,14 @@ def __init__(self, params): self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.up_mode) self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1) - if(self.mul_pred): + if(self.mul_pred and (self.training or self.mul_infer)): self.out_conv1 = nn.Conv2d(self.ft_chns[1], self.n_class, kernel_size = 1) self.out_conv2 = nn.Conv2d(self.ft_chns[2], self.n_class, kernel_size = 1) self.out_conv3 = nn.Conv2d(self.ft_chns[3], self.n_class, kernel_size = 1) + self.stage = 'train' + + def set_stage(self, stage): + self.stage = stage def forward(self, x): if(len(self.ft_chns) == 5): @@ -183,7 +185,7 @@ def forward(self, x): x_d1 = self.up3(x_d2, x1) x_d0 = self.up4(x_d1, x0) output = self.out_conv(x_d0) - if(self.mul_pred and self.training): + if(self.mul_pred and self.stage == 'train'): output1 = self.out_conv1(x_d1) output2 = self.out_conv2(x_d2) output3 = self.out_conv3(x_d3) @@ -239,6 +241,10 @@ def get_default_parameters(self, params): logging.info("{0:} = {1:}".format(key, params[key])) return params + def set_stage(self, stage): + self.stage = stage + self.decoder.set_stage(stage) + def forward(self, x): x_shape = list(x.shape) if(len(x_shape) == 5): @@ -258,28 +264,4 @@ def forward(self, x): new_shape = [N, D] + list(output.shape)[1:] output = torch.transpose(torch.reshape(output, new_shape), 1, 2) - return output - - -if __name__ == "__main__": - params = {'in_chns':4, - 'feature_chns':[16, 32, 64, 128, 256], - 'dropout': [0, 0, 0.3, 0.4, 0.5], - 'class_num': 2, - 'up_mode': 0, - 'multiscale_pred': True} - Net = UNet2D(params) - Net = Net.double() - - x = np.random.rand(4, 4, 10, 256, 256) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - out = Net(xt) - if params['multiscale_pred']: - for y in out: - print(len(y.size())) - y = y.detach().numpy() - print(y.shape) - else: - print(out.shape) + return output \ No newline at end of file diff --git a/pymic/net/net2d/unet2d_canet.py b/pymic/net/net2d/unet2d_canet.py index 7aeba84..defcb60 100644 --- a/pymic/net/net2d/unet2d_canet.py +++ b/pymic/net/net2d/unet2d_canet.py @@ -684,72 +684,6 @@ def forward(self, inputs): xt = torch.from_numpy(x) xt = torch.tensor(xt) - y = Net(xt) - print(len(y.size())) - y = y.detach().numpy() - print(y.shape) -axpool3(conv3) - - conv4 = self.conv4(maxpool3) - maxpool4 = self.maxpool4(conv4) - - # Gating Signal Generation - center = self.center(maxpool4) - - # Attention Mechanism - # Upscaling Part (Decoder) - up4 = self.up_concat4(conv4, center) - g_conv4 = self.nonlocal4_2(up4) - - up4, att_weight4 = self.up4(g_conv4) - g_conv3, att3 = self.attentionblock3(conv3, up4) - - # atten3_map = att3.cpu().detach().numpy().astype(np.float) - # atten3_map = ndimage.interpolation.zoom(atten3_map, [1.0, 1.0, 224 / atten3_map.shape[2], - # 300 / atten3_map.shape[3]], order=0) - - up3 = self.up_concat3(g_conv3, up4) - up3, att_weight3 = self.up3(up3) - g_conv2, att2 = self.attentionblock2(conv2, up3) - - # atten2_map = att2.cpu().detach().numpy().astype(np.float) - # atten2_map = ndimage.interpolation.zoom(atten2_map, [1.0, 1.0, 224 / atten2_map.shape[2], - # 300 / atten2_map.shape[3]], order=0) - - up2 = self.up_concat2(g_conv2, up3) - up2, att_weight2 = self.up2(up2) - # g_conv1, att1 = self.attentionblock1(conv1, up2) - - # atten1_map = att1.cpu().detach().numpy().astype(np.float) - # atten1_map = ndimage.interpolation.zoom(atten1_map, [1.0, 1.0, 224 / atten1_map.shape[2], - # 300 / atten1_map.shape[3]], order=0) - up1 = self.up_concat1(conv1, up2) - up1, att_weight1 = self.up1(up1) - - # Deep Supervision - dsv4 = self.dsv4(up4) - dsv3 = self.dsv3(up3) - dsv2 = self.dsv2(up2) - dsv1 = self.dsv1(up1) - dsv_cat = torch.cat([dsv1, dsv2, dsv3, dsv4], dim=1) - out = self.scale_att(dsv_cat) - - out = self.final(out) - - return out - - - -if __name__ == "__main__": - params = {'in_chns':3, - 'class_num':2} - Net = CANet(params) - Net = Net.double() - - x = np.random.rand(2, 3, 256, 256) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - y = Net(xt) print(len(y.size())) y = y.detach().numpy() diff --git a/pymic/net/net_dict_cls.py b/pymic/net/net_dict_cls.py index 7996e59..3a7808b 100644 --- a/pymic/net/net_dict_cls.py +++ b/pymic/net/net_dict_cls.py @@ -3,7 +3,7 @@ Built-in networks for classification. * resnet18 :mod:`pymic.net.cls.torch_pretrained_net.ResNet18` -* vgg16 :mod:`pymic.net.cls.torch_pretrained_net.VGG16` +* vgg16 :mod:`pymic.net.cls.torch_pretrained_net.VGG16` * mobilenetv2 :mod:`pymic.net.cls.torch_pretrained_net.MobileNetV2` """ diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index 8862fd1..e381421 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -17,6 +17,7 @@ from __future__ import print_function, division from pymic.net.net2d.unet2d import UNet2D from pymic.net.net2d.unet2d_dual_branch import UNet2D_DualBranch +from pymic.net.net2d.unet2d_canet import CANet from pymic.net.net2d.unet2d_cct import UNet2D_CCT from pymic.net.net2d.unet2d_mcnet import MCNet2D from pymic.net.net2d.cople_net import COPLENet @@ -48,6 +49,7 @@ 'UNet2D_DualBranch': UNet2D_DualBranch, 'UNet2D_CCT': UNet2D_CCT, 'MCNet2D': MCNet2D, + 'CANet': CANet, 'COPLENet': COPLENet, 'AttentionUNet2D': AttentionUNet2D, 'NestedUNet2D': NestedUNet2D, diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index 7a49a2b..f9575ab 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -154,6 +154,15 @@ def get_checkpoint_name(self): ckpt_name = self.config['testing']['ckpt_name'] return ckpt_name + @abstractmethod + def get_stage_transform_from_config(self, stage): + """ + Get the transform list required by dataset for training, validation or inference stage. + + :param stage: (str) `train`, `valid` or `test`. + """ + raise(ValueError("not implemented")) + @abstractmethod def get_stage_dataset_from_config(self, stage): """ @@ -261,13 +270,13 @@ def worker_init_fn(worker_id): bn_train = self.config['dataset']['train_batch_size'] bn_valid = self.config['dataset'].get('valid_batch_size', 1) - num_worker = self.config['dataset'].get('num_worker', 16) + num_worker = self.config['dataset'].get('num_worker', 8) g_train, g_valid = torch.Generator(), torch.Generator() g_train.manual_seed(self.random_seed) g_valid.manual_seed(self.random_seed) self.train_loader = torch.utils.data.DataLoader(self.train_set, batch_size = bn_train, shuffle=True, num_workers= num_worker, - worker_init_fn=worker_init, generator = g_train) + worker_init_fn=worker_init, generator = g_train, drop_last = True) self.valid_loader = torch.utils.data.DataLoader(self.valid_set, batch_size = bn_valid, shuffle=False, num_workers= num_worker, worker_init_fn=worker_init, generator = g_valid) diff --git a/pymic/net_run/agent_rec.py b/pymic/net_run/agent_rec.py index 634ea20..cd311ad 100644 --- a/pymic/net_run/agent_rec.py +++ b/pymic/net_run/agent_rec.py @@ -193,7 +193,7 @@ def train_valid(self): self.min_val_loss = 10000.0 self.max_val_it = 0 self.best_model_wts = None - self.checkpoint = None + checkpoint = None # initialize the network with pre-trained weights ckpt_init_name = self.config['training'].get('ckpt_init_name', None) ckpt_init_mode = self.config['training'].get('ckpt_init_mode', 0) @@ -212,7 +212,7 @@ def train_valid(self): else: self.net.load_state_dict(pretrained_dict, strict = False) if(ckpt_init_mode > 0): # Load other information - self.min_val_loss = self.checkpoint.get('valid_loss', 10000) + self.min_val_loss = checkpoint.get('valid_loss', 10000) iter_start = checkpoint['iteration'] self.max_val_it = iter_start self.best_model_wts = checkpoint['model_state_dict'] @@ -300,6 +300,7 @@ def save_outputs(self, data): if(isinstance(pred, (list, tuple))): pred = pred[0] pred = np.tanh(pred) + # pred = scipy.special.expit(pred) # save the output predictions test_dir = self.config['dataset'].get('test_dir', None) if(test_dir is None): diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 14f7567..2d6d489 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -46,8 +46,6 @@ def get_transform_names_and_parameters(self, stage): """ assert(stage in ['train', 'valid', 'test']) transform_key = stage + '_transform' - if(stage == "valid" and transform_key not in self.config['dataset']): - transform_key = "train_transform" trans_names = self.config['dataset'][transform_key] trans_params = self.config['dataset'] trans_params['task'] = self.task_type @@ -179,6 +177,8 @@ def training(self): inputs, labels_prob = mixup(inputs, labels_prob) # for debug + # if(it > 10): + # break # for i in range(inputs.shape[0]): # image_i = inputs[i][0] # # label_i = labels_prob[i][1] @@ -192,6 +192,7 @@ def training(self): # save_nd_array_as_image(label_i, label_name, reference_name = None) # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) # continue + inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) @@ -241,6 +242,9 @@ def validation(self): self.net.eval() for data in validIter: inputs = self.convert_tensor_type(data['image']) + if('label_prob' not in data): + raise ValueError("label_prob is not found in validation data, make sure" + + "that LabelToProbability is used in valid_transform.") labels_prob = self.convert_tensor_type(data['label_prob']) inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) batch_n = inputs.shape[0] diff --git a/pymic/net_run/get_optimizer.py b/pymic/net_run/get_optimizer.py index 53de5e3..771b448 100644 --- a/pymic/net_run/get_optimizer.py +++ b/pymic/net_run/get_optimizer.py @@ -13,8 +13,9 @@ def get_optimizer(name, net_params, optim_params): # see https://www.codeleading.com/article/44815584159/ param_group = [{'params': net_params, 'initial_lr': lr}] if(keyword_match(name, "SGD")): + nesterov = optim_params.get('nesterov', True) return optim.SGD(param_group, lr, - momentum = momentum, weight_decay = weight_decay, nesterov = True) + momentum = momentum, weight_decay = weight_decay, nesterov = nesterov) elif(keyword_match(name, "Adam")): return optim.Adam(param_group, lr, weight_decay = weight_decay) elif(keyword_match(name, "SparseAdam")): diff --git a/pymic/net_run/preprocess.py b/pymic/net_run/preprocess.py new file mode 100644 index 0000000..f2bbe0f --- /dev/null +++ b/pymic/net_run/preprocess.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import logging +import os +import sys +import shutil +from datetime import datetime +from pymic import TaskType +from pymic.util.parse_config import * +from pymic.net_run.agent_cls import ClassificationAgent +from pymic.net_run.agent_seg import SegmentationAgent +from pymic.net_run.semi_sup import SSLMethodDict +from pymic.net_run.weak_sup import WSLMethodDict +from pymic.net_run.self_sup import SelfSupMethodDict +from pymic.net_run.noisy_label import NLLMethodDict +# from pymic.net_run.self_sup import SelfSLSegAgent + +def get_seg_rec_agent(config, sup_type): + assert(sup_type in ['fully_sup', 'semi_sup', 'self_sup', 'weak_sup', 'noisy_label']) + if(sup_type == 'fully_sup'): + logging.info("\n********** Fully Supervised Learning **********\n") + agent = SegmentationAgent(config, 'train') + elif(sup_type == 'semi_sup'): + logging.info("\n********** Semi Supervised Learning **********\n") + method = config['semi_supervised_learning']['method_name'] + agent = SSLMethodDict[method](config, 'train') + elif(sup_type == 'weak_sup'): + logging.info("\n********** Weakly Supervised Learning **********\n") + method = config['weakly_supervised_learning']['method_name'] + agent = WSLMethodDict[method](config, 'train') + elif(sup_type == 'noisy_label'): + logging.info("\n********** Noisy Label Learning **********\n") + method = config['noisy_label_learning']['method_name'] + agent = NLLMethodDict[method](config, 'train') + elif(sup_type == 'self_sup'): + logging.info("\n********** Self Supervised Learning **********\n") + method = config['self_supervised_learning']['method_name'] + agent = SelfSupMethodDict[method](config, 'train') + else: + raise ValueError("undefined supervision type: {0:}".format(sup_type)) + return agent + +def main(): + """ + The main function for running a network for training. + """ + if(len(sys.argv) < 2): + print('Number of arguments should be 2. e.g.') + print(' pymic_train config.cfg') + exit() + cfg_file = str(sys.argv[1]) + if(not os.path.isfile(cfg_file)): + raise ValueError("The config file does not exist: " + cfg_file) + config = parse_config(cfg_file) + config = synchronize_config(config) + log_dir = config['training']['ckpt_save_dir'] + if(not os.path.exists(log_dir)): + os.makedirs(log_dir, exist_ok=True) + dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] + shutil.copy(cfg_file, log_dir + "/" + dst_cfg) + datetime_str = str(datetime.now())[:-7].replace(":", "_") + if sys.version.startswith("3.9"): + logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), + level=logging.INFO, format='%(message)s', force=True) # for python 3.9 + else: + logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), + level=logging.INFO, format='%(message)s') # for python 3.6 + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging_config(config) + task = config['dataset']['task_type'] + if(task == TaskType.CLASSIFICATION_ONE_HOT or task == TaskType.CLASSIFICATION_COEXIST): + agent = ClassificationAgent(config, 'train') + else: + sup_type = config['dataset'].get('supervise_type', 'fully_sup') + agent = get_seg_rec_agent(config, sup_type) + + agent.run() + +if __name__ == "__main__": + main() + + diff --git a/pymic/net_run/self_sup/__init__.py b/pymic/net_run/self_sup/__init__.py index 73308e6..d73e42a 100644 --- a/pymic/net_run/self_sup/__init__.py +++ b/pymic/net_run/self_sup/__init__.py @@ -1,3 +1,10 @@ from __future__ import absolute_import -from pymic.net_run.self_sup.self_sl_agent import SelfSLSegAgent -from pymic.net_run.self_sup.self_patch_mix_agent import SelfSLPatchMixAgent \ No newline at end of file +from pymic.net_run.self_sup.self_genesis import SelfSupModelGenesis +from pymic.net_run.self_sup.self_patch_swapping import SelfSupPatchSwapping +from pymic.net_run.self_sup.self_volume_fusion import SelfSupVolumeFusion + +SelfSupMethodDict = { + 'ModelGenesis': SelfSupModelGenesis, + 'PatchSwapping': SelfSupPatchSwapping, + 'VolumeFusion': SelfSupVolumeFusion + } \ No newline at end of file diff --git a/pymic/net_run/self_sup/self_genesis.py b/pymic/net_run/self_sup/self_genesis.py new file mode 100644 index 0000000..85ee194 --- /dev/null +++ b/pymic/net_run/self_sup/self_genesis.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import copy +import logging +import time +from pymic.net_run.agent_rec import ReconstructionAgent + +class SelfSupModelGenesis(ReconstructionAgent): + """ + Patch swapping-based self-supervised learning. + + Reference: Liang Chen et al., Self-supervised learning for medical image analysis + using image context restoration, Medical Image Analysis, 2019. + + A PatchSwaping transform need to be used in the cnfiguration. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `self_supervised_learning` is needed. See :doc:`usage.selfsl` for details. + + In the configuration file, it should look like this: + ``` + [dataset] + task_type = rec + supervise_type = self_sup + train_transform = [..., ..., PatchSwaping] + valid_transform = [..., ..., PatchSwaping] + + [self_supervised_learning] + method_name = ModelGenesis + + """ + def __init__(self, config, stage = 'train'): + super(SelfSupModelGenesis, self).__init__(config, stage) + + def get_transform_names_and_parameters(self, stage): + trans_names, trans_params = super(SelfSupModelGenesis, self).get_transform_names_and_parameters(stage) + # if(stage == 'train'): + # print('training transforms:', trans_names) + # if("LocalShuffling" not in trans_names): + # raise ValueError("LocalShuffling is required for model genesis, \ + # but it is not given in training transform") + # if("NonLinearTransform" not in trans_names): + # raise ValueError("NonLinearTransform is required for model genesis, \ + # but it is not given in training transform") + # if("InOutPainting" not in trans_names): + # raise ValueError("InOutPainting is required for model genesis, \ + # but it is not given in training transform") + return trans_names, trans_params diff --git a/pymic/net_run/self_sup/self_patch_swapping.py b/pymic/net_run/self_sup/self_patch_swapping.py new file mode 100644 index 0000000..1692fa7 --- /dev/null +++ b/pymic/net_run/self_sup/self_patch_swapping.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import copy +import logging +import time +from pymic.net_run.agent_rec import ReconstructionAgent + +class SelfSupPatchSwapping(ReconstructionAgent): + """ + Patch swapping-based self-supervised learning. + + Reference: Liang Chen et al., Self-supervised learning for medical image analysis + using image context restoration, Medical Image Analysis, 2019. + + A PatchSwaping transform need to be used in the cnfiguration. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `self_supervised_learning` is needed. See :doc:`usage.selfsl` for details. + + In the configuration file, it should look like this: + ``` + [dataset] + task_type = rec + supervise_type = self_sup + train_transform = [..., ..., PatchSwaping] + valid_transform = [..., ..., PatchSwaping] + + [self_supervised_learning] + method_name = PatchSwapping + + """ + def __init__(self, config, stage = 'train'): + super(SelfSupPatchSwapping, self).__init__(config, stage) + + def get_transform_names_and_parameters(self, stage): + trans_names, trans_params = super(SelfSupPatchSwapping, self).get_transform_names_and_parameters(stage) + if(stage == 'train'): + print('training transforms:', trans_names) + assert("PatchSwaping" in trans_names) + return trans_names, trans_params + diff --git a/pymic/net_run/self_sup/self_sl_agent.py b/pymic/net_run/self_sup/self_sl_agent.py index c352adf..45bee26 100644 --- a/pymic/net_run/self_sup/self_sl_agent.py +++ b/pymic/net_run/self_sup/self_sl_agent.py @@ -6,6 +6,7 @@ from pymic.net_run.agent_rec import ReconstructionAgent + class SelfSLSegAgent(ReconstructionAgent): """ Abstract class for self-supervised segmentation. @@ -17,7 +18,7 @@ class SelfSLSegAgent(ReconstructionAgent): In the configuration dictionary, in addition to the four sections (`dataset`, `network`, `training` and `inference`) used in fully supervised learning, an - extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. + extra section `self_supervised_learning` is needed. See :doc:`usage.selfsl` for details. """ def __init__(self, config, stage = 'train'): super(SelfSLSegAgent, self).__init__(config, stage) diff --git a/pymic/net_run/self_sup/self_patch_mix_agent.py b/pymic/net_run/self_sup/self_volume_fusion.py similarity index 71% rename from pymic/net_run/self_sup/self_patch_mix_agent.py rename to pymic/net_run/self_sup/self_volume_fusion.py index e30a131..91fe088 100644 --- a/pymic/net_run/self_sup/self_patch_mix_agent.py +++ b/pymic/net_run/self_sup/self_volume_fusion.py @@ -32,11 +32,13 @@ from pymic.util.post_process import PostProcessDict from pymic.util.image_process import convert_label from pymic.util.parse_config import * +from pymic.util.general import get_one_hot_seg from pymic.io.image_read_write import save_nd_array_as_image -from pymic.net_run.self_sup.util import patch_mix +from pymic.net_run.self_sup.util import volume_fusion from pymic.net_run.agent_seg import SegmentationAgent -class SelfSLPatchMixAgent(SegmentationAgent): + +class SelfSupVolumeFusion(SegmentationAgent): """ Abstract class for self-supervised segmentation. @@ -50,16 +52,15 @@ class SelfSLPatchMixAgent(SegmentationAgent): extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. """ def __init__(self, config, stage = 'train'): - super(SelfSLPatchMixAgent, self).__init__(config, stage) + super(SelfSupVolumeFusion, self).__init__(config, stage) def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] - fg_num = self.config['network']['class_num'] - 1 - patch_num = self.config['patch_mix']['patch_num_range'] - size_d = self.config['patch_mix']['patch_depth_range'] - size_h = self.config['patch_mix']['patch_height_range'] - size_w = self.config['patch_mix']['patch_width_range'] + cls_num = self.config['network']['class_num'] + block_range = self.config['self_supervised_learning']['VolumeFusion_block_range'.lower()] + size_min = self.config['self_supervised_learning']['VolumeFusion_size_min'.lower()] + size_max = self.config['self_supervised_learning']['VolumeFusion_size_max'.lower()] train_loss = 0 train_dice_list = [] @@ -72,16 +73,16 @@ def training(self): data = next(self.trainIter) # get the inputs inputs = self.convert_tensor_type(data['image']) - inputs, labels_prob = patch_mix(inputs, fg_num, patch_num, size_d, size_h, size_w) + inputs, labels = volume_fusion(inputs, cls_num - 1, block_range, size_min, size_max) + labels_prob = get_one_hot_seg(labels, cls_num) - # # for debug + # for debug # if(it==10): # break # for i in range(inputs.shape[0]): # image_i = inputs[i][0] # label_i = np.argmax(labels_prob[i], axis = 0) # # pixw_i = pix_w[i][0] - # print(image_i.shape, label_i.shape) # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) # # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) @@ -116,29 +117,3 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ 'class_dice': train_cls_dice} return train_scalers - -def main(): - cfg_file = str(sys.argv[1]) - if(not os.path.isfile(cfg_file)): - raise ValueError("The config file does not exist: " + cfg_file) - config = parse_config(cfg_file) - config = synchronize_config(config) - log_dir = config['training']['ckpt_save_dir'] - if(not os.path.exists(log_dir)): - os.makedirs(log_dir, exist_ok=True) - dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] - shutil.copy(cfg_file, log_dir + "/" + dst_cfg) - if sys.version.startswith("3.9"): - logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), - level=logging.INFO, format='%(message)s', force=True) # for python 3.9 - else: - logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), - level=logging.INFO, format='%(message)s') # for python 3.6 - logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging_config(config) - agent = SelfSLPatchMixAgent(config) - agent.run() - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/pymic/net_run/self_sup/util.py b/pymic/net_run/self_sup/util.py index 0ab1670..db27702 100644 --- a/pymic/net_run/self_sup/util.py +++ b/pymic/net_run/self_sup/util.py @@ -7,7 +7,7 @@ from scipy import ndimage from pymic.io.image_read_write import * from pymic.util.image_process import * -from pymic.util.general import get_one_hot_seg + def get_human_region_mask(img): """ @@ -137,111 +137,36 @@ def get_human_body_mask_and_crop(input_dir, out_img_dir, out_mask_dir): sitk.WriteImage(mask_obj, mask_name) -def patch_mix(x, fg_num, patch_num, size_d, size_h, size_w): +def volume_fusion(x, fg_num, block_range, size_min, size_max): """ - Copy a sub region of an impage and paste to another one to generate + Fuse a subregion of an impage with another one to generate images and labels for self-supervised segmentation. + input x should be a batch of tensors """ + #n_min, n_max, N, C, D, H, W = list(x.shape) - fg_mask = torch.zeros_like(x) + fg_mask = torch.zeros_like(x).to(torch.int32) # generate mask for n in range(N): - p_num = random.randint(patch_num[0], patch_num[1]) + p_num = random.randint(block_range[0], block_range[1]) for i in range(p_num): - d = random.randint(size_d[0], size_d[1]) - h = random.randint(size_h[0], size_h[1]) - w = random.randint(size_w[0], size_w[1]) - d_c = random.randint(0, D) - h_c = random.randint(0, H) - w_c = random.randint(0, W) - d0, d1 = max(0, d_c - d), min(D, d_c + d) - h0, h1 = max(0, h_c - h), min(H, h_c + h) - w0, w1 = max(0, w_c - w), min(W, w_c + w) - temp_m = torch.ones([C, d1-d0, h1-h0, w1-w0]) * random.randint(1, fg_num) + d = random.randint(size_min[0], size_max[0]) + h = random.randint(size_min[1], size_max[1]) + w = random.randint(size_min[2], size_max[2]) + dc = random.randint(0, D - 1) + hc = random.randint(0, H - 1) + wc = random.randint(0, W - 1) + d0 = dc - d // 2 + h0 = hc - h // 2 + w0 = wc - w // 2 + d1 = min(D, d0 + d) + h1 = min(H, h0 + h) + w1 = min(W, w0 + w) + d0, h0, w0 = max(0, d0), max(0, h0), max(0, w0) + temp_m = torch.ones([C, d1 - d0, h1 - h0, w1 - w0]) * random.randint(1, fg_num) fg_mask[n, :, d0:d1, h0:h1, w0:w1] = temp_m fg_w = fg_mask * 1.0 / fg_num x_roll = torch.roll(x, 1, 0) x_fuse = fg_w*x_roll + (1.0 - fg_w)*x - y_prob = get_one_hot_seg(fg_mask.to(torch.int32), fg_num + 1) - return x_fuse, y_prob - -def create_mixed_dataset(input_dir, output_dir, fg_num = 1, crop_num = 1, patch_size=[128,128,128], - mask_dir = None, data_format = "nii.gz"): - """ - Create dataset based on patch mix. - - :param input_dir: (str) The path of folder for input images - :param output_dir: (str) The path of folder for output images - :param fg_num: (int) The number of foreground classes - :param crop_num: (int) The number of patches to crop for each input image - :param mask: ND array to specify a mask, or 'default' or None. If default, - a mask for body region is automatically generated (just for CT). - :param data_format: (str) The format of images. - """ - img_names = os.listdir(input_dir) - img_names = [item for item in img_names if item.endswith(data_format)] - img_names = sorted(img_names) - out_img_dir = output_dir + "/image" - out_lab_dir = output_dir + "/label" - if(not os.path.exists(out_img_dir)): - os.mkdir(out_img_dir) - if(not os.path.exists(out_lab_dir)): - os.mkdir(out_lab_dir) - - img_num = len(img_names) - print("image number", img_num) - i_range = range(img_num) - j_range = list(i_range) - random.shuffle(j_range) - for i in i_range: - print(i, img_names[i]) - j = j_range[i] - if(i == j): - j = i + 1 if i < img_num - 1 else 0 - img_i = load_image_as_nd_array(input_dir + "/" + img_names[i])['data_array'] - img_j = load_image_as_nd_array(input_dir + "/" + img_names[j])['data_array'] - - chns = img_i.shape[0] - crop_size = [chns] + patch_size - # random crop to patch size - if(mask_dir is None): - mask_i = get_human_region_mask(img_i) - mask_j = get_human_region_mask(img_j) - else: - mask_i = load_image_as_nd_array(mask_dir + "/" + img_names[i])['data_array'] - mask_j = load_image_as_nd_array(mask_dir + "/" + img_names[j])['data_array'] - for k in range(crop_num): - # if(mask is None): - # img_ik = random_crop_ND_volume(img_i, [chns, 96, 96, 96]) - # img_jk = random_crop_ND_volume(img_j, [chns, 96, 96, 96]) - # else: - - img_ik = random_crop_ND_volume_with_mask(img_i, crop_size, mask_i) - img_jk = random_crop_ND_volume_with_mask(img_j, crop_size, mask_j) - C, D, H, W = img_ik.shape - # generate mask - fg_mask = np.zeros_like(img_ik, np.uint8) - patch_num = random.randint(4, 40) - for patch in range(patch_num): - d = random.randint(4, 20) # half of window size - h = random.randint(4, 40) - w = random.randint(4, 40) - d_c = random.randint(0, D) - h_c = random.randint(0, H) - w_c = random.randint(0, W) - d0, d1 = max(0, d_c - d), min(D, d_c + d) - h0, h1 = max(0, h_c - h), min(H, h_c + h) - w0, w1 = max(0, w_c - w), min(W, w_c + w) - temp_m = np.ones([C, d1-d0, h1-h0, w1-w0]) * random.randint(1, fg_num) - fg_mask[:, d0:d1, h0:h1, w0:w1] = temp_m - fg_w = fg_mask * 1.0 / fg_num - x_fuse = fg_w*img_jk + (1.0 - fg_w)*img_ik - - out_name = img_names[i] - if crop_num > 1: - out_name = out_name.replace(".nii.gz", "_{0:}.nii.gz".format(k)) - save_nd_array_as_image(x_fuse[0], out_img_dir + "/" + out_name, - reference_name = input_dir + "/" + img_names[i]) - save_nd_array_as_image(fg_mask[0], out_lab_dir + "/" + out_name, - reference_name = input_dir + "/" + img_names[i]) - + # y_prob = get_one_hot_seg(fg_mask.to(torch.int32), fg_num + 1) + return x_fuse, fg_mask diff --git a/pymic/net_run/semi_sup/__init__.py b/pymic/net_run/semi_sup/__init__.py index 39ca5dc..d3095f6 100644 --- a/pymic/net_run/semi_sup/__init__.py +++ b/pymic/net_run/semi_sup/__init__.py @@ -1,17 +1,18 @@ from __future__ import absolute_import -# from . import * from pymic.net_run.semi_sup.ssl_abstract import SSLSegAgent from pymic.net_run.semi_sup.ssl_em import SSLEntropyMinimization from pymic.net_run.semi_sup.ssl_mt import SSLMeanTeacher +from pymic.net_run.semi_sup.ssl_mcnet import SSLMCNet from pymic.net_run.semi_sup.ssl_uamt import SSLUncertaintyAwareMeanTeacher from pymic.net_run.semi_sup.ssl_cct import SSLCCT from pymic.net_run.semi_sup.ssl_cps import SSLCPS from pymic.net_run.semi_sup.ssl_urpc import SSLURPC -# SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization, -# 'MeanTeacher': SSLMeanTeacher, -# 'UAMT': SSLUncertaintyAwareMeanTeacher, -# 'CCT': SSLCCT, -# 'CPS': SSLCPS, -# 'URPC': SSLURPC} \ No newline at end of file +SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization, + 'MeanTeacher': SSLMeanTeacher, + 'MCNet': SSLMCNet, + 'UAMT': SSLUncertaintyAwareMeanTeacher, + 'CCT': SSLCCT, + 'CPS': SSLCPS, + 'URPC': SSLURPC} \ No newline at end of file diff --git a/pymic/net_run/semi_sup/ssl_abstract.py b/pymic/net_run/semi_sup/ssl_abstract.py index b27edc9..0e05281 100644 --- a/pymic/net_run/semi_sup/ssl_abstract.py +++ b/pymic/net_run/semi_sup/ssl_abstract.py @@ -35,7 +35,7 @@ def get_unlabeled_dataset_from_config(self): """ Create a dataset for the unlabeled images based on configuration. """ - root_dir = self.config['dataset']['root_dir'] + train_dir = self.config['dataset']['train_dir'] modal_num = self.config['dataset'].get('modal_num', 1) transform_names = self.config['dataset']['train_transform_unlab'] @@ -53,7 +53,7 @@ def get_unlabeled_dataset_from_config(self): data_transform = transforms.Compose(self.transform_list) csv_file = self.config['dataset'].get('train_csv_unlab', None) - dataset = NiftyDataset(root_dir=root_dir, + dataset = NiftyDataset(root_dir = train_dir, csv_file = csv_file, modal_num = modal_num, with_label= False, @@ -76,7 +76,7 @@ def worker_init_fn(worker_id): num_worker = self.config['dataset'].get('num_worker', 16) self.train_loader_unlab = torch.utils.data.DataLoader(self.train_set_unlab, batch_size = bn_train_unlab, shuffle=True, num_workers= num_worker, - worker_init_fn=worker_init) + worker_init_fn=worker_init, drop_last = True) def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): loss_scalar ={'train':train_scalars['loss'], diff --git a/pymic/net_run/semi_sup/ssl_cps.py b/pymic/net_run/semi_sup/ssl_cps.py index df4b5af..7acfe17 100644 --- a/pymic/net_run/semi_sup/ssl_cps.py +++ b/pymic/net_run/semi_sup/ssl_cps.py @@ -3,12 +3,14 @@ import logging import numpy as np import torch +from random import random from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.io.image_read_write import save_nd_array_as_image from pymic.net_run.semi_sup import SSLSegAgent from pymic.util.ramps import get_rampup_ratio +from pymic.util.general import mixup, tensor_shape_match class SSLCPS(SSLSegAgent): """ @@ -34,7 +36,8 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] ssl_cfg = self.config['semi_supervised_learning'] - iter_max = self.config['training']['iter_max'] + iter_max = self.config['training']['iter_max'] + mixup_prob = self.config['training'].get('mixup_probability', 0.0) rampup_start = ssl_cfg.get('rampup_start', 0) rampup_end = ssl_cfg.get('rampup_end', iter_max) train_loss = 0 @@ -70,7 +73,8 @@ def training(self): # save_nd_array_as_image(image_i, image_name, reference_name = None) # save_nd_array_as_image(label_i, label_name, reference_name = None) # continue - + if(mixup_prob > 0 and random() < mixup_prob): + x0, y0 = mixup(x0, y0) inputs = torch.cat([x0, x1], dim = 0) inputs, y0 = inputs.to(self.device), y0.to(self.device) diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index f2bbe0f..50a5fb7 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -9,6 +9,7 @@ from pymic.util.parse_config import * from pymic.net_run.agent_cls import ClassificationAgent from pymic.net_run.agent_seg import SegmentationAgent +from pymic.net_run.agent_rec import ReconstructionAgent from pymic.net_run.semi_sup import SSLMethodDict from pymic.net_run.weak_sup import WSLMethodDict from pymic.net_run.self_sup import SelfSupMethodDict @@ -19,7 +20,10 @@ def get_seg_rec_agent(config, sup_type): assert(sup_type in ['fully_sup', 'semi_sup', 'self_sup', 'weak_sup', 'noisy_label']) if(sup_type == 'fully_sup'): logging.info("\n********** Fully Supervised Learning **********\n") - agent = SegmentationAgent(config, 'train') + if config['dataset']['task_type'] == TaskType.SEGMENTATION: + agent = SegmentationAgent(config, 'train') + else: + agent = ReconstructionAgent(config, 'train') elif(sup_type == 'semi_sup'): logging.info("\n********** Semi Supervised Learning **********\n") method = config['semi_supervised_learning']['method_name'] diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index 6977639..b821bb2 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -384,4 +384,92 @@ def __call__(self, sample): if(resize): weight = ndimage.interpolation.zoom(weight, scale, order = 1) sample['pixel_weight'] = weight - return sample \ No newline at end of file + return sample + +class RandomSlice(AbstractTransform): + """Randomly selecting N slices from a volume + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `RandomSlice_output_size`: (int) Desired number of slice for output. + :param `RandomSlice_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `True`. + """ + def __init__(self, params): + self.output_size = params['RandomSlice_output_size'.lower()] + self.shuffle = params.get('RandomSlice_shuffle'.lower(), False) + self.inverse = params.get('RandomSlice_inverse'.lower(), False) + self.task = params['Task'.lower()] + + def __call__(self, sample): + image = sample['image'] + D = image.shape[1] + assert( D >= self.output_size) + slice_idx = list(range(D)) + if(self.shuffle): + random.shuffle(slice_idx) + slice_idx = slice_idx[:self.output_size] + else: + d0 = random.randint(0, D - self.output_size) + d1 = d0 + self.output_size + slice_idx = slice_idx[d0:d1] + sample['image'] = image[:, slice_idx, :, :] + + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + label = sample['label'] + sample['label'] = label[:, slice_idx, :, :] + + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + weight = sample['pixel_weight'] + sample['pixel_weight'] = weight[:, slice_idx, :, :] + + return sample + +class CropHumanRegionFromCT(CenterCrop): + """ + Crop the human region from a CT volume. + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `CropWithBoundingBox_start`: (None, or list/tuple) The start index + along each spatial axis. If None, calculate the start index automatically + so that the cropped region is centered at the mask region defined by the threshold. + :param `CropWithBoundingBox_output_size`: (None or tuple/list): + Desired spatial output size. + If None, set it as the size of bounding box of the mask region defined by the threshold. + :param `CropWithBoundingBox_threshold`: (None or float): + Threshold for obtaining a mask. This is used only when + `CropWithBoundingBox_start` is None. Default is 1.0 + :param `CropWithBoundingBox_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `True`. + """ + def __init__(self, params): + self.threshold_i = params.get('CropHumanRegionFromCT_intensity_threshold'.lower(), -600) + self.threshold_z = params.get('CropHumanRegionFromCT_zaxis_threshold'.lower(), 0.5) + self.inverse = params.get('CropHumanRegionFromCT_inverse'.lower(), True) + self.task = params['task'] + + def _get_crop_param(self, sample): + image = sample['image'] + input_shape = image.shape + mask = np.asarray(image[0] > self.threshold_i) + mask2d = np.mean(mask, axis = 0) > self.threshold_z + se = np.ones([3,3]) + mask2d = ndimage.binary_opening(mask2d, se, iterations = 2) + mask2d = get_largest_k_components(mask2d, 1) + bbmin, bbmax = get_ND_bounding_box(mask2d, margin = [0, 0]) + crop_min = [0, 0] + bbmin + crop_max = list(input_shape[:2]) + bbmax + sample['CropHumanRegionFromCT_Param'] = json.dumps((input_shape, crop_min, crop_max)) + return sample, crop_min, crop_max + + def _get_param_for_inverse_transform(self, sample): + if(isinstance(sample['CropHumanRegionFromCT_Param'], list) or \ + isinstance(sample['CropHumanRegionFromCT_Param'], tuple)): + params = json.loads(sample['CropHumanRegionFromCT_Param'][0]) + else: + params = json.loads(sample['CropHumanRegionFromCT_Param']) + return params \ No newline at end of file diff --git a/pymic/transform/extract_channel.py b/pymic/transform/extract_channel.py new file mode 100644 index 0000000..c4974be --- /dev/null +++ b/pymic/transform/extract_channel.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import json +import math +import random +import numpy as np +from scipy import ndimage +from pymic import TaskType +from pymic.transform.abstract_transform import AbstractTransform +from pymic.util.image_process import * + + +class ExtractChannel(AbstractTransform): + """ Random flip the image. The shape is [C, D, H, W] or [C, H, W]. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `RandomFlip_flip_depth`: (bool) + Random flip along depth axis or not, only used for 3D images. + :param `RandomFlip_flip_height`: (bool) Random flip along height axis or not. + :param `RandomFlip_flip_width`: (bool) Random flip along width axis or not. + :param `RandomFlip_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `True`. + """ + def __init__(self, params): + super(ExtractChannel, self).__init__(params) + self.channels = params['ExtractChannel_channels'.lower()] + self.inverse = params.get('ExtractChannel_inverse'.lower(), False) + + def __call__(self, sample): + image = sample['image'] + image_extract = [] + for i in self.channels: + image_extract.append(image[i]) + sample['image'] = np.asarray(image_extract) + return sample diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index d39eb2c..2b19ebc 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -37,7 +37,7 @@ def bezier_curve(points, nTimes=1000): t = np.linspace(0.0, 1.0, nTimes) - polynomial_array = np.array([ bernstein_poly(i, nPoints-1, t) for i in range(0, nPoints) ]) + polynomial_array = np.array([ bernstein_poly(i, nPoints-1, t) for i in range(0, nPoints)]) xvals = np.dot(xPoints, polynomial_array) yvals = np.dot(yPoints, polynomial_array) @@ -182,30 +182,54 @@ def __call__(self, sample): class NonLinearTransform(AbstractTransform): def __init__(self, params): super(NonLinearTransform, self).__init__(params) - self.channels = params['NonLinearTransform_channels'.lower()] + self.channels = params.get('NonLinearTransform_channels'.lower(), None) self.prob = params.get('NonLinearTransform_probability'.lower(), 0.5) self.inverse = params.get('NonLinearTransform_inverse'.lower(), False) + self.block_range = params.get('NonLinearTransform_block_range'.lower(), None) + self.block_size = params.get('NonLinearTransform_block_size'.lower(), [8, 16, 16]) + def __apply_nonlinear_transform(self, img): + """ + the input img should be normlized to [0, 1]""" + points = [[0, 0], [random.random(), random.random()], [random.random(), random.random()], [1, 1]] + xvals, yvals = bezier_curve(points, nTimes=10000) + if random.random() < 0.5: # Half chance to get flip + xvals = np.sort(xvals) + else: + xvals, yvals = np.sort(xvals), np.sort(yvals) + + img = np.interp(img, xvals, yvals) + return img + def __call__(self, sample): if(random.random() > self.prob): return sample - image = sample['image'] - for chn in self.channels: - points = [[0, 0], [random.random(), random.random()], [random.random(), random.random()], [1, 1]] - xvals, yvals = bezier_curve(points, nTimes=10000) - if random.random() < 0.5: # Half change to get flip - xvals = np.sort(xvals) - else: - xvals, yvals = np.sort(xvals), np.sort(yvals) + image = sample['image'] + img_shape = image.shape + img_dim = len(img_shape) - 1 + channels = self.channels if self.channels is not None else range(image.shape[0]) + for chn in channels: # normalize the image intensity to [0, 1] before the non-linear tranform img_c = image[chn] - v_min = img_c.min() - v_max = img_c.max() + v_min, v_max = img_c.min(), img_c.max() if(v_min < v_max): img_c = (img_c - v_min)/(v_max - v_min) - img_c = np.interp(img_c, xvals, yvals) + if(self.block_range is None): # apply non-linear transform to the entire image + img_c = self.__apply_nonlinear_transform(img_c) + else: # non-linear transform to random blocks + img_c_sr = copy.deepcopy(img_c) + for n in range(self.block_range[0], self.block_range[1]): + coord_min = [random.randint(0, img_shape[1+i] - self.block_size[i]) \ + for i in range(img_dim)] + window = img_c_sr[coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1], + coord_min[2]:coord_min[2] + self.block_size[2]] + img_c[coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1], + coord_min[2]:coord_min[2] + self.block_size[2]] = \ + self.__apply_nonlinear_transform(window) image[chn] = img_c * (v_max - v_min) + v_min sample['image'] = image return sample @@ -218,9 +242,8 @@ def __init__(self, params): super(LocalShuffling, self).__init__(params) self.inverse = params.get('LocalShuffling_inverse'.lower(), False) self.prob = params.get('LocalShuffling_probability'.lower(), 0.5) - self.block_range = params.get('LocalShuffling_block_range'.lower(), (5000, 10000)) - self.block_size_min = params.get('LocalShuffling_block_size_min'.lower(), None) - self.block_size_max = params.get('LocalShuffling_block_size_max'.lower(), None) + self.block_range = params.get('LocalShuffling_block_range'.lower(), [40, 80]) + self.block_size = params.get('LocalShuffling_block_size'.lower(), [4, 8, 8]) def __call__(self, sample): if(random.random() > self.prob): @@ -231,49 +254,33 @@ def __call__(self, sample): img_dim = len(img_shape) - 1 assert(img_dim == 2 or img_dim == 3) img_out = copy.deepcopy(image) - if(self.block_size_min is None): - block_size_min = [2] * img_dim - elif(isinstance(self.block_size_min, int)): - block_size_min = [self.block_size_min] * img_dim - else: - assert(len(self.block_size_min) == img_dim) - block_size_min = self.block_size_min - - if(self.block_size_max is None): - block_size_max = [img_shape[1+i]//10 for i in range(img_dim)] - elif(isinstance(self.block_size_min, int)): - block_size_max = [self.block_size_max] * img_dim - else: - assert(len(self.block_size_max) == img_dim) - block_size_max = self.block_size_max + block_num = random.randint(self.block_range[0], self.block_range[1]) for n in range(block_num): - block_size = [random.randint(block_size_min[i], block_size_max[i]) \ - for i in range(img_dim)] - coord_min = [random.randint(0, img_shape[1+i] - block_size[i]) \ + coord_min = [random.randint(0, img_shape[1+i] - self.block_size[i]) \ for i in range(img_dim)] if(img_dim == 2): - window = image[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1]] - n_pixels = block_size[0] * block_size[1] + window = image[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1]] + n_pixels = self.block_size[0] * self.block_size[1] else: - window = image[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1], - coord_min[2]:coord_min[2] + block_size[2]] - n_pixels = block_size[0] * block_size[1] * block_size[2] + window = image[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1], + coord_min[2]:coord_min[2] + self.block_size[2]] + n_pixels = self.block_size[0] * self.block_size[1] * self.block_size[2] window = np.reshape(window, [-1, n_pixels]) np.random.shuffle(np.transpose(window)) window = np.transpose(window) if(img_dim == 2): - window = np.reshape(window, [-1, block_size[0], block_size[1]]) - img_out[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1]] = window + window = np.reshape(window, [-1, self.block_size[0], self.block_size[1]]) + img_out[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1]] = window else: - window = np.reshape(window, [-1, block_size[0], block_size[1], block_size[2]]) - img_out[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1], - coord_min[2]:coord_min[2] + block_size[2]] = window + window = np.reshape(window, [-1, self.block_size[0], self.block_size[1], self.block_size[2]]) + img_out[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1], + coord_min[2]:coord_min[2] + self.block_size[2]] = window sample['image'] = img_out return sample @@ -285,10 +292,9 @@ def __init__(self, params): super(InPainting, self).__init__(params) self.inverse = params.get('InPainting_inverse'.lower(), False) self.prob = params.get('InPainting_probability'.lower(), 0.5) - self.block_range = params.get('InPainting_block_range'.lower(), (1, 6)) - self.block_size_min = params.get('InPainting_block_size_min'.lower(), None) - self.block_size_max = params.get('InPainting_block_size_max'.lower(), None) - + self.block_range = params.get('InPainting_block_range'.lower(), (20, 40)) + self.block_size = params.get('InPainting_block_size'.lower(), [8, 16, 16]) + def __call__(self, sample): if(random.random() > self.prob): return sample @@ -298,38 +304,21 @@ def __call__(self, sample): img_dim = len(img_shape) - 1 assert(img_dim == 2 or img_dim == 3) - if(self.block_size_min is None): - block_size_min = [img_shape[1+i]//6 for i in range(img_dim)] - elif(isinstance(self.block_size_min, int)): - block_size_min = [self.block_size_min] * img_dim - else: - assert(len(self.block_size_min) == img_dim) - block_size_min = self.block_size_min - - if(self.block_size_max is None): - block_size_max = [img_shape[1+i]//3 for i in range(img_dim)] - elif(isinstance(self.block_size_min, int)): - block_size_max = [self.block_size_max] * img_dim - else: - assert(len(self.block_size_max) == img_dim) - block_size_max = self.block_size_max block_num = random.randint(self.block_range[0], self.block_range[1]) - for n in range(block_num): - block_size = [random.randint(block_size_min[i], block_size_max[i]) \ - for i in range(img_dim)] - coord_min = [random.randint(3, img_shape[1+i] - block_size[i] - 3) \ + for n in range(block_num): + coord_min = [random.randint(3, img_shape[1+i] - self.block_size[i] - 3) \ for i in range(img_dim)] if(img_dim == 2): - random_block = np.random.rand(img_shape[0], block_size[0], block_size[1]) - image[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1]] = random_block + random_block = np.random.rand(img_shape[0], self.block_size[0], self.block_size[1]) * 2 -1 + image[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1]] = random_block else: - random_block = np.random.rand(img_shape[0], block_size[0], - block_size[1], block_size[2]) - image[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1], - coord_min[2]:coord_min[2] + block_size[2]] = random_block + random_block = np.random.rand(img_shape[0], self.block_size[0], + self.block_size[1], self.block_size[2]) * 2 -1 + image[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1], + coord_min[2]:coord_min[2] + self.block_size[2]] = random_block sample['image'] = image return sample @@ -341,9 +330,8 @@ def __init__(self, params): super(OutPainting, self).__init__(params) self.inverse = params.get('OutPainting_inverse'.lower(), False) self.prob = params.get('OutPainting_probability'.lower(), 0.5) - self.block_range = params.get('OutPainting_block_range'.lower(), (1, 6)) - self.block_size_min = params.get('OutPainting_block_size_min'.lower(), None) - self.block_size_max = params.get('OutPainting_block_size_max'.lower(), None) + self.block_range = params.get('OutPainting_block_range'.lower(), (2, 8)) + self.block_size = params.get('OutPainting_block_size'.lower(), None) def __call__(self, sample): if(random.random() > self.prob): @@ -353,28 +341,18 @@ def __call__(self, sample): img_shape = image.shape img_dim = len(img_shape) - 1 assert(img_dim == 2 or img_dim == 3) - img_out = np.random.rand(*img_shape) + img_out = np.random.rand(*img_shape) * 2 -1 - if(self.block_size_min is None): - block_size_min = [img_shape[1+i] - 4 * img_shape[1+i]//7 for i in range(img_dim)] - elif(isinstance(self.block_size_min, int)): - block_size_min = [self.block_size_min] * img_dim + if(self.block_size is None): + margin = [16, 32, 32] + block_size = [img_shape[1+i] - margin[i] for i in range(img_dim)] else: - assert(len(self.block_size_min) == img_dim) - block_size_min = self.block_size_min + assert(len(self.block_size) == img_dim) + block_size = self.block_size - if(self.block_size_max is None): - block_size_max = [img_shape[1+i] - 3 * img_shape[1+i]//7 for i in range(img_dim)] - elif(isinstance(self.block_size_min, int)): - block_size_max = [self.block_size_max] * img_dim - else: - assert(len(self.block_size_max) == img_dim) - block_size_max = self.block_size_max block_num = random.randint(self.block_range[0], self.block_range[1]) for n in range(block_num): - block_size = [random.randint(block_size_min[i], block_size_max[i]) \ - for i in range(img_dim)] coord_min = [random.randint(3, img_shape[1+i] - block_size[i] - 3) \ for i in range(img_dim)] if(img_dim == 2): @@ -401,8 +379,8 @@ def __init__(self, params): self.inverse = params.get('InOutPainting_inverse'.lower(), False) self.prob = params.get('InOutPainting_probability'.lower(), 0.5) self.in_prob = params.get('InPainting_probability'.lower(), 0.5) - params['InPainting_probability'] = 1.0 - params['outPainting_probability'] = 1.0 + params['InPainting_probability'.lower()] = 1.0 + params['OutPainting_probability'.lower()] = 1.0 self.inpaint = InPainting(params) self.outpaint = OutPainting(params) @@ -423,35 +401,23 @@ class PatchSwaping(AbstractTransform): """ def __init__(self, params): super(PatchSwaping, self).__init__(params) + self.block_range = params.get('PatchSwaping_block_range'.lower(), (10, 20)) + self.block_size = params.get('PatchSwaping_block_size'.lower(), [8, 16, 16]) self.inverse = params.get('PatchSwaping_inverse'.lower(), False) - self.swap_t = params.get('PatchSwaping_swap_time'.lower(), (1, 6)) - self.patch_size_min = params.get('PatchSwaping_patch_size_min'.lower(), None) - self.patch_size_max = params.get('PatchSwaping_patch_size_max'.lower(), None) - - def __call__(self, sample): + def __call__(self, sample): image= sample['image'] img_shape = image.shape img_dim = len(img_shape) - 1 assert(img_dim == 2 or img_dim == 3) - img_out = image - - C, D, H, W = image.shape - patch_size = [random.randint(self.patch_size_min[i], self.patch_size_max[i]) for \ - i in range(img_dim)] - - coordinate_list = [] - for d in range(0, D-patch_size[0], patch_size[0]): - for h in range(0, H-patch_size[1], patch_size[1]): - for w in range(0, W-patch_size[2], patch_size[2]): - coordinate_list.append((d, h, w)) - random.shuffle(coordinate_list) + img_out = copy.deepcopy(image) - for t in range(self.swap_t): - pos_a0 = coordinate_list[2*t] - pos_b0 = coordinate_list[2*t + 1] - pos_a1 = [pos_a0[i] + patch_size[i] for i in range(img_dim)] - pos_b1 = [pos_b0[i] + patch_size[i] for i in range(img_dim)] + block_num = random.randint(self.block_range[0], self.block_range[1]) + for t in range(block_num): + pos_a0 = [random.randint(0, img_shape[-3+i] - self.block_size[i]) for i in range(img_dim)] + pos_b0 = [random.randint(0, img_shape[-3+i] - self.block_size[i]) for i in range(img_dim)] + pos_a1 = [pos_a0[i] + self.block_size[i] for i in range(img_dim)] + pos_b1 = [pos_b0[i] + self.block_size[i] for i in range(img_dim)] img_out[:, pos_a0[0]:pos_a1[0], pos_a0[1]:pos_a1[1], pos_a0[2]:pos_a1[2]] = \ image[:, pos_b0[0]:pos_b1[0], pos_b0[1]:pos_b1[1], pos_b0[2]:pos_b1[2]] img_out[:, pos_b0[0]:pos_b1[0], pos_b0[1]:pos_b1[1], pos_b0[2]:pos_b1[2]] = \ diff --git a/pymic/transform/normalize.py b/pymic/transform/normalize.py index 3d28c0d..5f0e4ec 100644 --- a/pymic/transform/normalize.py +++ b/pymic/transform/normalize.py @@ -74,7 +74,7 @@ def __call__(self, sample): class NormalizeWithMinMax(AbstractTransform): - """Nomralize the image to [0, 1]. The shape should be [C, D, H, W] or [C, H, W]. + """Nomralize the image to [-1, 1]. The shape should be [C, D, H, W] or [C, H, W]. The arguments should be written in the `params` dictionary, and it has the following fields: @@ -112,13 +112,13 @@ def __call__(self, sample): img_chn[img_chn < v0] = v0 img_chn[img_chn > v1] = v1 - img_chn = (img_chn - v0) / (v1 - v0) + img_chn = 2.0* (img_chn - v0) / (v1 - v0) -1.0 image[chn] = img_chn sample['image'] = image return sample class NormalizeWithPercentiles(AbstractTransform): - """Nomralize the image to [0, 1] with percentiles for given channels. + """Nomralize the image to [-1, 1] with percentiles for given channels. The shape should be [C, D, H, W] or [C, H, W]. The arguments should be written in the `params` dictionary, and it has the @@ -152,7 +152,7 @@ def __call__(self, sample): img_chn[img_chn < v0] = v0 img_chn[img_chn > v1] = v1 - img_chn = (img_chn - v0) / (v1 - v0) + img_chn = 2.0* (img_chn - v0) / (v1 - v0) -1.0 image[chn] = img_chn sample['image'] = image return sample \ No newline at end of file diff --git a/pymic/transform/rescale.py b/pymic/transform/rescale.py index fa4f052..2896a4e 100644 --- a/pymic/transform/rescale.py +++ b/pymic/transform/rescale.py @@ -217,6 +217,7 @@ def inverse_transform_for_prediction(self, sample): origin_shape = json.loads(sample['Resample_origin_shape'][0]) else: origin_shape = json.loads(sample['Resample_origin_shape']) + origin_dim = len(origin_shape) - 1 predict = sample['predict'] input_shape = predict.shape diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index a28e848..ed5ad0c 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -50,6 +50,7 @@ 'ChannelWiseThresholdWithNormalize': ChannelWiseThresholdWithNormalize, 'CropWithBoundingBox': CropWithBoundingBox, 'CropWithForeground': CropWithForeground, + 'CropHumanRegionFromCT': CropHumanRegionFromCT, 'CenterCrop': CenterCrop, 'GrayscaleToRGB': GrayscaleToRGB, 'GammaCorrection': GammaCorrection, @@ -67,6 +68,7 @@ 'NormalizeWithPercentiles': NormalizeWithPercentiles, 'PartialLabelToProbability':PartialLabelToProbability, 'RandomCrop': RandomCrop, + 'RandomSlice': RandomSlice, 'RandomResizedCrop': RandomResizedCrop, 'RandomRescale': RandomRescale, 'RandomTranspose': RandomTranspose, diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index d6a7220..c813e5d 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -327,7 +327,7 @@ def convert_label(label, source_list, target_list): label_converted[label_s > 0] = label_t[label_s > 0] return label_converted -def resample_sitk_image_to_given_spacing(image, spacing, order): +def resample_sitk_image_to_given_spacing(image, spacing, order = 3): """ Resample an sitk image objct to a given spacing. From 3a0e1501e75b58bdca5e089309fcec221a92813e Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 10 Jan 2024 14:32:22 +0800 Subject: [PATCH 36/86] update readme --- README.md | 4 +- pymic/net/net3d/unet3d.py | 151 +++++++++++++++-------------------- pymic/net_run/preprocess.py | 82 ------------------- pymic/transform/normalize.py | 2 +- setup.py | 2 +- 5 files changed, 68 insertions(+), 173 deletions(-) delete mode 100644 pymic/net_run/preprocess.py diff --git a/README.md b/README.md index ac7f9b6..4af5eff 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # PyMIC: A Pytorch-Based Toolkit for Medical Image Computing -PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with high dimension and large volume, multiple modalities and difficulies in annotating. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations. +PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with high dimension and large volume, multiple modalities and difficulies in annotating. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised, self-supervised, and weakly supervised learning, and learning with noisy annotations. Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. If you use this toolkit, please cite the following paper: @@ -23,7 +23,7 @@ BibTeX entry: # Features PyMIC provides flixible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions: -* Support for annotation-efficient image segmentation, especially for semi-supervised, self-supervised, weakly-supervised and noisy-label learning. +* Support for annotation-efficient image segmentation, especially for semi-supervised, self-supervised, self-supervised, weakly-supervised and noisy-label learning. * User friendly: For beginners, you only need to edit the configuration files for model training and inference, without writing code. For advanced users, you can customize different modules (networks, loss functions, training pipeline, etc) and easily integrate them into PyMIC. * Easy-to-use I/O interface to read and write different 2D and 3D images. * Various data pre-processing/transformation methods before sending a tensor into a network. diff --git a/pymic/net/net3d/unet3d.py b/pymic/net/net3d/unet3d.py index a17bcb8..e383e77 100644 --- a/pymic/net/net3d/unet3d.py +++ b/pymic/net/net3d/unet3d.py @@ -1,10 +1,11 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import logging import torch import torch.nn as nn import numpy as np -from torch.nn.functional import interpolate + class ConvBlock(nn.Module): """ @@ -56,22 +57,32 @@ class UpBlock(nn.Module): :param in_channels2: (int) Channel number of low-level features. :param out_channels: (int) Output channel number. :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - trilinear=True): + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, up_mode=2): super(UpBlock, self).__init__() - self.trilinear = trilinear - if trilinear: - self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) + if(isinstance(up_mode, int)): + up_mode_values = ["transconv", "nearest", "trilinear"] + if(up_mode > 2): + raise ValueError("The upsample mode should be 0-2, but {0:} is given.".format(up_mode)) + self.up_mode = up_mode_values[up_mode] else: + self.up_mode = up_mode.lower() + + if (self.up_mode == "transconv"): self.up = nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) + else: + self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) + if(self.up_mode == "nearest"): + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode) + else: + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode, align_corners=True) self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) def forward(self, x1, x2): - if self.trilinear: + if self.up_mode != "transconv": x1 = self.conv1x1(x1) x1 = self.up(x1) x = torch.cat([x2, x1], dim=1) @@ -129,9 +140,10 @@ class Decoder(nn.Module): :param dropout: (list) The dropout ratio for each resolution level. The length should be the same as that of `feature_chns`. :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - :param multiscale_pred: (bool) Get multi-scale prediction. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). + :param multiscale_pred: (bool) Get multi-scale prediction. """ def __init__(self, params): super(Decoder, self).__init__() @@ -140,21 +152,25 @@ def __init__(self, params): self.ft_chns = self.params['feature_chns'] self.dropout = self.params['dropout'] self.n_class = self.params['class_num'] - self.trilinear = self.params.get('trilinear', True) + self.up_mode = self.params.get('up_mode', 2) self.mul_pred = self.params.get('multiscale_pred', False) - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) if(len(self.ft_chns) == 5): - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear) + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.up_mode) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.up_mode) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.up_mode) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.up_mode) self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) + if(self.mul_pred): self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) + self.stage = 'train' + + def set_stage(self, stage): + self.stage = stage def forward(self, x): if(len(self.ft_chns) == 5): @@ -169,7 +185,7 @@ def forward(self, x): x_d1 = self.up3(x_d2, x1) x_d0 = self.up4(x_d1, x0) output = self.out_conv(x_d0) - if(self.mul_pred): + if(self.mul_pred and self.stage == 'train'): output1 = self.out_conv1(x_d1) output2 = self.out_conv2(x_d2) output3 = self.out_conv3(x_d3) @@ -196,77 +212,38 @@ class UNet3D(nn.Module): :param dropout: (list) The dropout ratio for each resolution level. The length should be the same as that of `feature_chns`. :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using trilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). :param multiscale_pred: (bool) Get multi-scale prediction. """ def __init__(self, params): super(UNet3D, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.trilinear = self.params['trilinear'] - self.mul_pred = self.params['multiscale_pred'] - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - if(len(self.ft_chns) == 5): - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], - dropout_p = self.dropout[3], trilinear=self.trilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], - dropout_p = self.dropout[2], trilinear=self.trilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], - dropout_p = self.dropout[1], trilinear=self.trilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], - dropout_p = self.dropout[0], trilinear=self.trilinear) - - self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) - if(self.mul_pred): - self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) - self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) - self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) + params = self.get_default_parameters(params) + for p in params: + print(p, params[p]) + self.stage = 'train' + self.encoder = Encoder(params) + self.decoder = Decoder(params) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], + 'up_mode': 2, + 'multiscale_pred': False + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def set_stage(self, stage): + self.stage = stage + self.decoder.set_stage(stage) def forward(self, x): - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - if(len(self.ft_chns) == 5): - x4 = self.down4(x3) - x_d3 = self.up1(x4, x3) - else: - x_d3 = x3 - x_d2 = self.up2(x_d3, x2) - x_d1 = self.up3(x_d2, x1) - x_d0 = self.up4(x_d1, x0) - output = self.out_conv(x_d0) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] + f = self.encoder(x) + output = self.decoder(f) return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'class_num': 2, - 'feature_chns':[2, 8, 32, 64], - 'dropout' : [0, 0, 0, 0.5], - 'trilinear': True, - 'multiscale_pred': False} - Net = UNet3D(params) - Net = Net.double() - - x = np.random.rand(4, 4, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - y = y.detach().numpy() - print(y.shape) diff --git a/pymic/net_run/preprocess.py b/pymic/net_run/preprocess.py deleted file mode 100644 index f2bbe0f..0000000 --- a/pymic/net_run/preprocess.py +++ /dev/null @@ -1,82 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division -import logging -import os -import sys -import shutil -from datetime import datetime -from pymic import TaskType -from pymic.util.parse_config import * -from pymic.net_run.agent_cls import ClassificationAgent -from pymic.net_run.agent_seg import SegmentationAgent -from pymic.net_run.semi_sup import SSLMethodDict -from pymic.net_run.weak_sup import WSLMethodDict -from pymic.net_run.self_sup import SelfSupMethodDict -from pymic.net_run.noisy_label import NLLMethodDict -# from pymic.net_run.self_sup import SelfSLSegAgent - -def get_seg_rec_agent(config, sup_type): - assert(sup_type in ['fully_sup', 'semi_sup', 'self_sup', 'weak_sup', 'noisy_label']) - if(sup_type == 'fully_sup'): - logging.info("\n********** Fully Supervised Learning **********\n") - agent = SegmentationAgent(config, 'train') - elif(sup_type == 'semi_sup'): - logging.info("\n********** Semi Supervised Learning **********\n") - method = config['semi_supervised_learning']['method_name'] - agent = SSLMethodDict[method](config, 'train') - elif(sup_type == 'weak_sup'): - logging.info("\n********** Weakly Supervised Learning **********\n") - method = config['weakly_supervised_learning']['method_name'] - agent = WSLMethodDict[method](config, 'train') - elif(sup_type == 'noisy_label'): - logging.info("\n********** Noisy Label Learning **********\n") - method = config['noisy_label_learning']['method_name'] - agent = NLLMethodDict[method](config, 'train') - elif(sup_type == 'self_sup'): - logging.info("\n********** Self Supervised Learning **********\n") - method = config['self_supervised_learning']['method_name'] - agent = SelfSupMethodDict[method](config, 'train') - else: - raise ValueError("undefined supervision type: {0:}".format(sup_type)) - return agent - -def main(): - """ - The main function for running a network for training. - """ - if(len(sys.argv) < 2): - print('Number of arguments should be 2. e.g.') - print(' pymic_train config.cfg') - exit() - cfg_file = str(sys.argv[1]) - if(not os.path.isfile(cfg_file)): - raise ValueError("The config file does not exist: " + cfg_file) - config = parse_config(cfg_file) - config = synchronize_config(config) - log_dir = config['training']['ckpt_save_dir'] - if(not os.path.exists(log_dir)): - os.makedirs(log_dir, exist_ok=True) - dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] - shutil.copy(cfg_file, log_dir + "/" + dst_cfg) - datetime_str = str(datetime.now())[:-7].replace(":", "_") - if sys.version.startswith("3.9"): - logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), - level=logging.INFO, format='%(message)s', force=True) # for python 3.9 - else: - logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), - level=logging.INFO, format='%(message)s') # for python 3.6 - logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging_config(config) - task = config['dataset']['task_type'] - if(task == TaskType.CLASSIFICATION_ONE_HOT or task == TaskType.CLASSIFICATION_COEXIST): - agent = ClassificationAgent(config, 'train') - else: - sup_type = config['dataset'].get('supervise_type', 'fully_sup') - agent = get_seg_rec_agent(config, sup_type) - - agent.run() - -if __name__ == "__main__": - main() - - diff --git a/pymic/transform/normalize.py b/pymic/transform/normalize.py index 5f0e4ec..643c12e 100644 --- a/pymic/transform/normalize.py +++ b/pymic/transform/normalize.py @@ -38,7 +38,7 @@ def __init__(self, params): self.chns = params.get('NormalizeWithMeanStd_channels'.lower(), None) self.mean = params.get('NormalizeWithMeanStd_mean'.lower(), None) self.std = params.get('NormalizeWithMeanStd_std'.lower(), None) - self.mask_thrd = params.get('NormalizeWithMeanStd_mask_threshold'.lower(), 1.0) + self.mask_thrd = params.get('NormalizeWithMeanStd_mask_threshold'.lower(), None) self.bg_random = params.get('NormalizeWithMeanStd_set_background_to_random'.lower(), True) self.inverse = params.get('NormalizeWithMeanStd_inverse'.lower(), False) diff --git a/setup.py b/setup.py index 36daf9a..22406a0 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.4.0", + version = "0.4.0.1", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, From cafaac0ca5699750e34b1eeee092985085775949 Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 10 Jan 2024 14:39:40 +0800 Subject: [PATCH 37/86] fix minor issues for config --- pymic/loss/seg/mumford_shah.py | 7 +- pymic/net/net3d/trans3d/HiFormer_v1.py | 1010 --------------------- pymic/net/net3d/trans3d/HiFormer_v2.py | 381 -------- pymic/net/net3d/trans3d/HiFormer_v3.py | 455 ---------- pymic/net/net3d/trans3d/HiFormer_v4.py | 455 ---------- pymic/net/net3d/trans3d/HiFormer_v5.py | 308 ------- pymic/net/net3d/trans3d/MedFormer_v1.py | 173 ---- pymic/net/net3d/trans3d/MedFormer_v2.py | 464 ---------- pymic/net/net3d/trans3d/MedFormer_v3.py | 255 ------ pymic/net/net3d/trans3d/MedFormer_va1.py | 105 --- pymic/net/net3d/trans3d/__init__.py | 0 pymic/net/net3d/trans3d/nnFormer_wrap.py | 43 - pymic/net/net3d/trans3d/unetr.py | 227 ----- pymic/net/net3d/trans3d/unetr_pp.py | 469 ---------- pymic/net/net3d/trans3d/unetr_pp_block.py | 278 ------ pymic/net_run/agent_preprocess.py | 80 +- pymic/net_run/train.py | 1 - pymic/util/evaluation_seg.py | 96 +- pymic/util/image_process.py | 12 + pymic/util/parse_config.py | 29 +- pymic/util/preprocess.py | 63 -- setup.py | 2 +- 22 files changed, 128 insertions(+), 4785 deletions(-) delete mode 100644 pymic/net/net3d/trans3d/HiFormer_v1.py delete mode 100644 pymic/net/net3d/trans3d/HiFormer_v2.py delete mode 100644 pymic/net/net3d/trans3d/HiFormer_v3.py delete mode 100644 pymic/net/net3d/trans3d/HiFormer_v4.py delete mode 100644 pymic/net/net3d/trans3d/HiFormer_v5.py delete mode 100644 pymic/net/net3d/trans3d/MedFormer_v1.py delete mode 100644 pymic/net/net3d/trans3d/MedFormer_v2.py delete mode 100644 pymic/net/net3d/trans3d/MedFormer_v3.py delete mode 100644 pymic/net/net3d/trans3d/MedFormer_va1.py delete mode 100644 pymic/net/net3d/trans3d/__init__.py delete mode 100644 pymic/net/net3d/trans3d/nnFormer_wrap.py delete mode 100644 pymic/net/net3d/trans3d/unetr.py delete mode 100644 pymic/net/net3d/trans3d/unetr_pp.py delete mode 100644 pymic/net/net3d/trans3d/unetr_pp_block.py delete mode 100644 pymic/util/preprocess.py diff --git a/pymic/loss/seg/mumford_shah.py b/pymic/loss/seg/mumford_shah.py index 6da51b5..db9368a 100644 --- a/pymic/loss/seg/mumford_shah.py +++ b/pymic/loss/seg/mumford_shah.py @@ -3,8 +3,9 @@ import torch import torch.nn as nn +from pymic.loss.seg.abstract import AbstractSegLoss -class MumfordShahLoss(nn.Module): +class MumfordShahLoss(AbstractSegLoss): """ Implementation of Mumford Shah Loss for weakly supervised learning. @@ -76,8 +77,8 @@ def forward(self, loss_input_dict): image = loss_input_dict['image'] if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) pred_shape = list(predict.shape) if(len(pred_shape) == 5): diff --git a/pymic/net/net3d/trans3d/HiFormer_v1.py b/pymic/net/net3d/trans3d/HiFormer_v1.py deleted file mode 100644 index af73683..0000000 --- a/pymic/net/net3d/trans3d/HiFormer_v1.py +++ /dev/null @@ -1,1010 +0,0 @@ -from einops import rearrange -from copy import deepcopy -from nnformer.utilities.nd_softmax import softmax_helper -from torch import nn -import torch -import numpy as np -import torch.nn.functional -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from timm.models.layers import DropPath, to_3tuple, trunc_normal_ -from pymic.net.net3d.unet3d import ConvBlock, DownBlock -# from nnFormer -class ContiguousGrad(torch.autograd.Function): - @staticmethod - def forward(ctx, x): - return x - @staticmethod - def backward(ctx, grad_out): - return grad_out.contiguous() - -# from nnFormer -class Mlp(nn.Module): - """ Multilayer perceptron.""" - - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - -# from nnFormer -def window_partition(x, window_size): - - B, S, H, W, C = x.shape - x = x.view(B, S // window_size, window_size, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, C) - return windows - -# from nnFormer -def window_reverse(windows, window_size, S, H, W): - - B = int(windows.shape[0] / (S * H * W / window_size / window_size / window_size)) - x = windows.view(B, S // window_size, H // window_size, W // window_size, window_size, window_size, window_size, -1) - x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, S, H, W, -1) - return x - - -# from nnFormer -class SwinTransformerBlock_kv(nn.Module): - - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention_kv( - dim, window_size=to_3tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - #self.window_size=to_3tuple(self.window_size) - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - def forward(self, x, mask_matrix,skip=None,x_up=None): - - B, L, C = x.shape - S, H, W = self.input_resolution - - assert L == S * H * W, "input feature has wrong size" - - shortcut = x - skip = self.norm1(skip) - x_up = self.norm1(x_up) - - skip = skip.view(B, S, H, W, C) - x_up = x_up.view(B, S, H, W, C) - x = x.view(B, S, H, W, C) - # pad feature maps to multiples of window size - pad_r = (self.window_size - W % self.window_size) % self.window_size - pad_b = (self.window_size - H % self.window_size) % self.window_size - pad_g = (self.window_size - S % self.window_size) % self.window_size - - skip = F.pad(skip, (0, 0, 0, pad_r, 0, pad_b, 0, pad_g)) - x_up = F.pad(x_up, (0, 0, 0, pad_r, 0, pad_b, 0, pad_g)) - _, Sp, Hp, Wp, _ = skip.shape - - - - # cyclic shift - if self.shift_size > 0: - skip = torch.roll(skip, shifts=(-self.shift_size, -self.shift_size,-self.shift_size), dims=(1, 2,3)) - x_up = torch.roll(x_up, shifts=(-self.shift_size, -self.shift_size,-self.shift_size), dims=(1, 2,3)) - attn_mask = mask_matrix - else: - skip = skip - x_up=x_up - attn_mask = None - # partition windows - skip = window_partition(skip, self.window_size) - skip = skip.view(-1, self.window_size * self.window_size * self.window_size, - C) - x_up = window_partition(x_up, self.window_size) - x_up = x_up.view(-1, self.window_size * self.window_size * self.window_size, - C) - attn_windows=self.attn(skip,x_up,mask=attn_mask,pos_embed=None) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, Sp, Hp, Wp) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size, self.shift_size), dims=(1, 2, 3)) - else: - x = shifted_x - - if pad_r > 0 or pad_b > 0 or pad_g > 0: - x = x[:, :S, :H, :W, :].contiguous() - - x = x.view(B, S * H * W, C) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - - return x - -# from nnFormer -class WindowAttention_kv(nn.Module): - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), - num_heads)) - - # get pair-wise relative position index for each token inside the window - coords_s = torch.arange(self.window_size[0]) - coords_h = torch.arange(self.window_size[1]) - coords_w = torch.arange(self.window_size[2]) - coords = torch.stack(torch.meshgrid([coords_s, coords_h, coords_w])) - coords_flatten = torch.flatten(coords, 1) - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] - relative_coords = relative_coords.permute(1, 2, 0).contiguous() - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 2] += self.window_size[2] - 1 - - relative_coords[:, :, 0] *= 3 * self.window_size[1] - 1 - relative_coords[:, :, 1] *= 2 * self.window_size[1] - 1 - - relative_position_index = relative_coords.sum(-1) - self.register_buffer("relative_position_index", relative_position_index) - - self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - self.softmax = nn.Softmax(dim=-1) - trunc_normal_(self.relative_position_bias_table, std=.02) - - - def forward(self, skip,x_up,pos_embed=None, mask=None): - - B_, N, C = skip.shape - - kv = self.kv(skip) - q = x_up - - kv=kv.reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() - q = q.reshape(B_,N,self.num_heads,C//self.num_heads).permute(0,2,1,3).contiguous() - k,v = kv[0], kv[1] - q = q * self.scale - attn = (q @ k.transpose(-2, -1).contiguous()) - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1] * self.window_size[2], - self.window_size[0] * self.window_size[1] * self.window_size[2], -1) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous() - if pos_embed is not None: - x = x + pos_embed - x = self.proj(x) - x = self.proj_drop(x) - return x - -# from nnFormer -class WindowAttention(nn.Module): - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), - num_heads)) - - # get pair-wise relative position index for each token inside the window - coords_s = torch.arange(self.window_size[0]) - coords_h = torch.arange(self.window_size[1]) - coords_w = torch.arange(self.window_size[2]) - coords = torch.stack(torch.meshgrid([coords_s, coords_h, coords_w])) - coords_flatten = torch.flatten(coords, 1) - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] - relative_coords = relative_coords.permute(1, 2, 0).contiguous() - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 2] += self.window_size[2] - 1 - - relative_coords[:, :, 0] *= 3 * self.window_size[1] - 1 - relative_coords[:, :, 1] *= 2 * self.window_size[1] - 1 - - relative_position_index = relative_coords.sum(-1) - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - trunc_normal_(self.relative_position_bias_table, std=.02) - - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None,pos_embed=None): - - B_, N, C = x.shape - - qkv = self.qkv(x) - - qkv=qkv.reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1).contiguous()) - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1] * self.window_size[2], - self.window_size[0] * self.window_size[1] * self.window_size[2], -1) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous() - if pos_embed is not None: - x = x+pos_embed - x = self.proj(x) - x = self.proj_drop(x) - return x - -# from nnFormer -class SwinTransformerBlock(nn.Module): - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - - self.attn = WindowAttention( - dim, window_size=to_3tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - - def forward(self, x, mask_matrix): - - B, L, C = x.shape - S, H, W = self.input_resolution - - assert L == S * H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, S, H, W, C) - - # pad feature maps to multiples of window size - pad_r = (self.window_size - W % self.window_size) % self.window_size - pad_b = (self.window_size - H % self.window_size) % self.window_size - pad_g = (self.window_size - S % self.window_size) % self.window_size - - x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b, 0, pad_g)) - _, Sp, Hp, Wp, _ = x.shape - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size,-self.shift_size), dims=(1, 2,3)) - attn_mask = mask_matrix - else: - shifted_x = x - attn_mask = None - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size * self.window_size, - C) - - # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=attn_mask,pos_embed=None) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, Sp, Hp, Wp) - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size, self.shift_size), dims=(1, 2, 3)) - else: - x = shifted_x - - if pad_r > 0 or pad_b > 0 or pad_g > 0: - x = x[:, :S, :H, :W, :].contiguous() - - x = x.view(B, S * H * W, C) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - - return x - -# from nnFormer -class PatchMerging(nn.Module): - - - def __init__(self, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.reduction = nn.Conv3d(dim,dim*2,kernel_size=3,stride=2,padding=1) - - self.norm = norm_layer(dim) - - def forward(self, x, S, H, W): - - B, L, C = x.shape - assert L == H * W * S, "input feature has wrong size" - x = x.view(B, S, H, W, C) - - x = F.gelu(x) - x = self.norm(x) - x=x.permute(0,4,1,2,3).contiguous() - x=self.reduction(x) - x=x.permute(0,2,3,4,1).contiguous().view(B,-1,2*C) - - return x - -# from nnFormer -class Patch_Expanding(nn.Module): - def __init__(self, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - - self.norm = norm_layer(dim) - self.up=nn.ConvTranspose3d(dim,dim//2,2,2) - def forward(self, x, S, H, W): - - - B, L, C = x.shape - assert L == H * W * S, "input feature has wrong size" - - x = x.view(B, S, H, W, C) - - - - x = self.norm(x) - x=x.permute(0,4,1,2,3).contiguous() - x = self.up(x) - x = ContiguousGrad.apply(x) - x=x.permute(0,2,3,4,1).contiguous().view(B,-1,C//2) - - return x - -# from nnFormer -class BasicLayer(nn.Module): - - def __init__(self, - dim, - input_resolution, - depth, - num_heads, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop=0., - attn_drop=0., - drop_path=0., - norm_layer=nn.LayerNorm, - downsample=True - ): - super().__init__() - self.window_size = window_size - self.shift_size = window_size // 2 - self.depth = depth - # build blocks - - self.blocks = nn.ModuleList([ - SwinTransformerBlock( - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, S, H, W): - - - # calculate attention mask for SW-MSA - Sp = int(np.ceil(S / self.window_size)) * self.window_size - Hp = int(np.ceil(H / self.window_size)) * self.window_size - Wp = int(np.ceil(W / self.window_size)) * self.window_size - img_mask = torch.zeros((1, Sp, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 - s_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for s in s_slices: - for h in h_slices: - for w in w_slices: - img_mask[:, s, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) - mask_windows = mask_windows.view(-1, - self.window_size * self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - for blk in self.blocks: - - x = blk(x, attn_mask) - if self.downsample is not None: - x_down = self.downsample(x, S, H, W) - Ws, Wh, Ww = (S + 1) // 2, (H + 1) // 2, (W + 1) // 2 - return x, S, H, W, x_down, Ws, Wh, Ww - else: - return x, S, H, W, x, S, H, W - -# from nnFormer -class BasicLayer_up(nn.Module): - - def __init__(self, - dim, - input_resolution, - depth, - num_heads, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop=0., - attn_drop=0., - drop_path=0., - norm_layer=nn.LayerNorm, - upsample=True - ): - super().__init__() - self.window_size = window_size - self.shift_size = window_size // 2 - self.depth = depth - - - # build blocks - self.blocks = nn.ModuleList() - self.blocks.append( - SwinTransformerBlock_kv( - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - window_size=window_size, - shift_size=0 , - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path[0] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) - ) - for i in range(depth-1): - self.blocks.append( - SwinTransformerBlock( - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - window_size=window_size, - shift_size=window_size // 2 , - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path[i+1] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) - ) - - - - self.Upsample = upsample(dim=2*dim, norm_layer=norm_layer) - def forward(self, x,skip, S, H, W): - - - x_up = self.Upsample(x, S, H, W) - - x = x_up + skip - S, H, W = S * 2, H * 2, W * 2 - # calculate attention mask for SW-MSA - Sp = int(np.ceil(S / self.window_size)) * self.window_size - Hp = int(np.ceil(H / self.window_size)) * self.window_size - Wp = int(np.ceil(W / self.window_size)) * self.window_size - img_mask = torch.zeros((1, Sp, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 - s_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for s in s_slices: - for h in h_slices: - for w in w_slices: - img_mask[:, s, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, - self.window_size * self.window_size * self.window_size) # 3d��3��winds�˻�����Ŀ�Ǻܴ�ģ�����winds����̫�� - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - x = self.blocks[0](x, attn_mask,skip=skip,x_up=x_up) - for i in range(self.depth-1): - x = self.blocks[i+1](x,attn_mask) - - return x, S, H, W - - -# from nnFormer -class project(nn.Module): - def __init__(self,in_dim,out_dim,stride,padding,activate,norm,last=False): - super().__init__() - self.out_dim=out_dim - self.conv1=nn.Conv3d(in_dim,out_dim,kernel_size=3,stride=stride,padding=padding) - self.conv2=nn.Conv3d(out_dim,out_dim,kernel_size=3,stride=1,padding=1) - self.activate=activate() - self.norm1=norm(out_dim) - self.last=last - if not last: - self.norm2=norm(out_dim) - - def forward(self,x): - x=self.conv1(x) - x=self.activate(x) - #norm1 - Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.norm1(x) - x = x.transpose(1, 2).contiguous().view(-1, self.out_dim, Ws, Wh, Ww) - - - x=self.conv2(x) - if not self.last: - x=self.activate(x) - #norm2 - Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.norm2(x) - x = x.transpose(1, 2).contiguous().view(-1, self.out_dim, Ws, Wh, Ww) - return x - - -# from nnFormer -class PatchEmbed_backup(nn.Module): - def __init__(self, patch_size=4, in_chans=4, embed_dim=96, norm_layer=None): - super().__init__() - patch_size = to_3tuple(patch_size) - self.patch_size = patch_size - - self.in_chans = in_chans - self.embed_dim = embed_dim - stride1=[patch_size[0]//2,patch_size[1]//2,patch_size[2]//2] - stride2=[patch_size[0]//2,patch_size[1]//2,patch_size[2]//2] - self.proj1 = project(in_chans,embed_dim//2,stride1,1,nn.GELU,nn.LayerNorm,False) - self.proj2 = project(embed_dim//2,embed_dim,stride2,1,nn.GELU,nn.LayerNorm,True) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - """Forward function.""" - # padding - _, _, S, H, W = x.size() - if W % self.patch_size[2] != 0: - x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) - if H % self.patch_size[1] != 0: - x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) - if S % self.patch_size[0] != 0: - x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - S % self.patch_size[0])) - x = self.proj1(x) # B C Ws Wh Ww - x = self.proj2(x) # B C Ws Wh Ww - if self.norm is not None: - Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.norm(x) - x = x.transpose(1, 2).contiguous().view(-1, self.embed_dim, Ws, Wh, Ww) - - return x - - -class PatchEmbed(nn.Module): - """ - replace patch embed with conv layers""" - def __init__(self, in_chns=1, ft_chns = [32, 64, 128], dropout = [0, 0, 0.2]): - super().__init__() - self.in_conv= ConvBlock(in_chns, ft_chns[0], dropout[0]) - self.down1 = DownBlock(ft_chns[0], ft_chns[1], dropout[1]) - self.down2 = DownBlock(ft_chns[1], ft_chns[2], dropout[2]) - - - def forward(self, x): - """Forward function.""" - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - return x2 - -# from nnFormer -class Encoder(nn.Module): - - def __init__(self, - pretrain_img_size=224, - patch_size=4, - in_chans=1 , - embed_dim=96, - depths=[2, 2, 2, 2], - num_heads=[4, 8, 16, 32], - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.2, - norm_layer=nn.LayerNorm, - patch_norm=True, - out_indices=(0, 1, 2, 3) - ): - super().__init__() - - self.pretrain_img_size = pretrain_img_size - - self.num_layers = len(depths) - print("number of layers in encoder", self.num_layers, depths) - self.embed_dim = embed_dim - self.patch_norm = patch_norm - self.out_indices = out_indices - - # split image into non-overlapping patches - # self.patch_embed = PatchEmbed( - # patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - # norm_layer=norm_layer if self.patch_norm else None) - self.patch_embed = PatchEmbed(in_chans, ft_chns=[embed_dim // 4, embed_dim //2, embed_dim]) - - - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build layers - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = BasicLayer( - dim=int(embed_dim * 2 ** i_layer), - input_resolution=( - pretrain_img_size[0] // patch_size[0] // 2 ** i_layer, pretrain_img_size[1] // patch_size[1] // 2 ** i_layer, - pretrain_img_size[2] // patch_size[2] // 2 ** i_layer), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size[i_layer], - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - attn_drop=attn_drop_rate, - drop_path=dpr[sum( - depths[:i_layer]):sum(depths[:i_layer + 1])], - norm_layer=norm_layer, - downsample=PatchMerging - if (i_layer < self.num_layers - 1) else None - ) - self.layers.append(layer) - - num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] - self.num_features = num_features - - # add a norm layer for each output - for i_layer in out_indices: - layer = norm_layer(num_features[i_layer]) - layer_name = f'norm{i_layer}' - self.add_module(layer_name, layer) - - - def forward(self, x): - """Forward function.""" - - x = self.patch_embed(x) - down=[] - Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) - - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.pos_drop(x) - - - for i in range(self.num_layers): - layer = self.layers[i] - x_out, S, H, W, x, Ws, Wh, Ww = layer(x, Ws, Wh, Ww) - if i in self.out_indices: - norm_layer = getattr(self, f'norm{i}') - x_out = norm_layer(x_out) - - out = x_out.view(-1, S, H, W, self.num_features[i]).permute(0, 4, 1, 2, 3).contiguous() - - down.append(out) - return down - - -# from nnFormer -class Decoder(nn.Module): - def __init__(self, - pretrain_img_size, - embed_dim, - patch_size=4, - depths=[2,2,2], - num_heads=[24,12,6], - window_size=4, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.2, - norm_layer=nn.LayerNorm - ): - super().__init__() - - - self.num_layers = len(depths) - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build layers - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers)[::-1]: - - layer = BasicLayer_up( - dim=int(embed_dim * 2 ** (len(depths)-i_layer-1)), - input_resolution=( - pretrain_img_size[0] // patch_size[0] // 2 ** (len(depths)-i_layer-1), pretrain_img_size[1] // patch_size[1] // 2 ** (len(depths)-i_layer-1), - pretrain_img_size[2] // patch_size[2] // 2 ** (len(depths)-i_layer-1)), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size[i_layer], - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - attn_drop=attn_drop_rate, - drop_path=dpr[sum( - depths[:i_layer]):sum(depths[:i_layer + 1])], - norm_layer=norm_layer, - upsample=Patch_Expanding - ) - self.layers.append(layer) - self.num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] - def forward(self,x,skips): - - outs=[] - S, H, W = x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2).contiguous() - for index,i in enumerate(skips): - i = i.flatten(2).transpose(1, 2).contiguous() - skips[index]=i - x = self.pos_drop(x) - - for i in range(self.num_layers)[::-1]: - - layer = self.layers[i] - - x, S, H, W, = layer(x,skips[i], S, H, W) - out = x.view(-1, S, H, W, self.num_features[i]) - outs.append(out) - return outs - - -class final_patch_expanding(nn.Module): - def __init__(self,dim,num_class,patch_size): - super().__init__() - self.up=nn.ConvTranspose3d(dim,num_class,patch_size,patch_size) - - def forward(self,x): - x=x.permute(0,4,1,2,3).contiguous() - x=self.up(x) - - - return x - - - - - -class HiFormer_v1(nn.Module): - def __init__(self, params): - """ - replace the embedding layer with convolutional blocks - """ - super(HiFormer_v1, self).__init__() - # crop_size=[96,96,96], - # embedding_dim=192, - # input_channels=1, - # num_classes=9, - # conv_op=nn.Conv3d, - # depths=[2,2,2,2], - # num_heads=[6, 12, 24, 48], - # patch_size=[4,4,4], - # window_size=[4,4,8,4], - # deep_supervision=False): - - crop_size = params["input_size"] - embed_dim = params.get("embedding_dim", 192) - input_channels = params["in_chns"] - num_classes = params["class_num"] - self.conv_op = nn.Conv3d - depths = params.get("depths", [2, 2, 2, 2]) - num_heads = params.get("num_heads", [6, 12, 24, 48]) - patch_size = params.get("patch_size", [4, 4, 4]) # for patch embedding - window_size = params.get("window_size", [4, 4, 8, 4]) # for swin transformer window - self._deep_supervision = params.get("deep_supervision", False) - self.do_ds = params.get("deep_supervision", False) - - - self.num_classes = num_classes - self.upscale_logits_ops = [] - self.upscale_logits_ops.append(lambda x: x) - - self.model_down=Encoder(pretrain_img_size=crop_size,window_size=window_size,embed_dim=embed_dim, - patch_size=patch_size,depths=depths,num_heads=num_heads,in_chans=input_channels, out_indices=range(len(depths))) - self.decoder=Decoder(pretrain_img_size=crop_size,embed_dim=embed_dim,window_size=window_size[::-1][1:],patch_size=patch_size,num_heads=num_heads[::-1][:-1],depths=depths[::-1][1:]) - - self.final=[] - if self.do_ds: - - for i in range(len(depths)-1): - self.final.append(final_patch_expanding(embed_dim*2**i,num_classes,patch_size=patch_size)) - - else: - self.final.append(final_patch_expanding(embed_dim,num_classes,patch_size=patch_size)) - - self.final=nn.ModuleList(self.final) - - - def forward(self, x): - - - seg_outputs=[] - skips = self.model_down(x) - neck=skips[-1] - - out=self.decoder(neck,skips) - - - - if self.do_ds: - for i in range(len(out)): - seg_outputs.append(self.final[-(i+1)](out[i])) - - - return seg_outputs[::-1] - else: - seg_outputs.append(self.final[0](out[-1])) - return seg_outputs[-1] - - -if __name__ == "__main__": - # params = {"input_size": [96, 96, 96], - # "in_chns": 1, - # "depth": [2, 2, 2, 2], - # "num_heads": [6, 12, 24, 48], - # "window_size": [6, 6, 6, 3], - # "class_num": 5} - params = {"input_size": [96, 96, 96], - "in_chns": 1, - "depths": [2, 2, 2], - "num_heads": [6, 12, 24], - "window_size": [6, 6, 6], - "class_num": 5} - Net = HiFormer_v1(params) - Net = Net.double() - - x = np.random.rand(1, 1, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print(y.shape) - - - diff --git a/pymic/net/net3d/trans3d/HiFormer_v2.py b/pymic/net/net3d/trans3d/HiFormer_v2.py deleted file mode 100644 index 7d4c440..0000000 --- a/pymic/net/net3d/trans3d/HiFormer_v2.py +++ /dev/null @@ -1,381 +0,0 @@ - -import torch -import numpy as np -import torch.utils.checkpoint as checkpoint -from einops import rearrange -from copy import deepcopy -from torch import nn -from pymic.net.net3d.trans3d.HiFormer_v1 import BasicLayer - -class ConvBlock(nn.Module): - """ - 2D or 3D convolutional block - - :param in_channels: (int) Input channel number. - :param out_channels: (int) Output channel number. - :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. - :param dropout_p: (int) Dropout probability. - """ - def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0): - super(ConvBlock, self).__init__() - assert(dim == 2 or dim == 3) - if(dim == 2): - kernel_size = [1, 3, 3] - padding = [0, 1, 1] - else: - kernel_size = 3 - padding = 1 - - self.conv_conv = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.PReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), - nn.BatchNorm3d(out_channels), - nn.PReLU(), - nn.Dropout(dropout_p), - nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x): - return self.conv_conv(x) - - -class DownSample(nn.Module): - def __init__(self, in_channels, out_channels, dim = 2, first_layer = False): - super(DownSample, self).__init__() - assert(dim == 2 or dim == 3) - if(dim == 2): - kernel_size = [1, 3, 3] - stride = [1, 2, 2] - padding = [0, 1, 1] - else: - kernel_size = 3 - stride = 2 - padding = 1 - - if(first_layer): - self.down = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, - padding=padding, stride = stride) - else: - self.down = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.PReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, - padding=padding, stride = stride), - ) - - def forward(self, x): - return self.down(x) - - - -class ConvTransBlock(nn.Module): - def __init__(self, - input_resolution= [32, 32, 32], - chns=96, - depth=2, - num_head=4, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.2, - norm_layer=nn.LayerNorm, - patch_norm=True, - ): - super().__init__() - self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) - self.trans = BasicLayer( - dim= chns, - input_resolution= input_resolution, - depth=depth, - num_heads=num_head, - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - attn_drop=attn_drop_rate, - drop_path=drop_path_rate, - norm_layer=norm_layer, - downsample= None - ) - self.norm_layer = nn.LayerNorm(chns) - self.pos_drop = nn.Dropout(p=drop_rate) - - def forward(self, x): - """Forward function.""" - x1 = self.conv(x) - C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.pos_drop(x) - x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) - # x2 = self.norm_layer(x2) - x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() - return x1 + x2 - - -class UpCatBlock(nn.Module): - """ - 3D upsampling followed by ConvBlock - - :param in_channels1: (int) Channel number of high-level features. - :param in_channels2: (int) Channel number of low-level features. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. - """ - def __init__(self, chns_l, chns_h, up_dim = 3, conv_dim = 3): - super(UpCatBlock, self).__init__() - assert(up_dim == 2 or up_dim == 3) - if(up_dim == 2): - kernel_size, stride = [1, 2, 2], [1, 2, 2] - else: - kernel_size, stride = 2, 2 - self.up = nn.ConvTranspose3d(chns_h, chns_l, - kernel_size = kernel_size, stride=stride) - - if(conv_dim == 2): - kernel_size, padding = [1, 3, 3], [0, 1, 1] - else: - kernel_size, padding = 3, 1 - self.conv = nn.Sequential( - nn.BatchNorm3d(chns_l*2), - nn.PReLU(), - nn.Conv3d(chns_l*2, chns_l, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x_l, x_h): - # print("input shapes", x1.shape, x2.shape) - # print("after upsample", x1.shape) - y = torch.cat([x_l, self.up(x_h)], dim=1) - return self.conv(y) - -class Encoder(nn.Module): - def __init__(self, - in_chns = 1 , - ft_chns = [48, 192, 384, 768], - input_size= [32, 128, 128], - down_dims = [2, 2, 3, 3], - conv_dims = [2, 3, 3, 3], - dropout = [0, 0.2, 0.2, 0.2], - depths = [2, 2, 2], - num_heads = [4, 8, 16], - window_sizes = [6, 6, 6], - ): - super().__init__() - - self.down1 = DownSample(in_chns, ft_chns[0], down_dims[0], first_layer=True) - self.down2 = DownSample(ft_chns[0], ft_chns[1], down_dims[1]) - self.down3 = DownSample(ft_chns[1], ft_chns[2], down_dims[2]) - self.down4 = DownSample(ft_chns[2], ft_chns[3], down_dims[3]) - - self.conv1 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) - self.conv2 = ConvBlock(ft_chns[1], ft_chns[1], conv_dims[1], dropout[1]) - - down_scales = [] - for i in range(4): - down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] - down_scales.append(down_scale) - - r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] - r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] - r_t4 = [r_t3[i] // down_scales[3][i] for i in range(3)] - - self.conv_t2 = ConvTransBlock(chns = ft_chns[1], - input_resolution = r_t2, - window_size = window_sizes[0], - depth = depths[0], - num_head = num_heads[0], - drop_rate = dropout[1], - attn_drop_rate=dropout[1] - ) - self.conv_t3 = ConvTransBlock(chns = ft_chns[2], - input_resolution = r_t3, - window_size = window_sizes[1], - depth = depths[1], - num_head = num_heads[1], - drop_rate = dropout[2], - attn_drop_rate=dropout[2] - ) - self.conv_t4 = ConvTransBlock(chns = ft_chns[3], - input_resolution = r_t4, - window_size = window_sizes[2], - depth = depths[2], - num_head = num_heads[2], - drop_rate = dropout[3], - attn_drop_rate=dropout[3] - ) - - - - def forward(self, x): - """Forward function.""" - x1 = self.conv1(self.down1(x)) - x2 = self.conv2(self.down2(x1)) - x2 = self.conv_t2(x2) - x3 = self.conv_t3(self.down3(x2)) - x4 = self.conv_t4(self.down4(x3)) - - return x1, x2, x3, x4 - -class Decoder(nn.Module): - """ - Decoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - :param multiscale_pred: (bool) Get multi-scale prediction. - """ - def __init__(self, - ft_chns = [48, 192, 384, 768], - input_size = [32, 128, 128], - down_dims = [2, 2, 3, 3], - conv_dims = [2, 3, 3, 3], - dropout = [0, 0, 0.2, 0.2], - depths = [2, 2, 2], - num_heads = [4, 8, 16], - window_sizes = [6, 6, 6], - class_num = 2, - multiscale_pred = False - ): - super(Decoder, self).__init__() - - self.up1 = UpCatBlock(ft_chns[0], ft_chns[1], down_dims[1], conv_dims[0]) - self.up2 = UpCatBlock(ft_chns[1], ft_chns[2], down_dims[2], conv_dims[1]) - self.up3 = UpCatBlock(ft_chns[2], ft_chns[3], down_dims[3], conv_dims[2]) - - down_scales = [] - for i in range(4): - down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] - down_scales.append(down_scale) - - r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] - r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] - - self.conv1 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) - self.conv2 = ConvTransBlock(chns = ft_chns[1], - input_resolution = r_t2, - window_size = window_sizes[0], - depth = depths[0], - num_head = num_heads[0], - drop_rate = dropout[1], - attn_drop_rate=dropout[1] - ) - self.conv3 = ConvTransBlock(chns = ft_chns[2], - input_resolution = r_t3, - window_size = window_sizes[1], - depth = depths[1], - num_head = num_heads[1], - drop_rate = dropout[2], - attn_drop_rate=dropout[2] - ) - - kernel_size, stride = 2, 2 - if down_dims[0] == 2: - kernel_size, stride = [1, 2, 2], [1, 2, 2] - self.out_conv0 = nn.ConvTranspose3d(ft_chns[0], class_num, - kernel_size = kernel_size, stride= stride) - - self.mul_pred = multiscale_pred - if(self.mul_pred): - self.out_conv1 = nn.Conv3d(ft_chns[0], class_num, kernel_size = 1) - self.out_conv2 = nn.Conv3d(ft_chns[1], class_num, kernel_size = 1) - self.out_conv3 = nn.Conv3d(ft_chns[2], class_num, kernel_size = 1) - - def forward(self, x): - x1, x2, x3, x4 = x - x_d3 = self.conv3(self.up3(x3, x4)) - x_d2 = self.conv2(self.up2(x2, x_d3)) - x_d1 = self.conv1(self.up1(x1, x_d2)) - - output = self.out_conv0(x_d1) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - return output - -class HiFormer_v2(nn.Module): - def __init__(self, params): - """ - replace the embedding layer with convolutional blocks - """ - super(HiFormer_v2, self).__init__() - in_chns = params["in_chns"] - class_num = params["class_num"] - input_size = params["input_size"] - ft_chns = params.get("feature_chns", [48, 192, 384, 764]) - down_dims = params.get("down_dims", [2, 2, 3, 3]) - conv_dims = params.get("conv_dims", [2, 3, 3, 3]) - dropout = params.get('dropout', [0, 0.2, 0.2, 0.2]) - depths = params.get("depths", [2, 2, 2]) - num_heads = params.get("num_heads", [4, 8, 16]) - window_sizes= params.get("window_sizes", [6, 6, 6]) - multiscale_pred = params.get("multiscale_pred", False) - - self.encoder = Encoder(in_chns, - ft_chns = ft_chns, - input_size = input_size, - down_dims = down_dims, - conv_dims = conv_dims, - dropout = dropout, - depths = depths, - num_heads = num_heads, - window_sizes= window_sizes) - - self.decoder = Decoder(ft_chns = ft_chns, - input_size = input_size, - down_dims = down_dims, - conv_dims = conv_dims, - dropout = dropout, - depths = depths, - num_heads = num_heads, - window_sizes= window_sizes, - class_num = class_num, - multiscale_pred = multiscale_pred - ) - - def forward(self, x): - x = self.encoder(x) - x = self.decoder(x) - return x - - -if __name__ == "__main__": - params = {"input_size": [32, 128, 128], - "in_chns": 1, - "down_dims": [2, 2, 3, 3], - "conv_dims": [2, 3, 3, 3], - "feature_chns": [96, 192, 384, 768], - "class_num": 5, - "multiscale_pred": True} - Net = HiFormer_v2(params) - Net = Net.double() - - x = np.random.rand(1, 1, 32, 128, 128) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - if(params['multiscale_pred']): - for yi in y: - print(yi.shape) - else: - print(y.shape) - - - diff --git a/pymic/net/net3d/trans3d/HiFormer_v3.py b/pymic/net/net3d/trans3d/HiFormer_v3.py deleted file mode 100644 index 2f8c831..0000000 --- a/pymic/net/net3d/trans3d/HiFormer_v3.py +++ /dev/null @@ -1,455 +0,0 @@ - -import torch -import numpy as np -import torch.utils.checkpoint as checkpoint -from einops import rearrange -from copy import deepcopy -from torch import nn -from pymic.net.net3d.trans3d.HiFormer_v1 import BasicLayer - -class ConvBlock(nn.Module): - """ - 2D or 3D convolutional block - - :param in_channels: (int) Input channel number. - :param out_channels: (int) Output channel number. - :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. - :param dropout_p: (int) Dropout probability. - """ - def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0): - super(ConvBlock, self).__init__() - assert(dim == 2 or dim == 3) - if(dim == 2): - kernel_size = [1, 3, 3] - padding = [0, 1, 1] - else: - kernel_size = 3 - padding = 1 - - self.conv_conv = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.PReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), - nn.BatchNorm3d(out_channels), - nn.PReLU(), - nn.Dropout(dropout_p), - nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x): - return self.conv_conv(x) - - -class DownSample(nn.Module): - def __init__(self, in_channels, out_channels, dim = 2, first_layer = False): - super(DownSample, self).__init__() - assert(dim == 2 or dim == 3) - if(dim == 2): - kernel_size = [1, 3, 3] - stride = [1, 2, 2] - padding = [0, 1, 1] - else: - kernel_size = 3 - stride = 2 - padding = 1 - - if(first_layer): - self.down = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, - padding=padding, stride = stride) - else: - self.down = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.PReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, - padding=padding, stride = stride), - ) - - def forward(self, x): - return self.down(x) - - - -class ConvTransBlock_backup(nn.Module): - def __init__(self, - input_resolution= [32, 32, 32], - chns=96, - depth=2, - num_head=4, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.2, - norm_layer=nn.LayerNorm, - patch_norm=True, - ): - super().__init__() - self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) - self.trans = BasicLayer( - dim= chns, - input_resolution= input_resolution, - depth=depth, - num_heads=num_head, - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - attn_drop=attn_drop_rate, - drop_path=drop_path_rate, - norm_layer=norm_layer, - downsample= None - ) - self.norm_layer = nn.LayerNorm(chns) - self.pos_drop = nn.Dropout(p=drop_rate) - - def forward(self, x): - """Forward function.""" - x1 = self.conv(x) - C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.pos_drop(x) - x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) - # x2 = self.norm_layer(x2) - x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() - return x1 + x2 - -# only using the conv block -class ConvTransBlock(nn.Module): - def __init__(self, - input_resolution= [32, 32, 32], - chns=96, - depth=2, - num_head=4, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.2, - norm_layer=nn.LayerNorm, - patch_norm=True, - ): - super().__init__() - self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) - # self.trans = BasicLayer( - # dim= chns, - # input_resolution= input_resolution, - # depth=depth, - # num_heads=num_head, - # window_size=window_size, - # mlp_ratio=mlp_ratio, - # qkv_bias=qkv_bias, - # qk_scale=qk_scale, - # drop=drop_rate, - # attn_drop=attn_drop_rate, - # drop_path=drop_path_rate, - # norm_layer=norm_layer, - # downsample= None - # ) - # self.norm_layer = nn.LayerNorm(chns) - # self.pos_drop = nn.Dropout(p=drop_rate) - - def forward(self, x): - """Forward function.""" - x1 = self.conv(x) - return x1 - # C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) - # x = x.flatten(2).transpose(1, 2).contiguous() - # x = self.pos_drop(x) - # x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) - # # x2 = self.norm_layer(x2) - # x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() - # return x1 + x2 - -class UpCatBlock(nn.Module): - """ - 3D upsampling followed by ConvBlock - - :param in_channels1: (int) Channel number of high-level features. - :param in_channels2: (int) Channel number of low-level features. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. - """ - def __init__(self, chns_l, chns_h, up_dim = 3, conv_dim = 3): - super(UpCatBlock, self).__init__() - assert(up_dim == 2 or up_dim == 3) - if(up_dim == 2): - kernel_size, stride = [1, 2, 2], [1, 2, 2] - else: - kernel_size, stride = 2, 2 - self.up = nn.ConvTranspose3d(chns_h, chns_l, - kernel_size = kernel_size, stride=stride) - - if(conv_dim == 2): - kernel_size, padding = [1, 3, 3], [0, 1, 1] - else: - kernel_size, padding = 3, 1 - self.conv = nn.Sequential( - nn.BatchNorm3d(chns_l*2), - nn.PReLU(), - nn.Conv3d(chns_l*2, chns_l, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x_l, x_h): - # print("input shapes", x1.shape, x2.shape) - # print("after upsample", x1.shape) - y = torch.cat([x_l, self.up(x_h)], dim=1) - return self.conv(y) - -class Encoder(nn.Module): - def __init__(self, - in_chns = 1 , - ft_chns = [48, 192, 384, 768], - input_size= [32, 128, 128], - down_dims = [2, 2, 3, 3], - conv_dims = [2, 3, 3, 3], - dropout = [0, 0.2, 0.2, 0.2], - depths = [2, 2, 2], - num_heads = [4, 8, 16], - window_sizes = [6, 6, 6], - high_res = False, - ): - super().__init__() - self.high_res = high_res - - self.down1 = DownSample(in_chns, ft_chns[0], down_dims[0], first_layer=True) - self.down2 = DownSample(ft_chns[0], ft_chns[1], down_dims[1]) - self.down3 = DownSample(ft_chns[1], ft_chns[2], down_dims[2]) - self.down4 = DownSample(ft_chns[2], ft_chns[3], down_dims[3]) - - if(high_res): - self.conv0 = ConvBlock(in_chns, ft_chns[0] // 2, 3, 0) - self.conv1 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) - self.conv2 = ConvBlock(ft_chns[1], ft_chns[1], conv_dims[1], dropout[1]) - - down_scales = [] - for i in range(4): - down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] - down_scales.append(down_scale) - - r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] - r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] - r_t4 = [r_t3[i] // down_scales[3][i] for i in range(3)] - - self.conv_t2 = ConvTransBlock(chns = ft_chns[1], - input_resolution = r_t2, - window_size = window_sizes[0], - depth = depths[0], - num_head = num_heads[0], - drop_rate = dropout[1], - attn_drop_rate=dropout[1] - ) - self.conv_t3 = ConvTransBlock(chns = ft_chns[2], - input_resolution = r_t3, - window_size = window_sizes[1], - depth = depths[1], - num_head = num_heads[1], - drop_rate = dropout[2], - attn_drop_rate=dropout[2] - ) - self.conv_t4 = ConvTransBlock(chns = ft_chns[3], - input_resolution = r_t4, - window_size = window_sizes[2], - depth = depths[2], - num_head = num_heads[2], - drop_rate = dropout[3], - attn_drop_rate=dropout[3] - ) - - - - def forward(self, x): - """Forward function.""" - if(self.high_res): - x0 = self.conv0(x) - x1 = self.conv1(self.down1(x)) - x2 = self.conv2(self.down2(x1)) - x2 = self.conv_t2(x2) - x3 = self.conv_t3(self.down3(x2)) - x4 = self.conv_t4(self.down4(x3)) - if(self.high_res): - return x0, x1, x2, x3, x4 - else: - return x1, x2, x3, x4 - -class Decoder(nn.Module): - """ - Decoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - :param multiscale_pred: (bool) Get multi-scale prediction. - """ - def __init__(self, - ft_chns = [48, 192, 384, 768], - input_size = [32, 128, 128], - down_dims = [2, 2, 3, 3], - conv_dims = [2, 3, 3, 3], - dropout = [0, 0, 0.2, 0.2], - depths = [2, 2, 2], - num_heads = [4, 8, 16], - window_sizes = [6, 6, 6], - high_res = False, - class_num = 2, - multiscale_pred = False - ): - super(Decoder, self).__init__() - self.high_res = high_res - if(self.high_res): - self.up0 = UpCatBlock(ft_chns[0] // 2, ft_chns[0], down_dims[0], 3) - self.conv0 = ConvBlock(ft_chns[0] // 2, ft_chns[0] // 2, 3, 0) - self.up1 = UpCatBlock(ft_chns[0], ft_chns[1], down_dims[1], conv_dims[0]) - self.up2 = UpCatBlock(ft_chns[1], ft_chns[2], down_dims[2], conv_dims[1]) - self.up3 = UpCatBlock(ft_chns[2], ft_chns[3], down_dims[3], conv_dims[2]) - - down_scales = [] - for i in range(4): - down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] - down_scales.append(down_scale) - - r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] - r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] - - self.conv1 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) - self.conv2 = ConvTransBlock(chns = ft_chns[1], - input_resolution = r_t2, - window_size = window_sizes[0], - depth = depths[0], - num_head = num_heads[0], - drop_rate = dropout[1], - attn_drop_rate=dropout[1] - ) - self.conv3 = ConvTransBlock(chns = ft_chns[2], - input_resolution = r_t3, - window_size = window_sizes[1], - depth = depths[1], - num_head = num_heads[1], - drop_rate = dropout[2], - attn_drop_rate=dropout[2] - ) - - kernel_size, stride = 2, 2 - if down_dims[0] == 2: - kernel_size, stride = [1, 2, 2], [1, 2, 2] - if(self.high_res): - self.out_conv0 = nn.Conv3d(ft_chns[0] // 2, class_num, - kernel_size = [1, 3, 3], padding = [0, 1, 1]) - else: - self.out_conv0 = nn.ConvTranspose3d(ft_chns[0], class_num, - kernel_size = kernel_size, stride= stride) - - self.mul_pred = multiscale_pred - if(self.mul_pred): - self.out_conv1 = nn.Conv3d(ft_chns[0], class_num, kernel_size = 1) - self.out_conv2 = nn.Conv3d(ft_chns[1], class_num, kernel_size = 1) - self.out_conv3 = nn.Conv3d(ft_chns[2], class_num, kernel_size = 1) - - def forward(self, x): - if(self.high_res): - x0, x1, x2, x3, x4 = x - else: - x1, x2, x3, x4 = x - x_d3 = self.conv3(self.up3(x3, x4)) - x_d2 = self.conv2(self.up2(x2, x_d3)) - x_d1 = self.conv1(self.up1(x1, x_d2)) - if(self.high_res): - x_d0 = self.conv0(self.up0(x0, x_d1)) - output = self.out_conv0(x_d0) - else: - output = self.out_conv0(x_d1) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - return output - -class HiFormer_v3(nn.Module): - def __init__(self, params): - """ - replace the embedding layer with convolutional blocks - """ - super(HiFormer_v3, self).__init__() - in_chns = params["in_chns"] - class_num = params["class_num"] - input_size = params["input_size"] - ft_chns = params.get("feature_chns", [48, 192, 384, 764]) - down_dims = params.get("down_dims", [2, 2, 3, 3]) - conv_dims = params.get("conv_dims", [2, 3, 3, 3]) - dropout = params.get('dropout', [0, 0.2, 0.2, 0.2]) - high_res = params.get("high_res", False) - depths = params.get("depths", [2, 2, 2]) - num_heads = params.get("num_heads", [4, 8, 16]) - window_sizes= params.get("window_sizes", [6, 6, 6]) - multiscale_pred = params.get("multiscale_pred", False) - - self.encoder = Encoder(in_chns, - ft_chns = ft_chns, - input_size = input_size, - down_dims = down_dims, - conv_dims = conv_dims, - dropout = dropout, - depths = depths, - num_heads = num_heads, - window_sizes= window_sizes, - high_res = high_res) - - self.decoder = Decoder(ft_chns = ft_chns, - input_size = input_size, - down_dims = down_dims, - conv_dims = conv_dims, - dropout = dropout, - depths = depths, - num_heads = num_heads, - window_sizes= window_sizes, - high_res = high_res, - class_num = class_num, - multiscale_pred = multiscale_pred - ) - - def forward(self, x): - x = self.encoder(x) - x = self.decoder(x) - return x - - -if __name__ == "__main__": - params = {"input_size": [64, 96, 96], - "in_chns": 1, - "down_dims": [3, 3, 3, 3], - "conv_dims": [3, 3, 3, 3], - "feature_chns": [96, 192, 384, 768], - "high_res": True, - "class_num": 5, - "multiscale_pred": True} - Net = HiFormer_v3(params) - Net = Net.double() - - x = np.random.rand(1, 1, 64, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - if(params['multiscale_pred']): - for yi in y: - print(yi.shape) - else: - print(y.shape) - - - diff --git a/pymic/net/net3d/trans3d/HiFormer_v4.py b/pymic/net/net3d/trans3d/HiFormer_v4.py deleted file mode 100644 index f0c6087..0000000 --- a/pymic/net/net3d/trans3d/HiFormer_v4.py +++ /dev/null @@ -1,455 +0,0 @@ - -import torch -import numpy as np -import torch.utils.checkpoint as checkpoint -from einops import rearrange -from copy import deepcopy -from torch import nn -from pymic.net.net3d.trans3d.HiFormer_v1 import BasicLayer - -class ConvBlock(nn.Module): - """ - 2D or 3D convolutional block - - :param in_channels: (int) Input channel number. - :param out_channels: (int) Output channel number. - :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. - :param dropout_p: (int) Dropout probability. - """ - def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0): - super(ConvBlock, self).__init__() - assert(dim == 2 or dim == 3) - if(dim == 2): - kernel_size = [1, 3, 3] - padding = [0, 1, 1] - else: - kernel_size = 3 - padding = 1 - - self.conv_conv = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.PReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), - nn.BatchNorm3d(out_channels), - nn.PReLU(), - nn.Dropout(dropout_p), - nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x): - return self.conv_conv(x) - - -class DownSample(nn.Module): - def __init__(self, in_channels, out_channels, down_dim = 3, conv_dim = 3): - super(DownSample, self).__init__() - assert(down_dim == 2 or down_dim == 3) - assert(conv_dim == 2 or conv_dim == 3) - - kernel_size = [1, 2, 2] if(down_dim == 2) else 2 - self.pool = nn.MaxPool3d(kernel_size) - - if(conv_dim == 2): - kernel_size = [1, 3, 3] - padding = [0, 1, 1] - else: - kernel_size = 3 - padding = 1 - - self.conv = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.PReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x): - return self.conv(self.pool(x)) - - - -# class ConvTransBlock(nn.Module): -# def __init__(self, -# input_resolution= [32, 32, 32], -# chns=96, -# depth=2, -# num_head=4, -# window_size=7, -# mlp_ratio=4., -# qkv_bias=True, -# qk_scale=None, -# drop_rate=0., -# attn_drop_rate=0., -# drop_path_rate=0.2, -# norm_layer=nn.LayerNorm, -# patch_norm=True, -# ): -# super().__init__() -# self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) -# self.trans = BasicLayer( -# dim= chns, -# input_resolution= input_resolution, -# depth=depth, -# num_heads=num_head, -# window_size=window_size, -# mlp_ratio=mlp_ratio, -# qkv_bias=qkv_bias, -# qk_scale=qk_scale, -# drop=drop_rate, -# attn_drop=attn_drop_rate, -# drop_path=drop_path_rate, -# norm_layer=norm_layer, -# downsample= None -# ) -# self.norm_layer = nn.LayerNorm(chns) -# self.pos_drop = nn.Dropout(p=drop_rate) - -# def forward(self, x): -# """Forward function.""" -# x1 = self.conv(x) -# C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) -# x = x.flatten(2).transpose(1, 2).contiguous() -# x = self.pos_drop(x) -# x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) -# # x2 = self.norm_layer(x2) -# x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() -# return x1 + x2 - -# only using the conv block -class ConvTransBlock(nn.Module): - def __init__(self, - input_resolution= [32, 32, 32], - chns=96, - depth=2, - num_head=4, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.2, - norm_layer=nn.LayerNorm, - patch_norm=True, - ): - super().__init__() - self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) - # self.trans = BasicLayer( - # dim= chns, - # input_resolution= input_resolution, - # depth=depth, - # num_heads=num_head, - # window_size=window_size, - # mlp_ratio=mlp_ratio, - # qkv_bias=qkv_bias, - # qk_scale=qk_scale, - # drop=drop_rate, - # attn_drop=attn_drop_rate, - # drop_path=drop_path_rate, - # norm_layer=norm_layer, - # downsample= None - # ) - # self.norm_layer = nn.LayerNorm(chns) - # self.pos_drop = nn.Dropout(p=drop_rate) - - def forward(self, x): - """Forward function.""" - x1 = self.conv(x) - return x1 - # C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) - # x = x.flatten(2).transpose(1, 2).contiguous() - # x = self.pos_drop(x) - # x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) - # # x2 = self.norm_layer(x2) - # x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() - # return x1 + x2 - -class ConvLayer(nn.Module): - """ - 2D or 3D convolutional block - - :param in_channels: (int) Input channel number. - :param out_channels: (int) Output channel number. - :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. - :param dropout_p: (int) Dropout probability. - """ - def __init__(self, in_channels, out_channels, kernel = 1, padding = 0): - super(ConvLayer, self).__init__() - - self.conv = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.PReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel, padding=padding), - ) - - def forward(self, x): - return self.conv(x) - -class UpCatBlock(nn.Module): - """ - 3D upsampling followed by ConvBlock - - :param in_channels1: (int) Channel number of high-level features. - :param in_channels2: (int) Channel number of low-level features. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. - """ - def __init__(self, chns_l, chns_h, up_dim = 3, conv_dim = 3): - super(UpCatBlock, self).__init__() - assert(up_dim == 2 or up_dim == 3) - if(up_dim == 2): - kernel_size, stride = [1, 2, 2], [1, 2, 2] - else: - kernel_size, stride = 2, 2 - - self.up = nn.Sequential( - nn.BatchNorm3d(chns_h), - nn.PReLU(), - nn.ConvTranspose3d(chns_h, chns_l, kernel_size = kernel_size, stride=stride) - ) - - if(conv_dim == 2): - kernel_size, padding = [1, 3, 3], [0, 1, 1] - else: - kernel_size, padding = 3, 1 - self.conv = nn.Sequential( - nn.BatchNorm3d(chns_l*2), - nn.PReLU(), - nn.Conv3d(chns_l*2, chns_l, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x_l, x_h): - # print("input shapes", x1.shape, x2.shape) - # print("after upsample", x1.shape) - y = torch.cat([x_l, self.up(x_h)], dim=1) - return self.conv(y) - -class Encoder(nn.Module): - def __init__(self, - in_chns = 1 , - ft_chns = [24, 48, 192, 384, 768], - input_size= [32, 128, 128], - down_dims = [3, 3, 3, 3, 3], - conv_dims = [3, 3, 3, 3, 3], - dropout = [0, 0, 0.2, 0.2, 0.2], - depths = [2, 2, 2], - num_heads = [4, 8, 16], - window_sizes = [6, 6, 6], - ): - super().__init__() - self.proj = nn.Conv3d(in_chns, ft_chns[0], kernel_size=3, padding=1) - self.conv0 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) - self.conv1 = ConvBlock(ft_chns[1], ft_chns[1], conv_dims[1], dropout[1]) - self.conv2 = ConvBlock(ft_chns[2], ft_chns[2], conv_dims[2], dropout[2]) - - self.down1 = DownSample(ft_chns[0], ft_chns[1], down_dims[0], conv_dims[1]) - self.down2 = DownSample(ft_chns[1], ft_chns[2], down_dims[1], conv_dims[2]) - self.down3 = DownSample(ft_chns[2], ft_chns[3], down_dims[2], conv_dims[3]) - self.down4 = DownSample(ft_chns[3], ft_chns[4], down_dims[3], conv_dims[4]) - - down_scales = [] - for i in range(4): - down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] - down_scales.append(down_scale) - - r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] - r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] - r_t4 = [r_t3[i] // down_scales[3][i] for i in range(3)] - - self.conv_t2 = ConvTransBlock(chns = ft_chns[2], - input_resolution = r_t2, - window_size = window_sizes[0], - depth = depths[0], - num_head = num_heads[0], - drop_rate = dropout[2], - attn_drop_rate=dropout[2] - ) - self.conv_t3 = ConvTransBlock(chns = ft_chns[3], - input_resolution = r_t3, - window_size = window_sizes[1], - depth = depths[1], - num_head = num_heads[1], - drop_rate = dropout[3], - attn_drop_rate=dropout[3] - ) - self.conv_t4 = ConvTransBlock(chns = ft_chns[4], - input_resolution = r_t4, - window_size = window_sizes[2], - depth = depths[2], - num_head = num_heads[2], - drop_rate = dropout[4], - attn_drop_rate=dropout[4] - ) - - - - def forward(self, x): - """Forward function.""" - x0 = self.conv0(self.proj(x)) - x1 = self.conv1(self.down1(x0)) - x2 = self.conv2(self.down2(x1)) - x2 = self.conv_t2(x2) - x3 = self.conv_t3(self.down3(x2)) - x4 = self.conv_t4(self.down4(x3)) - return x0, x1, x2, x3, x4 - - -class Decoder(nn.Module): - """ - Decoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - :param multiscale_pred: (bool) Get multi-scale prediction. - """ - def __init__(self, - ft_chns = [24, 48, 192, 384, 768], - input_size= [32, 128, 128], - down_dims = [3, 3, 3, 3, 3], - conv_dims = [3, 3, 3, 3, 3], - dropout = [0, 0, 0.2, 0.2, 0.2], - depths = [2, 2, 2], - num_heads = [4, 8, 16], - window_sizes = [6, 6, 6], - class_num = 2, - multiscale_pred = False - ): - super(Decoder, self).__init__() - # self.up0 = UpCatBlock(ft_chns[0] // 2, ft_chns[0], down_dims[0], 3) - # self.conv0 = ConvBlock(ft_chns[0] // 2, ft_chns[0] // 2, 3, 0) - self.up1 = UpCatBlock(ft_chns[0], ft_chns[1], down_dims[0], conv_dims[0]) - self.up2 = UpCatBlock(ft_chns[1], ft_chns[2], down_dims[1], conv_dims[1]) - self.up3 = UpCatBlock(ft_chns[2], ft_chns[3], down_dims[2], conv_dims[2]) - self.up4 = UpCatBlock(ft_chns[3], ft_chns[4], down_dims[3], conv_dims[3]) - - down_scales = [] - for i in range(4): - down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] - down_scales.append(down_scale) - - r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] - r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] - - self.conv0 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) - self.conv1 = ConvBlock(ft_chns[1], ft_chns[1], conv_dims[1], dropout[1]) - self.conv2 = ConvTransBlock(chns = ft_chns[2], - input_resolution = r_t2, - window_size = window_sizes[0], - depth = depths[0], - num_head = num_heads[0], - drop_rate = dropout[2], - attn_drop_rate=dropout[2] - ) - self.conv3 = ConvTransBlock(chns = ft_chns[3], - input_resolution = r_t3, - window_size = window_sizes[1], - depth = depths[1], - num_head = num_heads[1], - drop_rate = dropout[3], - attn_drop_rate=dropout[3] - ) - - self.out_conv0 = ConvLayer(ft_chns[0], class_num) - - self.mul_pred = multiscale_pred - if(self.mul_pred): - self.out_conv1 = ConvLayer(ft_chns[1], class_num) - self.out_conv2 = ConvLayer(ft_chns[2], class_num) - self.out_conv3 = ConvLayer(ft_chns[3], class_num) - - def forward(self, x): - x0, x1, x2, x3, x4 = x - - x_d3 = self.conv3(self.up4(x3, x4)) - x_d2 = self.conv2(self.up3(x2, x_d3)) - x_d1 = self.conv1(self.up2(x1, x_d2)) - x_d0 = self.conv0(self.up1(x0, x_d1)) - output = self.out_conv0(x_d0) - - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - return output - -class HiFormer_v4(nn.Module): - def __init__(self, params): - """ - replace the embedding layer with convolutional blocks - """ - super(HiFormer_v4, self).__init__() - in_chns = params["in_chns"] - class_num = params["class_num"] - input_size = params["input_size"] - ft_chns = params.get("feature_chns", [32, 64, 128, 256, 512]) - down_dims = params.get("down_dims", [3, 3, 3, 3, 3]) - conv_dims = params.get("conv_dims", [3, 3, 3, 3, 3]) - dropout = params.get('dropout', [0, 0, 0.2, 0.2, 0.2]) - depths = params.get("depths", [2, 2, 2]) - num_heads = params.get("num_heads", [4, 8, 16]) - window_sizes= params.get("window_sizes", [6, 6, 6]) - multiscale_pred = params.get("multiscale_pred", False) - - self.encoder = Encoder(in_chns, - ft_chns = ft_chns, - input_size = input_size, - down_dims = down_dims, - conv_dims = conv_dims, - dropout = dropout, - depths = depths, - num_heads = num_heads, - window_sizes= window_sizes) - - self.decoder = Decoder(ft_chns = ft_chns, - input_size = input_size, - down_dims = down_dims, - conv_dims = conv_dims, - dropout = dropout, - depths = depths, - num_heads = num_heads, - window_sizes= window_sizes, - class_num = class_num, - multiscale_pred = multiscale_pred - ) - - def forward(self, x): - x = self.encoder(x) - x = self.decoder(x) - return x - - -if __name__ == "__main__": - params = {"input_size": [64, 96, 96], - "in_chns": 1, - "down_dims": [3, 3, 3, 3, 3], - "conv_dims": [3, 3, 3, 3, 3], - "feature_chns": [32, 64, 128, 256, 512], - "class_num": 5, - "multiscale_pred": True} - Net = HiFormer_v4(params) - Net = Net.double() - - x = np.random.rand(1, 1, 64, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - if(params['multiscale_pred']): - for yi in y: - print(yi.shape) - else: - print(y.shape) - - - diff --git a/pymic/net/net3d/trans3d/HiFormer_v5.py b/pymic/net/net3d/trans3d/HiFormer_v5.py deleted file mode 100644 index 5fcef5a..0000000 --- a/pymic/net/net3d/trans3d/HiFormer_v5.py +++ /dev/null @@ -1,308 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import torch -import torch.nn as nn -import numpy as np -from torch.nn.functional import interpolate - - -class ConvBlock(nn.Module): - """ - 2D or 3D convolutional block - - :param in_channels: (int) Input channel number. - :param out_channels: (int) Output channel number. - :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. - :param dropout_p: (int) Dropout probability. - """ - def __init__(self, in_channels, out_channels, dropout_p = 0.0, dim = 3): - super(ConvBlock, self).__init__() - assert(dim == 2 or dim == 3) - if(dim == 2): - kernel_size = [1, 3, 3] - padding = [0, 1, 1] - else: - kernel_size = 3 - padding = 1 - - self.conv_conv = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.LeakyReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), - nn.BatchNorm3d(out_channels), - nn.LeakyReLU(), - nn.Dropout(dropout_p), - nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x): - return self.conv_conv(x) - -class ConvLayer(nn.Module): - """ - 2D or 3D convolutional block - - :param in_channels: (int) Input channel number. - :param out_channels: (int) Output channel number. - :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. - :param dropout_p: (int) Dropout probability. - """ - def __init__(self, in_channels, out_channels, kernel = 1, padding = 0): - super(ConvLayer, self).__init__() - - self.conv = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.LeakyReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel, padding=padding), - ) - - def forward(self, x): - return self.conv(x) - -class DownBlock(nn.Module): - """ - 3D downsampling followed by ConvBlock - - :param in_channels: (int) Input channel number. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - """ - def __init__(self, in_channels, out_channels, dropout_p): - super(DownBlock, self).__init__() - self.maxpool_conv = nn.Sequential( - nn.MaxPool3d(2), - ConvBlock(in_channels, out_channels, dropout_p) - ) - - def forward(self, x): - return self.maxpool_conv(x) - -class UpBlock(nn.Module): - """ - 3D upsampling followed by ConvBlock - - :param in_channels1: (int) Channel number of high-level features. - :param in_channels2: (int) Channel number of low-level features. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. - """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - trilinear=True): - super(UpBlock, self).__init__() - self.trilinear = trilinear - if trilinear: - self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) - else: - self.up = nn.Sequential( - nn.BatchNorm3d(in_channels1), - nn.LeakyReLU(), - nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) - ) - self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) - - def forward(self, x1, x2): - if self.trilinear: - x1 = self.conv1x1(x1) - x1 = self.up(x1) - x = torch.cat([x2, x1], dim=1) - return self.conv(x) - -class Encoder(nn.Module): - """ - Encoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - """ - def __init__(self, params): - super(Encoder, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - self.proj = nn.Conv3d(self.in_chns, self.ft_chns[0], kernel_size=3, padding=1) - self.in_conv= ConvBlock(self.ft_chns[0], self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - if(len(self.ft_chns) == 5): - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - - def forward(self, x): - x0 = self.in_conv(self.proj(x)) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - output = [x0, x1, x2, x3] - if(len(self.ft_chns) == 5): - x4 = self.down4(x3) - output.append(x4) - return output - -class Decoder(nn.Module): - """ - Decoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - :param multiscale_pred: (bool) Get multi-scale prediction. - """ - def __init__(self, params): - super(Decoder, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.trilinear = self.params.get('trilinear', True) - self.mul_pred = self.params.get('multiscale_pred', False) - - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - if(len(self.ft_chns) == 5): - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear) - self.out_conv = ConvLayer(self.ft_chns[0], self.n_class) - if(self.mul_pred): - self.out_conv1 = ConvLayer(self.ft_chns[1], self.n_class) - self.out_conv2 = ConvLayer(self.ft_chns[2], self.n_class) - self.out_conv3 = ConvLayer(self.ft_chns[3], self.n_class) - - def forward(self, x): - if(len(self.ft_chns) == 5): - assert(len(x) == 5) - x0, x1, x2, x3, x4 = x - x_d3 = self.up1(x4, x3) - else: - assert(len(x) == 4) - x0, x1, x2, x3 = x - x_d3 = x3 - x_d2 = self.up2(x_d3, x2) - x_d1 = self.up3(x_d2, x1) - x_d0 = self.up4(x_d1, x0) - output = self.out_conv(x_d0) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - return output - -class HiFormer_v5(nn.Module): - """ - An implementation of the U-Net. - - * Reference: Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: - 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. - `MICCAI (2) 2016: 424-432. `_ - - Note that there are some modifications from the original paper, such as - the use of batch normalization, dropout, leaky relu and deep supervision. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using trilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - :param multiscale_pred: (bool) Get multi-scale prediction. - """ - def __init__(self, params): - super(HiFormer_v5, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.trilinear = self.params['trilinear'] - self.mul_pred = self.params['multiscale_pred'] - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - self.proj = nn.Conv3d(self.in_chns, self.ft_chns[0], kernel_size=3, padding=1) - self.in_conv= ConvBlock(self.ft_chns[0], self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - if(len(self.ft_chns) == 5): - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], - dropout_p = self.dropout[3], trilinear=self.trilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], - dropout_p = self.dropout[2], trilinear=self.trilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], - dropout_p = self.dropout[1], trilinear=self.trilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], - dropout_p = self.dropout[0], trilinear=self.trilinear) - - self.out_conv = ConvLayer(self.ft_chns[0], self.n_class) - if(self.mul_pred): - self.out_conv1 = ConvLayer(self.ft_chns[1], self.n_class) - self.out_conv2 = ConvLayer(self.ft_chns[2], self.n_class) - self.out_conv3 = ConvLayer(self.ft_chns[3], self.n_class) - - def forward(self, x): - x0 = self.in_conv(self.proj(x)) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - if(len(self.ft_chns) == 5): - x4 = self.down4(x3) - x_d3 = self.up1(x4, x3) - else: - x_d3 = x3 - x_d2 = self.up2(x_d3, x2) - x_d1 = self.up3(x_d2, x1) - x_d0 = self.up4(x_d1, x0) - output = self.out_conv(x_d0) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'class_num': 2, - 'feature_chns':[32, 64, 128, 256, 512], - 'dropout' : [0, 0, 0, 0, 0.5], - 'trilinear': False, - 'multiscale_pred': False} - Net = HiFormer_v5(params) - Net = Net.double() - - x = np.random.rand(4, 4, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - y = y.detach().numpy() - print(y.shape) diff --git a/pymic/net/net3d/trans3d/MedFormer_v1.py b/pymic/net/net3d/trans3d/MedFormer_v1.py deleted file mode 100644 index 1f2ed54..0000000 --- a/pymic/net/net3d/trans3d/MedFormer_v1.py +++ /dev/null @@ -1,173 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import math -import torch -import torch.nn as nn -import numpy as np -from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm -from pymic.net.net3d.unet3d import Encoder, Decoder - -class Attention(nn.Module): - def __init__(self, params): - super(Attention, self).__init__() - hidden_size = params["attention_hidden_size"] - self.num_attention_heads = params["attention_num_heads"] - self.attention_head_size = int(hidden_size / self.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = Linear(hidden_size, self.all_head_size) - self.key = Linear(hidden_size, self.all_head_size) - self.value = Linear(hidden_size, self.all_head_size) - - self.out = Linear(hidden_size, hidden_size) - self.attn_dropout = Dropout(params["attention_dropout_rate"]) - self.proj_dropout = Dropout(params["attention_dropout_rate"]) - - self.softmax = Softmax(dim=-1) - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward(self, hidden_states): - mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) - - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) - - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - attention_probs = self.softmax(attention_scores) - # weights = attention_probs if self.vis else None - attention_probs = self.attn_dropout(attention_probs) - - context_layer = torch.matmul(attention_probs, value_layer) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) - attention_output = self.out(context_layer) - attention_output = self.proj_dropout(attention_output) - return attention_output - -class MLP(nn.Module): - def __init__(self, params): - super(MLP, self).__init__() - hidden_size = params["attention_hidden_size"] - mlp_dim = params["attention_mlp_dim"] - self.fc1 = Linear(hidden_size, mlp_dim) - self.fc2 = Linear(mlp_dim, hidden_size) - self.act_fn = torch.nn.functional.gelu - self.dropout = Dropout(params["attention_dropout_rate"]) - - self._init_weights() - - def _init_weights(self): - nn.init.xavier_uniform_(self.fc1.weight) - nn.init.xavier_uniform_(self.fc2.weight) - nn.init.normal_(self.fc1.bias, std=1e-6) - nn.init.normal_(self.fc2.bias, std=1e-6) - - def forward(self, x): - x = self.fc1(x) - x = self.act_fn(x) - x = self.dropout(x) - x = self.fc2(x) - x = self.dropout(x) - return x - -class Block(nn.Module): - def __init__(self, params): - super(Block, self).__init__() - hidden_size = params["attention_hidden_size"] - self.attention_norm = LayerNorm(hidden_size, eps=1e-6) - self.ffn_norm = LayerNorm(hidden_size, eps=1e-6) - self.ffn = MLP(params) - self.attn = Attention(params) - - def forward(self, x): - # convert the tensor shape from [B, C, D, H, W] to [B, DHW, C] - [B, C, D, H, W] = list(x.shape) - new_shape = [B, C, D*H*W] - x = torch.reshape(x, new_shape) - x = torch.transpose(x, 1, 2) - - h = x - x = self.attention_norm(x) - x = self.attn(x) - x = x + h - - h = x - x = self.ffn_norm(x) - x = self.ffn(x) - x = x + h - - # convert the result back to [B, C, D, H, W] - x = torch.transpose(x, 1, 2) - x = torch.reshape(x, [B, C, D, H, W]) - return x - -class MedFormerV1(nn.Module): - """ - An implementation of the U-Net. - - * Reference: Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: - 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. - `MICCAI (2) 2016: 424-432. `_ - - Note that there are some modifications from the original paper, such as - the use of batch normalization, dropout, leaky relu and deep supervision. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using trilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - :param deep_supervise: (bool) Using deep supervision for training or not. - """ - def __init__(self, params): - super(MedFormerV1, self).__init__() - self.params = params - self.encoder = Encoder(params) - self.decoder = Decoder(params) - self.attn = Block(params) - - def forward(self, x): - f = self.encoder(x) - f[-1] = self.attn(f[-1]) - output = self.decoder(f) - return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'class_num': 2, - 'feature_chns':[16, 32, 64, 128], - 'dropout' : [0, 0, 0, 0.5], - 'trilinear': True, - 'deep_supervise': True, - 'attention_hidden_size': 128, - 'attention_num_heads': 4, - 'attention_mlp_dim': 256, - 'attention_dropout_rate': 0.2} - Net = MedFormerV1(params) - Net = Net.double() - - x = np.random.rand(1, 4, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print("output length", len(y)) - for yi in y: - yi = yi.detach().numpy() - print(yi.shape) diff --git a/pymic/net/net3d/trans3d/MedFormer_v2.py b/pymic/net/net3d/trans3d/MedFormer_v2.py deleted file mode 100644 index 00cb295..0000000 --- a/pymic/net/net3d/trans3d/MedFormer_v2.py +++ /dev/null @@ -1,464 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import math -import copy -import torch -import torch.nn as nn -import numpy as np -import torch.nn.functional as F -from pymic.net.net3d.unet3d import ConvBlock, Encoder, Decoder -from pymic.net.net3d.trans3d.MedFormer_v1 import Block -from timm.models.layers import DropPath, to_3tuple, trunc_normal_ - - -# code from nnFormer -class Mlp(nn.Module): - """ Multilayer perceptron.""" - - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - - B, S, H, W, C = x.shape - x = x.view(B, S // window_size, window_size, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, S, H, W): - - B = int(windows.shape[0] / (S * H * W / window_size / window_size / window_size)) - x = windows.view(B, S // window_size, H // window_size, W // window_size, window_size, window_size, window_size, -1) - x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, S, H, W, -1) - return x - - -class WindowAttention(nn.Module): - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), - num_heads)) - - # get pair-wise relative position index for each token inside the window - coords_s = torch.arange(self.window_size[0]) - coords_h = torch.arange(self.window_size[1]) - coords_w = torch.arange(self.window_size[2]) - coords = torch.stack(torch.meshgrid([coords_s, coords_h, coords_w])) - coords_flatten = torch.flatten(coords, 1) - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] - relative_coords = relative_coords.permute(1, 2, 0).contiguous() - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 2] += self.window_size[2] - 1 - - relative_coords[:, :, 0] *= 3 * self.window_size[1] - 1 - relative_coords[:, :, 1] *= 2 * self.window_size[1] - 1 - - relative_position_index = relative_coords.sum(-1) - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - trunc_normal_(self.relative_position_bias_table, std=.02) - - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None,pos_embed=None): - - B_, N, C = x.shape - - qkv = self.qkv(x) - - qkv=qkv.reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1).contiguous()) - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1] * self.window_size[2], - self.window_size[0] * self.window_size[1] * self.window_size[2], -1) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous() - if pos_embed is not None: - x = x+pos_embed - x = self.proj(x) - x = self.proj_drop(x) - return x - -class SwinTransformerBlock(nn.Module): - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - - self.attn = WindowAttention( - dim, window_size=to_3tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - - def forward(self, x, mask_matrix): - - B, L, C = x.shape - S, H, W = self.input_resolution - - assert L == S * H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, S, H, W, C) - - # pad feature maps to multiples of window size - pad_r = (self.window_size - W % self.window_size) % self.window_size - pad_b = (self.window_size - H % self.window_size) % self.window_size - pad_g = (self.window_size - S % self.window_size) % self.window_size - - x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b, 0, pad_g)) - _, Sp, Hp, Wp, _ = x.shape - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size,-self.shift_size), dims=(1, 2,3)) - attn_mask = mask_matrix - else: - shifted_x = x - attn_mask = None - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size * self.window_size, - C) - - # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=attn_mask,pos_embed=None) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, Sp, Hp, Wp) - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size, self.shift_size), dims=(1, 2, 3)) - else: - x = shifted_x - - if pad_r > 0 or pad_b > 0 or pad_g > 0: - x = x[:, :S, :H, :W, :].contiguous() - - x = x.view(B, S * H * W, C) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - - return x - -class BasicLayer(nn.Module): - - def __init__(self, - dim, - input_resolution, - depth, - num_heads, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop=0., - attn_drop=0., - drop_path=0., - norm_layer=nn.LayerNorm, - downsample=True - ): - super().__init__() - self.window_size = window_size - self.shift_size = window_size // 2 - self.depth = depth - # build blocks - - self.blocks = nn.ModuleList([ - SwinTransformerBlock( - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, S, H, W): - - - # calculate attention mask for SW-MSA - Sp = int(np.ceil(S / self.window_size)) * self.window_size - Hp = int(np.ceil(H / self.window_size)) * self.window_size - Wp = int(np.ceil(W / self.window_size)) * self.window_size - img_mask = torch.zeros((1, Sp, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 - s_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for s in s_slices: - for h in h_slices: - for w in w_slices: - img_mask[:, s, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) - mask_windows = mask_windows.view(-1, - self.window_size * self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - for blk in self.blocks: - - x = blk(x, attn_mask) - if self.downsample is not None: - x_down = self.downsample(x, S, H, W) - Ws, Wh, Ww = (S + 1) // 2, (H + 1) // 2, (W + 1) // 2 - return x, S, H, W, x_down, Ws, Wh, Ww - else: - return x, S, H, W, x, S, H, W - - -class AttUpBlock(nn.Module): - """ - 3D upsampling followed by ConvBlock - - :param in_channels1: (int) Channel number of high-level features. - :param in_channels2: (int) Channel number of low-level features. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. - """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - trilinear=True, with_att = False, att_params = None): - super(AttUpBlock, self).__init__() - self.trilinear = trilinear - self.with_att = with_att - if trilinear: - self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) - else: - self.up = nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) - self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) - if(self.with_att): - input_resolution = att_params['input_resolution'] - depth = att_params['depth'] - num_heads = att_params['num_heads'] - self.attn = BasicLayer(out_channels, input_resolution, depth, num_heads, downsample=None) - - def forward(self, x1, x2): - if self.trilinear: - x1 = self.conv1x1(x1) - x1 = self.up(x1) - x = torch.cat([x2, x1], dim=1) - x = self.conv(x) - if(self.with_att): - [B, C, D, H, W] = list(x.shape) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.attn(x, D, H, W)[0] - x = x.view(-1, D, H, W, C).permute(0, 4, 1, 2, 3).contiguous() - return x - -class AttDecoder(nn.Module): - """ - Decoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - """ - def __init__(self, params): - super(AttDecoder, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.trilinear = self.params.get('trilinear', True) - self.mul_pred = self.params['multiscale_pred'] - - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - if(len(self.ft_chns) == 5): - self.up1 = AttUpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear) - att_params = {"input_resolution": [24, 24, 24], "depth": 2, "num_heads": 4} - self.up2 = AttUpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear, True, att_params) - att_params = {"input_resolution": [48, 48, 48], "depth": 2, "num_heads": 4} - self.up3 = AttUpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear, True, att_params) - self.up4 = AttUpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear) - self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) - if(self.mul_pred): - self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) - self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) - self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) - - def forward(self, x): - if(len(self.ft_chns) == 5): - assert(len(x) == 5) - x0, x1, x2, x3, x4 = x - x_d3 = self.up1(x4, x3) - else: - assert(len(x) == 4) - x0, x1, x2, x3 = x - x_d3 = x3 - x_d2 = self.up2(x_d3, x2) - x_d1 = self.up3(x_d2, x1) - x_d0 = self.up4(x_d1, x0) - output = self.out_conv(x_d0) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - return output - -class MedFormerV2(nn.Module): - """ - An implementation of the U-Net. - - * Reference: Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: - 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. - `MICCAI (2) 2016: 424-432. `_ - - Note that there are some modifications from the original paper, such as - the use of batch normalization, dropout, leaky relu and deep supervision. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using trilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - """ - def __init__(self, params): - super(MedFormerV2, self).__init__() - self.params = params - self.encoder = Encoder(params) - self.decoder = AttDecoder(params) - self.attn = Block(params) - - def forward(self, x): - f = self.encoder(x) - f[-1] = self.attn(f[-1]) - output = self.decoder(f) - return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'class_num': 2, - 'feature_chns':[16, 32, 64, 128], - 'dropout' : [0, 0, 0, 0.5], - 'trilinear': True, - 'multiscale_pred': True, - 'attention_hidden_size': 128, - 'attention_num_heads': 4, - 'attention_mlp_dim': 256, - 'attention_dropout_rate': 0.2} - - Net = MedFormerV2(params) - Net = Net.double() - - x = np.random.rand(1, 4, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print("output length", len(y)) - for yi in y: - yi = yi.detach().numpy() - print(yi.shape) diff --git a/pymic/net/net3d/trans3d/MedFormer_v3.py b/pymic/net/net3d/trans3d/MedFormer_v3.py deleted file mode 100644 index f119a9c..0000000 --- a/pymic/net/net3d/trans3d/MedFormer_v3.py +++ /dev/null @@ -1,255 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import numpy as np -import torch -import torch.nn as nn -from torch.nn.functional import interpolate -from pymic.net.net3d.unet3d import ConvBlock, Encoder -from pymic.net.net3d.trans3d.MedFormer_v1 import Block -from pymic.net.net3d.trans3d.MedFormer_v2 import SwinTransformerBlock, window_partition - -class GLAttLayer(nn.Module): - def __init__(self, - dim, - input_resolution, - num_heads, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop=0., - attn_drop=0., - drop_path=0., - norm_layer=nn.LayerNorm): - super().__init__() - self.window_size = window_size - self.shift_size = window_size // 2 - # build blocks - - self.lcl_att = SwinTransformerBlock( - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - window_size=window_size, - shift_size=0, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path) - self.adpool = nn.AdaptiveAvgPool3d([12, 12, 12]) - - params = {'attention_hidden_size': dim, - 'attention_num_heads': 4, - 'attention_mlp_dim': dim, - 'attention_dropout_rate': 0.2} - self.glb_att = Block(params) - self.conv1x1 = nn.Sequential( - nn.Conv3d(2*dim, dim, kernel_size=1), - nn.BatchNorm3d(dim), - nn.LeakyReLU()) - - def forward(self, x): - [B, C, S, H, W] = list(x.shape) - # calculate attention mask for SW-MSA - Sp = int(np.ceil(S / self.window_size)) * self.window_size - Hp = int(np.ceil(H / self.window_size)) * self.window_size - Wp = int(np.ceil(W / self.window_size)) * self.window_size - img_mask = torch.zeros((1, Sp, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 - s_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for s in s_slices: - for h in h_slices: - for w in w_slices: - img_mask[:, s, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) - mask_windows = mask_windows.view(-1, - self.window_size * self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - # for local attention - xl = x.flatten(2).transpose(1, 2).contiguous() - xl = self.lcl_att(xl, attn_mask) - xl = xl.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() - - # for global attention - xg = self.adpool(x) - xg = self.glb_att(xg) - xg = interpolate(xg, [S, H, W], mode = 'trilinear') - out = torch.cat([xl, xg], dim=1) - out = self.conv1x1(out) - return out - -class AttUpBlock(nn.Module): - """ - 3D upsampling followed by ConvBlock - - :param in_channels1: (int) Channel number of high-level features. - :param in_channels2: (int) Channel number of low-level features. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. - """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - trilinear=True, with_att = False, att_params = None): - super(AttUpBlock, self).__init__() - self.trilinear = trilinear - self.with_att = with_att - if trilinear: - self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) - else: - self.up = nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) - self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) - if(self.with_att): - input_resolution = att_params['input_resolution'] - num_heads = att_params['num_heads'] - window_size = att_params['window_size'] - self.attn = GLAttLayer(out_channels, input_resolution, num_heads, window_size, 2.0) - - def forward(self, x1, x2): - if self.trilinear: - x1 = self.conv1x1(x1) - x1 = self.up(x1) - x = torch.cat([x2, x1], dim=1) - x = self.conv(x) - if(self.with_att): - x = self.attn(x) - return x - - -class AttDecoder(nn.Module): - """ - Decoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - """ - def __init__(self, params): - super(AttDecoder, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.trilinear = self.params.get('trilinear', True) - self.mul_pred = self.params['multiscale_pred'] - - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - if(len(self.ft_chns) == 5): - self.up1 = AttUpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear) - att_params = {"input_resolution": [24, 24, 24], "num_heads": 4, "window_size": 7} - self.up2 = AttUpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear, True, att_params) - att_params = {"input_resolution": [48, 48, 48], "num_heads": 4, "window_size": 7} - self.up3 = AttUpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear, True, att_params) - self.up4 = AttUpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear) - self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) - if(self.mul_pred): - self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) - self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) - self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) - - def forward(self, x): - if(len(self.ft_chns) == 5): - assert(len(x) == 5) - x0, x1, x2, x3, x4 = x - x_d3 = self.up1(x4, x3) - else: - assert(len(x) == 4) - x0, x1, x2, x3 = x - x_d3 = x3 - x_d2 = self.up2(x_d3, x2) - x_d1 = self.up3(x_d2, x1) - x_d0 = self.up4(x_d1, x0) - output = self.out_conv(x_d0) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - return output - -class MedFormerV3(nn.Module): - """ - An implementation of the U-Net. - - * Reference: Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: - 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. - `MICCAI (2) 2016: 424-432. `_ - - Note that there are some modifications from the original paper, such as - the use of batch normalization, dropout, leaky relu and deep supervision. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using trilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - """ - def __init__(self, params): - super(MedFormerV3, self).__init__() - self.params = params - self.encoder = Encoder(params) - self.decoder = AttDecoder(params) - params["attention_hidden_size"] = params['feature_chns'][-1] - params["attention_mlp_dim"] = params['feature_chns'][-1] - self.attn = Block(params) - - def forward(self, x): - f = self.encoder(x) - f[-1] = self.attn(f[-1]) - output = self.decoder(f) - return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'class_num': 2, - 'feature_chns':[16, 32, 64, 128], - 'dropout' : [0, 0, 0, 0.5], - 'trilinear': True, - 'multiscale_pred': True, - 'attention_num_heads': 4, - 'attention_dropout_rate': 0.2} - - Net = MedFormerV3(params) - Net = Net.double() - - x = np.random.rand(2, 4, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print("output length", len(y)) - for yi in y: - yi = yi.detach().numpy() - print(yi.shape) diff --git a/pymic/net/net3d/trans3d/MedFormer_va1.py b/pymic/net/net3d/trans3d/MedFormer_va1.py deleted file mode 100644 index 27dfa3e..0000000 --- a/pymic/net/net3d/trans3d/MedFormer_va1.py +++ /dev/null @@ -1,105 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import math -import torch -import torch.nn as nn -import numpy as np -from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm -from pymic.net.net3d.unet3d import Decoder - -class EmbeddingBlock(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, padding, stride): - super(EmbeddingBlock, self).__init__() - self.out_channels = out_channels - self.conv1 = nn.Conv3d(in_channels, out_channels//2, kernel_size=kernel_size, padding=padding, stride = stride) - self.conv2 = nn.Conv3d(out_channels//2, out_channels, kernel_size=1) - self.act = nn.GELU() - self.norm1 = nn.LayerNorm(out_channels//2) - self.norm2 = nn.LayerNorm(out_channels) - - - def forward(self, x): - x = self.act(self.conv1(x)) - # norm 1 - Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.norm1(x) - x = x.transpose(1, 2).contiguous().view(-1, self.out_channels // 2, Ws, Wh, Ww) - - x = self.act(self.conv2(x)) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.norm2(x) - x = x.transpose(1, 2).contiguous().view(-1, self.out_channels, Ws, Wh, Ww) - - return x - -class Encoder(nn.Module): - """ - Encoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - """ - def __init__(self, params): - super(Encoder, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - assert(len(self.ft_chns) == 4) - - self.down0 = EmbeddingBlock(self.in_chns, self.ft_chns[0], 3, 1, 1) - self.down1 = EmbeddingBlock(self.in_chns, self.ft_chns[1], 2, 0, 2) - self.down2 = EmbeddingBlock(self.in_chns, self.ft_chns[2], 4, 0, 4) - self.down3 = EmbeddingBlock(self.in_chns, self.ft_chns[3], 8, 0, 8) - - def forward(self, x): - x0 = self.down0(x) - x1 = self.down1(x) - x2 = self.down2(x) - x3 = self.down3(x) - output = [x0, x1, x2, x3] - return output - -class MedFormerVA1(nn.Module): - def __init__(self, params): - super(MedFormerVA1, self).__init__() - self.params = params - self.encoder = Encoder(params) - self.decoder = Decoder(params) - - def forward(self, x): - f = self.encoder(x) - output = self.decoder(f) - return output - - -if __name__ == "__main__": - params = {'in_chns':1, - 'class_num': 8, - 'feature_chns':[16, 32, 64, 128], - 'dropout' : [0, 0, 0, 0.5], - 'trilinear': True, - 'deep_supervise': True, - 'attention_hidden_size': 128, - 'attention_num_heads': 4, - 'attention_mlp_dim': 256, - 'attention_dropout_rate': 0.2} - Net = MedFormerVA1(params) - Net = Net.double() - - x = np.random.rand(1, 1, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print("output length", len(y)) - for yi in y: - yi = yi.detach().numpy() - print(yi.shape) \ No newline at end of file diff --git a/pymic/net/net3d/trans3d/__init__.py b/pymic/net/net3d/trans3d/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/pymic/net/net3d/trans3d/nnFormer_wrap.py b/pymic/net/net3d/trans3d/nnFormer_wrap.py deleted file mode 100644 index 35593a4..0000000 --- a/pymic/net/net3d/trans3d/nnFormer_wrap.py +++ /dev/null @@ -1,43 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import math -import torch -import torch.nn as nn -import numpy as np -from nnformer.network_architecture.nnFormer_tumor import nnFormer - -class nnFormer_wrap(nn.Module): - def __init__(self, params): - super(nnFormer_wrap, self).__init__() - patch_size = params["patch_size"] # 96x96x96 - n_class = params['class_num'] - in_chns = params['in_chns'] - # https://github.com/282857341/nnFormer/blob/main/nnformer/network_architecture/nnFormer_tumor.py - self.nnformer = nnFormer(crop_size = patch_size, - embedding_dim=192, - input_channels = in_chns, - num_classes = n_class, - conv_op=nn.Conv3d, - depths =[2,2,2,2], - num_heads = [6, 12, 24, 48], - patch_size = [4,4,4], - window_size= [4,4,8,4], - deep_supervision=False) - - def forward(self, x): - return self.nnformer(x) - -if __name__ == "__main__": - params = {"patch_size": [96, 96, 96], - "in_chns": 1, - "class_num": 5} - Net = nnFormer_wrap(params) - Net = Net.double() - - x = np.random.rand(1, 1, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print(y.shape) diff --git a/pymic/net/net3d/trans3d/unetr.py b/pymic/net/net3d/trans3d/unetr.py deleted file mode 100644 index ea90b2f..0000000 --- a/pymic/net/net3d/trans3d/unetr.py +++ /dev/null @@ -1,227 +0,0 @@ -from __future__ import print_function, division - -import torch -import torch.nn as nn - -from monai.networks.blocks import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock -from monai.networks.blocks.dynunet_block import UnetOutBlock -from monai.networks.nets import ViT - - -class UNETR(nn.Module): - """ - UNETR based on: "Hatamizadeh et al., - UNETR: Transformers for 3D Medical Image Segmentation " - """ - - def __init__(self, params): - # in_channels: int, - # out_channels: int, - # img_size: Tuple[int, int, int], - # feature_size: int = 16, - # hidden_size: int = 768, - # mlp_dim: int = 3072, - # num_heads: int = 12, - # pos_embed: str = "perceptron", - # norm_name: Union[Tuple, str] = "instance", - # conv_block: bool = False, - # res_block: bool = True, - # dropout_rate: float = 0.0, - # ) -> None: - """ - Args: - in_channels: dimension of input channels. - out_channels: dimension of output channels. - img_size: dimension of input image. - feature_size: dimension of network feature size. - hidden_size: dimension of hidden layer. - mlp_dim: dimension of feedforward layer. - num_heads: number of attention heads. - pos_embed: position embedding layer type. - norm_name: feature normalization type and arguments. - conv_block: bool argument to determine if convolutional block is used. - res_block: bool argument to determine if residual block is used. - dropout_rate: faction of the input units to drop. - Examples:: - # for single channel input 4-channel output with patch size of (96,96,96), feature size of 32 and batch norm - >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch') - # for 4-channel input 3-channel output with patch size of (128,128,128), conv position embedding and instance norm - >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance') - """ - - super().__init__() - in_channels = params['in_chns'] - out_channels = params['class_num'] - img_size = params['img_size'] - feature_size = 16 - hidden_size = 768 - mlp_dim = 3072 - num_heads = 12 - pos_embed = "perceptron" - norm_name = "instance" - conv_block = False - res_block = True - dropout_rate = 0.0 - - if not (0 <= dropout_rate <= 1): - raise AssertionError("dropout_rate should be between 0 and 1.") - - if hidden_size % num_heads != 0: - raise AssertionError("hidden size should be divisible by num_heads.") - - if pos_embed not in ["conv", "perceptron"]: - raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") - - self.num_layers = 12 - self.patch_size = (16, 16, 16) - self.feat_size = ( - img_size[0] // self.patch_size[0], - img_size[1] // self.patch_size[1], - img_size[2] // self.patch_size[2], - ) - self.hidden_size = hidden_size - self.classification = False - self.vit = ViT( - in_channels=in_channels, - img_size=img_size, - patch_size=self.patch_size, - hidden_size=hidden_size, - mlp_dim=mlp_dim, - num_layers=self.num_layers, - num_heads=num_heads, - pos_embed=pos_embed, - classification=self.classification, - dropout_rate=dropout_rate, - ) - self.encoder1 = UnetrBasicBlock( - spatial_dims=3, - in_channels=in_channels, - out_channels=feature_size, - kernel_size=3, - stride=1, - norm_name=norm_name, - res_block=res_block, - ) - self.encoder2 = UnetrPrUpBlock( - spatial_dims=3, - in_channels=hidden_size, - out_channels=feature_size * 2, - num_layer=2, - kernel_size=3, - stride=1, - upsample_kernel_size=2, - norm_name=norm_name, - conv_block=conv_block, - res_block=res_block, - ) - self.encoder3 = UnetrPrUpBlock( - spatial_dims=3, - in_channels=hidden_size, - out_channels=feature_size * 4, - num_layer=1, - kernel_size=3, - stride=1, - upsample_kernel_size=2, - norm_name=norm_name, - conv_block=conv_block, - res_block=res_block, - ) - self.encoder4 = UnetrPrUpBlock( - spatial_dims=3, - in_channels=hidden_size, - out_channels=feature_size * 8, - num_layer=0, - kernel_size=3, - stride=1, - upsample_kernel_size=2, - norm_name=norm_name, - conv_block=conv_block, - res_block=res_block, - ) - self.decoder5 = UnetrUpBlock( - spatial_dims=3, - in_channels=hidden_size, - out_channels=feature_size * 8, - kernel_size=3, - upsample_kernel_size=2, - norm_name=norm_name, - res_block=res_block, - ) - self.decoder4 = UnetrUpBlock( - spatial_dims=3, - in_channels=feature_size * 8, - out_channels=feature_size * 4, - kernel_size=3, - upsample_kernel_size=2, - norm_name=norm_name, - res_block=res_block, - ) - self.decoder3 = UnetrUpBlock( - spatial_dims=3, - in_channels=feature_size * 4, - out_channels=feature_size * 2, - kernel_size=3, - upsample_kernel_size=2, - norm_name=norm_name, - res_block=res_block, - ) - self.decoder2 = UnetrUpBlock( - spatial_dims=3, - in_channels=feature_size * 2, - out_channels=feature_size, - kernel_size=3, - upsample_kernel_size=2, - norm_name=norm_name, - res_block=res_block, - ) - self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) # type: ignore - - def proj_feat(self, x, hidden_size, feat_size): - x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) - x = x.permute(0, 4, 1, 2, 3).contiguous() - return x - - def load_from(self, weights): - with torch.no_grad(): - res_weight = weights - # copy weights from patch embedding - for i in weights["state_dict"]: - print(i) - self.vit.patch_embedding.position_embeddings.copy_( - weights["state_dict"]["module.transformer.patch_embedding.position_embeddings_3d"] - ) - self.vit.patch_embedding.cls_token.copy_( - weights["state_dict"]["module.transformer.patch_embedding.cls_token"] - ) - self.vit.patch_embedding.patch_embeddings[1].weight.copy_( - weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings.1.weight"] - ) - self.vit.patch_embedding.patch_embeddings[1].bias.copy_( - weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings.1.bias"] - ) - - # copy weights from encoding blocks (default: num of blocks: 12) - for bname, block in self.vit.blocks.named_children(): - print(block) - block.loadFrom(weights, n_block=bname) - # last norm layer of transformer - self.vit.norm.weight.copy_(weights["state_dict"]["module.transformer.norm.weight"]) - self.vit.norm.bias.copy_(weights["state_dict"]["module.transformer.norm.bias"]) - - def forward(self, x_in): - x, hidden_states_out = self.vit(x_in) - enc1 = self.encoder1(x_in) - x2 = hidden_states_out[3] - enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size)) - x3 = hidden_states_out[6] - enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size)) - x4 = hidden_states_out[9] - enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size)) - dec4 = self.proj_feat(x, self.hidden_size, self.feat_size) - dec3 = self.decoder5(dec4, enc4) - dec2 = self.decoder4(dec3, enc3) - dec1 = self.decoder3(dec2, enc2) - out = self.decoder2(dec1, enc1) - logits = self.out(out) - return logits - diff --git a/pymic/net/net3d/trans3d/unetr_pp.py b/pymic/net/net3d/trans3d/unetr_pp.py deleted file mode 100644 index a4ab7e6..0000000 --- a/pymic/net/net3d/trans3d/unetr_pp.py +++ /dev/null @@ -1,469 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from typing import Optional, Sequence, Tuple, Union -from pymic.net.net3d.trans3d.unetr_pp_block import UnetOutBlock, UnetResBlock, get_conv_layer -from timm.models.layers import trunc_normal_ -from monai.utils import optional_import -from monai.networks.blocks.convolutions import Convolution -from monai.networks.layers.factories import Act, Norm -from monai.networks.layers.utils import get_act_layer, get_norm_layer - -einops, _ = optional_import("einops") - -class LayerNorm(nn.Module): - def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): - super().__init__() - self.weight = nn.Parameter(torch.ones(normalized_shape)) - self.bias = nn.Parameter(torch.zeros(normalized_shape)) - self.eps = eps - self.data_format = data_format - if self.data_format not in ["channels_last", "channels_first"]: - raise NotImplementedError - self.normalized_shape = (normalized_shape,) - - def forward(self, x): - if self.data_format == "channels_last": - return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - elif self.data_format == "channels_first": - u = x.mean(1, keepdim=True) - s = (x - u).pow(2).mean(1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.eps) - x = self.weight[:, None, None] * x + self.bias[:, None, None] - return x - -class EPA(nn.Module): - """ - Efficient Paired Attention Block, based on: "Shaker et al., - UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation" - """ - def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False, - channel_attn_drop=0.1, spatial_attn_drop=0.1): - super().__init__() - self.num_heads = num_heads - self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) - self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1)) - - # qkvv are 4 linear layers (query_shared, key_shared, value_spatial, value_channel) - self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias) - - # E and F are projection matrices with shared weights used in spatial attention module to project - # keys and values from HWD-dimension to P-dimension - self.E = self.F = nn.Linear(input_size, proj_size) - - self.attn_drop = nn.Dropout(channel_attn_drop) - self.attn_drop_2 = nn.Dropout(spatial_attn_drop) - - self.out_proj = nn.Linear(hidden_size, int(hidden_size // 2)) - self.out_proj2 = nn.Linear(hidden_size, int(hidden_size // 2)) - - def forward(self, x): - B, N, C = x.shape - - qkvv = self.qkvv(x).reshape(B, N, 4, self.num_heads, C // self.num_heads) - - qkvv = qkvv.permute(2, 0, 3, 1, 4) - - q_shared, k_shared, v_CA, v_SA = qkvv[0], qkvv[1], qkvv[2], qkvv[3] - - q_shared = q_shared.transpose(-2, -1) - k_shared = k_shared.transpose(-2, -1) - v_CA = v_CA.transpose(-2, -1) - v_SA = v_SA.transpose(-2, -1) - - k_shared_projected = self.E(k_shared) - - v_SA_projected = self.F(v_SA) - - q_shared = torch.nn.functional.normalize(q_shared, dim=-1) - k_shared = torch.nn.functional.normalize(k_shared, dim=-1) - - attn_CA = (q_shared @ k_shared.transpose(-2, -1)) * self.temperature - - attn_CA = attn_CA.softmax(dim=-1) - attn_CA = self.attn_drop(attn_CA) - - x_CA = (attn_CA @ v_CA).permute(0, 3, 1, 2).reshape(B, N, C) - - attn_SA = (q_shared.permute(0, 1, 3, 2) @ k_shared_projected) * self.temperature2 - - attn_SA = attn_SA.softmax(dim=-1) - attn_SA = self.attn_drop_2(attn_SA) - - x_SA = (attn_SA @ v_SA_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(B, N, C) - - # Concat fusion - x_SA = self.out_proj(x_SA) - x_CA = self.out_proj2(x_CA) - x = torch.cat((x_SA, x_CA), dim=-1) - return x - - @torch.jit.ignore - def no_weight_decay(self): - return {'temperature', 'temperature2'} - - -class TransformerBlock(nn.Module): - """ - A transformer block, based on: "Shaker et al., - UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation" - """ - - def __init__( - self, - input_size: int, - hidden_size: int, - proj_size: int, - num_heads: int, - dropout_rate: float = 0.0, - pos_embed=False, - ) -> None: - """ - Args: - input_size: the size of the input for each stage. - hidden_size: dimension of hidden layer. - proj_size: projection size for keys and values in the spatial attention module. - num_heads: number of attention heads. - dropout_rate: faction of the input units to drop. - pos_embed: bool argument to determine if positional embedding is used. - - """ - - super().__init__() - - if not (0 <= dropout_rate <= 1): - raise ValueError("dropout_rate should be between 0 and 1.") - - if hidden_size % num_heads != 0: - print("Hidden size is ", hidden_size) - print("Num heads is ", num_heads) - raise ValueError("hidden_size should be divisible by num_heads.") - - self.norm = nn.LayerNorm(hidden_size) - self.gamma = nn.Parameter(1e-6 * torch.ones(hidden_size), requires_grad=True) - self.epa_block = EPA(input_size=input_size, hidden_size=hidden_size, proj_size=proj_size, num_heads=num_heads, channel_attn_drop=dropout_rate,spatial_attn_drop=dropout_rate) - self.conv51 = UnetResBlock(3, hidden_size, hidden_size, kernel_size=3, stride=1, norm_name="batch") - self.conv8 = nn.Sequential(nn.Dropout3d(0.1, False), nn.Conv3d(hidden_size, hidden_size, 1)) - - self.pos_embed = None - if pos_embed: - self.pos_embed = nn.Parameter(torch.zeros(1, input_size, hidden_size)) - - def forward(self, x): - B, C, H, W, D = x.shape - x = x.reshape(B, C, H * W * D).permute(0, 2, 1) - - if self.pos_embed is not None: - x = x + self.pos_embed - attn = x + self.gamma * self.epa_block(self.norm(x)) - - attn_skip = attn.reshape(B, H, W, D, C).permute(0, 4, 1, 2, 3) # (B, C, H, W, D) - attn = self.conv51(attn_skip) - x = attn_skip + self.conv8(attn) - - return x - -class UnetrPPEncoder(nn.Module): - def __init__(self, input_size=[32 * 32 * 32, 16 * 16 * 16, 8 * 8 * 8, 4 * 4 * 4],dims=[32, 64, 128, 256], - proj_size =[64,64,64,32], depths=[3, 3, 3, 3], num_heads=4, spatial_dims=3, - in_channels=1, dropout=0.0, transformer_dropout_rate=0.15, kernel_size=(2,4,4), **kwargs): - super().__init__() - - self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers - stem_layer = nn.Sequential( - get_conv_layer(spatial_dims, in_channels, dims[0], kernel_size=kernel_size, stride=kernel_size, - dropout=dropout, conv_only=True, ), - get_norm_layer(name=("group", {"num_groups": in_channels}), channels=dims[0]), - ) - self.downsample_layers.append(stem_layer) - for i in range(3): - downsample_layer = nn.Sequential( - get_conv_layer(spatial_dims, dims[i], dims[i + 1], kernel_size=(2, 2, 2), stride=(2, 2, 2), - dropout=dropout, conv_only=True, ), - get_norm_layer(name=("group", {"num_groups": dims[i]}), channels=dims[i + 1]), - ) - self.downsample_layers.append(downsample_layer) - - self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple Transformer blocks - for i in range(4): - stage_blocks = [] - for j in range(depths[i]): - stage_blocks.append(TransformerBlock(input_size=input_size[i], hidden_size=dims[i], proj_size=proj_size[i], num_heads=num_heads, - dropout_rate=transformer_dropout_rate, pos_embed=True)) - self.stages.append(nn.Sequential(*stage_blocks)) - self.hidden_states = [] - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv2d, nn.Linear)): - trunc_normal_(m.weight, std=.02) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, (LayerNorm, nn.LayerNorm)): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def forward_features(self, x): - hidden_states = [] - x = self.downsample_layers[0](x) - x = self.stages[0](x) - - hidden_states.append(x) - - for i in range(1, 4): - x = self.downsample_layers[i](x) - x = self.stages[i](x) - if i == 3: # Reshape the output of the last stage - x = einops.rearrange(x, "b c h w d -> b (h w d) c") - hidden_states.append(x) - return x, hidden_states - - def forward(self, x): - x, hidden_states = self.forward_features(x) - return x, hidden_states - - -class UnetrUpBlock(nn.Module): - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], - upsample_kernel_size: Union[Sequence[int], int], - norm_name: Union[Tuple, str], - proj_size: int = 64, - num_heads: int = 4, - out_size: int = 0, - depth: int = 3, - conv_decoder: bool = False, - ) -> None: - """ - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - kernel_size: convolution kernel size. - upsample_kernel_size: convolution kernel size for transposed convolution layers. - norm_name: feature normalization type and arguments. - proj_size: projection size for keys and values in the spatial attention module. - num_heads: number of heads inside each EPA module. - out_size: spatial size for each decoder. - depth: number of blocks for the current decoder stage. - """ - - super().__init__() - upsample_stride = upsample_kernel_size - self.transp_conv = get_conv_layer( - spatial_dims, - in_channels, - out_channels, - kernel_size=upsample_kernel_size, - stride=upsample_stride, - conv_only=True, - is_transposed=True, - ) - - # 4 feature resolution stages, each consisting of multiple residual blocks - self.decoder_block = nn.ModuleList() - - # If this is the last decoder, use ConvBlock(UnetResBlock) instead of EPA_Block (see suppl. material in the paper) - if conv_decoder == True: - self.decoder_block.append( - UnetResBlock(spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, - norm_name=norm_name, )) - else: - stage_blocks = [] - for j in range(depth): - stage_blocks.append(TransformerBlock(input_size=out_size, hidden_size= out_channels, proj_size=proj_size, num_heads=num_heads, - dropout_rate=0.15, pos_embed=True)) - self.decoder_block.append(nn.Sequential(*stage_blocks)) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv2d, nn.Linear)): - trunc_normal_(m.weight, std=.02) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, (nn.LayerNorm)): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def forward(self, inp, skip): - - out = self.transp_conv(inp) - out = out + skip - out = self.decoder_block[0](out) - - return out - - -class UNETR_PP(nn.Module): - """ - UNETR++ based on: "Shaker et al., - UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation" - """ - - def __init__(self, params): - """ - Args: - in_channels: dimension of input channels. - out_channels: dimension of output channels. - img_size: dimension of input image. - feature_size: dimension of network feature size. - hidden_size: dimension of the last encoder. - num_heads: number of attention heads. - pos_embed: position embedding layer type. - norm_name: feature normalization type and arguments. - dropout_rate: faction of the input units to drop. - depths: number of blocks for each stage. - dims: number of channel maps for the stages. - conv_op: type of convolution operation. - do_ds: use deep supervision to compute the loss. - - """ - super().__init__() - in_channels = params['in_chns'] - out_channels = params['class_num'] - img_size = params['img_size'] - self.res_mode= params.get("resolution_mode", 1) - feature_size = params.get('feature_size', 16) - hidden_size = params.get('hidden_size', 256) - num_heads = params.get('num_heads', 4) - pos_embed = params.get('pos_embed', "perceptron") - norm_name = params.get('norm_name', "instance") - dropout_rate = params.get('dropout_rate', 0.0) - depths = params.get('depths', [3, 3, 3, 3]) - dims = params.get('dims', [32, 64, 128, 256]) - conv_op = nn.Conv3d - do_ds = params.get('deep_supervise', True) - - self.do_ds = do_ds - self.conv_op = conv_op - self.num_classes = out_channels - if not (0 <= dropout_rate <= 1): - raise AssertionError("dropout_rate should be between 0 and 1.") - - if pos_embed not in ["conv", "perceptron"]: - raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") - - kernel_ds = [4, 2, 1] - kernel_d = kernel_ds[self.res_mode] - self.patch_size = (kernel_d, 4, 4) - - self.feat_size = ( - img_size[0] // self.patch_size[0] // 8, # 8 is the downsampling happened through the four encoders stages - img_size[1] // self.patch_size[1] // 8, # 8 is the downsampling happened through the four encoders stages - img_size[2] // self.patch_size[2] // 8, # 8 is the downsampling happened through the four encoders stages - ) - - self.hidden_size = hidden_size - - self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads, - in_channels=in_channels, kernel_size=self.patch_size) - - self.encoder1 = UnetResBlock( - spatial_dims=3, - in_channels=in_channels, - out_channels=feature_size, - kernel_size=3, - stride=1, - norm_name=norm_name, - ) - self.decoder5 = UnetrUpBlock( - spatial_dims=3, - in_channels=feature_size * 16, - out_channels=feature_size * 8, - kernel_size=3, - upsample_kernel_size=2, - norm_name=norm_name, - out_size=8 * 8 * 8, - ) - self.decoder4 = UnetrUpBlock( - spatial_dims=3, - in_channels=feature_size * 8, - out_channels=feature_size * 4, - kernel_size=3, - upsample_kernel_size=2, - norm_name=norm_name, - out_size=16 * 16 * 16, - ) - self.decoder3 = UnetrUpBlock( - spatial_dims=3, - in_channels=feature_size * 4, - out_channels=feature_size * 2, - kernel_size=3, - upsample_kernel_size=2, - norm_name=norm_name, - out_size=32 * 32 * 32, - ) - - self.decoder2 = UnetrUpBlock( - spatial_dims=3, - in_channels=feature_size * 2, - out_channels=feature_size, - kernel_size=3, - upsample_kernel_size= self.patch_size, - norm_name=norm_name, - out_size= kernel_d*32 * 128 * 128, - conv_decoder=True, - ) - self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) - # if self.do_ds: - self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels) - self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels) - - def proj_feat(self, x, hidden_size, feat_size): - x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) - x = x.permute(0, 4, 1, 2, 3).contiguous() - return x - - def forward(self, x_in): - x_output, hidden_states = self.unetr_pp_encoder(x_in) - - convBlock = self.encoder1(x_in) - - # Four encoders - enc1 = hidden_states[0] - enc2 = hidden_states[1] - enc3 = hidden_states[2] - enc4 = hidden_states[3] - - # Four decoders - dec4 = self.proj_feat(enc4, self.hidden_size, self.feat_size) - dec3 = self.decoder5(dec4, enc3) - dec2 = self.decoder4(dec3, enc2) - dec1 = self.decoder3(dec2, enc1) - - out = self.decoder2(dec1, convBlock) - if self.do_ds: - logits = [self.out1(out), self.out2(dec1), self.out3(dec2)] - else: - logits = self.out1(out) - - return logits - - -if __name__ == "__main__": - depths = [128, 64, 32] - for i in range(3): - params = {'in_chns': 4, - 'class_num': 2, - 'img_size': [depths[i], 128, 128], - 'resolution_mode': i - } - net = UNETR_PP(params) - net.double() - - x = np.random.rand(2, 4, depths[i], 128, 128) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = net(xt) - print(len(y)) - for yi in y: - yi = yi.detach().numpy() - print(yi.shape) \ No newline at end of file diff --git a/pymic/net/net3d/trans3d/unetr_pp_block.py b/pymic/net/net3d/trans3d/unetr_pp_block.py deleted file mode 100644 index 89a8769..0000000 --- a/pymic/net/net3d/trans3d/unetr_pp_block.py +++ /dev/null @@ -1,278 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import numpy as np -import torch -import torch.nn as nn -from typing import Optional, Sequence, Tuple, Union -from monai.networks.blocks.convolutions import Convolution -from monai.networks.layers.factories import Act, Norm -from monai.networks.layers.utils import get_act_layer, get_norm_layer - - -class UnetResBlock(nn.Module): - """ - A skip-connection based module that can be used for DynUNet, based on: - `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. - `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - kernel_size: convolution kernel size. - stride: convolution stride. - norm_name: feature normalization type and arguments. - act_name: activation layer type and arguments. - dropout: dropout probability. - - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - norm_name: Union[Tuple, str], - act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), - dropout: Optional[Union[Tuple, str, float]] = None, - ): - super().__init__() - self.conv1 = get_conv_layer( - spatial_dims, - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - dropout=dropout, - conv_only=True, - ) - self.conv2 = get_conv_layer( - spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True - ) - self.lrelu = get_act_layer(name=act_name) - self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) - self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) - self.downsample = in_channels != out_channels - stride_np = np.atleast_1d(stride) - if not np.all(stride_np == 1): - self.downsample = True - if self.downsample: - self.conv3 = get_conv_layer( - spatial_dims, in_channels, out_channels, kernel_size=1, stride=stride, dropout=dropout, conv_only=True - ) - self.norm3 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) - - def forward(self, inp): - residual = inp - out = self.conv1(inp) - out = self.norm1(out) - out = self.lrelu(out) - out = self.conv2(out) - out = self.norm2(out) - if hasattr(self, "conv3"): - residual = self.conv3(residual) - if hasattr(self, "norm3"): - residual = self.norm3(residual) - out += residual - out = self.lrelu(out) - return out - - -class UnetBasicBlock(nn.Module): - """ - A CNN module module that can be used for DynUNet, based on: - `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. - `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - kernel_size: convolution kernel size. - stride: convolution stride. - norm_name: feature normalization type and arguments. - act_name: activation layer type and arguments. - dropout: dropout probability. - - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - norm_name: Union[Tuple, str], - act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), - dropout: Optional[Union[Tuple, str, float]] = None, - ): - super().__init__() - self.conv1 = get_conv_layer( - spatial_dims, - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - dropout=dropout, - conv_only=True, - ) - self.conv2 = get_conv_layer( - spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True - ) - self.lrelu = get_act_layer(name=act_name) - self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) - self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) - - def forward(self, inp): - out = self.conv1(inp) - out = self.norm1(out) - out = self.lrelu(out) - out = self.conv2(out) - out = self.norm2(out) - out = self.lrelu(out) - return out - - -class UnetUpBlock(nn.Module): - """ - An upsampling module that can be used for DynUNet, based on: - `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. - `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - kernel_size: convolution kernel size. - stride: convolution stride. - upsample_kernel_size: convolution kernel size for transposed convolution layers. - norm_name: feature normalization type and arguments. - act_name: activation layer type and arguments. - dropout: dropout probability. - trans_bias: transposed convolution bias. - - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - upsample_kernel_size: Union[Sequence[int], int], - norm_name: Union[Tuple, str], - act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), - dropout: Optional[Union[Tuple, str, float]] = None, - trans_bias: bool = False, - ): - super().__init__() - upsample_stride = upsample_kernel_size - self.transp_conv = get_conv_layer( - spatial_dims, - in_channels, - out_channels, - kernel_size=upsample_kernel_size, - stride=upsample_stride, - dropout=dropout, - bias=trans_bias, - conv_only=True, - is_transposed=True, - ) - self.conv_block = UnetBasicBlock( - spatial_dims, - out_channels + out_channels, - out_channels, - kernel_size=kernel_size, - stride=1, - dropout=dropout, - norm_name=norm_name, - act_name=act_name, - ) - - def forward(self, inp, skip): - # number of channels for skip should equals to out_channels - out = self.transp_conv(inp) - out = torch.cat((out, skip), dim=1) - out = self.conv_block(out) - return out - - -class UnetOutBlock(nn.Module): - def __init__( - self, spatial_dims: int, in_channels: int, out_channels: int, dropout: Optional[Union[Tuple, str, float]] = None - ): - super().__init__() - self.conv = get_conv_layer( - spatial_dims, in_channels, out_channels, kernel_size=1, stride=1, dropout=dropout, bias=True, conv_only=True - ) - - def forward(self, inp): - return self.conv(inp) - - -def get_conv_layer( - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int] = 3, - stride: Union[Sequence[int], int] = 1, - act: Optional[Union[Tuple, str]] = Act.PRELU, - norm: Union[Tuple, str] = Norm.INSTANCE, - dropout: Optional[Union[Tuple, str, float]] = None, - bias: bool = False, - conv_only: bool = True, - is_transposed: bool = False, -): - padding = get_padding(kernel_size, stride) - output_padding = None - if is_transposed: - output_padding = get_output_padding(kernel_size, stride, padding) - return Convolution( - spatial_dims, - in_channels, - out_channels, - strides=stride, - kernel_size=kernel_size, - act=act, - norm=norm, - dropout=dropout, - bias=bias, - conv_only=conv_only, - is_transposed=is_transposed, - padding=padding, - output_padding=output_padding, - ) - - -def get_padding( - kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int] -) -> Union[Tuple[int, ...], int]: - - kernel_size_np = np.atleast_1d(kernel_size) - stride_np = np.atleast_1d(stride) - padding_np = (kernel_size_np - stride_np + 1) / 2 - if np.min(padding_np) < 0: - raise AssertionError("padding value should not be negative, please change the kernel size and/or stride.") - padding = tuple(int(p) for p in padding_np) - - return padding if len(padding) > 1 else padding[0] - - -def get_output_padding( - kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], padding: Union[Sequence[int], int] -) -> Union[Tuple[int, ...], int]: - kernel_size_np = np.atleast_1d(kernel_size) - stride_np = np.atleast_1d(stride) - padding_np = np.atleast_1d(padding) - - out_padding_np = 2 * padding_np + stride_np - kernel_size_np - if np.min(out_padding_np) < 0: - raise AssertionError("out_padding value should not be negative, please change the kernel size and/or stride.") - out_padding = tuple(int(p) for p in out_padding_np) - - return out_padding if len(out_padding) > 1 else out_padding[0] diff --git a/pymic/net_run/agent_preprocess.py b/pymic/net_run/agent_preprocess.py index c681de9..67b1262 100644 --- a/pymic/net_run/agent_preprocess.py +++ b/pymic/net_run/agent_preprocess.py @@ -8,8 +8,8 @@ from pymic.io.image_read_write import save_nd_array_as_image from pymic.io.nifty_dataset import NiftyDataset from pymic.transform.trans_dict import TransformDict - - +from pymic.net_run.agent_abstract import seed_torch +from pymic.net_run.self_sup.util import volume_fusion class PreprocessAgent(object): def __init__(self, config): @@ -19,9 +19,14 @@ def __init__(self, config): self.task_type = config['dataset']['task_type'] self.dataloader = None self.dataloader_unlab= None + + deterministic = config['dataset'].get('deterministic', True) + if(deterministic): + random_seed = config['dataset'].get('random_seed', 1) + seed_torch(random_seed) def get_dataset_from_config(self): - root_dir = self.config['dataset']['root_dir'] + root_dir = self.config['dataset']['data_dir'] modal_num = self.config['dataset'].get('modal_num', 1) transform_names = self.config['dataset']["transform"] @@ -40,6 +45,8 @@ def get_dataset_from_config(self): data_csv = self.config['dataset'].get('data_csv', None) data_csv_unlab = self.config['dataset'].get('data_csv_unlab', None) + batch_size = self.config['dataset'].get('batch_size', 1) + data_shuffle = self.config['dataset'].get('data_shuffle', False) if(data_csv is not None): dataset = NiftyDataset(root_dir = root_dir, csv_file = data_csv, @@ -48,7 +55,7 @@ def get_dataset_from_config(self): transform = data_transform, task = self.task_type) self.dataloader = torch.utils.data.DataLoader(dataset, - batch_size = 1, shuffle=False, num_workers= 8, + batch_size = batch_size, shuffle=data_shuffle, num_workers= 8, worker_init_fn=None, generator = torch.Generator()) if(data_csv_unlab is not None): dataset_unlab = NiftyDataset(root_dir = root_dir, @@ -58,7 +65,7 @@ def get_dataset_from_config(self): transform = data_transform, task = self.task_type) self.dataloader_unlab = torch.utils.data.DataLoader(dataset_unlab, - batch_size = 1, shuffle=False, num_workers= 8, + batch_size = batch_size, shuffle=data_shuffle, num_workers= 8, worker_init_fn=None, generator = torch.Generator()) def run(self): @@ -67,38 +74,35 @@ def run(self): """ self.get_dataset_from_config() out_dir = self.config['dataset']['output_dir'] + if(not os.path.isdir(out_dir)): + os.mkdir(out_dir) + batch_operation = self.config['dataset'].get('batch_operation', None) for dataloader in [self.dataloader, self.dataloader_unlab]: - for item in dataloader: - img = item['image'][0] # the batch size is 1 - # save differnt modaliteis - img_names = item['names'] - spacing = [x.numpy()[0] for x in item['spacing']] - for i in range(img.shape[0]): - image_name = out_dir + "/" + img_names[i][0] - print(image_name) - save_nd_array_as_image(img[i], image_name, reference_name = None, spacing=spacing) - if('label' in item): - lab = item['label'][0] - label_name = out_dir + "/" + img_names[-1][0] - print(label_name) - save_nd_array_as_image(lab[0], label_name, reference_name = None, spacing=spacing) - -def main(): - """ - The main function for data preprocessing. - """ - if(len(sys.argv) < 2): - print('Number of arguments should be 2. e.g.') - print(' pymic_preprocess config.cfg') - exit() - cfg_file = str(sys.argv[1]) - if(not os.path.isfile(cfg_file)): - raise ValueError("The config file does not exist: " + cfg_file) - config = parse_config(cfg_file) - config = synchronize_config(config) - agent = PreprocessAgent(config) - agent.run() + if(dataloader is None): + continue + for data in dataloader: + inputs = data['image'] + labels = data.get('label', None) + img_names = data['names'] + lab_names = img_names[-1] + B, C = inputs.shape[0], inputs.shape[1] + spacing = [x.numpy()[0] for x in data['spacing']] + + if(batch_operation is not None and 'VolumeFusion' in batch_operation): + class_num = self.config['dataset']['VolumeFusion_cls_num'.lower()] + block_range = self.config['dataset']['VolumeFusion_block_range'.lower()] + size_min = self.config['dataset']['VolumeFusion_size_min'.lower()] + size_max = self.config['dataset']['VolumeFusion_size_max'.lower()] + if(labels is None): + lab_names = [item.replace(".nii.gz", "_lab.nii.gz") for item in img_names[0]] + inputs, labels = volume_fusion(inputs, class_num - 1, block_range, size_min, size_max) -if __name__ == "__main__": - main() - + for b in range(B): + for c in range(C): + image_name = out_dir + "/" + img_names[c][b] + print(image_name) + save_nd_array_as_image(inputs[b][c], image_name, reference_name = None, spacing=spacing) + if(labels is not None): + label_name = out_dir + "/" + lab_names[b] + print(label_name) + save_nd_array_as_image(labels[b][0], label_name, reference_name = None, spacing=spacing) diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index 50a5fb7..3a4571f 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -14,7 +14,6 @@ from pymic.net_run.weak_sup import WSLMethodDict from pymic.net_run.self_sup import SelfSupMethodDict from pymic.net_run.noisy_label import NLLMethodDict -# from pymic.net_run.self_sup import SelfSLSegAgent def get_seg_rec_agent(config, sup_type): assert(sup_type in ['fully_sup', 'semi_sup', 'self_sup', 'weak_sup', 'noisy_label']) diff --git a/pymic/util/evaluation_seg.py b/pymic/util/evaluation_seg.py index a9b114b..82939ce 100644 --- a/pymic/util/evaluation_seg.py +++ b/pymic/util/evaluation_seg.py @@ -286,63 +286,65 @@ def evaluation(config): label_list = [label_list] label_fuse = config.get('label_fuse', False) output_name = config.get('output_name', None) - gt_root = config['ground_truth_folder_root'] - seg_root = config['segmentation_folder_root'] + gt_dir = config['ground_truth_folder'] + seg_dirs = config['segmentation_folder'] image_pair_csv = config.get('evaluation_image_pair', None) + if(not isinstance(seg_dirs, (tuple, list))): + seg_dirs = [seg_dirs] if(image_pair_csv is not None): image_pair = pd.read_csv(image_pair_csv) gt_names, seg_names = image_pair.iloc[:, 0], image_pair.iloc[:, 1] else: - seg_names = sorted(os.listdir(seg_root)) + seg_names = sorted(os.listdir(seg_dirs[0])) seg_names = [item for item in seg_names if is_image_name(item)] gt_names = seg_names + for seg_dir in seg_dirs: + for metric in metric_list: + print(metric) + score_all_data = [] + name_score_list= [] + for i in range(len(gt_names)): + gt_full_name = join(gt_dir, gt_names[i]) + seg_full_name = join(seg_dir, seg_names[i]) + s_dict = load_image_as_nd_array(seg_full_name) + g_dict = load_image_as_nd_array(gt_full_name) + s_volume = s_dict["data_array"]; s_spacing = s_dict["spacing"] + g_volume = g_dict["data_array"]; g_spacing = g_dict["spacing"] + # for dim in range(len(s_spacing)): + # assert(s_spacing[dim] == g_spacing[dim]) + + score_vector = get_multi_class_evaluation_score(s_volume, g_volume, label_list, + label_fuse, s_spacing, metric ) + if(len(label_list) > 1): + score_vector.append(np.asarray(score_vector).mean()) + score_all_data.append(score_vector) + name_score_list.append([seg_names[i]] + score_vector) + print(seg_names[i], score_vector) + score_all_data = np.asarray(score_all_data) + score_mean = score_all_data.mean(axis = 0) + score_std = score_all_data.std(axis = 0) + name_score_list.append(['mean'] + list(score_mean)) + name_score_list.append(['std'] + list(score_std)) - for metric in metric_list: - print(metric) - score_all_data = [] - name_score_list= [] - for i in range(len(gt_names)): - gt_full_name = join(gt_root, gt_names[i]) - seg_full_name = join(seg_root, seg_names[i]) - s_dict = load_image_as_nd_array(seg_full_name) - g_dict = load_image_as_nd_array(gt_full_name) - s_volume = s_dict["data_array"]; s_spacing = s_dict["spacing"] - g_volume = g_dict["data_array"]; g_spacing = g_dict["spacing"] - # for dim in range(len(s_spacing)): - # assert(s_spacing[dim] == g_spacing[dim]) - - score_vector = get_multi_class_evaluation_score(s_volume, g_volume, label_list, - label_fuse, s_spacing, metric ) - if(len(label_list) > 1): - score_vector.append(np.asarray(score_vector).mean()) - score_all_data.append(score_vector) - name_score_list.append([seg_names[i]] + score_vector) - print(seg_names[i], score_vector) - score_all_data = np.asarray(score_all_data) - score_mean = score_all_data.mean(axis = 0) - score_std = score_all_data.std(axis = 0) - name_score_list.append(['mean'] + list(score_mean)) - name_score_list.append(['std'] + list(score_std)) - - # save the result as csv - if(output_name is None): - metric_output_name = "{0:}/eval_{1:}.csv".format(seg_root, metric) - else: - metric_output_name = output_name - with open(metric_output_name, mode='w') as csv_file: - csv_writer = csv.writer(csv_file, delimiter=',', - quotechar='"',quoting=csv.QUOTE_MINIMAL) - head = ['image'] + ["class_{0:}".format(i) for i in label_list] - if(len(label_list) > 1): - head = head + ["average"] - csv_writer.writerow(head) - for item in name_score_list: - csv_writer.writerow(item) - - print("{0:} mean ".format(metric), score_mean) - print("{0:} std ".format(metric), score_std) + # save the result as csv + if(output_name is None): + metric_output_name = "{0:}/eval_{1:}.csv".format(seg_dir, metric) + else: + metric_output_name = output_name + with open(metric_output_name, mode='w') as csv_file: + csv_writer = csv.writer(csv_file, delimiter=',', + quotechar='"',quoting=csv.QUOTE_MINIMAL) + head = ['image'] + ["class_{0:}".format(i) for i in label_list] + if(len(label_list) > 1): + head = head + ["average"] + csv_writer.writerow(head) + for item in name_score_list: + csv_writer.writerow(item) + + print("{0:} mean ".format(metric), score_mean) + print("{0:} std ".format(metric), score_std) def main(): """ diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index c813e5d..158569c 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -38,6 +38,18 @@ def get_ND_bounding_box(volume, margin = None): bb_max[i] = min(bb_max[i] + margin[i], input_shape[i]) return bb_min, bb_max +def get_human_region_from_ct(image, threshold_i = -600, threshold_z = 0.6): + input_shape = image.shape + mask = np.asarray(image > threshold_i) + mask2d = np.mean(mask, axis = 0) > threshold_z + se = np.ones([3,3]) + mask2d = ndimage.binary_opening(mask2d, se, iterations = 2) + mask2d = get_largest_k_components(mask2d, 1) + bbmin, bbmax = get_ND_bounding_box(mask2d, margin = [0, 0]) + bb_min = [0] + bbmin + bb_max = list(input_shape[:1]) + bbmax + return bb_min, bb_max + def crop_ND_volume_with_bounding_box(volume, bb_min, bb_max): """ Extract a subregion form an ND image. diff --git a/pymic/util/parse_config.py b/pymic/util/parse_config.py index a12cc76..0e38b91 100644 --- a/pymic/util/parse_config.py +++ b/pymic/util/parse_config.py @@ -102,24 +102,35 @@ def parse_config(filename): def synchronize_config(config): data_cfg = config['dataset'] - net_cfg = config['network'] - # data_cfg["modal_num"] = net_cfg["in_chns"] data_cfg["task_type"] = TaskDict[data_cfg["task_type"]] - data_cfg["LabelToProbability_class_num".lower()] = net_cfg["class_num"] - if "PartialLabelToProbability" in data_cfg['train_transform']: + if('network' in config): + net_cfg = config['network'] + # data_cfg["modal_num"] = net_cfg["in_chns"] + data_cfg["LabelToProbability_class_num".lower()] = net_cfg["class_num"] + transform = [] + if('transform' in data_cfg and data_cfg['transform'] is not None): + transform.extend(data_cfg['transform']) + if('train_transform' in data_cfg and data_cfg['train_transform'] is not None): + transform.extend(data_cfg['train_transform']) + if('valid_transform' in data_cfg and data_cfg['valid_transform'] is not None): + transform.extend(data_cfg['valid_transform']) + if('test_transform' in data_cfg and data_cfg['test_transform'] is not None): + transform.extend(data_cfg['test_transform']) + if ( "PartialLabelToProbability" in transform and 'network' in config): data_cfg["PartialLabelToProbability_class_num".lower()] = net_cfg["class_num"] patch_size = data_cfg.get('patch_size', None) if(patch_size is not None): - if('Pad' in data_cfg['train_transform']): + if('Pad' in transform and 'Pad_output_size'.lower() not in data_cfg): data_cfg['Pad_output_size'.lower()] = patch_size - if('CenterCrop' in data_cfg['train_transform']): + if('CenterCrop' in transform and 'CenterCrop_output_size'.lower() not in data_cfg): data_cfg['CenterCrop_output_size'.lower()] = patch_size - if('RandomCrop' in data_cfg['train_transform']): + if('RandomCrop' in transform and 'RandomCrop_output_size'.lower() not in data_cfg): data_cfg['RandomCrop_output_size'.lower()] = patch_size - if('RandomResizedCrop' in data_cfg['train_transform']): + if('RandomResizedCrop' in transform and \ + 'RandomResizedCrop_output_size'.lower() not in data_cfg): data_cfg['RandomResizedCrop_output_size'.lower()] = patch_size config['dataset'] = data_cfg - config['network'] = net_cfg + # config['network'] = net_cfg return config def logging_config(config): diff --git a/pymic/util/preprocess.py b/pymic/util/preprocess.py deleted file mode 100644 index c0dc9a1..0000000 --- a/pymic/util/preprocess.py +++ /dev/null @@ -1,63 +0,0 @@ -# -*- coding: utf-8 -*- -import os -import numpy as np -import SimpleITK as sitk -from pymic.io.image_read_write import load_image_as_nd_array -from pymic.transform.trans_dict import TransformDict -from pymic.util.parse_config import parse_config - -def get_transform_list(trans_config_file): - """ - Create a list of transforms given a configuration file. - """ - config = parse_config(trans_config_file) - transform_list = [] - - transform_param = config['dataset'] - transform_param['task'] = 'segmentation' - transform_names = config['dataset']['transform'] - for name in transform_names: - print(name) - if(name not in TransformDict): - raise(ValueError("Undefined transform {0:}".format(name))) - one_transform = TransformDict[name](transform_param) - transform_list.append(one_transform) - return transform_list - -def preprocess_with_transform(transforms, img_in_name, img_out_name, - lab_in_name = None, lab_out_name = None): - """ - Using a list of data transforms for preprocessing, - such as image normalization, cropping, etc. - TODO: support multip-modality preprocessing. - - :param transforms: (list) A list of transform objects. - :param img_in_name: (str) Input file name. - :param img_out_name: (str) Output file name. - :param lab_in_name: (optional, str) If None, load the image's - corresponding label for preprocessing as well. - :param lab_out_name: (optional, str) The output label name. - """ - image_dict = load_image_as_nd_array(img_in_name) - sample = {'image': np.asarray(image_dict['data_array'], np.float32), - 'origin':image_dict['origin'], - 'spacing': image_dict['spacing'], - 'direction':image_dict['direction']} - if(lab_in_name is not None): - label_dict = load_image_as_nd_array(lab_in_name) - sample['label'] = label_dict['data_array'] - for transform in transforms: - sample = transform(sample) - - out_img = sitk.GetImageFromArray(sample['image'][0]) - out_img.SetSpacing(sample['spacing']) - out_img.SetOrigin(sample['origin']) - out_img.SetDirection(sample['direction']) - sitk.WriteImage(out_img, img_out_name) - if(lab_in_name is not None and lab_out_name is not None): - out_lab = sitk.GetImageFromArray(sample['label'][0]) - out_lab.CopyInformation(out_img) - sitk.WriteImage(out_lab, lab_out_name) - - - diff --git a/setup.py b/setup.py index 22406a0..cfb5634 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.4.0.1", + version = "0.4.0.2", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, From 613923b8f7e0fa943f456943c855a6d667ebd5e0 Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 10 Jan 2024 14:47:47 +0800 Subject: [PATCH 38/86] addd pymic_preprocess to setup --- pymic/net_run/agent_seg.py | 1 + setup.py | 1 + 2 files changed, 2 insertions(+) diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 2d6d489..2d61d0d 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -177,6 +177,7 @@ def training(self): inputs, labels_prob = mixup(inputs, labels_prob) # for debug + # print("current iteration", it) # if(it > 10): # break # for i in range(inputs.shape[0]): diff --git a/setup.py b/setup.py index cfb5634..1cad2ff 100644 --- a/setup.py +++ b/setup.py @@ -41,6 +41,7 @@ python_requires = '>=3.6', entry_points = { 'console_scripts': [ + 'pymic_preprocess = pymic.net_run.agent_preprocess:main' 'pymic_train = pymic.net_run.train:main', 'pymic_test = pymic.net_run.predict:main', 'pymic_eval_cls = pymic.util.evaluation_cls:main', From 6b08588ee46e5a3c0c45a6d832c015443fa52108 Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 10 Jan 2024 16:24:28 +0800 Subject: [PATCH 39/86] update config for multi-network models --- pymic/net_run/noisy_label/nll_co_teaching.py | 24 ---------------- pymic/net_run/noisy_label/nll_dast.py | 28 ++++++++---------- pymic/net_run/noisy_label/nll_trinet.py | 26 ----------------- pymic/net_run/preprocess.py | 30 ++++++++++++++++++++ setup.py | 2 +- 5 files changed, 43 insertions(+), 67 deletions(-) create mode 100644 pymic/net_run/preprocess.py diff --git a/pymic/net_run/noisy_label/nll_co_teaching.py b/pymic/net_run/noisy_label/nll_co_teaching.py index ec8e230..c60616e 100644 --- a/pymic/net_run/noisy_label/nll_co_teaching.py +++ b/pymic/net_run/noisy_label/nll_co_teaching.py @@ -18,22 +18,6 @@ from pymic.util.parse_config import * from pymic.util.ramps import get_rampup_ratio -class BiNet(nn.Module): - def __init__(self, params): - super(BiNet, self).__init__() - net_name = params['net_type'] - self.net1 = SegNetDict[net_name](params) - self.net2 = SegNetDict[net_name](params) - - def forward(self, x): - out1 = self.net1(x) - out2 = self.net2(x) - - if(self.training): - return out1, out2 - else: - return (out1 + out2) / 2 - class NLLCoTeaching(SegmentationAgent): """ Co-teaching for noisy-label learning. @@ -58,14 +42,6 @@ def __init__(self, config, stage = 'train'): logging.warn("only CrossEntropyLoss supported for" + " coteaching, the specified loss {0:} is ingored".format(loss_type)) - def create_network(self): - if(self.net is None): - self.net = BiNet(self.config['network']) - if(self.tensor_type == 'float'): - self.net.float() - else: - self.net.double() - def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] diff --git a/pymic/net_run/noisy_label/nll_dast.py b/pymic/net_run/noisy_label/nll_dast.py index 1921e9c..a90747c 100644 --- a/pymic/net_run/noisy_label/nll_dast.py +++ b/pymic/net_run/noisy_label/nll_dast.py @@ -117,31 +117,27 @@ def get_noisy_dataset_from_config(self): """ Create a dataset for images with noisy labels based on configuraiton. """ - root_dir = self.config['dataset']['root_dir'] - modal_num = self.config['dataset'].get('modal_num', 1) - transform_names = self.config['dataset']['train_transform'] - - self.transform_list = [] - if(transform_names is None or len(transform_names) == 0): - data_transform = None - else: - transform_param = self.config['dataset'] - transform_param['task'] = 'segmentation' - for name in transform_names: + trans_names, trans_params = self.get_transform_names_and_parameters('train') + transform_list = [] + if(trans_names is not None and len(trans_names) > 0): + for name in trans_names: if(name not in self.transform_dict): raise(ValueError("Undefined transform {0:}".format(name))) - one_transform = self.transform_dict[name](transform_param) - self.transform_list.append(one_transform) - data_transform = transforms.Compose(self.transform_list) + one_transform = self.transform_dict[name](trans_params) + transform_list.append(one_transform) + data_transform = transforms.Compose(transform_list) + modal_num = self.config['dataset'].get('modal_num', 1) csv_file = self.config['dataset'].get('train_csv_noise', None) - dataset = NiftyDataset(root_dir=root_dir, + dataset = NiftyDataset(root_dir = self.config['dataset']['train_dir'], csv_file = csv_file, modal_num = modal_num, with_label= True, - transform = data_transform ) + transform = data_transform , + task = self.task_type) return dataset + def create_dataset(self): super(NLLDAST, self).create_dataset() if(self.stage == 'train'): diff --git a/pymic/net_run/noisy_label/nll_trinet.py b/pymic/net_run/noisy_label/nll_trinet.py index 25c90cf..64d87b6 100644 --- a/pymic/net_run/noisy_label/nll_trinet.py +++ b/pymic/net_run/noisy_label/nll_trinet.py @@ -17,24 +17,6 @@ from pymic.util.parse_config import * from pymic.util.ramps import get_rampup_ratio -class TriNet(nn.Module): - def __init__(self, params): - super(TriNet, self).__init__() - net_name = params['net_type'] - self.net1 = SegNetDict[net_name](params) - self.net2 = SegNetDict[net_name](params) - self.net3 = SegNetDict[net_name](params) - - def forward(self, x): - out1 = self.net1(x) - out2 = self.net2(x) - out3 = self.net3(x) - - if(self.training): - return out1, out2, out3 - else: - return (out1 + out2 + out3) / 3 - class NLLTriNet(SegmentationAgent): """ Implementation of trinet for learning from noisy samples for @@ -56,14 +38,6 @@ class NLLTriNet(SegmentationAgent): def __init__(self, config, stage = 'train'): super(NLLTriNet, self).__init__(config, stage) - def create_network(self): - if(self.net is None): - self.net = TriNet(self.config['network']) - if(self.tensor_type == 'float'): - self.net.float() - else: - self.net.double() - def get_loss_and_confident_mask(self, pred, labels_prob, conf_ratio): prob = nn.Softmax(dim = 1)(pred) prob_2d = reshape_tensor_to_2D(prob) * 0.999 + 5e-4 diff --git a/pymic/net_run/preprocess.py b/pymic/net_run/preprocess.py new file mode 100644 index 0000000..3b34887 --- /dev/null +++ b/pymic/net_run/preprocess.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import os +import sys +from datetime import datetime +from pymic.util.parse_config import * +from pymic.net_run.agent_preprocess import PreprocessAgent + + +def main(): + """ + The main function for data preprocessing. + """ + if(len(sys.argv) < 2): + print('Number of arguments should be 2. e.g.') + print(' pymic_preprocess config.cfg') + exit() + cfg_file = str(sys.argv[1]) + if(not os.path.isfile(cfg_file)): + raise ValueError("The config file does not exist: " + cfg_file) + config = parse_config(cfg_file) + config = synchronize_config(config) + agent = PreprocessAgent(config) + agent.run() + +if __name__ == "__main__": + main() + + + diff --git a/setup.py b/setup.py index 1cad2ff..8030dc2 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ python_requires = '>=3.6', entry_points = { 'console_scripts': [ - 'pymic_preprocess = pymic.net_run.agent_preprocess:main' + 'pymic_preprocess = pymic.net_run.preprocess:main', 'pymic_train = pymic.net_run.train:main', 'pymic_test = pymic.net_run.predict:main', 'pymic_eval_cls = pymic.util.evaluation_cls:main', From 57e5257b431968fb3e858838870b7dd6e81fdc5f Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 11 Jan 2024 22:26:52 +0800 Subject: [PATCH 40/86] update code for nll_clslsr --- pymic/loss/seg/slsr.py | 4 +-- pymic/net_run/noisy_label/nll_clslsr.py | 33 +++++++++++-------------- 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/pymic/loss/seg/slsr.py b/pymic/loss/seg/slsr.py index d5c4151..92adea7 100644 --- a/pymic/loss/seg/slsr.py +++ b/pymic/loss/seg/slsr.py @@ -38,8 +38,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) if(pix_w is not None): diff --git a/pymic/net_run/noisy_label/nll_clslsr.py b/pymic/net_run/noisy_label/nll_clslsr.py index 0148621..836272a 100644 --- a/pymic/net_run/noisy_label/nll_clslsr.py +++ b/pymic/net_run/noisy_label/nll_clslsr.py @@ -142,9 +142,9 @@ def test_time_dropout(m): print(gt.shape, pred_cat.shape) conf = get_confident_map(gt, pred_cat) conf = conf.reshape(-1, 256, 256).astype(np.uint8) * 255 - save_dir = self.config['dataset']['root_dir'] + "/slsr_conf" + save_dir = self.config['dataset']['train_dir'] + "/slsr_conf" for idx in range(len(filename_list)): - filename = filename_list[idx][0].split('/')[-1] + filename = filename_list[idx][0][0].split('/')[-1] conf_map = Image.fromarray(conf[idx]) dst_path = os.path.join(save_dir, filename) conf_map.save(dst_path) @@ -152,32 +152,29 @@ def test_time_dropout(m): def get_confidence_map(cfg_file): config = parse_config(cfg_file) config = synchronize_config(config) + agent = NLLCLSLSR(config, 'test') - # set dataset - transform_names = config['dataset']['valid_transform'] + # set customized dataset for testing, i.e,. inference with training images + trans_names, trans_params = agent.get_transform_names_and_parameters('valid') transform_list = [] - transform_dict = TransformDict - if(transform_names is None or len(transform_names) == 0): - data_transform = None - else: - transform_param = config['dataset'] - transform_param['task'] = 'segmentation' - for name in transform_names: - if(name not in transform_dict): + if(trans_names is not None and len(trans_names) > 0): + for name in trans_names: + if(name not in agent.transform_dict): raise(ValueError("Undefined transform {0:}".format(name))) - one_transform = transform_dict[name](transform_param) + one_transform = agent.transform_dict[name](trans_params) transform_list.append(one_transform) - data_transform = transforms.Compose(transform_list) + data_transform = transforms.Compose(transform_list) csv_file = config['dataset']['train_csv'] modal_num = config['dataset'].get('modal_num', 1) - dataset = NiftyDataset(root_dir = config['dataset']['root_dir'], + stage_dir = config['dataset']['train_dir'] + dataset = NiftyDataset(root_dir = stage_dir, csv_file = csv_file, modal_num = modal_num, with_label= True, - transform = data_transform ) + transform = data_transform, + task = agent.task_type) - agent = NLLCLSLSR(config, 'test') agent.set_datasets(None, None, dataset) agent.transform_list = transform_list agent.create_dataset() @@ -196,4 +193,4 @@ def get_confidence_map(cfg_file): "label": df_train["label"]} train_cl_csv = csv_file.replace(".csv", "_clslsr.csv") df_cl = pd.DataFrame.from_dict(train_cl_dict) - df_cl.to_csv(train_cl_csv, index = False) \ No newline at end of file + df_cl.to_csv(train_cl_csv, index = False) False) \ No newline at end of file From 1f44153ea5ae9d4ea500d24c5d0847d818279956 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 11 Jan 2024 23:04:07 +0800 Subject: [PATCH 41/86] Update nll_clslsr.py --- pymic/net_run/noisy_label/nll_clslsr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymic/net_run/noisy_label/nll_clslsr.py b/pymic/net_run/noisy_label/nll_clslsr.py index 836272a..c977eba 100644 --- a/pymic/net_run/noisy_label/nll_clslsr.py +++ b/pymic/net_run/noisy_label/nll_clslsr.py @@ -193,4 +193,4 @@ def get_confidence_map(cfg_file): "label": df_train["label"]} train_cl_csv = csv_file.replace(".csv", "_clslsr.csv") df_cl = pd.DataFrame.from_dict(train_cl_dict) - df_cl.to_csv(train_cl_csv, index = False) False) \ No newline at end of file + df_cl.to_csv(train_cl_csv, index = False) \ No newline at end of file From 59a6f70612f474e39120306244501b9f839d22d6 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 12 Jan 2024 10:32:55 +0800 Subject: [PATCH 42/86] update preprocess file fix issues for label name in output --- pymic/net_run/agent_preprocess.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pymic/net_run/agent_preprocess.py b/pymic/net_run/agent_preprocess.py index 67b1262..db8b10b 100644 --- a/pymic/net_run/agent_preprocess.py +++ b/pymic/net_run/agent_preprocess.py @@ -73,7 +73,8 @@ def run(self): Do preprocessing for labeled and unlabeled data. """ self.get_dataset_from_config() - out_dir = self.config['dataset']['output_dir'] + out_dir = self.config['dataset']['output_dir'] + modal_num = self.config['dataset']['modal_num'] if(not os.path.isdir(out_dir)): os.mkdir(out_dir) batch_operation = self.config['dataset'].get('batch_operation', None) @@ -82,9 +83,12 @@ def run(self): continue for data in dataloader: inputs = data['image'] - labels = data.get('label', None) + labels = data.get('label', None) img_names = data['names'] - lab_names = img_names[-1] + if(len(img_names) == modal_num): # for unlabeled dataset + lab_names = [item.replace(".nii.gz", "_lab.nii.gz") for item in img_names[0]] + else: + lab_names = img_names[-1] B, C = inputs.shape[0], inputs.shape[1] spacing = [x.numpy()[0] for x in data['spacing']] @@ -93,8 +97,6 @@ def run(self): block_range = self.config['dataset']['VolumeFusion_block_range'.lower()] size_min = self.config['dataset']['VolumeFusion_size_min'.lower()] size_max = self.config['dataset']['VolumeFusion_size_max'.lower()] - if(labels is None): - lab_names = [item.replace(".nii.gz", "_lab.nii.gz") for item in img_names[0]] inputs, labels = volume_fusion(inputs, class_num - 1, block_range, size_min, size_max) for b in range(B): From 4a3eeb407d5cdfb0896777b49a1c3d48f8b196c0 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 13 Jan 2024 20:22:02 +0800 Subject: [PATCH 43/86] update activation function for loss, and ema --- pymic/loss/seg/ssl.py | 8 ++++---- pymic/net_run/infer_func.py | 2 +- pymic/net_run/semi_sup/ssl_mt.py | 2 +- pymic/net_run/semi_sup/ssl_uamt.py | 2 +- pymic/net_run/weak_sup/wsl_ustm.py | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pymic/loss/seg/ssl.py b/pymic/loss/seg/ssl.py index 0bf276f..3a7430a 100644 --- a/pymic/loss/seg/ssl.py +++ b/pymic/loss/seg/ssl.py @@ -34,8 +34,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) # for numeric stability predict = predict * 0.999 + 5e-4 @@ -70,8 +70,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) # for numeric stability predict = predict * 0.999 + 5e-4 diff --git a/pymic/net_run/infer_func.py b/pymic/net_run/infer_func.py index b0190ad..e0e466e 100644 --- a/pymic/net_run/infer_func.py +++ b/pymic/net_run/infer_func.py @@ -104,7 +104,7 @@ def __infer_with_sliding_window(self, image): weight = torch.zeros(output_shape).to(image.device) temp_w = self.__get_gaussian_weight_map(window_size) temp_w = np.broadcast_to(temp_w, [batch_size, class_num] + window_size) - temp_w = torch.from_numpy(temp_w).to(image.device) + temp_w = torch.from_numpy(np.array(temp_w)).to(image.device) temp_in_shape = img_full_shape[:2] + window_size tempx = torch.ones(temp_in_shape).to(image.device) out_num, scale_list = self.__get_prediction_number_and_scales(tempx) diff --git a/pymic/net_run/semi_sup/ssl_mt.py b/pymic/net_run/semi_sup/ssl_mt.py index 2a2abb8..409af19 100644 --- a/pymic/net_run/semi_sup/ssl_mt.py +++ b/pymic/net_run/semi_sup/ssl_mt.py @@ -106,7 +106,7 @@ def training(self): alpha = ssl_cfg.get('ema_decay', 0.99) alpha = min(1 - 1 / (self.glob_it / iter_valid + 1), alpha) for ema_param, param in zip(self.net_ema.parameters(), self.net.parameters()): - ema_param.data.mul_(alpha).add_(1 - alpha, param.data) + ema_param.data.mul_(alpha).add(param.data, alpha = 1.0 - alpha) train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run/semi_sup/ssl_uamt.py b/pymic/net_run/semi_sup/ssl_uamt.py index 6222fe3..053a012 100644 --- a/pymic/net_run/semi_sup/ssl_uamt.py +++ b/pymic/net_run/semi_sup/ssl_uamt.py @@ -108,7 +108,7 @@ def training(self): alpha = ssl_cfg.get('ema_decay', 0.99) alpha = min(1 - 1 / (self.glob_it / iter_valid + 1), alpha) for ema_param, param in zip(self.net_ema.parameters(), self.net.parameters()): - ema_param.data.mul_(alpha).add_(1 - alpha, param.data) + ema_param.data.mul_(alpha).add(param.data, alpha = 1.0 - alpha) train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run/weak_sup/wsl_ustm.py b/pymic/net_run/weak_sup/wsl_ustm.py index 0ea3fbc..31a6644 100644 --- a/pymic/net_run/weak_sup/wsl_ustm.py +++ b/pymic/net_run/weak_sup/wsl_ustm.py @@ -125,7 +125,7 @@ def training(self): alpha = wsl_cfg.get('ema_decay', 0.99) alpha = min(1 - 1 / (self.glob_it / iter_valid + 1), alpha) for ema_param, param in zip(self.net_ema.parameters(), self.net.parameters()): - ema_param.data.mul_(alpha).add_(1 - alpha, param.data) + ema_param.data.mul_(alpha).add(param.data, alpha = 1.0 - alpha) train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() From be140dec60b1ed37529ac57d8edb270b2cdf0add Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 16 Jan 2024 21:06:24 +0800 Subject: [PATCH 44/86] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 8030dc2..ebb738f 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.4.0.2", + version = "0.4.1", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, From 0e6b60cc5d355651b714a3af2db6a5361a0917ec Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 16 Jan 2024 21:22:18 +0800 Subject: [PATCH 45/86] Update README.md update version to 0.4.1 --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4af5eff..90a6de7 100644 --- a/README.md +++ b/README.md @@ -47,10 +47,10 @@ Run the following command to install the latest released version of PyMIC: ```bash pip install PYMIC ``` -To install a specific version of PYMIC such as 0.4.0, run: +To install a specific version of PYMIC such as 0.4.1, run: ```bash -pip install PYMIC==0.4.0 +pip install PYMIC==0.4.1 ``` Alternatively, you can download the source code for the latest version. Run the following command to compile and install: From 84d8ed084819cb2c76971dc03de44d75d71344ce Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 16 Jan 2024 21:25:14 +0800 Subject: [PATCH 46/86] Update __init__.py update version --- pymic/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymic/__init__.py b/pymic/__init__.py index 1520d82..33943e4 100644 --- a/pymic/__init__.py +++ b/pymic/__init__.py @@ -1,7 +1,7 @@ from __future__ import absolute_import from enum import Enum -__version__ = "0.4.0" +__version__ = "0.4.1" class TaskType(Enum): CLASSIFICATION_ONE_HOT = 1 From ca4ebb0a66a69fbe3bec19fdc042671fc4116465 Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 1 May 2024 16:18:41 +0800 Subject: [PATCH 47/86] update config files use argparse for configuration --- pymic/io/image_read_write.py | 10 +++- pymic/io/nifty_dataset.py | 32 ++++++++++--- pymic/net_run/agent_abstract.py | 5 +- pymic/net_run/agent_cls.py | 10 ++-- pymic/net_run/agent_preprocess.py | 34 +++++++++++--- pymic/net_run/agent_seg.py | 7 +-- pymic/net_run/predict.py | 32 +++++++++---- pymic/net_run/preprocess.py | 13 ++++-- pymic/net_run/train.py | 32 +++++++++---- pymic/transform/affine.py | 6 ++- pymic/transform/intensity.py | 77 +++++++++++++++++++++++++++++-- pymic/transform/label_convert.py | 2 + pymic/util/evaluation_seg.py | 8 ++-- pymic/util/parse_config.py | 26 ++++++++++- 14 files changed, 239 insertions(+), 55 deletions(-) diff --git a/pymic/io/image_read_write.py b/pymic/io/image_read_write.py index 3aa87bd..efbe656 100644 --- a/pymic/io/image_read_write.py +++ b/pymic/io/image_read_write.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division - +import logging import os import numpy as np import SimpleITK as sitk @@ -53,10 +53,16 @@ def load_rgb_image_as_3d_array(filename): image = np.expand_dims(image, axis = 0) else: # transpose rgb image from [H, W, C] to [C, H, W] - assert(image_shape[2] == 3 or image_shape[2] == 4) + # logging.warning("The image is expected to have 1 or three channels, but it has a different channel number") + # logging.warning("({0:} {1:}".format(filename, image_shape)) if(image_shape[2] == 4): image = image[:, :, range(3)] + elif(image_shape[2] == 2): + image = image[:, :, 0:1] + elif(image_shape[2] != 3): + raise ValueError("invalid channel number {0:}", image_shape[2]) image = np.transpose(image, axes = [2, 0, 1]) + output = {} output['data_array'] = image output['origin'] = (0, 0) diff --git a/pymic/io/nifty_dataset.py b/pymic/io/nifty_dataset.py index aefe4da..c3253c9 100644 --- a/pymic/io/nifty_dataset.py +++ b/pymic/io/nifty_dataset.py @@ -118,7 +118,7 @@ def __getitem__(self, idx): return sample -class ClassificationDataset(NiftyDataset): +class ClassificationDataset(Dataset): """ Dataset for loading images for classification. It generates 4D tensors with dimention order [C, D, H, W] for 3D images, and 3D tensors @@ -134,16 +134,32 @@ class ClassificationDataset(NiftyDataset): """ def __init__(self, root_dir, csv_file, modal_num = 1, class_num = 2, with_label = False, transform=None, task = TaskType.CLASSIFICATION_ONE_HOT): - super(ClassificationDataset, self).__init__(root_dir, - csv_file, modal_num, with_label, transform) + # super(ClassificationDataset, self).__init__(root_dir, + # csv_file, modal_num, with_label, transform, task) + self.root_dir = root_dir + self.csv_items = pd.read_csv(csv_file) + self.modal_num = modal_num + self.with_label = with_label + self.transform = transform self.class_num = class_num self.task = task assert self.task in [TaskType.CLASSIFICATION_ONE_HOT, TaskType.CLASSIFICATION_COEXIST] + + csv_keys = list(self.csv_items.keys()) + self.image_weight_idx = None + if('image_weight' in csv_keys): + self.image_weight_idx = csv_keys.index('image_weight') + + def __len__(self): + return len(self.csv_items) def __getlabel__(self, idx): csv_keys = list(self.csv_items.keys()) - label_idx = csv_keys.index('label') - label = self.csv_items.iloc[idx, label_idx] + if self.task == TaskType.CLASSIFICATION_ONE_HOT: + label_idx = csv_keys.index('label') + label = self.csv_items.iloc[idx, label_idx] + else: + label = np.asarray(self.csv_items.iloc[idx, 1:self.class_num + 1], np.float32) return label def __getweight__(self, idx): @@ -161,13 +177,15 @@ def __getitem__(self, idx): names_list.append(image_name) image_list.append(image_data) image = np.concatenate(image_list, axis = 0) - image = np.asarray(image, np.float32) + image = np.asarray(image, np.float32) sample = {'image': image, 'names' : names_list[0], 'origin':image_dict['origin'], 'spacing': image_dict['spacing'], 'direction':image_dict['direction']} + if (self.with_label): - sample['label'] = self.__getlabel__(idx) + label = self.__getlabel__(idx) + sample['label'] = label #np.asarray(label, np.float32) if (self.image_weight_idx is not None): sample['image_weight'] = self.__getweight__(idx) if self.transform: diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index f9575ab..d8abb19 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -63,6 +63,7 @@ def __init__(self, config, stage = 'train'): if(self.deterministic): seed_torch(self.random_seed) logging.info("deterministric is true") + def set_datasets(self, train_set, valid_set, test_set): """ @@ -139,7 +140,7 @@ def get_checkpoint_name(self): """ ckpt_mode = self.config['testing']['ckpt_mode'] if(ckpt_mode == 0 or ckpt_mode == 1): - ckpt_dir = self.config['training']['ckpt_save_dir'] + ckpt_dir = self.config['training']['ckpt_dir'] ckpt_prefix = self.config['training'].get('ckpt_prefix', None) if(ckpt_prefix is None): ckpt_prefix = ckpt_dir.split('/')[-1] @@ -326,3 +327,5 @@ def run(self): else: self.infer() + + diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index ee4e25b..a31df84 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -73,7 +73,8 @@ def get_stage_dataset_from_config(self, stage): modal_num = modal_num, class_num = class_num, with_label= not (stage == 'test'), - transform = data_transform ) + transform = data_transform, + task = self.task_type) return dataset def create_network(self): @@ -97,6 +98,8 @@ def create_loss_calculator(self): if(self.loss_dict is None): self.loss_dict = PyMICClsLossDict loss_name = self.config['training']['loss_type'] + if(loss_name != "SigmoidCELoss" and self.task_type == TaskType.CLASSIFICATION_COEXIST): + raise ValueError("SigmoidCELoss should be used when task_type is cls_coexist") if(loss_name in self.loss_dict): self.loss_calculater = self.loss_dict[loss_name](self.config['training']) else: @@ -218,7 +221,7 @@ def train_valid(self): self.device = torch.device("cuda:{0:}".format(device_ids[0])) self.net.to(self.device) - ckpt_dir = self.config['training']['ckpt_save_dir'] + ckpt_dir = self.config['training']['ckpt_dir'] if(ckpt_dir[-1] == "/"): ckpt_dir = ckpt_dir[:-1] ckpt_prefix = self.config['training'].get('ckpt_prefix', None) @@ -259,7 +262,7 @@ def train_valid(self): self.trainIter = iter(self.train_loader) logging.info("{0:} training start".format(str(datetime.now())[:-7])) - self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) + self.summ_writer = SummaryWriter(self.config['training']['ckpt_dir']) self.glob_it = iter_start for it in range(iter_start, iter_max, iter_valid): lr_value = self.optimizer.param_groups[0]['lr'] @@ -267,6 +270,7 @@ def train_valid(self): train_scalars = self.training() t1 = time.time() valid_scalars = self.validation() + t2 = time.time() if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step(valid_scalars[metrics]) diff --git a/pymic/net_run/agent_preprocess.py b/pymic/net_run/agent_preprocess.py index db8b10b..c53421f 100644 --- a/pymic/net_run/agent_preprocess.py +++ b/pymic/net_run/agent_preprocess.py @@ -9,7 +9,7 @@ from pymic.io.nifty_dataset import NiftyDataset from pymic.transform.trans_dict import TransformDict from pymic.net_run.agent_abstract import seed_torch -from pymic.net_run.self_sup.util import volume_fusion +from pymic.net_run.self_sup.util import volume_fusion, nonlienar_volume_fusion,augmented_volume_fusion,self_volume_fusion class PreprocessAgent(object): def __init__(self, config): @@ -92,17 +92,37 @@ def run(self): B, C = inputs.shape[0], inputs.shape[1] spacing = [x.numpy()[0] for x in data['spacing']] - if(batch_operation is not None and 'VolumeFusion' in batch_operation): - class_num = self.config['dataset']['VolumeFusion_cls_num'.lower()] - block_range = self.config['dataset']['VolumeFusion_block_range'.lower()] - size_min = self.config['dataset']['VolumeFusion_size_min'.lower()] - size_max = self.config['dataset']['VolumeFusion_size_max'.lower()] - inputs, labels = volume_fusion(inputs, class_num - 1, block_range, size_min, size_max) + if(batch_operation is not None): + if('VolumeFusion' in batch_operation): + class_num = self.config['dataset']['VolumeFusion_cls_num'.lower()] + block_range = self.config['dataset']['VolumeFusion_block_range'.lower()] + size_min = self.config['dataset']['VolumeFusion_size_min'.lower()] + size_max = self.config['dataset']['VolumeFusion_size_max'.lower()] + inputs, labels = volume_fusion(inputs, class_num - 1, block_range, size_min, size_max) + elif('SelfVolumeFusion' in batch_operation): + class_num = self.config['dataset']['SelfVolumeFusion_cls_num'.lower()] + fuse_ratio = self.config['dataset']['SelfVolumeFusion_fuse_ratio'.lower()] + size_min = self.config['dataset']['SelfVolumeFusion_size_min'.lower()] + size_max = self.config['dataset']['SelfVolumeFusion_size_max'.lower()] + inputs, labels = self_volume_fusion(inputs, class_num - 1, fuse_ratio, size_min, size_max) + elif('NonLinearVolumeFusion' in batch_operation): + block_range = self.config['dataset']['NonLinearVolumeFusion_block_range'.lower()] + size_min = self.config['dataset']['NonLinearVolumeFusion_size_min'.lower()] + size_max = self.config['dataset']['NonLinearVolumeFusion_size_max'.lower()] + inputs, labels = nonlienar_volume_fusion(inputs, block_range, size_min, size_max) + elif('AugmentedVolumeFusion' in batch_operation): + size_min = self.config['dataset']['AugmentedVolumeFusion_size_min'.lower()] + size_max = self.config['dataset']['AugmentedVolumeFusion_size_max'.lower()] + inputs, labels = augmented_volume_fusion(inputs, size_min, size_max) for b in range(B): for c in range(C): image_name = out_dir + "/" + img_names[c][b] print(image_name) + out_dir_full = "/".join(image_name.split("/")[:-1]) + print(out_dir_full) + if(not os.path.exists(out_dir_full)): + os.mkdir(out_dir_full) save_nd_array_as_image(inputs[b][c], image_name, reference_name = None, spacing=spacing) if(labels is not None): label_name = out_dir + "/" + lab_names[b] diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 2d61d0d..4a80298 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -321,7 +321,7 @@ def train_valid(self): self.device = torch.device("cuda:{0:}".format(device_ids[0])) self.net.to(self.device) - ckpt_dir = self.config['training']['ckpt_save_dir'] + ckpt_dir = self.config['training']['ckpt_dir'] if(ckpt_dir[-1] == "/"): ckpt_dir = ckpt_dir[:-1] ckpt_prefix = self.config['training'].get('ckpt_prefix', None) @@ -365,7 +365,7 @@ def train_valid(self): self.trainIter = iter(self.train_loader) logging.info("{0:} training start".format(str(datetime.now())[:-7])) - self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) + self.summ_writer = SummaryWriter(self.config['training']['ckpt_dir']) self.glob_it = iter_start for it in range(iter_start, iter_max, iter_valid): lr_value = self.optimizer.param_groups[0]['lr'] @@ -581,7 +581,7 @@ def save_outputs(self, data): if(test_dir is None): test_dir = self.config['dataset']['train_dir'] - for i in range(len(names)): + for i in range(output.shape[0]): save_name = names[i][0].split('/')[-1] if ignore_dir else \ names[i][0].replace('/', '_') if((filename_replace_source is not None) and (filename_replace_target is not None)): @@ -607,3 +607,4 @@ def save_outputs(self, data): if(len(temp_prob.shape) == 2): temp_prob = np.asarray(temp_prob * 255, np.uint8) save_nd_array_as_image(temp_prob, prob_save_name, test_dir + '/' + names[i][0]) +0]) diff --git a/pymic/net_run/predict.py b/pymic/net_run/predict.py index 80134d8..e618be6 100644 --- a/pymic/net_run/predict.py +++ b/pymic/net_run/predict.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import argparse import logging import os import sys @@ -12,16 +13,30 @@ def main(): """ - The main function for running a network for training or inference. + The main function for running a network for inference. """ if(len(sys.argv) < 2): - print('Number of arguments should be 2. e.g.') - print(' pymic_test config.cfg') + print('Number of arguments should be at least 2. e.g.') + print(' pymic_test config.cfg -test_csv train.csv -output_dir result_dir -ckpt_mode 1') exit() - cfg_file = str(sys.argv[1]) - if(not os.path.isfile(cfg_file)): - raise ValueError("The config file does not exist: " + cfg_file) - config = parse_config(cfg_file) + parser = argparse.ArgumentParser() + parser.add_argument("cfg", help="configuration file for testing") + parser.add_argument("-test_csv", help="the csv file for testing images", + required=False, default=None) + parser.add_argument("-output_dir", help="the output dir for inference results", + required=False, default=None) + parser.add_argument("-ckpt_dir", help="the dir for trained model", + required=False, default=None) + parser.add_argument("-ckpt_mode", help="the mode for chekpoint: 0-latest, 1-best, 2-customized", + required=False, default=None) + parser.add_argument("-ckpt_name", help="the name chekpoint if ckpt_mode = 2", + required=False, default=None) + parser.add_argument("-gpus", help="the gpus for runing, e.g., [0]", + required=False, default=None) + args = parser.parse_args() + if(not os.path.isfile(args.cfg)): + raise ValueError("The config file does not exist: " + args.cfg) + config = parse_config(args) config = synchronize_config(config) log_dir = config['testing']['output_dir'] if(not os.path.exists(log_dir)): @@ -34,7 +49,8 @@ def main(): logging.basicConfig(filename=log_dir+"/log_test.txt", level=logging.INFO, format='%(message)s') # for python 3.6 logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging_config(config) + dst_cfg = args.cfg if "/" not in args.cfg else args.cfg.split("/")[-1] + wrtie_config(config, log_dir + "/" + dst_cfg) task = config['dataset']['task_type'] if(task == TaskType.CLASSIFICATION_ONE_HOT or task == TaskType.CLASSIFICATION_COEXIST): agent = ClassificationAgent(config, 'test') diff --git a/pymic/net_run/preprocess.py b/pymic/net_run/preprocess.py index 3b34887..63410b5 100644 --- a/pymic/net_run/preprocess.py +++ b/pymic/net_run/preprocess.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import argparse import os import sys from datetime import datetime @@ -15,11 +16,13 @@ def main(): print('Number of arguments should be 2. e.g.') print(' pymic_preprocess config.cfg') exit() - cfg_file = str(sys.argv[1]) - if(not os.path.isfile(cfg_file)): - raise ValueError("The config file does not exist: " + cfg_file) - config = parse_config(cfg_file) - config = synchronize_config(config) + parser = argparse.ArgumentParser() + parser.add_argument("cfg", help="configuration file for preprocessing") + args = parser.parse_args() + if(not os.path.isfile(args.cfg)): + raise ValueError("The config file does not exist: " + args.cfg) + config = parse_config(args) + config = synchronize_config(config) agent = PreprocessAgent(config) agent.run() diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index 3a4571f..ed60fa1 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import argparse import logging import os import sys @@ -48,19 +49,28 @@ def main(): The main function for running a network for training. """ if(len(sys.argv) < 2): - print('Number of arguments should be 2. e.g.') - print(' pymic_train config.cfg') + print('Number of arguments should be at least 2. e.g.') + print(' pymic_train config.cfg -train_csv train.csv') exit() - cfg_file = str(sys.argv[1]) - if(not os.path.isfile(cfg_file)): - raise ValueError("The config file does not exist: " + cfg_file) - config = parse_config(cfg_file) + parser = argparse.ArgumentParser() + parser.add_argument("cfg", help="configuration file for training") + parser.add_argument("-train_csv", help="the csv file for training images", + required=False, default=None) + parser.add_argument("-valid_csv", help="the csv file for validation images", + required=False, default=None) + parser.add_argument("-ckpt_dir", help="the output dir for trained model", + required=False, default=None) + parser.add_argument("-gpus", help="the gpus for runing, e.g., [0]", + required=False, default=None) + args = parser.parse_args() + if(not os.path.isfile(args.cfg)): + raise ValueError("The config file does not exist: " + args.cfg) + config = parse_config(args) config = synchronize_config(config) - log_dir = config['training']['ckpt_save_dir'] + + log_dir = config['training']['ckpt_dir'] if(not os.path.exists(log_dir)): os.makedirs(log_dir, exist_ok=True) - dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] - shutil.copy(cfg_file, log_dir + "/" + dst_cfg) datetime_str = str(datetime.now())[:-7].replace(":", "_") if sys.version.startswith("3.9"): logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), @@ -69,7 +79,9 @@ def main(): logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), level=logging.INFO, format='%(message)s') # for python 3.6 logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging_config(config) + dst_cfg = args.cfg if "/" not in args.cfg else args.cfg.split("/")[-1] + wrtie_config(config, log_dir + "/" + dst_cfg) + task = config['dataset']['task_type'] if(task == TaskType.CLASSIFICATION_ONE_HOT or task == TaskType.CLASSIFICATION_COEXIST): agent = ClassificationAgent(config, 'train') diff --git a/pymic/transform/affine.py b/pymic/transform/affine.py index 552516f..2efd586 100644 --- a/pymic/transform/affine.py +++ b/pymic/transform/affine.py @@ -86,7 +86,7 @@ def _get_affine_param(self, sample, output_shape): # sample['Affine_Param'] = json.dumps((input_shape, tform["matrix"])) return sample, tform - def _apply_affine_to_ND_volume(self, image, output_shape, tform, order = 3): + def _apply_affine_to_ND_volume(self, image, output_shape, tform, order = 2): """ output_shape should only has two dimensions, e.g., (H, W) """ @@ -152,5 +152,9 @@ def _get_param_for_inverse_transform(self, sample): # aff_out_shape = origin_shape[-2:] # output_predict = self._apply_affine_to_ND_volume(predict, aff_out_shape, tform.inverse) + # sample['predict'] = output_predict + # return sample + = self._apply_affine_to_ND_volume(predict, aff_out_shape, tform.inverse) + # sample['predict'] = output_predict # return sample diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index 2b19ebc..2829e76 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -1,10 +1,13 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import copy +import itertools import json import math import random import numpy as np +from scipy import ndimage +from skimage import exposure from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * try: # SciPy >= 0.19 @@ -82,7 +85,37 @@ def __call__(self, sample): image[chn] = np.clip(image[chn], lower_c, upper_c) sample['image'] = image return sample + +class HistEqual(AbstractTransform): + """ + Histogram equalization. Note that the output will be in the range of [0, 1]. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `HistEqual_channels`: (list) A list of int for specifying the channels. + :param `HistEqual_bin`: (int) The number of bins. + :param `HistEqual_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. + """ + def __init__(self, params): + super(HistEqual, self).__init__(params) + self.channels = params.get('HistEqual_channels'.lower(), None) + # self.min = params.get('HistEqual_min'.lower(), None) + # self.max = params.get('HistEqual_max'.lower(), None) + self.bin = params.get('HistEqual_bin'.lower(), 2000) + self.inverse = params.get('HistEqual_inverse'.lower(), False) + def __call__(self, sample): + image = sample['image'] + C = image.shape[0] + chns = range(C) if self.channels is None else self.channels + for i in range(len(chns)): + c = chns[i] + image[c] = exposure.equalize_hist(image[c],nbins= self.bin) + sample['image'] = image + return sample + class GammaCorrection(AbstractTransform): """ Apply random gamma correction to given channels. @@ -189,7 +222,7 @@ def __init__(self, params): self.block_size = params.get('NonLinearTransform_block_size'.lower(), [8, 16, 16]) - def __apply_nonlinear_transform(self, img): + def apply_nonlinear_transform(self, img): """ the input img should be normlized to [0, 1]""" points = [[0, 0], [random.random(), random.random()], [random.random(), random.random()], [1, 1]] @@ -217,7 +250,7 @@ def __call__(self, sample): if(v_min < v_max): img_c = (img_c - v_min)/(v_max - v_min) if(self.block_range is None): # apply non-linear transform to the entire image - img_c = self.__apply_nonlinear_transform(img_c) + img_c = self.apply_nonlinear_transform(img_c) else: # non-linear transform to random blocks img_c_sr = copy.deepcopy(img_c) for n in range(self.block_range[0], self.block_range[1]): @@ -229,7 +262,7 @@ def __call__(self, sample): img_c[coord_min[0]:coord_min[0] + self.block_size[0], coord_min[1]:coord_min[1] + self.block_size[1], coord_min[2]:coord_min[2] + self.block_size[2]] = \ - self.__apply_nonlinear_transform(window) + self.apply_nonlinear_transform(window) image[chn] = img_c * (v_max - v_min) + v_min sample['image'] = image return sample @@ -423,6 +456,44 @@ def __call__(self, sample): img_out[:, pos_b0[0]:pos_b1[0], pos_b0[1]:pos_b1[1], pos_b0[2]:pos_b1[2]] = \ image[:, pos_a0[0]:pos_a1[0], pos_a0[1]:pos_a1[1], pos_a0[2]:pos_a1[2]] + sample['image'] = img_out + sample['label'] = image + return sample + +class MaskedImageModeling(AbstractTransform): + """ + Apply masking for context restoration in self-supervised learning. + Reference: Zekai Chen et al., Masked Image Modeling Advances 3D Medical Image Analysis, + WACV, 2023 . + """ + def __init__(self, params): + super(MaskedImageModeling, self).__init__(params) + self.ratio = params.get('MaskedImageModeling_ratio'.lower(), 0.45) + self.block_size = params.get('MaskedImageModeling_block_size'.lower(), [8, 16, 16]) + self.inverse = params.get('MaskedImageModeling_inverse'.lower(), False) + + def __call__(self, sample): + image= sample['image'] + C, D, H, W = image.shape + img_out = copy.deepcopy(image) + + block = np.zeros([C] + list(self.block_size)) + for d in range(0, D, self.block_size[0]): + d1 = d + self.block_size[0] + if d1 > D: + continue + for h in range(0, H, self.block_size[1]): + h1 = h + self.block_size[1] + if h1 > H: + continue + for w in range(0, W, self.block_size[2]): + w1 = w + self.block_size[2] + if w1 > W: + continue + r = random.random() + if ( r < self.ratio): + img_out[:, d:d1, h:h1, w:w1] = block + sample['image'] = img_out sample['label'] = image return sample \ No newline at end of file diff --git a/pymic/transform/label_convert.py b/pymic/transform/label_convert.py index 00c505e..afbbaaf 100644 --- a/pymic/transform/label_convert.py +++ b/pymic/transform/label_convert.py @@ -92,6 +92,8 @@ def __call__(self, sample): label_prob = np.zeros((self.class_num,), np.float32) label_prob[label_idx] = 1.0 sample['label_prob'] = label_prob + elif(self.task == TaskType.CLASSIFICATION_COEXIST): + sample['label_prob'] = sample['label'] return sample class LabelSmooth(AbstractTransform): diff --git a/pymic/util/evaluation_seg.py b/pymic/util/evaluation_seg.py index 82939ce..926ff0e 100644 --- a/pymic/util/evaluation_seg.py +++ b/pymic/util/evaluation_seg.py @@ -387,14 +387,14 @@ def main(): args = parser.parse_args() print(args) if(args.cfg is not None): - config = parse_config(args.cfg)['evaluation'] + config = parse_config(args)['evaluation'] else: config = {} config['metric_list'] = parse_value_from_string(args.metric) config['label_list'] = None if args.cls_index is None else parse_value_from_string(args.cls_index) config['class_number']= None if args.cls_num is None else parse_value_from_string(args.cls_num) - config['ground_truth_folder_root'] = args.gt_dir - config['segmentation_folder_root'] = args.seg_dir + config['ground_truth_folder'] = args.gt_dir + config['segmentation_folder'] = args.seg_dir config['evaluation_image_pair'] = args.name_pair config['output_name'] = args.out print(config) @@ -402,3 +402,5 @@ def main(): if __name__ == '__main__': main() + + main() diff --git a/pymic/util/parse_config.py b/pymic/util/parse_config.py index 0e38b91..09f6db0 100644 --- a/pymic/util/parse_config.py +++ b/pymic/util/parse_config.py @@ -84,14 +84,18 @@ def parse_value_from_string(val_str): val = val_str return val -def parse_config(filename): +def parse_config(args): config = configparser.ConfigParser() - config.read(filename) + config.read(args.cfg) output = {} for section in config.sections(): output[section] = {} for key in config[section]: val_str = str(config[section][key]) + if hasattr(args, key): + args_key = getattr(args, key) + if(args_key is not None): + val_str = args_key if(len(val_str)>0): val = parse_value_from_string(val_str) output[section][key] = val @@ -133,6 +137,24 @@ def synchronize_config(config): # config['network'] = net_cfg return config +def wrtie_config(config, output_name): + logging.info("The running configuations are: ") + with open(output_name, 'w') as f: + for section in config: + if(isinstance(config[section], dict)): + line = '[' + section + ']' + f.write('\n' + line + '\n') + logging.info(line) + for key in config[section]: + value = config[section][key] + line = "{0:} = {1:}".format(key, value) + f.write(line + '\n') + logging.info(line) + else: + line = "{0:} = {1:}".format(section, config[section]) + f.write(line + "\n") + logging.info(line) + def logging_config(config): for section in config: if(isinstance(config[section], dict)): From f823eb8ba5e23327f8556f432a8f114a792731ca Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 6 Aug 2024 17:02:39 +0800 Subject: [PATCH 48/86] update transform --- README.md | 2 +- pymic/transform/crop.py | 10 ++-- pymic/transform/intensity.py | 100 ++++++++++++++++++++++++++++++++--- pymic/transform/normalize.py | 17 ++++-- 4 files changed, 112 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 90a6de7..6f34f92 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # PyMIC: A Pytorch-Based Toolkit for Medical Image Computing -PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with high dimension and large volume, multiple modalities and difficulies in annotating. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised, self-supervised, and weakly supervised learning, and learning with noisy annotations. +PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with high dimension and large volume, multiple modalities and difficulies in annotating. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised, self-supervised, and weakly supervised learning, and learning with noisy annotations. Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. If you use this toolkit, please cite the following paper: diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index b821bb2..9e1c077 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -428,9 +428,9 @@ def __call__(self, sample): return sample -class CropHumanRegionFromCT(CenterCrop): +class CropHumanRegion(CenterCrop): """ - Crop the human region from a CT volume. + Crop the human region from a CT for MRI volume. The arguments should be written in the `params` dictionary, and it has the following fields: @@ -447,9 +447,9 @@ class CropHumanRegionFromCT(CenterCrop): Default is `True`. """ def __init__(self, params): - self.threshold_i = params.get('CropHumanRegionFromCT_intensity_threshold'.lower(), -600) - self.threshold_z = params.get('CropHumanRegionFromCT_zaxis_threshold'.lower(), 0.5) - self.inverse = params.get('CropHumanRegionFromCT_inverse'.lower(), True) + self.threshold_i = params.get('CropHumanRegion_intensity_threshold'.lower(), -600) + self.threshold_z = params.get('CropHumanRegion_zaxis_threshold'.lower(), 0.5) + self.inverse = params.get('CropHumanRegion_inverse'.lower(), True) self.task = params['task'] def _get_crop_param(self, sample): diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index 2829e76..f05b95b 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -47,6 +47,7 @@ def bezier_curve(points, nTimes=1000): return xvals, yvals + class IntensityClip(AbstractTransform): """ Clip the intensity for input image @@ -161,6 +162,48 @@ def __call__(self, sample): sample['image'] = image return sample +def gaussian_noise(image, std_min, std_max,): + """ + The input has a shape of [C, D, H, W] or [D, H, W]. + In the former case, volume-level noise will be added. + In the latter case, slice-level noise will ba added. + """ + v_min = image.min() + v_max = image.max() + std = random.random() * (std_max - std_min) + std_min + noise = np.random.normal(0, std, image.shape) + out = image + noise + out = np.clip(out, v_min, v_max) + return out + +def gaussian_blur(image, sigma_min, sigma_max): + sigma = random.random() * (sigma_max - sigma_min) + sigma_min + out = ndimage.gaussian_filter(image, sigma, order = 0) + return out + +def gaussian_sharpen(image, sigma_min, sigma_max, alpha = 10.0): + blurred = gaussian_blur(image, sigma_min, sigma_max) + out = image + (image - blurred) * alpha + return out + +def window_level_augment(image, offset = 0.1): + v_min = image.min() + v_max = image.max() + margin = (v_max - v_min) * offset + v0 = random.uniform(v_min - margin, v_min + margin) + v1 = random.uniform(v_max - margin, v_max + margin) + out = np.clip((image - v0) / (v1 - v0), 0, 1) + return out + +def gamma_correction(image, gamma_min, gamma_max): + v_min = image.min() + v_max = image.max() + if(v_min < v_max): + image = (image - v_min)/(v_max - v_min) + gamma = random.random() * (gamma_max - gamma_min) + gamma_min + image = np.power(image, gamma)*(v_max - v_min) + v_min + return image + class GaussianNoise(AbstractTransform): """ Add Gaussian Noise to given channels. @@ -179,8 +222,8 @@ class GaussianNoise(AbstractTransform): def __init__(self, params): super(GaussianNoise, self).__init__(params) self.channels = params.get('GaussianNoise_channels'.lower(), None) - self.mean = params['GaussianNoise_mean'.lower()] - self.std = params['GaussianNoise_std'.lower()] + self.std_min = params.get('GaussianNoise_std_min'.lower(), 0.02) + self.std_max = params.get('GaussianNoise_std_max'.lower(), 0.1) self.prob = params.get('GaussianNoise_probability'.lower(), 0.5) self.inverse = params.get('GaussianNoise_inverse'.lower(), False) @@ -190,10 +233,53 @@ def __call__(self, sample): self.channels = range(image.shape[0]) for chn in self.channels: if(np.random.uniform() < self.prob): - img_c = image[chn] - noise = np.random.normal(self.mean, self.std, img_c.shape) - image[chn] = img_c + noise + image[chn] = gaussian_noise(image[chn], self.std_min, self.std_max) + sample['image'] = image + return sample + +def adaptive_contrast_adjust(image, p0=0.1, p1=99.9): + v_min = image.min() + v_max = image.max() + v0 = np.percentile(image, p0) + v1 = np.percentile(image, p1) + mask_l = image < v0 + mask_m = (image >= v0) * (image <= v1) + mask_u = image > v1 + image[mask_l] = (image[mask_l] - v_min) * 0.1 / (v0 - v_min) + image[mask_m] = (image[mask_m] - v0) / (v1 - v0)*0.8 + 0.1 + image[mask_u] = 0.9 + 0.1 * (image[mask_u] - v1) / (v_max - v1) + return image + +class AdaptiveContrastAdjust(AbstractTransform): + """ + Add Gaussian Noise to given channels. + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `GaussianNoise_channels`: (list) A list of int for specifying the channels. + :param `GaussianNoise_mean`: (float) The mean value of noise. + :param `GaussianNoise_std`: (float) The std of noise. + :param `GaussianNoise_probability`: (optional, float) + The probability of applying GaussianNoise. Default is 0.5. + :param `GaussianNoise_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. + """ + def __init__(self, params): + super(AdaptiveContrastAdjust, self).__init__(params) + self.channels = params.get('AdaptiveContrastAdjust_channels'.lower(), None) + self.p0 = params.get('AdaptiveContrastAdjust_percent_lower'.lower(), 2) + self.p1 = params.get('AdaptiveContrastAdjust_percent_upper'.lower(), 98) + self.prob = params.get('AdaptiveContrastAdjust_probability'.lower(), 0.5) + self.inverse = params.get('AdaptiveContrastAdjust_inverse'.lower(), False) + + def __call__(self, sample): + image = sample['image'] * 1.0 + if(self.channels is None): + self.channels = range(image.shape[0]) + for chn in self.channels: + if(np.random.uniform() < self.prob): + image[chn] = adaptive_contrast_adjust(image[chn], self.p0, self.p1) sample['image'] = image return sample @@ -219,7 +305,7 @@ def __init__(self, params): self.prob = params.get('NonLinearTransform_probability'.lower(), 0.5) self.inverse = params.get('NonLinearTransform_inverse'.lower(), False) self.block_range = params.get('NonLinearTransform_block_range'.lower(), None) - self.block_size = params.get('NonLinearTransform_block_size'.lower(), [8, 16, 16]) + self.block_size = params.get('NonLinearTransform_block_size'.lower(), [4, 8, 8]) def apply_nonlinear_transform(self, img): @@ -326,7 +412,7 @@ def __init__(self, params): self.inverse = params.get('InPainting_inverse'.lower(), False) self.prob = params.get('InPainting_probability'.lower(), 0.5) self.block_range = params.get('InPainting_block_range'.lower(), (20, 40)) - self.block_size = params.get('InPainting_block_size'.lower(), [8, 16, 16]) + self.block_size = params.get('InPainting_block_size'.lower(), [4, 8, 8]) def __call__(self, sample): if(random.random() > self.prob): diff --git a/pymic/transform/normalize.py b/pymic/transform/normalize.py index 643c12e..35c5dc4 100644 --- a/pymic/transform/normalize.py +++ b/pymic/transform/normalize.py @@ -131,14 +131,17 @@ class NormalizeWithPercentiles(AbstractTransform): The min percentile, which must be between 0 and 100 inclusive. :param `NormalizeWithPercentiles_percentile_upper`: (float) The max percentile, which must be between 0 and 100 inclusive. + :param `NormalizeWithPercentiles_output_mode`: (int) 0: the output is in the range [0,1] + Otherwise the output is in the range of [-1, 1] :param `NormalizeWithMinMax_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `False`. """ def __init__(self, params): super(NormalizeWithPercentiles, self).__init__(params) - self.chns = params['NormalizeWithPercentiles_channels'.lower()] - self.percent_lower = params['NormalizeWithPercentiles_percentile_lower'.lower()] - self.percent_upper = params['NormalizeWithPercentiles_percentile_upper'.lower()] + self.chns = params.get('NormalizeWithPercentiles_channels'.lower(), None) + self.percent_lower = params.get('NormalizeWithPercentiles_percentile_lower'.lower(), 0.1) + self.percent_upper = params.get('NormalizeWithPercentiles_percentile_upper'.lower(), 99.9) + self.out_mode = params.get('NormalizeWithPercentiles_output_mode'.lower(), 0) self.inverse = params.get('NormalizeWithPercentiles_inverse'.lower(), False) def __call__(self, sample): @@ -152,7 +155,13 @@ def __call__(self, sample): img_chn[img_chn < v0] = v0 img_chn[img_chn > v1] = v1 - img_chn = 2.0* (img_chn - v0) / (v1 - v0) -1.0 + if(self.out_mode == 0): + img_chn = (img_chn - v0) / (v1 - v0) + img_chn = np.clip(img_chn, 0, 1) + else: + img_chn = 2.0* (img_chn - v0) / (v1 - v0) -1.0 + img_chn = np.clip(img_chn, -1, 1) + image[chn] = img_chn sample['image'] = image return sample \ No newline at end of file From f4a2dcea7800f503772380835fefda782b578766 Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 4 Sep 2024 16:24:59 +0800 Subject: [PATCH 49/86] allow timing --- pymic/net/net3d/unet2d5.py | 258 ++++++++++++------- pymic/net/net3d/unet3d.py | 102 +++++--- pymic/net/net3d/unet3d_dual_branch.py | 3 +- pymic/net/net3d/unet3d_scse.py | 133 ++++------ pymic/net_run/noisy_label/nll_co_teaching.py | 31 ++- pymic/net_run/noisy_label/nll_dast.py | 24 +- pymic/net_run/noisy_label/nll_trinet.py | 26 +- pymic/net_run/self_sup/self_volume_fusion.py | 245 +++++++++++++++++- pymic/net_run/self_sup/util.py | 159 +++++++++++- pymic/net_run/semi_sup/__init__.py | 2 + pymic/net_run/semi_sup/ssl_abstract.py | 3 + pymic/net_run/semi_sup/ssl_cct.py | 22 +- pymic/net_run/semi_sup/ssl_cps.py | 23 +- pymic/net_run/semi_sup/ssl_em.py | 24 +- pymic/net_run/semi_sup/ssl_mcnet.py | 24 +- pymic/net_run/semi_sup/ssl_mt.py | 23 +- pymic/net_run/semi_sup/ssl_uamt.py | 22 +- pymic/net_run/semi_sup/ssl_urpc.py | 22 +- pymic/net_run/weak_sup/wsl_abstract.py | 9 +- pymic/net_run/weak_sup/wsl_dmpls.py | 24 +- pymic/net_run/weak_sup/wsl_em.py | 24 +- pymic/net_run/weak_sup/wsl_gatedcrf.py | 22 +- pymic/net_run/weak_sup/wsl_mumford_shah.py | 22 +- pymic/net_run/weak_sup/wsl_tv.py | 22 +- pymic/net_run/weak_sup/wsl_ustm.py | 22 +- 25 files changed, 951 insertions(+), 340 deletions(-) diff --git a/pymic/net/net3d/unet2d5.py b/pymic/net/net3d/unet2d5.py index 308fdde..4e70393 100644 --- a/pymic/net/net3d/unet2d5.py +++ b/pymic/net/net3d/unet2d5.py @@ -1,9 +1,16 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division + +import logging import torch import torch.nn as nn import numpy as np +ConvND = {2: nn.Conv2d, 3: nn.Conv3d} +BatchNormND = {2: nn.BatchNorm2d, 3: nn.BatchNorm3d} +MaxPoolND = {2: nn.MaxPool2d, 3: nn.MaxPool3d} +ConvTransND = {2: nn.ConvTranspose2d, 3: nn.ConvTranspose3d} + class ConvBlockND(nn.Module): """ 2D or 3D convolutional block @@ -13,29 +20,17 @@ class ConvBlockND(nn.Module): :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. :param dropout_p: (int) Dropout probability. """ - def __init__(self, in_channels, out_channels, - dim = 2, dropout_p = 0.0): + def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0): super(ConvBlockND, self).__init__() assert(dim == 2 or dim == 3) self.dim = dim - if(self.dim == 2): - self.conv_conv = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), - nn.PReLU(), - nn.Dropout(dropout_p), - nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), - nn.PReLU() - ) - else: - self.conv_conv = nn.Sequential( - nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm3d(out_channels), + self.conv_conv = nn.Sequential( + ConvND[dim](in_channels, out_channels, kernel_size=3, padding=1), + BatchNormND[dim](out_channels), nn.PReLU(), nn.Dropout(dropout_p), - nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm3d(out_channels), + ConvND[dim](out_channels, out_channels, kernel_size=3, padding=1), + BatchNormND[dim](out_channels), nn.PReLU() ) @@ -52,17 +47,12 @@ class DownBlock(nn.Module): :param dropout_p: (int) Dropout probability. :param downsample: (bool) Use downsample or not after convolution. """ - def __init__(self, in_channels, out_channels, - dim = 2, dropout_p = 0.0, downsample = True): + def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0, downsample = True): super(DownBlock, self).__init__() self.downsample = downsample self.dim = dim self.conv = ConvBlockND(in_channels, out_channels, dim, dropout_p) - if(downsample): - if(self.dim == 2): - self.down_layer = nn.MaxPool2d(kernel_size = 2, stride = 2) - else: - self.down_layer = nn.MaxPool3d(kernel_size = 2, stride = 2) + self.down_layer = MaxPoolND[dim](kernel_size = 2, stride = 2) def forward(self, x): x_shape = list(x.shape) @@ -95,28 +85,31 @@ class UpBlock(nn.Module): :param out_channels: (int) Output channel number. :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. :param dropout_p: (int) Dropout probability. - :param bilinear: (bool) Use bilinear for up-sampling or not. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear` for 3D and `Bilinear` for 2D). + The default value is 2. """ def __init__(self, in_channels1, in_channels2, out_channels, - dim = 2, dropout_p = 0.0, bilinear=True): + dim = 2, dropout_p = 0.0, up_mode= 2): super(UpBlock, self).__init__() - self.bilinear = bilinear - self.dim = dim - if bilinear: - if(dim == 2): - self.up = nn.Sequential( - nn.Conv2d(in_channels1, in_channels2, kernel_size = 1), - nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)) - else: - self.up = nn.Sequential( - nn.Conv3d(in_channels1, in_channels2, kernel_size = 1), - nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)) + if(isinstance(up_mode, int)): + up_mode_values = ["transconv", "nearest", "trilinear"] + if(up_mode > 2): + raise ValueError("The upsample mode should be 0-2, but {0:} is given.".format(up_mode)) + self.up_mode = up_mode_values[up_mode] else: - if(dim == 2): - self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2) + self.up_mode = up_mode.lower() + + self.dim = dim + if (self.up_mode == "transconv"): + self.up = ConvTransND[dim](in_channels1, in_channels2, kernel_size=2, stride=2) + else: + self.conv1x1 = ConvND[dim](in_channels1, in_channels2, kernel_size = 1) + if(self.up_mode == "nearest"): + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode) else: - self.up = nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) - + mode = "trilinear" if dim == 3 else "bilinear" + self.up = nn.Upsample(scale_factor=2, mode=mode, align_corners=True) self.conv = ConvBlockND(in_channels2 * 2, out_channels, dim, dropout_p) def forward(self, x1, x2): @@ -132,6 +125,8 @@ def forward(self, x1, x2): x2 = torch.transpose(x2, 1, 2) x2 = torch.reshape(x2, new_shape) + if self.up_mode != "transconv": + x1 = self.conv1x1(x1) x1 = self.up(x1) output = torch.cat([x2, x1], dim=1) output = self.conv(output) @@ -141,6 +136,98 @@ def forward(self, x1, x2): output = torch.transpose(output, 1, 2) return output +class Encoder(nn.Module): + """ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. + """ + def __init__(self, params): + super(Encoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.n_class = self.params['class_num'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + self.dims = self.params['conv_dims'] + + self.block0 = DownBlock(self.in_chns, self.ft_chns[0], self.dims[0], self.dropout[0], True) + self.block1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dims[1], self.dropout[1], True) + self.block2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dims[2], self.dropout[2], True) + self.block3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dims[3], self.dropout[3], True) + self.block4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dims[4], self.dropout[4], False) + + def forward(self, x): + x0, x0_d = self.block0(x) + x1, x1_d = self.block1(x0_d) + x2, x2_d = self.block2(x1_d) + x3, x3_d = self.block3(x2_d) + x4, x4_d = self.block4(x3_d) + return [x0, x1, x2, x3, x4] + +class Decoder(nn.Module): + """ + Decoder of 3D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear` for 3D and `Bilinear` for 2D). + The default value is 2. + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, params): + super(Decoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.n_class = self.params['class_num'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + self.dims = self.params['conv_dims'] + self.up_mode = self.params.get('up_mode', 2) + self.mul_pred = self.params.get('multiscale_pred', False) + + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], + self.dims[3], dropout_p = self.dropout[3], up_mode=self.up_mode) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], + self.dims[2], dropout_p = self.dropout[2], up_mode=self.up_mode) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], + self.dims[1], dropout_p = self.dropout[1], up_mode=self.up_mode) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], + self.dims[0], dropout_p = self.dropout[0], up_mode=self.up_mode) + + self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) + if(self.mul_pred): + self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) + self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) + self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) + self.stage = 'train' + + def set_stage(self, stage): + self.stage = stage + + def forward(self, x): + x0, x1, x2, x3, x4 = x + x_d3 = self.up1(x4, x3) + x_d2 = self.up2(x_d3, x2) + x_d1 = self.up3(x_d2, x1) + x_d0 = self.up4(x_d1, x0) + output = self.out_conv(x_d0) + if(self.mul_pred and self.stage == 'train'): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) + output = [output, output1, output2, output3] + return output + class UNet2D5(nn.Module): """ A 2.5D network combining 3D convolutions with 2D convolutions. @@ -164,68 +251,39 @@ class UNet2D5(nn.Module): :param conv_dims: (list) The convolution dimension (2 or 3) for each resolution level. The length should be the same as that of `feature_chns`. :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). + :param multiscale_pred: (bool) Get multi-scale prediction. """ def __init__(self, params): super(UNet2D5, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.dims = self.params['conv_dims'] - self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] - - assert(len(self.ft_chns) == 5) + params = self.get_default_parameters(params) + for p in params: + print(p, params[p]) + self.stage = 'train' + self.encoder = Encoder(params) + self.decoder = Decoder(params) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], + 'conv_dims':[2, 2, 3, 3, 3], + 'up_mode': 2, + 'multiscale_pred': False + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params - self.block0 = DownBlock(self.in_chns, self.ft_chns[0], self.dims[0], self.dropout[0], True) - self.block1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dims[1], self.dropout[1], True) - self.block2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dims[2], self.dropout[2], True) - self.block3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dims[3], self.dropout[3], True) - self.block4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dims[4], self.dropout[4], False) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], - self.dims[3], dropout_p = self.dropout[3], bilinear = self.bilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], - self.dims[2], dropout_p = self.dropout[2], bilinear = self.bilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], - self.dims[1], dropout_p = self.dropout[1], bilinear = self.bilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], - self.dims[0], dropout_p = self.dropout[0], bilinear = self.bilinear) - - self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, - kernel_size = (1, 3, 3), padding = (0, 1, 1)) + def set_stage(self, stage): + self.stage = stage + self.decoder.set_stage(stage) def forward(self, x): - x0, x0_d = self.block0(x) - x1, x1_d = self.block1(x0_d) - x2, x2_d = self.block2(x1_d) - x3, x3_d = self.block3(x2_d) - x4, x4_d = self.block4(x3_d) - - x = self.up1(x4, x3) - x = self.up2(x, x2) - x = self.up3(x, x1) - x = self.up4(x, x0) - output = self.out_conv(x) + f = self.encoder(x) + output = self.decoder(f) return output - - -if __name__ == "__main__": - params = {'in_chns':4, - 'feature_chns':[2, 8, 32, 48, 64], - 'conv_dims': [2, 2, 3, 3, 3], - 'dropout': [0, 0, 0.3, 0.4, 0.5], - 'class_num': 2, - 'bilinear': False} - Net = UNet2D5(params) - Net = Net.double() - - x = np.random.rand(4, 4, 32, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print(len(y.size())) - y = y.detach().numpy() - print(y.shape) diff --git a/pymic/net/net3d/unet3d.py b/pymic/net/net3d/unet3d.py index e383e77..b66ea38 100644 --- a/pymic/net/net3d/unet3d.py +++ b/pymic/net/net3d/unet3d.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn import numpy as np +from pymic.net.net_init import Initialization_He, Initialization_XavierUniform class ConvBlock(nn.Module): @@ -16,15 +17,23 @@ class ConvBlock(nn.Module): :param out_channels: (int) Output channel number. :param dropout_p: (int) Dropout probability. """ - def __init__(self, in_channels, out_channels, dropout_p): + def __init__(self, in_channels, out_channels, dropout_p, norm_type = 'batch_norm'): super(ConvBlock, self).__init__() + if(norm_type == 'batch_norm'): + norm1 = nn.BatchNorm3d(out_channels, affine = True) + norm2 = nn.BatchNorm3d(out_channels, affine = True) + elif(norm_type == 'instance_norm'): + norm1 = nn.InstanceNorm3d(out_channels, affine = True) + norm2 = nn.InstanceNorm3d(out_channels, affine = True) + else: + raise ValueError("norm_type {0:} not supported, it should be batch_norm or instance_norm".format(norm_type)) self.conv_conv = nn.Sequential( nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm3d(out_channels), + norm1, nn.LeakyReLU(), nn.Dropout(dropout_p), nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm3d(out_channels), + norm2, nn.LeakyReLU() ) @@ -39,11 +48,11 @@ class DownBlock(nn.Module): :param out_channels: (int) Output channel number. :param dropout_p: (int) Dropout probability. """ - def __init__(self, in_channels, out_channels, dropout_p): + def __init__(self, in_channels, out_channels, dropout_p, norm_type = 'batch_norm'): super(DownBlock, self).__init__() self.maxpool_conv = nn.Sequential( nn.MaxPool3d(2), - ConvBlock(in_channels, out_channels, dropout_p) + ConvBlock(in_channels, out_channels, dropout_p, norm_type) ) def forward(self, x): @@ -61,7 +70,8 @@ class UpBlock(nn.Module): 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value is 2 (`Trilinear`). """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, up_mode=2): + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, + up_mode=2, norm_type = 'batch_norm'): super(UpBlock, self).__init__() if(isinstance(up_mode, int)): up_mode_values = ["transconv", "nearest", "trilinear"] @@ -79,7 +89,7 @@ def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, up_mode= self.up = nn.Upsample(scale_factor=2, mode=self.up_mode) else: self.up = nn.Upsample(scale_factor=2, mode=self.up_mode, align_corners=True) - self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) + self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p, norm_type) def forward(self, x1, x2): if self.up_mode != "transconv": @@ -104,17 +114,19 @@ class Encoder(nn.Module): def __init__(self, params): super(Encoder, self).__init__() self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - if(len(self.ft_chns) == 5): - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) + in_chns = self.params['in_chns'] + ft_chns = self.params['feature_chns'] + dropout = self.params['dropout'] + norm_type = self.params['norm_type'] + assert(len(ft_chns) == 5 or len(ft_chns) == 4) + + self.ft_chns= ft_chns + self.in_conv= ConvBlock(in_chns, ft_chns[0], dropout[0], norm_type) + self.down1 = DownBlock(ft_chns[0], ft_chns[1], dropout[1], norm_type) + self.down2 = DownBlock(ft_chns[1], ft_chns[2], dropout[2], norm_type) + self.down3 = DownBlock(ft_chns[2], ft_chns[3], dropout[3], norm_type) + if(len(ft_chns) == 5): + self.down4 = DownBlock(ft_chns[3], ft_chns[4], dropout[4]) def forward(self, x): x0 = self.in_conv(x) @@ -148,25 +160,26 @@ class Decoder(nn.Module): def __init__(self, params): super(Decoder, self).__init__() self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.up_mode = self.params.get('up_mode', 2) - self.mul_pred = self.params.get('multiscale_pred', False) - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + ft_chns = self.params['feature_chns'] + dropout = self.params['dropout'] + n_class = self.params['class_num'] + norm_type = self.params['norm_type'] + up_mode = self.params.get('up_mode', 2) + self.ft_chns = ft_chns + self.mul_pred = self.params.get('multiscale_pred', False) + assert(len(ft_chns) == 5 or len(ft_chns) == 4) - if(len(self.ft_chns) == 5): - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.up_mode) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.up_mode) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.up_mode) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.up_mode) - self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) + if(len(ft_chns) == 5): + self.up1 = UpBlock(ft_chns[4], ft_chns[3], ft_chns[3], dropout[3], up_mode, norm_type) + self.up2 = UpBlock(ft_chns[3], ft_chns[2], ft_chns[2], dropout[2], up_mode, norm_type) + self.up3 = UpBlock(ft_chns[2], ft_chns[1], ft_chns[1], dropout[1], up_mode, norm_type) + self.up4 = UpBlock(ft_chns[1], ft_chns[0], ft_chns[0], dropout[0], up_mode, norm_type) + self.out_conv = nn.Conv3d(ft_chns[0], n_class, kernel_size = 1) if(self.mul_pred): - self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) - self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) - self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) + self.out_conv1 = nn.Conv3d(ft_chns[1], n_class, kernel_size = 1) + self.out_conv2 = nn.Conv3d(ft_chns[2], n_class, kernel_size = 1) + self.out_conv3 = nn.Conv3d(ft_chns[3], n_class, kernel_size = 1) self.stage = 'train' def set_stage(self, stage): @@ -223,14 +236,21 @@ def __init__(self, params): for p in params: print(p, params[p]) self.stage = 'train' + self.update_mode = params.get("update_mode", "all") self.encoder = Encoder(params) - self.decoder = Decoder(params) - + self.decoder = Decoder(params) + + init = params['initialization'].lower() + weightInitializer = Initialization_He(1e-2) if init == 'he' else Initialization_XavierUniform() + self.apply(weightInitializer) + def get_default_parameters(self, params): default_param = { 'feature_chns': [32, 64, 128, 256, 512], 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], 'up_mode': 2, + 'initialization': 'he', + 'norm_type': 'batch_norm', 'multiscale_pred': False } for key in default_param: @@ -243,6 +263,16 @@ def set_stage(self, stage): self.stage = stage self.decoder.set_stage(stage) + def get_parameters_to_update(self): + if(self.update_mode == "all"): + return self.parameters() + elif(self.update_mode == "decoder"): + print("only update parameters in decoder") + params = self.decoder.parameters() + return params + else: + raise(ValueError("update_mode can only be 'all' or 'decoder'.")) + def forward(self, x): f = self.encoder(x) output = self.decoder(f) diff --git a/pymic/net/net3d/unet3d_dual_branch.py b/pymic/net/net3d/unet3d_dual_branch.py index 3bede4e..54b01a0 100644 --- a/pymic/net/net3d/unet3d_dual_branch.py +++ b/pymic/net/net3d/unet3d_dual_branch.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division -import torch import torch.nn as nn from pymic.net.net3d.unet3d import * @@ -20,7 +19,7 @@ class UNet3D_DualBranch(nn.Module): :param output_mode: (str) How to obtain the result during the inference. `average`: taking average of the two branches. - `first`: takeing the result in the first branch. + `first`: taking the result in the first branch. `second`: taking the result in the second branch. """ def __init__(self, params): diff --git a/pymic/net/net3d/unet3d_scse.py b/pymic/net/net3d/unet3d_scse.py index b2da0dc..79abf9c 100644 --- a/pymic/net/net3d/unet3d_scse.py +++ b/pymic/net/net3d/unet3d_scse.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn import numpy as np +from pymic.net.net3d.unet3d import UpBlock, Encoder, Decoder, UNet3D from pymic.net.net3d.scse3d import * class ConvScSEBlock3D(nn.Module): @@ -48,108 +49,70 @@ def __init__(self, in_channels, out_channels, dropout_p): def forward(self, x): return self.maxpool_conv(x) -class UpBlock(nn.Module): +class UpBlockScSE(UpBlock): """3D Up-sampling followed by `ConvScSEBlock3D` in UNet3D_ScSE. :param in_channels1: (int) Input channel number for low-resolution feature map. :param in_channels2: (int) Input channel number for high-resolution feature map. :param out_channels: (int) Output channel number. :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling or not. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - trilinear=True): - super(UpBlock, self).__init__() - self.trilinear = trilinear - if trilinear: - self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) - else: - self.up = nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, up_mode=2): + super(UpBlockScSE, self).__init__(in_channels1, in_channels2, + out_channels, dropout_p, up_mode) self.conv = ConvScSEBlock3D(in_channels2 * 2, out_channels, dropout_p) - def forward(self, x1, x2): - if self.trilinear: - x1 = self.conv1x1(x1) - x1 = self.up(x1) - x = torch.cat([x2, x1], dim=1) - return self.conv(x) -class UNet3D_ScSE(nn.Module): +class EncoderScSE(Encoder): """ - Combining 3D U-Net with SCSE module. - - * Reference: Abhijit Guha Roy, Nassir Navab, Christian Wachinger: - Recalibrating Fully Convolutional Networks With Spatial and Channel - "Squeeze and Excitation" Blocks. - `IEEE Trans. Med. Imaging 38(2): 540-549 (2019). `_ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using trilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. """ def __init__(self, params): - super(UNet3D_ScSE, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.bilinear = self.params['trilinear'] + super(EncoderScSE, self).__init__(params) - assert(len(self.ft_chns) == 5) - self.in_conv= ConvScSEBlock3D(self.in_chns, self.ft_chns[0], self.dropout[0]) self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = self.dropout[3]) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p = self.dropout[2]) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p = self.dropout[1]) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p = self.dropout[0]) - - self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, - kernel_size = 3, padding = 1) + if(len(self.ft_chns) == 5): + self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - def forward(self, x): - - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - x4 = self.down4(x3) - - x = self.up1(x4, x3) - x = self.up2(x, x2) - x = self.up3(x, x1) - x = self.up4(x, x0) - output = self.out_conv(x) - - return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'feature_chns':[2, 8, 32, 48, 64], - 'dropout': [0, 0, 0.3, 0.4, 0.5], - 'class_num': 2, - 'trilinear': True} - Net = UNet3D_ScSE(params) - Net = Net.double() - - x = np.random.rand(4, 4, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print(len(y.size())) - y = y.detach().numpy() - print(y.shape) \ No newline at end of file +class DecoderScSE(Decoder): + """ + A modification of the decoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Decoder` for details. + """ + def __init__(self, params): + super(DecoderScSE, self).__init__(params) + + if(len(self.ft_chns) == 5): + self.up1 = UpBlockScSE(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.up_mode) + self.up2 = UpBlockScSE(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.up_mode) + self.up3 = UpBlockScSE(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.up_mode) + self.up4 = UpBlockScSE(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.up_mode) + + +class UNet3D_ScSE(UNet3D): + """ + Combining 3D U-Net with SCSE module. + + * Reference: Abhijit Guha Roy, Nassir Navab, Christian Wachinger: + Recalibrating Fully Convolutional Networks With Spatial and Channel + "Squeeze and Excitation" Blocks. + `IEEE Trans. Med. Imaging 38(2): 540-549 (2019). `_ + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.UNet3D` for details. + """ + def __init__(self, params): + super(UNet3D_ScSE, self).__init__(params) + self.encoder = EncoderScSE(params) + self.decoder = DecoderScSE(params) diff --git a/pymic/net_run/noisy_label/nll_co_teaching.py b/pymic/net_run/noisy_label/nll_co_teaching.py index c60616e..d46b05b 100644 --- a/pymic/net_run/noisy_label/nll_co_teaching.py +++ b/pymic/net_run/noisy_label/nll_co_teaching.py @@ -5,16 +5,14 @@ import os import sys import numpy as np +import time import torch import torch.nn as nn -import torch.optim as optim -from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.util import reshape_tensor_to_2D from pymic.net_run.agent_seg import SegmentationAgent -from pymic.net.net_dict_seg import SegNetDict from pymic.util.parse_config import * from pymic.util.ramps import get_rampup_ratio @@ -51,19 +49,19 @@ def training(self): rampup_start = nll_cfg.get('rampup_start', 0) rampup_end = nll_cfg.get('rampup_end', iter_max) - train_loss_no_select1 = 0 - train_loss_no_select2 = 0 - train_loss1 = 0 - train_loss2 = 0 + train_loss_no_select1, train_loss_no_select2 = 0, 0 + train_loss1, train_avg_loss2 = 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) labels_prob = self.convert_tensor_type(data['label_prob']) @@ -74,7 +72,7 @@ def training(self): # forward + backward + optimize outputs1, outputs2 = self.net(inputs) - + t2 = time.time() prob1 = nn.Softmax(dim = 1)(outputs1) prob2 = nn.Softmax(dim = 1)(outputs2) prob1_2d = reshape_tensor_to_2D(prob1) * 0.999 + 5e-4 @@ -101,8 +99,9 @@ def training(self): loss2_select = loss2[ind_1_update] loss = loss1_select.mean() + loss2_select.mean() - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss_no_select1 = train_loss_no_select1 + loss1.mean().item() @@ -115,6 +114,11 @@ def training(self): soft_out1, labels_prob = reshape_prediction_and_ground_truth(soft_out1, labels_prob) dice_list = get_classwise_dice(soft_out1, labels_prob).detach().cpu().numpy() train_dice_list.append(dice_list) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss_no_select1 = train_loss_no_select1 / iter_valid train_avg_loss_no_select2 = train_loss_no_select2 / iter_valid train_avg_loss1 = train_loss1 / iter_valid @@ -126,7 +130,9 @@ def training(self): 'loss1':train_avg_loss1, 'loss2': train_avg_loss2, 'loss_no_select1':train_avg_loss_no_select1, 'loss_no_select2':train_avg_loss_no_select2, - 'select_ratio':remb_ratio, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'select_ratio':remb_ratio, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): @@ -153,3 +159,6 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format( valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) diff --git a/pymic/net_run/noisy_label/nll_dast.py b/pymic/net_run/noisy_label/nll_dast.py index a90747c..938e10a 100644 --- a/pymic/net_run/noisy_label/nll_dast.py +++ b/pymic/net_run/noisy_label/nll_dast.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import numpy as np import random import torch -import numpy as np +import time import torch.nn as nn import torchvision.transforms as transforms from pymic.io.nifty_dataset import NiftyDataset @@ -163,15 +164,15 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = nll_cfg.get('rampup_start', 0) rampup_end = nll_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() rank_length = nll_cfg.get("dast_rank_length", 20) consist_loss = ConsistLoss() for it in range(iter_valid): + t0 = time.time() try: data_cl = next(self.trainIter) except StopIteration: @@ -182,7 +183,7 @@ def training(self): except StopIteration: self.trainIter_noise = iter(self.train_loader_noise) data_no = next(self.trainIter_noise) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_cl['image']) # clean sample y0 = self.convert_tensor_type(data_cl['label_prob']) @@ -196,6 +197,7 @@ def training(self): # forward + backward + optimize b0_pred, b1_pred = self.net(inputs) + t2 = time.time() n0 = list(x0.shape)[0] # number of clean samples b0_x0_pred = b0_pred[:n0] # predication of clean samples from clean branch b0_x1_pred = b0_pred[n0:] # predication of noisy samples from clean branch @@ -231,8 +233,9 @@ def training(self): b0_x1_prob = nn.Softmax(dim = 1)(b0_x1_pred) loss_st = torch.mean(torch.abs(b0_x1_prob - sharpen(pseudo_label, 0.5))) loss = loss + loss_st * w_st - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -248,6 +251,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -256,7 +264,9 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':w_dbc, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers def train_valid(self): diff --git a/pymic/net_run/noisy_label/nll_trinet.py b/pymic/net_run/noisy_label/nll_trinet.py index 64d87b6..4c4198f 100644 --- a/pymic/net_run/noisy_label/nll_trinet.py +++ b/pymic/net_run/noisy_label/nll_trinet.py @@ -2,9 +2,8 @@ from __future__ import print_function, division import logging -import os -import sys import numpy as np +import time import torch import torch.nn as nn import torch.optim as optim @@ -62,14 +61,16 @@ def training(self): train_loss_no_select2 = 0 train_loss1, train_loss2, train_loss3 = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) labels_prob = self.convert_tensor_type(data['label_prob']) @@ -80,7 +81,7 @@ def training(self): # forward + backward + optimize outputs1, outputs2, outputs3 = self.net(inputs) - + t2 = time.time() rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end) forget_ratio = (1 - select_ratio) * rampup_ratio remb_ratio = 1 - forget_ratio @@ -95,8 +96,9 @@ def training(self): loss2_avg = torch.sum(loss2 * mask13) / mask13.sum() loss3_avg = torch.sum(loss3 * mask12) / mask12.sum() loss = (loss1_avg + loss2_avg + loss3_avg) / 3 - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss_no_select1 = train_loss_no_select1 + loss1.mean().item() @@ -109,6 +111,11 @@ def training(self): soft_out1, labels_prob = reshape_prediction_and_ground_truth(soft_out1, labels_prob) dice_list = get_classwise_dice(soft_out1, labels_prob).detach().cpu().numpy() train_dice_list.append(dice_list) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss_no_select1 = train_loss_no_select1 / iter_valid train_avg_loss_no_select2 = train_loss_no_select2 / iter_valid train_avg_loss1 = train_loss1 / iter_valid @@ -120,7 +127,9 @@ def training(self): 'loss1':train_avg_loss1, 'loss2': train_avg_loss2, 'loss_no_select1':train_avg_loss_no_select1, 'loss_no_select2':train_avg_loss_no_select2, - 'select_ratio':remb_ratio, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'select_ratio':remb_ratio, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): @@ -146,4 +155,7 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format( valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) diff --git a/pymic/net_run/self_sup/self_volume_fusion.py b/pymic/net_run/self_sup/self_volume_fusion.py index 91fe088..ad2d640 100644 --- a/pymic/net_run/self_sup/self_volume_fusion.py +++ b/pymic/net_run/self_sup/self_volume_fusion.py @@ -30,11 +30,11 @@ from pymic.loss.seg.util import get_classwise_dice from pymic.transform.trans_dict import TransformDict from pymic.util.post_process import PostProcessDict -from pymic.util.image_process import convert_label from pymic.util.parse_config import * from pymic.util.general import get_one_hot_seg from pymic.io.image_read_write import save_nd_array_as_image -from pymic.net_run.self_sup.util import volume_fusion +from pymic.net_run.self_sup.util import volume_fusion, nonlienar_volume_fusion, augmented_volume_fusion +from pymic.net_run.self_sup.util import self_volume_fusion from pymic.net_run.agent_seg import SegmentationAgent @@ -57,7 +57,6 @@ def __init__(self, config, stage = 'train'): def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] - cls_num = self.config['network']['class_num'] block_range = self.config['self_supervised_learning']['VolumeFusion_block_range'.lower()] size_min = self.config['self_supervised_learning']['VolumeFusion_size_min'.lower()] size_max = self.config['self_supervised_learning']['VolumeFusion_size_max'.lower()] @@ -73,8 +72,8 @@ def training(self): data = next(self.trainIter) # get the inputs inputs = self.convert_tensor_type(data['image']) - inputs, labels = volume_fusion(inputs, cls_num - 1, block_range, size_min, size_max) - labels_prob = get_one_hot_seg(labels, cls_num) + inputs, labels = volume_fusion(inputs, class_num - 1, block_range, size_min, size_max) + labels_prob = get_one_hot_seg(labels, class_num) # for debug # if(it==10): @@ -117,3 +116,239 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ 'class_dice': train_cls_dice} return train_scalers + +class SelfSupSelfVolumeFusion(SegmentationAgent): + """ + Abstract class for self-supervised segmentation. + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. + """ + def __init__(self, config, stage = 'train'): + super(SelfSupSelfVolumeFusion, self).__init__(config, stage) + + def training(self): + class_num = self.config['network']['class_num'] + iter_valid = self.config['training']['iter_valid'] + fuse_ratio = self.config['self_supervised_learning']['SelfVolumeFusion_fuse_ratio'.lower()] + size_min = self.config['self_supervised_learning']['SelfVolumeFusion_size_min'.lower()] + size_max = self.config['self_supervised_learning']['SelfVolumeFusion_size_max'.lower()] + + train_loss = 0 + train_dice_list = [] + self.net.train() + for it in range(iter_valid): + try: + data = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data = next(self.trainIter) + # get the inputs + inputs = self.convert_tensor_type(data['image']) + inputs, labels = self_volume_fusion(inputs, class_num - 1, fuse_ratio, size_min, size_max) + labels_prob = get_one_hot_seg(labels, class_num) + + # for debug + # if(it==10): + # break + # for i in range(inputs.shape[0]): + # image_i = inputs[i][0] + # label_i = np.argmax(labels_prob[i], axis = 0) + # # pixw_i = pix_w[i][0] + # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) + # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) + # # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) + # save_nd_array_as_image(image_i, image_name, reference_name = None) + # save_nd_array_as_image(label_i, label_name, reference_name = None) + # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) + # continue + + inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + outputs = self.net(inputs) + loss = self.get_loss_value(data, outputs, labels_prob) + loss.backward() + self.optimizer.step() + train_loss = train_loss + loss.item() + # get dice evaluation for each class + if(isinstance(outputs, tuple) or isinstance(outputs, list)): + outputs = outputs[0] + outputs_argmax = torch.argmax(outputs, dim = 1, keepdim = True) + soft_out = get_soft_label(outputs_argmax, class_num, self.tensor_type) + soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob) + dice_list = get_classwise_dice(soft_out, labels_prob) + train_dice_list.append(dice_list.cpu().numpy()) + train_avg_loss = train_loss / iter_valid + train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) + train_avg_dice = train_cls_dice[1:].mean() + + train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ + 'class_dice': train_cls_dice} + return train_scalers + +class SelfSupNonLinearVolumeFusion(SegmentationAgent): + """ + Abstract class for self-supervised segmentation. + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. + """ + def __init__(self, config, stage = 'train'): + super(SelfSupNonLinearVolumeFusion, self).__init__(config, stage) + + def training(self): + class_num = 3 + iter_valid = self.config['training']['iter_valid'] + block_range = self.config['self_supervised_learning']['NonLinearVolumeFusion_block_range'.lower()] + size_min = self.config['self_supervised_learning']['NonLinearVolumeFusion_size_min'.lower()] + size_max = self.config['self_supervised_learning']['NonLinearVolumeFusion_size_max'.lower()] + + train_loss = 0 + train_dice_list = [] + self.net.train() + for it in range(iter_valid): + try: + data = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data = next(self.trainIter) + # get the inputs + inputs = self.convert_tensor_type(data['image']) + inputs, labels = nonlienar_volume_fusion(inputs, block_range, size_min, size_max) + labels_prob = get_one_hot_seg(labels, class_num) + + # for debug + # if(it==10): + # break + # for i in range(inputs.shape[0]): + # image_i = inputs[i][0] + # label_i = np.argmax(labels_prob[i], axis = 0) + # # pixw_i = pix_w[i][0] + # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) + # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) + # # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) + # save_nd_array_as_image(image_i, image_name, reference_name = None) + # save_nd_array_as_image(label_i, label_name, reference_name = None) + # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) + # continue + + inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + outputs = self.net(inputs) + loss = self.get_loss_value(data, outputs, labels_prob) + loss.backward() + self.optimizer.step() + train_loss = train_loss + loss.item() + # get dice evaluation for each class + if(isinstance(outputs, tuple) or isinstance(outputs, list)): + outputs = outputs[0] + outputs_argmax = torch.argmax(outputs, dim = 1, keepdim = True) + soft_out = get_soft_label(outputs_argmax, class_num, self.tensor_type) + soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob) + dice_list = get_classwise_dice(soft_out, labels_prob) + train_dice_list.append(dice_list.cpu().numpy()) + train_avg_loss = train_loss / iter_valid + train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) + train_avg_dice = train_cls_dice[1:].mean() + + train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ + 'class_dice': train_cls_dice} + return train_scalers + +class SelfSupAugmentedVolumeFusion(SegmentationAgent): + """ + Abstract class for self-supervised segmentation. + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. + """ + def __init__(self, config, stage = 'train'): + super(SelfSupAugmentedVolumeFusion, self).__init__(config, stage) + + def training(self): + class_num = 5 + iter_valid = self.config['training']['iter_valid'] + size_min = self.config['self_supervised_learning']['AugmentedVolumeFusion_size_min'.lower()] + size_max = self.config['self_supervised_learning']['AugmentedVolumeFusion_size_max'.lower()] + + train_loss = 0 + train_dice_list = [] + self.net.train() + for it in range(iter_valid): + try: + data = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data = next(self.trainIter) + # get the inputs + inputs = self.convert_tensor_type(data['image']) + inputs, labels = augmented_volume_fusion(inputs, size_min, size_max) + labels_prob = get_one_hot_seg(labels, class_num) + + # for debug + # if(it==10): + # break + # for i in range(inputs.shape[0]): + # image_i = inputs[i][0] + # label_i = np.argmax(labels_prob[i], axis = 0) + # # pixw_i = pix_w[i][0] + # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) + # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) + # # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) + # save_nd_array_as_image(image_i, image_name, reference_name = None) + # save_nd_array_as_image(label_i, label_name, reference_name = None) + # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) + # continue + + inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + outputs = self.net(inputs) + loss = self.get_loss_value(data, outputs, labels_prob) + loss.backward() + self.optimizer.step() + train_loss = train_loss + loss.item() + # get dice evaluation for each class + if(isinstance(outputs, tuple) or isinstance(outputs, list)): + outputs = outputs[0] + outputs_argmax = torch.argmax(outputs, dim = 1, keepdim = True) + soft_out = get_soft_label(outputs_argmax, class_num, self.tensor_type) + soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob) + dice_list = get_classwise_dice(soft_out, labels_prob) + train_dice_list.append(dice_list.cpu().numpy()) + train_avg_loss = train_loss / iter_valid + train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) + train_avg_dice = train_cls_dice[1:].mean() + + train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ + 'class_dice': train_cls_dice} + return train_scalers \ No newline at end of file diff --git a/pymic/net_run/self_sup/util.py b/pymic/net_run/self_sup/util.py index db27702..d6adcc1 100644 --- a/pymic/net_run/self_sup/util.py +++ b/pymic/net_run/self_sup/util.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import os +import copy import torch import random import numpy as np @@ -136,7 +137,6 @@ def get_human_body_mask_and_crop(input_dir, out_img_dir, out_mask_dir): mask_obj.CopyInformation(out_img_obj) sitk.WriteImage(mask_obj, mask_name) - def volume_fusion(x, fg_num, block_range, size_min, size_max): """ Fuse a subregion of an impage with another one to generate @@ -145,7 +145,7 @@ def volume_fusion(x, fg_num, block_range, size_min, size_max): """ #n_min, n_max, N, C, D, H, W = list(x.shape) - fg_mask = torch.zeros_like(x).to(torch.int32) + fg_mask = torch.zeros_like(x[:, :1, :, :, :]).to(torch.int32) # generate mask for n in range(N): p_num = random.randint(block_range[0], block_range[1]) @@ -163,10 +163,163 @@ def volume_fusion(x, fg_num, block_range, size_min, size_max): h1 = min(H, h0 + h) w1 = min(W, w0 + w) d0, h0, w0 = max(0, d0), max(0, h0), max(0, w0) - temp_m = torch.ones([C, d1 - d0, h1 - h0, w1 - w0]) * random.randint(1, fg_num) + temp_m = torch.ones([1, d1 - d0, h1 - h0, w1 - w0]) * random.randint(1, fg_num) fg_mask[n, :, d0:d1, h0:h1, w0:w1] = temp_m fg_w = fg_mask * 1.0 / fg_num x_roll = torch.roll(x, 1, 0) x_fuse = fg_w*x_roll + (1.0 - fg_w)*x # y_prob = get_one_hot_seg(fg_mask.to(torch.int32), fg_num + 1) return x_fuse, fg_mask + +def nonlinear_transform(x): + v_min = torch.min(x) + v_max = torch.max(x) + x = (x - v_min)/(v_max - v_min) + a = random.random() * 0.7 + 0.15 + b = random.random() * 0.7 + 0.15 + alpha = b / a + beta = (1 - b) / (1 - a) + if(alpha < 1.0 ): + y = torch.maximum(alpha*x, beta*x + 1 - beta) + else: + y = torch.minimum(alpha*x, beta*x + 1 - beta) + if(random.random() < 0.5): + y = 1.0 - y + y = y * (v_max - v_min) + v_min + return y + +def nonlienar_volume_fusion(x, block_range, size_min, size_max): + """ + Fuse a subregion of an impage with another one to generate + images and labels for self-supervised segmentation. + input x should be a batch of tensors + """ + #n_min, n_max, + N, C, D, H, W = list(x.shape) + # apply nonlinear transform to x: + x_nl1 = torch.zeros_like(x).to(torch.float32) + x_nl2 = torch.zeros_like(x).to(torch.float32) + for n in range(N): + x_nl1[n] = nonlinear_transform(x[n]) + x_nl2[n] = nonlinear_transform(x[n]) + x_roll = torch.roll(x_nl2, 1, 0) + mask = torch.zeros_like(x).to(torch.int32) + p_num = random.randint(block_range[0], block_range[1]) + for n in range(N): + for i in range(p_num): + d = random.randint(size_min[0], size_max[0]) + h = random.randint(size_min[1], size_max[1]) + w = random.randint(size_min[2], size_max[2]) + dc = random.randint(0, D - 1) + hc = random.randint(0, H - 1) + wc = random.randint(0, W - 1) + d0 = dc - d // 2 + h0 = hc - h // 2 + w0 = wc - w // 2 + d1 = min(D, d0 + d) + h1 = min(H, h0 + h) + w1 = min(W, w0 + w) + d0, h0, w0 = max(0, d0), max(0, h0), max(0, w0) + temp_m = torch.ones([C, d1 - d0, h1 - h0, w1 - w0]) + if(random.random() < 0.5): + temp_m = temp_m * 2 + mask[n, :, d0:d1, h0:h1, w0:w1] = temp_m + + mask1 = (mask == 1).to(torch.int32) + mask2 = (mask == 2).to(torch.int32) + y = x_nl1 * (1.0 - mask1) + x_nl2 * mask1 + y = y * (1.0 - mask2) + x_roll * mask2 + return y, mask + +def augmented_volume_fusion(x, size_min, size_max): + """ + Fuse a subregion of an impage with another one to generate + images and labels for self-supervised segmentation. + input x should be a batch of tensors + """ + #n_min, n_max, + N, C, D, H, W = list(x.shape) + # apply nonlinear transform to x: + x1 = torch.zeros_like(x).to(torch.float32) + y = torch.zeros_like(x).to(torch.float32) + mask = torch.zeros_like(x).to(torch.int32) + for n in range(N): + x1[n] = nonlinear_transform(x[n]) + y[n] = nonlinear_transform(x[n]) + x2 = torch.roll(x1, 1, 0) + + for n in range(N): + block_size = [random.randint(size_min[i], size_max[i]) for i in range(3)] + d_start = random.randint(0, block_size[0] // 2) + h_start = random.randint(0, block_size[1] // 2) + w_stat = random.randint(0, block_size[2] // 2) + for d in range(d_start, D, block_size[0]): + if(D - d < block_size[0] // 2): + continue + d1 = min(d + block_size[0], D) + for h in range(h_start, H, block_size[1]): + if(H - h < block_size[1] // 2): + continue + h1 = min(h + block_size[1], H) + for w in range(w_stat, W, block_size[2]): + if(W - w < block_size[2] // 2): + continue + w1 = min(w + block_size[2], W) + p = random.random() + if(p < 0.15): # nonlinear intensity augmentation + mask[n, :, d:d1, h:h1, w:w1] = 1 + y[n, :, d:d1, h:h1, w:w1] = x1[n, :, d:d1, h:h1, w:w1] + elif(p < 0.3): # random flip across a certain axis + mask[n, :, d:d1, h:h1, w:w1] = 2 + flip_axis = random.randint(-3, -1) + y[n, :, d:d1, h:h1, w:w1] = torch.flip(y[n, :, d:d1, h:h1, w:w1], (flip_axis,)) + elif(p < 0.45): # nonlinear intensity augmentation and random flip across a certain axis + mask[n, :, d:d1, h:h1, w:w1] = 3 + flip_axis = random.randint(-3, -1) + y[n, :, d:d1, h:h1, w:w1] = torch.flip(x1[n, :, d:d1, h:h1, w:w1], (flip_axis,)) + elif(p < 0.6): # paste from another volume + mask[n, :, d:d1, h:h1, w:w1] = 4 + y[n, :, d:d1, h:h1, w:w1] = x2[n, :, d:d1, h:h1, w:w1] + return y, mask + +def self_volume_fusion(x, fg_num, fuse_ratio, size_min, size_max): + """ + Fuse a subregion of an impage with another one to generate + images and labels for self-supervised segmentation. + input x should be a batch of tensors + """ + #n_min, n_max, + N, C, D, H, W = list(x.shape) + y = 1.0 * x + fg_mask = torch.zeros_like(x[:, :1, :, :, :]).to(torch.int32) + + for n in range(N): + db = random.randint(size_min[0], size_max[0]) + hb = random.randint(size_min[1], size_max[1]) + wb = random.randint(size_min[2], size_max[2]) + d0 = random.randint(0, D % db) + h0 = random.randint(0, H % hb) + w0 = random.randint(0, W % wb) + coord_list_source = [] + for di in range(D // db): + for hi in range(H // hb): + for wi in range(W // wb): + coord_list_source.append([di, hi, wi]) + coord_list_target = copy.deepcopy(coord_list_source) + random.shuffle(coord_list_source) + random.shuffle(coord_list_target) + for i in range(int(len(coord_list_source)*fuse_ratio)): + ds_l = d0 + db * coord_list_source[i][0] + hs_l = h0 + hb * coord_list_source[i][1] + ws_l = w0 + wb * coord_list_source[i][2] + dt_l = d0 + db * coord_list_target[i][0] + ht_l = h0 + hb * coord_list_target[i][1] + wt_l = w0 + wb * coord_list_target[i][2] + s_crop = x[n, :, ds_l:ds_l+db, hs_l:hs_l+hb, ws_l:ws_l+wb] + t_crop = x[n, :, dt_l:dt_l+db, ht_l:ht_l+hb, wt_l:wt_l+wb] + fg_m = random.randint(1, fg_num) + fg_w = fg_m / (fg_num + 0.0) + y[n, :, dt_l:dt_l+db, ht_l:ht_l+hb, wt_l:wt_l+wb] = t_crop * (1.0 - fg_w) + s_crop * fg_w + fg_mask[n, 0, dt_l:dt_l+db, ht_l:ht_l+hb, wt_l:wt_l+wb] = \ + torch.ones([1, db, hb, wb]) * fg_m + return y, fg_mask \ No newline at end of file diff --git a/pymic/net_run/semi_sup/__init__.py b/pymic/net_run/semi_sup/__init__.py index d3095f6..cb5d1a3 100644 --- a/pymic/net_run/semi_sup/__init__.py +++ b/pymic/net_run/semi_sup/__init__.py @@ -3,6 +3,7 @@ from pymic.net_run.semi_sup.ssl_em import SSLEntropyMinimization from pymic.net_run.semi_sup.ssl_mt import SSLMeanTeacher from pymic.net_run.semi_sup.ssl_mcnet import SSLMCNet +from pymic.net_run.semi_sup.ssl_cdma import SSLCDMA from pymic.net_run.semi_sup.ssl_uamt import SSLUncertaintyAwareMeanTeacher from pymic.net_run.semi_sup.ssl_cct import SSLCCT from pymic.net_run.semi_sup.ssl_cps import SSLCPS @@ -12,6 +13,7 @@ SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization, 'MeanTeacher': SSLMeanTeacher, 'MCNet': SSLMCNet, + 'CDMA': SSLCDMA, 'UAMT': SSLUncertaintyAwareMeanTeacher, 'CCT': SSLCCT, 'CPS': SSLCPS, diff --git a/pymic/net_run/semi_sup/ssl_abstract.py b/pymic/net_run/semi_sup/ssl_abstract.py index 0e05281..69a09fd 100644 --- a/pymic/net_run/semi_sup/ssl_abstract.py +++ b/pymic/net_run/semi_sup/ssl_abstract.py @@ -101,6 +101,9 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format( valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) def train_valid(self): self.trainIter_unlab = iter(self.train_loader_unlab) diff --git a/pymic/net_run/semi_sup/ssl_cct.py b/pymic/net_run/semi_sup/ssl_cct.py index 1943608..81723bd 100644 --- a/pymic/net_run/semi_sup/ssl_cct.py +++ b/pymic/net_run/semi_sup/ssl_cct.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import logging +import time import torch import torch.nn as nn import torch.nn.functional as F @@ -88,13 +89,13 @@ def training(self): rampup_end = ssl_cfg.get('rampup_end', iter_max) unsup_loss_name = ssl_cfg.get('unsupervised_loss', "MSE") self.unsup_loss_f = unsup_loss_dict[unsup_loss_name] - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data_lab = next(self.trainIter) except StopIteration: @@ -105,7 +106,7 @@ def training(self): except StopIteration: self.trainIter_unlab = iter(self.train_loader_unlab) data_unlab = next(self.trainIter_unlab) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) @@ -118,6 +119,7 @@ def training(self): # forward pass output, aux_outputs = self.net(inputs) + t2 = time.time() n0 = list(x0.shape)[0] # get supervised loss @@ -135,8 +137,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -150,6 +153,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -158,5 +166,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers diff --git a/pymic/net_run/semi_sup/ssl_cps.py b/pymic/net_run/semi_sup/ssl_cps.py index 7acfe17..dc0d325 100644 --- a/pymic/net_run/semi_sup/ssl_cps.py +++ b/pymic/net_run/semi_sup/ssl_cps.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import logging +import time import numpy as np import torch from random import random @@ -44,8 +45,10 @@ def training(self): train_loss_sup1, train_loss_pseudo_sup1 = 0, 0 train_loss_sup2, train_loss_pseudo_sup2 = 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data_lab = next(self.trainIter) except StopIteration: @@ -56,7 +59,7 @@ def training(self): except StopIteration: self.trainIter_unlab = iter(self.train_loader_unlab) data_unlab = next(self.trainIter_unlab) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) @@ -84,6 +87,7 @@ def training(self): outputs1, outputs2 = self.net(inputs) outputs_soft1 = torch.softmax(outputs1, dim=1) outputs_soft2 = torch.softmax(outputs2, dim=1) + t2 = time.time() n0 = list(x0.shape)[0] p0 = outputs_soft1[:n0] @@ -105,8 +109,9 @@ def training(self): model1_loss = loss_sup1 + regular_w * pse_sup1 model2_loss = loss_sup2 + regular_w * pse_sup2 loss = model1_loss + model2_loss - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -123,6 +128,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup1 = train_loss_sup1 / iter_valid train_avg_loss_sup2 = train_loss_sup2 / iter_valid @@ -134,7 +144,9 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup1':train_avg_loss_sup1, 'loss_sup2': train_avg_loss_sup2, 'loss_pse_sup1':train_avg_loss_pse_sup1, 'loss_pse_sup2': train_avg_loss_pse_sup2, - 'regular_w':regular_w, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'regular_w':regular_w, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): @@ -162,4 +174,7 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") \ No newline at end of file + ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) \ No newline at end of file diff --git a/pymic/net_run/semi_sup/ssl_em.py b/pymic/net_run/semi_sup/ssl_em.py index fde941b..750e3d3 100644 --- a/pymic/net_run/semi_sup/ssl_em.py +++ b/pymic/net_run/semi_sup/ssl_em.py @@ -2,6 +2,7 @@ from __future__ import print_function, division import logging import numpy as np +import time import torch from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth @@ -40,12 +41,12 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = ssl_cfg.get('rampup_start', 0) rampup_end = ssl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data_lab = next(self.trainIter) except StopIteration: @@ -56,7 +57,7 @@ def training(self): except StopIteration: self.trainIter_unlab = iter(self.train_loader_unlab) data_unlab = next(self.trainIter_unlab) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) @@ -69,6 +70,8 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) + t2 = time.time() + n0 = list(x0.shape)[0] p0 = outputs[:n0] loss_sup = self.get_loss_value(data_lab, p0, y0) @@ -79,8 +82,10 @@ def training(self): regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - # if (self.config['training']['use']) + t3 = time.time() + loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -94,6 +99,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -102,5 +112,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers \ No newline at end of file diff --git a/pymic/net_run/semi_sup/ssl_mcnet.py b/pymic/net_run/semi_sup/ssl_mcnet.py index 66e1034..d374773 100644 --- a/pymic/net_run/semi_sup/ssl_mcnet.py +++ b/pymic/net_run/semi_sup/ssl_mcnet.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import logging +import time import torch import torch.nn as nn import torch.nn.functional as F @@ -44,13 +45,13 @@ def training(self): rampup_end = ssl_cfg.get('rampup_end', iter_max) temperature = ssl_cfg.get('temperature', 0.1) unsup_loss_name = ssl_cfg.get('unsupervised_loss', "MSE") - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data_lab = next(self.trainIter) except StopIteration: @@ -61,7 +62,7 @@ def training(self): except StopIteration: self.trainIter_unlab = iter(self.train_loader_unlab) data_unlab = next(self.trainIter_unlab) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) @@ -74,6 +75,7 @@ def training(self): # forward pass to obtain multiple predictions outputs = self.net(inputs) + t2 = time.time() num_outputs = len(outputs) n0 = list(x0.shape)[0] p0 = F.softmax(outputs[0], dim=1)[:n0] @@ -81,7 +83,7 @@ def training(self): p_ori = torch.zeros((num_outputs,) + outputs[0].shape) y_psu = torch.zeros((num_outputs,) + outputs[0].shape) - # get supervised loss + # get supervised loss loss_sup = 0 for idx in range(num_outputs): p0i = outputs[idx][:n0] @@ -102,8 +104,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -117,6 +120,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -125,5 +133,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers diff --git a/pymic/net_run/semi_sup/ssl_mt.py b/pymic/net_run/semi_sup/ssl_mt.py index 409af19..303eeff 100644 --- a/pymic/net_run/semi_sup/ssl_mt.py +++ b/pymic/net_run/semi_sup/ssl_mt.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import logging +import time import torch import numpy as np from pymic.loss.seg.util import get_soft_label @@ -50,13 +51,13 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = ssl_cfg.get('rampup_start', 0) rampup_end = ssl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() self.net_ema.to(self.device) for it in range(iter_valid): + t0 = time.time() try: data_lab = next(self.trainIter) except StopIteration: @@ -67,7 +68,7 @@ def training(self): except StopIteration: self.trainIter_unlab = iter(self.train_loader_unlab) data_unlab = next(self.trainIter_unlab) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) @@ -82,6 +83,8 @@ def training(self): self.optimizer.zero_grad() outputs = self.net(inputs) + t2 = time.time() + n0 = list(x0.shape)[0] p0 = outputs[:n0] loss_sup = self.get_loss_value(data_lab, p0, y0) @@ -98,8 +101,9 @@ def training(self): loss_reg = torch.nn.MSELoss()(p1_soft, p1_ema_soft) loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() # update EMA @@ -119,6 +123,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -127,5 +136,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time} return train_scalers \ No newline at end of file diff --git a/pymic/net_run/semi_sup/ssl_uamt.py b/pymic/net_run/semi_sup/ssl_uamt.py index 053a012..5888a8b 100644 --- a/pymic/net_run/semi_sup/ssl_uamt.py +++ b/pymic/net_run/semi_sup/ssl_uamt.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import logging +import time import torch import numpy as np from pymic.loss.seg.util import get_soft_label @@ -33,13 +34,13 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = ssl_cfg.get('rampup_start', 0) rampup_end = ssl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() self.net_ema.to(self.device) for it in range(iter_valid): + t0 = time.time() try: data_lab = next(self.trainIter) except StopIteration: @@ -50,7 +51,7 @@ def training(self): except StopIteration: self.trainIter_unlab = iter(self.train_loader_unlab) data_unlab = next(self.trainIter_unlab) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) @@ -64,6 +65,7 @@ def training(self): self.optimizer.zero_grad() outputs = self.net(inputs) + t2 = time.time() n0 = list(x0.shape)[0] p0, p1 = torch.tensor_split(outputs, [n0,], dim = 0) outputs_soft = torch.softmax(outputs, dim=1) @@ -100,8 +102,9 @@ def training(self): regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() # update EMA @@ -121,6 +124,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -129,5 +137,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers \ No newline at end of file diff --git a/pymic/net_run/semi_sup/ssl_urpc.py b/pymic/net_run/semi_sup/ssl_urpc.py index 56bb77e..8447709 100644 --- a/pymic/net_run/semi_sup/ssl_urpc.py +++ b/pymic/net_run/semi_sup/ssl_urpc.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import logging +import time import torch import torch.nn as nn import numpy as np @@ -35,13 +36,13 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = ssl_cfg.get('rampup_start', 0) rampup_end = ssl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() kl_distance = nn.KLDivLoss(reduction='none') for it in range(iter_valid): + t0 = time.time() try: data_lab = next(self.trainIter) except StopIteration: @@ -52,7 +53,7 @@ def training(self): except StopIteration: self.trainIter_unlab = iter(self.train_loader_unlab) data_unlab = next(self.trainIter_unlab) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) @@ -65,6 +66,7 @@ def training(self): # forward pass outputs_list = self.net(inputs) + t2 = time.time() n0 = list(x0.shape)[0] # get supervised loss @@ -95,8 +97,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -110,6 +113,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -118,5 +126,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time} return train_scalers diff --git a/pymic/net_run/weak_sup/wsl_abstract.py b/pymic/net_run/weak_sup/wsl_abstract.py index f290465..37b9445 100644 --- a/pymic/net_run/weak_sup/wsl_abstract.py +++ b/pymic/net_run/weak_sup/wsl_abstract.py @@ -24,7 +24,7 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): 'valid':valid_scalars['loss']} loss_sup_scalar = {'train':train_scalars['loss_sup']} loss_upsup_scalar = {'train':train_scalars['loss_reg']} - dice_scalar ={'train':train_scalars['avg_fg_dice'], 'valid':valid_scalars['avg_fg_dice']} + dice_scalar ={'valid':valid_scalars['avg_fg_dice']} self.summ_writer.add_scalars('loss', loss_scalar, glob_it) self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it) self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it) @@ -36,9 +36,10 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ 'valid':valid_scalars['class_dice'][c]} self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) - logging.info('train loss {0:.4f}, avg foreground dice {1:.4f} '.format( - train_scalars['loss'], train_scalars['avg_fg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") + logging.info('train loss {0:.4f}'.format(train_scalars['loss'])) logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format( valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) diff --git a/pymic/net_run/weak_sup/wsl_dmpls.py b/pymic/net_run/weak_sup/wsl_dmpls.py index 4212409..01902c7 100644 --- a/pymic/net_run/weak_sup/wsl_dmpls.py +++ b/pymic/net_run/weak_sup/wsl_dmpls.py @@ -3,6 +3,7 @@ import logging import numpy as np import random +import time import torch from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label @@ -44,18 +45,18 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = wsl_cfg.get('rampup_start', 0) rampup_end = wsl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) y = self.convert_tensor_type(data['label_prob']) @@ -67,6 +68,8 @@ def training(self): # forward + backward + optimize outputs1, outputs2 = self.net(inputs) + t2 = time.time() + loss_sup1 = self.get_loss_value(data, outputs1, y) loss_sup2 = self.get_loss_value(data, outputs2, y) loss_sup = 0.5 * (loss_sup1 + loss_sup2) @@ -88,8 +91,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -103,6 +107,11 @@ def training(self): p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) dice_list = get_classwise_dice(p_soft, y) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -111,7 +120,8 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} - + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time} return train_scalers \ No newline at end of file diff --git a/pymic/net_run/weak_sup/wsl_em.py b/pymic/net_run/weak_sup/wsl_em.py index adcd70c..987aa89 100644 --- a/pymic/net_run/weak_sup/wsl_em.py +++ b/pymic/net_run/weak_sup/wsl_em.py @@ -2,12 +2,12 @@ from __future__ import print_function, division import logging import numpy as np +import time import torch from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.ssl import EntropyLoss -from pymic.net_run.agent_seg import SegmentationAgent from pymic.net_run.weak_sup import WSLSegAgent from pymic.util.ramps import get_rampup_ratio @@ -38,18 +38,18 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = wsl_cfg.get('rampup_start', 0) rampup_end = wsl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) y = self.convert_tensor_type(data['label_prob']) @@ -61,6 +61,8 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) + t2 = time.time() + loss_sup = self.get_loss_value(data, outputs, y) loss_dict= {"prediction":outputs, 'softmax':True} loss_reg = EntropyLoss()(loss_dict) @@ -68,8 +70,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -83,6 +86,11 @@ def training(self): p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) dice_list = get_classwise_dice(p_soft, y) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -91,5 +99,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers \ No newline at end of file diff --git a/pymic/net_run/weak_sup/wsl_gatedcrf.py b/pymic/net_run/weak_sup/wsl_gatedcrf.py index 2ce1f95..1ecae4a 100644 --- a/pymic/net_run/weak_sup/wsl_gatedcrf.py +++ b/pymic/net_run/weak_sup/wsl_gatedcrf.py @@ -2,6 +2,7 @@ from __future__ import print_function, division import logging import numpy as np +import time import torch from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth @@ -48,20 +49,19 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = wsl_cfg.get('rampup_start', 0) rampup_end = wsl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] - + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 gatecrf_loss = GatedCRFLoss() self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) y = self.convert_tensor_type(data['label_prob']) @@ -94,8 +94,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -109,6 +110,11 @@ def training(self): p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) dice_list = get_classwise_dice(p_soft, y) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -117,6 +123,8 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers \ No newline at end of file diff --git a/pymic/net_run/weak_sup/wsl_mumford_shah.py b/pymic/net_run/weak_sup/wsl_mumford_shah.py index 2480fee..e917fb3 100644 --- a/pymic/net_run/weak_sup/wsl_mumford_shah.py +++ b/pymic/net_run/weak_sup/wsl_mumford_shah.py @@ -2,6 +2,7 @@ from __future__ import print_function, division import logging import numpy as np +import time import torch from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth @@ -37,20 +38,20 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = wsl_cfg.get('rampup_start', 0) rampup_end = wsl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 reg_loss_calculator = MumfordShahLoss(wsl_cfg) self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) y = self.convert_tensor_type(data['label_prob']) @@ -62,6 +63,7 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) + t2 = time.time() loss_sup = self.get_loss_value(data, outputs, y) loss_dict = {"prediction":outputs, 'image':inputs} loss_reg = reg_loss_calculator(loss_dict) @@ -69,8 +71,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - # if (self.config['training']['use']) + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -84,6 +87,11 @@ def training(self): p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) dice_list = get_classwise_dice(p_soft, y) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -92,6 +100,8 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers \ No newline at end of file diff --git a/pymic/net_run/weak_sup/wsl_tv.py b/pymic/net_run/weak_sup/wsl_tv.py index 9d13c5d..5a43ce0 100644 --- a/pymic/net_run/weak_sup/wsl_tv.py +++ b/pymic/net_run/weak_sup/wsl_tv.py @@ -2,6 +2,7 @@ from __future__ import print_function, division import logging import numpy as np +import time import torch from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth @@ -34,18 +35,18 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = wsl_cfg.get('rampup_start', 0) rampup_end = wsl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) y = self.convert_tensor_type(data['label_prob']) @@ -57,6 +58,7 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) + t2 = time.time() loss_sup = self.get_loss_value(data, outputs, y) loss_dict = {"prediction":outputs, 'softmax':True} loss_reg = TotalVariationLoss()(loss_dict) @@ -64,8 +66,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - # if (self.config['training']['use']) + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -79,6 +82,11 @@ def training(self): p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) dice_list = get_classwise_dice(p_soft, y) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -87,6 +95,8 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers \ No newline at end of file diff --git a/pymic/net_run/weak_sup/wsl_ustm.py b/pymic/net_run/weak_sup/wsl_ustm.py index 31a6644..ac1c6a5 100644 --- a/pymic/net_run/weak_sup/wsl_ustm.py +++ b/pymic/net_run/weak_sup/wsl_ustm.py @@ -3,6 +3,7 @@ import logging import numpy as np import random +import time import torch import torch.nn.functional as F from pymic.loss.seg.util import get_soft_label @@ -54,19 +55,19 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = wsl_cfg.get('rampup_start', 0) rampup_end = wsl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() self.net_ema.to(self.device) for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) y = self.convert_tensor_type(data['label_prob']) @@ -79,6 +80,7 @@ def training(self): # forward + backward + optimize noise = torch.clamp(torch.randn_like(inputs) * 0.1, -0.2, 0.2) outputs = self.net(inputs + noise) + t2 = time.time() out_prob= F.softmax(outputs, dim=1) loss_sup = self.get_loss_value(data, outputs, y) @@ -117,7 +119,7 @@ def training(self): regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() self.optimizer.step() @@ -126,6 +128,7 @@ def training(self): alpha = min(1 - 1 / (self.glob_it / iter_valid + 1), alpha) for ema_param, param in zip(self.net_ema.parameters(), self.net.parameters()): ema_param.data.mul_(alpha).add(param.data, alpha = 1.0 - alpha) + t4 = time.time() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() @@ -138,6 +141,11 @@ def training(self): p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) dice_list = get_classwise_dice(p_soft, y) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -146,5 +154,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers \ No newline at end of file From b11dd8c22239e533a756fd709c907cb2e07b8e7f Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 30 Sep 2024 16:17:35 +0800 Subject: [PATCH 50/86] update self-supervised learning methods --- pymic/net_run/agent_rec.py | 30 +- pymic/net_run/agent_seg.py | 25 +- pymic/net_run/self_sup/self_volf.py | 37 ++ pymic/net_run/self_sup/self_volume_fusion.py | 354 ------------------- pymic/transform/trans_dict.py | 17 +- pymic/transform/volume_fusion.py | 117 ++++++ 6 files changed, 213 insertions(+), 367 deletions(-) create mode 100644 pymic/net_run/self_sup/self_volf.py delete mode 100644 pymic/net_run/self_sup/self_volume_fusion.py create mode 100644 pymic/transform/volume_fusion.py diff --git a/pymic/net_run/agent_rec.py b/pymic/net_run/agent_rec.py index cd311ad..cc83526 100644 --- a/pymic/net_run/agent_rec.py +++ b/pymic/net_run/agent_rec.py @@ -29,6 +29,9 @@ class ReconstructionAgent(SegmentationAgent): """ def __init__(self, config, stage = 'train'): super(ReconstructionAgent, self).__init__(config, stage) + if (self.config['network']['class_num'] != 1): + raise ValueError("For reconstruction tasks, the output channel number should be 1, " + + "but {} was given.".format(self.config['network']['class_num'])) def create_loss_calculator(self): if(self.loss_dict is None): @@ -55,14 +58,17 @@ def create_loss_calculator(self): def training(self): iter_valid = self.config['training']['iter_valid'] train_loss = 0 + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) # get the inputs + t1 = time.time() inputs = self.convert_tensor_type(data['image']) label = self.convert_tensor_type(data['label']) @@ -87,7 +93,7 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) - + t2 = time.time() # for debug # if it < 5: # outputs = nn.Tanh()(outputs) @@ -100,15 +106,23 @@ def training(self): # break loss = self.get_loss_value(data, outputs, label) + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() - # get dice evaluation for each class + if(isinstance(outputs, tuple) or isinstance(outputs, list)): outputs = outputs[0] + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid - train_scalers = {'loss': train_avg_loss} + train_scalers = {'loss': train_avg_loss, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time} return train_scalers def validation(self): @@ -163,6 +177,9 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) logging.info('train loss {0:.4f}'.format(train_scalars['loss'])) logging.info('valid loss {0:.4f}'.format(valid_scalars['loss'])) + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) def train_valid(self): device_ids = self.config['training']['gpus'] @@ -173,7 +190,7 @@ def train_valid(self): self.device = torch.device("cuda:{0:}".format(device_ids[0])) self.net.to(self.device) - ckpt_dir = self.config['training']['ckpt_save_dir'] + ckpt_dir = self.config['training']['ckpt_dir'] ckpt_prefix = self.config['training'].get('ckpt_prefix', None) if(ckpt_prefix is None): ckpt_prefix = ckpt_dir.split('/')[-1] @@ -224,7 +241,7 @@ def train_valid(self): self.trainIter = iter(self.train_loader) logging.info("{0:} training start".format(str(datetime.now())[:-7])) - self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) + self.summ_writer = SummaryWriter(self.config['training']['ckpt_dir']) self.glob_it = iter_start for it in range(iter_start, iter_max, iter_valid): lr_value = self.optimizer.param_groups[0]['lr'] @@ -242,6 +259,9 @@ def train_valid(self): logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) logging.info('learning rate {0:}'.format(lr_value)) logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1)) + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it) if(valid_scalars['loss'] < self.min_val_loss): self.min_val_loss = valid_scalars['loss'] diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 4a80298..59ffe05 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -74,6 +74,7 @@ def get_stage_dataset_from_config(self, stage): stage_dir = self.config['dataset']['valid_dir'] if(stage == 'test' and "test_dir" in self.config['dataset']): stage_dir = self.config['dataset']['test_dir'] + logging.info("Creating dataset for {0:}".format(stage)) dataset = NiftyDataset(root_dir = stage_dir, csv_file = csv_file, modal_num = modal_num, @@ -163,14 +164,16 @@ def training(self): mixup_prob = self.config['training'].get('mixup_probability', 0.0) train_loss = 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - # get the inputs + t1 = time.time() inputs = self.convert_tensor_type(data['image']) labels_prob = self.convert_tensor_type(data['label_prob']) if(mixup_prob > 0 and random() < mixup_prob): @@ -196,14 +199,17 @@ def training(self): inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) - + # zero the parameter gradients self.optimizer.zero_grad() # forward + backward + optimize outputs = self.net(inputs) + t2 = time.time() loss = self.get_loss_value(data, outputs, labels_prob) + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() # get dice evaluation for each class @@ -214,12 +220,19 @@ def training(self): soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob) dice_list = get_classwise_dice(soft_out, labels_prob) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) train_avg_dice = train_cls_dice[1:].mean() train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ - 'class_dice': train_cls_dice} + 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time} return train_scalers def validation(self): @@ -289,7 +302,10 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format( valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) def load_pretrained_weights(self, network, pretrained_dict, device_ids): if(len(device_ids) > 1): @@ -607,4 +623,3 @@ def save_outputs(self, data): if(len(temp_prob.shape) == 2): temp_prob = np.asarray(temp_prob * 255, np.uint8) save_nd_array_as_image(temp_prob, prob_save_name, test_dir + '/' + names[i][0]) -0]) diff --git a/pymic/net_run/self_sup/self_volf.py b/pymic/net_run/self_sup/self_volf.py new file mode 100644 index 0000000..4615979 --- /dev/null +++ b/pymic/net_run/self_sup/self_volf.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +from pymic.net_run.agent_seg import SegmentationAgent + + +class SelfSupVolumeFusion(SegmentationAgent): + """ + Abstract class for self-supervised segmentation. + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. + """ + def __init__(self, config, stage = 'train'): + super(SelfSupVolumeFusion, self).__init__(config, stage) + + def get_transform_names_and_parameters(self, stage): + trans_names, trans_params = super(SelfSupVolumeFusion, self).get_transform_names_and_parameters(stage) + if(stage == 'train'): + print('training transforms:', trans_names) + if("Crop4VolumeFusion" not in trans_names): + raise ValueError("Crop4VolumeFusion is required for VolF, \ + but it is not given in training transform") + if("VolumeFusion" not in trans_names): + raise ValueError("VolumeFusion is required for VolF, \ + but it is not given in training transform") + if("LabelToProbability" not in trans_names): + raise ValueError("LabelToProbability is required for VolF, \ + but it is not given in training transform") + return trans_names, trans_params + + \ No newline at end of file diff --git a/pymic/net_run/self_sup/self_volume_fusion.py b/pymic/net_run/self_sup/self_volume_fusion.py deleted file mode 100644 index ad2d640..0000000 --- a/pymic/net_run/self_sup/self_volume_fusion.py +++ /dev/null @@ -1,354 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division -import copy -import os -import sys -import shutil -import time -import logging -import scipy -import torch -import torchvision.transforms as transforms -import numpy as np -import torch.nn as nn -import torch.optim as optim -import torch.nn.functional as F -from datetime import datetime -from random import random -from torch.optim import lr_scheduler -from tensorboardX import SummaryWriter -from pymic.io.image_read_write import save_nd_array_as_image -from pymic.io.nifty_dataset import NiftyDataset -from pymic.net.net_dict_seg import SegNetDict -from pymic.net_run.agent_abstract import NetRunAgent -from pymic.net_run.infer_func import Inferer -from pymic.loss.loss_dict_seg import SegLossDict -from pymic.loss.seg.combined import CombinedLoss -from pymic.loss.seg.deep_sup import DeepSuperviseLoss -from pymic.loss.seg.util import get_soft_label -from pymic.loss.seg.util import reshape_prediction_and_ground_truth -from pymic.loss.seg.util import get_classwise_dice -from pymic.transform.trans_dict import TransformDict -from pymic.util.post_process import PostProcessDict -from pymic.util.parse_config import * -from pymic.util.general import get_one_hot_seg -from pymic.io.image_read_write import save_nd_array_as_image -from pymic.net_run.self_sup.util import volume_fusion, nonlienar_volume_fusion, augmented_volume_fusion -from pymic.net_run.self_sup.util import self_volume_fusion -from pymic.net_run.agent_seg import SegmentationAgent - - -class SelfSupVolumeFusion(SegmentationAgent): - """ - Abstract class for self-supervised segmentation. - - :param config: (dict) A dictionary containing the configuration. - :param stage: (str) One of the stage in `train` (default), `inference` or `test`. - - .. note:: - - In the configuration dictionary, in addition to the four sections (`dataset`, - `network`, `training` and `inference`) used in fully supervised learning, an - extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. - """ - def __init__(self, config, stage = 'train'): - super(SelfSupVolumeFusion, self).__init__(config, stage) - - def training(self): - class_num = self.config['network']['class_num'] - iter_valid = self.config['training']['iter_valid'] - block_range = self.config['self_supervised_learning']['VolumeFusion_block_range'.lower()] - size_min = self.config['self_supervised_learning']['VolumeFusion_size_min'.lower()] - size_max = self.config['self_supervised_learning']['VolumeFusion_size_max'.lower()] - - train_loss = 0 - train_dice_list = [] - self.net.train() - for it in range(iter_valid): - try: - data = next(self.trainIter) - except StopIteration: - self.trainIter = iter(self.train_loader) - data = next(self.trainIter) - # get the inputs - inputs = self.convert_tensor_type(data['image']) - inputs, labels = volume_fusion(inputs, class_num - 1, block_range, size_min, size_max) - labels_prob = get_one_hot_seg(labels, class_num) - - # for debug - # if(it==10): - # break - # for i in range(inputs.shape[0]): - # image_i = inputs[i][0] - # label_i = np.argmax(labels_prob[i], axis = 0) - # # pixw_i = pix_w[i][0] - # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) - # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) - # # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) - # save_nd_array_as_image(image_i, image_name, reference_name = None) - # save_nd_array_as_image(label_i, label_name, reference_name = None) - # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) - # continue - - inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) - - # zero the parameter gradients - self.optimizer.zero_grad() - - # forward + backward + optimize - outputs = self.net(inputs) - loss = self.get_loss_value(data, outputs, labels_prob) - loss.backward() - self.optimizer.step() - train_loss = train_loss + loss.item() - # get dice evaluation for each class - if(isinstance(outputs, tuple) or isinstance(outputs, list)): - outputs = outputs[0] - outputs_argmax = torch.argmax(outputs, dim = 1, keepdim = True) - soft_out = get_soft_label(outputs_argmax, class_num, self.tensor_type) - soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob) - dice_list = get_classwise_dice(soft_out, labels_prob) - train_dice_list.append(dice_list.cpu().numpy()) - train_avg_loss = train_loss / iter_valid - train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice[1:].mean() - - train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ - 'class_dice': train_cls_dice} - return train_scalers - -class SelfSupSelfVolumeFusion(SegmentationAgent): - """ - Abstract class for self-supervised segmentation. - - :param config: (dict) A dictionary containing the configuration. - :param stage: (str) One of the stage in `train` (default), `inference` or `test`. - - .. note:: - - In the configuration dictionary, in addition to the four sections (`dataset`, - `network`, `training` and `inference`) used in fully supervised learning, an - extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. - """ - def __init__(self, config, stage = 'train'): - super(SelfSupSelfVolumeFusion, self).__init__(config, stage) - - def training(self): - class_num = self.config['network']['class_num'] - iter_valid = self.config['training']['iter_valid'] - fuse_ratio = self.config['self_supervised_learning']['SelfVolumeFusion_fuse_ratio'.lower()] - size_min = self.config['self_supervised_learning']['SelfVolumeFusion_size_min'.lower()] - size_max = self.config['self_supervised_learning']['SelfVolumeFusion_size_max'.lower()] - - train_loss = 0 - train_dice_list = [] - self.net.train() - for it in range(iter_valid): - try: - data = next(self.trainIter) - except StopIteration: - self.trainIter = iter(self.train_loader) - data = next(self.trainIter) - # get the inputs - inputs = self.convert_tensor_type(data['image']) - inputs, labels = self_volume_fusion(inputs, class_num - 1, fuse_ratio, size_min, size_max) - labels_prob = get_one_hot_seg(labels, class_num) - - # for debug - # if(it==10): - # break - # for i in range(inputs.shape[0]): - # image_i = inputs[i][0] - # label_i = np.argmax(labels_prob[i], axis = 0) - # # pixw_i = pix_w[i][0] - # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) - # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) - # # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) - # save_nd_array_as_image(image_i, image_name, reference_name = None) - # save_nd_array_as_image(label_i, label_name, reference_name = None) - # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) - # continue - - inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) - - # zero the parameter gradients - self.optimizer.zero_grad() - - # forward + backward + optimize - outputs = self.net(inputs) - loss = self.get_loss_value(data, outputs, labels_prob) - loss.backward() - self.optimizer.step() - train_loss = train_loss + loss.item() - # get dice evaluation for each class - if(isinstance(outputs, tuple) or isinstance(outputs, list)): - outputs = outputs[0] - outputs_argmax = torch.argmax(outputs, dim = 1, keepdim = True) - soft_out = get_soft_label(outputs_argmax, class_num, self.tensor_type) - soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob) - dice_list = get_classwise_dice(soft_out, labels_prob) - train_dice_list.append(dice_list.cpu().numpy()) - train_avg_loss = train_loss / iter_valid - train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice[1:].mean() - - train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ - 'class_dice': train_cls_dice} - return train_scalers - -class SelfSupNonLinearVolumeFusion(SegmentationAgent): - """ - Abstract class for self-supervised segmentation. - - :param config: (dict) A dictionary containing the configuration. - :param stage: (str) One of the stage in `train` (default), `inference` or `test`. - - .. note:: - - In the configuration dictionary, in addition to the four sections (`dataset`, - `network`, `training` and `inference`) used in fully supervised learning, an - extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. - """ - def __init__(self, config, stage = 'train'): - super(SelfSupNonLinearVolumeFusion, self).__init__(config, stage) - - def training(self): - class_num = 3 - iter_valid = self.config['training']['iter_valid'] - block_range = self.config['self_supervised_learning']['NonLinearVolumeFusion_block_range'.lower()] - size_min = self.config['self_supervised_learning']['NonLinearVolumeFusion_size_min'.lower()] - size_max = self.config['self_supervised_learning']['NonLinearVolumeFusion_size_max'.lower()] - - train_loss = 0 - train_dice_list = [] - self.net.train() - for it in range(iter_valid): - try: - data = next(self.trainIter) - except StopIteration: - self.trainIter = iter(self.train_loader) - data = next(self.trainIter) - # get the inputs - inputs = self.convert_tensor_type(data['image']) - inputs, labels = nonlienar_volume_fusion(inputs, block_range, size_min, size_max) - labels_prob = get_one_hot_seg(labels, class_num) - - # for debug - # if(it==10): - # break - # for i in range(inputs.shape[0]): - # image_i = inputs[i][0] - # label_i = np.argmax(labels_prob[i], axis = 0) - # # pixw_i = pix_w[i][0] - # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) - # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) - # # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) - # save_nd_array_as_image(image_i, image_name, reference_name = None) - # save_nd_array_as_image(label_i, label_name, reference_name = None) - # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) - # continue - - inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) - - # zero the parameter gradients - self.optimizer.zero_grad() - - # forward + backward + optimize - outputs = self.net(inputs) - loss = self.get_loss_value(data, outputs, labels_prob) - loss.backward() - self.optimizer.step() - train_loss = train_loss + loss.item() - # get dice evaluation for each class - if(isinstance(outputs, tuple) or isinstance(outputs, list)): - outputs = outputs[0] - outputs_argmax = torch.argmax(outputs, dim = 1, keepdim = True) - soft_out = get_soft_label(outputs_argmax, class_num, self.tensor_type) - soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob) - dice_list = get_classwise_dice(soft_out, labels_prob) - train_dice_list.append(dice_list.cpu().numpy()) - train_avg_loss = train_loss / iter_valid - train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice[1:].mean() - - train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ - 'class_dice': train_cls_dice} - return train_scalers - -class SelfSupAugmentedVolumeFusion(SegmentationAgent): - """ - Abstract class for self-supervised segmentation. - - :param config: (dict) A dictionary containing the configuration. - :param stage: (str) One of the stage in `train` (default), `inference` or `test`. - - .. note:: - - In the configuration dictionary, in addition to the four sections (`dataset`, - `network`, `training` and `inference`) used in fully supervised learning, an - extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. - """ - def __init__(self, config, stage = 'train'): - super(SelfSupAugmentedVolumeFusion, self).__init__(config, stage) - - def training(self): - class_num = 5 - iter_valid = self.config['training']['iter_valid'] - size_min = self.config['self_supervised_learning']['AugmentedVolumeFusion_size_min'.lower()] - size_max = self.config['self_supervised_learning']['AugmentedVolumeFusion_size_max'.lower()] - - train_loss = 0 - train_dice_list = [] - self.net.train() - for it in range(iter_valid): - try: - data = next(self.trainIter) - except StopIteration: - self.trainIter = iter(self.train_loader) - data = next(self.trainIter) - # get the inputs - inputs = self.convert_tensor_type(data['image']) - inputs, labels = augmented_volume_fusion(inputs, size_min, size_max) - labels_prob = get_one_hot_seg(labels, class_num) - - # for debug - # if(it==10): - # break - # for i in range(inputs.shape[0]): - # image_i = inputs[i][0] - # label_i = np.argmax(labels_prob[i], axis = 0) - # # pixw_i = pix_w[i][0] - # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) - # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) - # # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) - # save_nd_array_as_image(image_i, image_name, reference_name = None) - # save_nd_array_as_image(label_i, label_name, reference_name = None) - # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) - # continue - - inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) - - # zero the parameter gradients - self.optimizer.zero_grad() - - # forward + backward + optimize - outputs = self.net(inputs) - loss = self.get_loss_value(data, outputs, labels_prob) - loss.backward() - self.optimizer.step() - train_loss = train_loss + loss.item() - # get dice evaluation for each class - if(isinstance(outputs, tuple) or isinstance(outputs, list)): - outputs = outputs[0] - outputs_argmax = torch.argmax(outputs, dim = 1, keepdim = True) - soft_out = get_soft_label(outputs_argmax, class_num, self.tensor_type) - soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob) - dice_list = get_classwise_dice(soft_out, labels_prob) - train_dice_list.append(dice_list.cpu().numpy()) - train_avg_loss = train_loss / iter_valid - train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) - train_avg_dice = train_cls_dice[1:].mean() - - train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ - 'class_dice': train_cls_dice} - return train_scalers \ No newline at end of file diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index ed5ad0c..2a15857 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -15,6 +15,7 @@ 'LabelConvertNonzero': LabelConvertNonzero, 'LabelToProbability': LabelToProbability, 'IntensityClip': IntensityClip, + 'NonLinearTransform': NonLinearTransform, 'NormalizeWithMeanStd': NormalizeWithMeanStd, 'NormalizeWithMinMax': NormalizeWithMinMax, 'NormalizeWithPercentiles': NormalizeWithPercentiles, @@ -41,20 +42,28 @@ from pymic.transform.threshold import * from pymic.transform.normalize import * from pymic.transform.crop import * -from pymic.transform.mix import * +from pymic.transform.crop4dino import Crop4Dino +from pymic.transform.crop4vox2vec import Crop4Vox2Vec +from pymic.transform.crop4vf import Crop4VolumeFusion, VolumeFusion, VolumeFusionShuffle +from pymic.transform.volume_fusion import * from pymic.transform.label_convert import * TransformDict = { 'Affine': Affine, + 'AdaptiveContrastAdjust': AdaptiveContrastAdjust, 'ChannelWiseThreshold': ChannelWiseThreshold, 'ChannelWiseThresholdWithNormalize': ChannelWiseThresholdWithNormalize, 'CropWithBoundingBox': CropWithBoundingBox, 'CropWithForeground': CropWithForeground, - 'CropHumanRegionFromCT': CropHumanRegionFromCT, + 'CropHumanRegion': CropHumanRegion, 'CenterCrop': CenterCrop, + 'Crop4Dino': Crop4Dino, + 'Crop4Vox2Vec': Crop4Vox2Vec, + 'Crop4VolumeFusion': Crop4VolumeFusion, 'GrayscaleToRGB': GrayscaleToRGB, 'GammaCorrection': GammaCorrection, 'GaussianNoise': GaussianNoise, + 'HistEqual': HistEqual, 'InPainting': InPainting, 'InOutPainting': InOutPainting, 'LabelConvert': LabelConvert, @@ -62,6 +71,7 @@ 'LabelToProbability': LabelToProbability, 'LocalShuffling': LocalShuffling, 'IntensityClip': IntensityClip, + 'MaskedImageModeling': MaskedImageModeling, 'NonLinearTransform': NonLinearTransform, 'NormalizeWithMeanStd': NormalizeWithMeanStd, 'NormalizeWithMinMax': NormalizeWithMinMax, @@ -82,5 +92,6 @@ 'OutPainting': OutPainting, 'Pad': Pad, 'PatchSwaping':PatchSwaping, - 'PatchMix': PatchMix + 'VolumeFusion': VolumeFusion, + 'VolumeFusionShuffle': VolumeFusionShuffle } diff --git a/pymic/transform/volume_fusion.py b/pymic/transform/volume_fusion.py new file mode 100644 index 0000000..38c74c0 --- /dev/null +++ b/pymic/transform/volume_fusion.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import random +import numpy as np +from pymic.transform.abstract_transform import AbstractTransform +from pymic.util.image_process import * +try: # SciPy >= 0.19 + from scipy.special import comb +except ImportError: + from scipy.misc import comb + +def random_resized_crop(x, output_shape): + img_shape = x.shape[1:] + ratio = [img_shape[i] / output_shape[i] for i in range(3)] + r_max = [min(ratio[i], 1.25) for i in range(3)] + r_min = (0.8, 0.8, 0.8) + scale = [r_min[i] + random.random() * (r_max[i] - r_min[i]) for i in range(3)] + crop_size = [int(output_shape[i] * scale[i]) for i in range(3)] + + bb_min = [random.randint(0, img_shape[i] - crop_size[i]) for i in range(3)] + bb_max = [bb_min[i] + crop_size[i] for i in range(3)] + bb_min = [0] + bb_min + bb_max = [x.shape[0]] + bb_max + crop_volume = crop_ND_volume_with_bounding_box(x, bb_min, bb_max) + + scale = [(output_shape[i] + 0.0)/crop_size[i] for i in range(3)] + scale = [1.0] + scale + y = ndimage.interpolation.zoom(crop_volume, scale, order = 1) + return y + +def nonlinear_transform(x): + v_min = np.min(x) + v_max = np.max(x) + x = (x - v_min)/(v_max - v_min) + a = random.random() * 0.7 + 0.15 + b = random.random() * 0.7 + 0.15 + alpha = b / a + beta = (1 - b) / (1 - a) + if(alpha < 1.0 ): + y = np.maximum(alpha*x, beta*x + 1 - beta) + else: + y = np.minimum(alpha*x, beta*x + 1 - beta) + if(random.random() < 0.5): + y = 1.0 - y + y = y * (v_max - v_min) + v_min + return y + +def random_flip(x): + flip_axis = [] + if(random.random() > 0.5): + flip_axis.append(-1) + if(random.random() > 0.5): + flip_axis.append(-2) + if(random.random() > 0.5): + flip_axis.append(-3) + + if(len(flip_axis) > 0): + # use .copy() to avoid negative strides of numpy array + # current pytorch does not support negative strides + y = np.flip(x, flip_axis).copy() + else: + y = x + return y + + +class VolumeFusion(AbstractTransform): + """ + fusing two subvolumes of an image, used for self-supervised learning + """ + def __init__(self, params): + super(VolumeFusion, self).__init__(params) + self.inverse = params.get('VolumeFusion_inverse'.lower(), False) + self.crop_size = params.get('VolumeFusion_crop_size'.lower(), [64, 128, 128]) + self.block_range = params.get('VolumeFusion_block_range'.lower(), [20, 40]) + self.size_min = params.get('VolumeFusion_size_min'.lower(), [8, 16, 16]) + self.size_max = params.get('VolumeFusion_size_max'.lower(), [16, 32, 32]) + + def __call__(self, sample): + x = sample['image'] + x0 = random_resized_crop(x, self.crop_size) + x1 = random_resized_crop(x, self.crop_size) + x0 = random_flip(x0) + x1 = random_flip(x1) + # nonlinear transform + x0a = nonlinear_transform(x0) + x0b = nonlinear_transform(x0) + x1 = nonlinear_transform(x1) + + D, H, W = x0.shape[1:] + mask = np.zeros_like(x0, np.uint8) + p_num = random.randint(self.block_range[0], self.block_range[1]) + for i in range(p_num): + d = random.randint(self.size_min[0], self.size_max[0]) + h = random.randint(self.size_min[1], self.size_max[1]) + w = random.randint(self.size_min[2], self.size_max[2]) + dc = random.randint(0, D - 1) + hc = random.randint(0, H - 1) + wc = random.randint(0, W - 1) + d0 = dc - d // 2 + h0 = hc - h // 2 + w0 = wc - w // 2 + d1 = min(D, d0 + d) + h1 = min(H, h0 + h) + w1 = min(W, w0 + w) + d0, h0, w0 = max(0, d0), max(0, h0), max(0, w0) + temp_m = np.ones([d1 - d0, h1 - h0, w1 - w0]) + if(random.random() < 0.5): + temp_m = temp_m * 2 + mask[:, d0:d1, h0:h1, w0:w1] = temp_m + + mask1 = np.asarray(mask == 1, np.uint8) + mask2 = np.asarray(mask == 2, np.uint8) + y = x0a * (1.0 - mask1) + x0b * mask1 + y = y * (1.0 - mask2) + x1 * mask2 + sample['image'] = y + sample['label'] = mask + return sample From 53db91d5066f2ff20a9487fd4118f9e2750f6908 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 30 Sep 2024 16:20:24 +0800 Subject: [PATCH 51/86] update transform --- pymic/transform/affine.py | 4 - pymic/transform/crop.py | 50 +++++-- pymic/transform/crop4dino.py | 175 +++++++++++++++++++++++ pymic/transform/crop4vf.py | 232 +++++++++++++++++++++++++++++++ pymic/transform/crop4vox2vec.py | 160 +++++++++++++++++++++ pymic/transform/flip.py | 3 +- pymic/transform/rescale.py | 11 +- pymic/transform/volume_fusion.py | 117 ---------------- 8 files changed, 612 insertions(+), 140 deletions(-) create mode 100644 pymic/transform/crop4dino.py create mode 100644 pymic/transform/crop4vf.py create mode 100644 pymic/transform/crop4vox2vec.py delete mode 100644 pymic/transform/volume_fusion.py diff --git a/pymic/transform/affine.py b/pymic/transform/affine.py index 2efd586..1717a97 100644 --- a/pymic/transform/affine.py +++ b/pymic/transform/affine.py @@ -152,9 +152,5 @@ def _get_param_for_inverse_transform(self, sample): # aff_out_shape = origin_shape[-2:] # output_predict = self._apply_affine_to_ND_volume(predict, aff_out_shape, tform.inverse) - # sample['predict'] = output_predict - # return sample - = self._apply_affine_to_ND_volume(predict, aff_out_shape, tform.inverse) - # sample['predict'] = output_predict # return sample diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index 9e1c077..c444acd 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -260,7 +260,7 @@ def _get_crop_param(self, sample): crop_min = [0 if item == 0 else random.randint(0, item) for item in crop_margin] crop_max = [crop_min[i] + output_size[i] for i in range(input_dim)] - label_exist = False if ('label' not in sample or sample['label']) is None else True + label_exist = True if ('label' in sample and sample['label'].sum() > 0) else False if(label_exist and self.fg_focus and random.random() < self.fg_ratio): label = sample['label'][0] if(self.mask_label is None): @@ -398,6 +398,9 @@ class RandomSlice(AbstractTransform): """ def __init__(self, params): self.output_size = params['RandomSlice_output_size'.lower()] + self.fg_focus = params.get('RandomSlice_foreground_focus'.lower(), False) + self.fg_ratio = params.get('RandomSlice_foreground_ratio'.lower(), 0.5) + self.mask_label = params.get('RandomSlice_mask_label'.lower(), None) self.shuffle = params.get('RandomSlice_shuffle'.lower(), False) self.inverse = params.get('RandomSlice_inverse'.lower(), False) self.task = params['Task'.lower()] @@ -406,13 +409,30 @@ def __call__(self, sample): image = sample['image'] D = image.shape[1] assert( D >= self.output_size) + out_half = self.output_size // 2 + + label_exist = True if ('label' in sample and sample['label'].sum() > 0) else False + if(label_exist and self.fg_focus and random.random() < self.fg_ratio): + label = sample['label'][0] + if(self.mask_label is None): + mask_label = np.unique(label)[1:] + else: + mask_label = self.mask_label + random_label = random.choice(mask_label) + mask = label == random_label + dc = random.choice(np.nonzero(mask)[0]) + else: + dc = random.choice(range(out_half, D - out_half)) + slice_idx = list(range(D)) if(self.shuffle): random.shuffle(slice_idx) - slice_idx = slice_idx[:self.output_size] - else: d0 = random.randint(0, D - self.output_size) d1 = d0 + self.output_size + slice_idx = slice_idx[d0:d1-1] + [dc] + else: + d0 = max(0, dc - out_half) + d1 = d0 + self.output_size slice_idx = slice_idx[d0:d1] sample['image'] = image[:, slice_idx, :, :] @@ -448,6 +468,7 @@ class CropHumanRegion(CenterCrop): """ def __init__(self, params): self.threshold_i = params.get('CropHumanRegion_intensity_threshold'.lower(), -600) + self.threshold_mode = params.get('CropHumanRegion_threshold_mode'.lower(), 'mean') self.threshold_z = params.get('CropHumanRegion_zaxis_threshold'.lower(), 0.5) self.inverse = params.get('CropHumanRegion_inverse'.lower(), True) self.task = params['task'] @@ -456,20 +477,27 @@ def _get_crop_param(self, sample): image = sample['image'] input_shape = image.shape mask = np.asarray(image[0] > self.threshold_i) - mask2d = np.mean(mask, axis = 0) > self.threshold_z + if(self.threshold_mode == "mean"): + mask2d = np.mean(mask, axis = 0) > self.threshold_z + else: + mask2d = np.max(mask, axis = 0) se = np.ones([3,3]) mask2d = ndimage.binary_opening(mask2d, se, iterations = 2) - mask2d = get_largest_k_components(mask2d, 1) - bbmin, bbmax = get_ND_bounding_box(mask2d, margin = [0, 0]) + if(mask2d.sum() > 0): + mask2d = get_largest_k_components(mask2d, 1) + bbmin, bbmax = get_ND_bounding_box(mask2d, margin = [0, 0]) + else: + bbmin = [0] * (image.ndim - 2) + bbmax = list(input_shape[2:]) crop_min = [0, 0] + bbmin crop_max = list(input_shape[:2]) + bbmax - sample['CropHumanRegionFromCT_Param'] = json.dumps((input_shape, crop_min, crop_max)) + sample['CropHumanRegion_Param'] = json.dumps((input_shape, crop_min, crop_max)) return sample, crop_min, crop_max def _get_param_for_inverse_transform(self, sample): - if(isinstance(sample['CropHumanRegionFromCT_Param'], list) or \ - isinstance(sample['CropHumanRegionFromCT_Param'], tuple)): - params = json.loads(sample['CropHumanRegionFromCT_Param'][0]) + if(isinstance(sample['CropHumanRegion_Param'], list) or \ + isinstance(sample['CropHumanRegion_Param'], tuple)): + params = json.loads(sample['CropHumanRegion_Param'][0]) else: - params = json.loads(sample['CropHumanRegionFromCT_Param']) + params = json.loads(sample['CropHumanRegion_Param']) return params \ No newline at end of file diff --git a/pymic/transform/crop4dino.py b/pymic/transform/crop4dino.py new file mode 100644 index 0000000..df86787 --- /dev/null +++ b/pymic/transform/crop4dino.py @@ -0,0 +1,175 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import json +import math +import random +import numpy as np +from scipy import ndimage +from pymic import TaskType +from pymic.transform.abstract_transform import AbstractTransform +from pymic.transform.crop import CenterCrop +from pymic.transform.intensity import * +from pymic.util.image_process import * + +class Crop4Dino(CenterCrop): + """ + Randomly crop an volume into two views with augmentation. This is used for + self-supervised pretraining such as DeSD. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `DualViewCrop_output_size`: (list/tuple) Desired output size [D, H, W]. + The output channel is the same as the input channel. + :param `DualViewCrop_scale_lower_bound`: (list/tuple) Lower bound of the range of scale + for each dimension. e.g. (1.0, 0.5, 0.5). + param `DualViewCrop_scale_upper_bound`: (list/tuple) Upper bound of the range of scale + for each dimension. e.g. (1.0, 2.0, 2.0). + :param `DualViewCrop_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `False`. Currently, the inverse transform is not supported, and + this transform is assumed to be used only during training stage. + """ + def __init__(self, params): + self.output_size = params['Crop4Dino_output_size'.lower()] + self.scale_lower = params['Crop4Dino_resize_lower_bound'.lower()] + self.scale_upper = params['Crop4Dino_resize_upper_bound'.lower()] + self.prob = params.get('Crop4Dino_resize_prob'.lower(), 0.5) + self.noise_std_range = params.get('Crop4Dino_noise_std_range'.lower(), (0.05, 0.1)) + self.blur_sigma_range = params.get('Crop4Dino_blur_sigma_range'.lower(), (1.0, 3.0)) + self.gamma_range = params.get('Crop4Dino_gamma_range'.lower(), (0.75, 1.25)) + self.inverse = params.get('Crop4Dino_inverse'.lower(), False) + self.task = params['Task'.lower()] + assert isinstance(self.output_size, (list, tuple)) + assert isinstance(self.scale_lower, (list, tuple)) + assert isinstance(self.scale_upper, (list, tuple)) + + def __call__(self, sample): + image = sample['image'] + channel, input_size = image.shape[0], image.shape[1:] + input_dim = len(input_size) + assert(input_dim == len(self.output_size)) + + # # center crop first + # crop_size = self.output_size + # crop_margin = [input_size[i] - crop_size[i] for i in range(input_dim)] + # crop_min = [int(item/2) for item in crop_margin] + # crop_max = [crop_min[i] + crop_size[i] for i in range(input_dim)] + # crop_min = [0] + crop_min + # crop_max = [channel] + crop_max + # crop0 = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + + crop_num = 2 + crop_img = [] + for crop_i in range(crop_num): + resize = random.random() < self.prob + if(resize): + scale = [self.scale_lower[i] + (self.scale_upper[i] - self.scale_lower[i]) * random.random() \ + for i in range(input_dim)] + crop_size = [int(self.output_size[i] * scale[i]) for i in range(input_dim)] + else: + crop_size = self.output_size + + crop_margin = [input_size[i] - crop_size[i] for i in range(input_dim)] + pad_image = min(crop_margin) < 0 + if(pad_image): # pad the image if necessary + pad_size = [max(0, -crop_margin[i]) for i in range(input_dim)] + pad_lower = [int(pad_size[i] / 2) for i in range(input_dim)] + pad_upper = [pad_size[i] - pad_lower[i] for i in range(input_dim)] + pad = [(pad_lower[i], pad_upper[i]) for i in range(input_dim)] + pad = tuple([(0, 0)] + pad) + image = np.pad(image, pad, 'reflect') + crop_margin = [max(0, crop_margin[i]) for i in range(input_dim)] + + + crop_min = [random.randint(0, item) for item in crop_margin] + crop_max = [crop_min[i] + crop_size[i] for i in range(input_dim)] + crop_min = [0] + crop_min + crop_max = [channel] + crop_max + + crop_out = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + if(resize): + scale = [(self.output_size[i] + 0.0)/crop_size[i] for i in range(input_dim)] + scale = [1.0] + scale + crop_out = ndimage.interpolation.zoom(crop_out, scale, order = 1) + + # add intensity augmentation + C = crop_out.shape[0] + for c in range(C): + if(random.random() < 0.8): + crop_out[c] = gaussian_noise(crop_out[c], self.noise_std_range[0], self.noise_std_range[1]) + + if(random.uniform(0, 1) < 0.5): + crop_out[c] = gaussian_blur(crop_out[c], self.blur_sigma_range[0], self.blur_sigma_range[1]) + else: + alpha = random.uniform(0.0, 2.0) + crop_out[c] = gaussian_sharpen(crop_out[c], self.blur_sigma_range[0], self.blur_sigma_range[1], alpha) + if(random.random() < 0.8): + crop_out[c] = gamma_correction(crop_out[c], self.gamma_range[0], self.gamma_range[1]) + if(random.random() < 0.8): + crop_out[c] = window_level_augment(crop_out[c]) + crop_img.append(crop_out) + sample['image'] = crop_img + return sample + + def __call__backup(self, sample): + image = sample['image'] + channel, input_size = image.shape[0], image.shape[1:] + input_dim = len(input_size) + assert(input_dim == len(self.output_size)) + + # center crop first + crop_size = self.output_size + crop_margin = [input_size[i] - crop_size[i] for i in range(input_dim)] + crop_min = [int(item/2) for item in crop_margin] + crop_max = [crop_min[i] + crop_size[i] for i in range(input_dim)] + crop_min = [0] + crop_min + crop_max = [channel] + crop_max + crop0 = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + + # crop_num = 2 + # crop_img = [] + # for crop_i in range(crop_num): + # get another resized crop size + resize = random.random() < self.prob + if(resize): + scale = [self.scale_lower[i] + (self.scale_upper[i] - self.scale_lower[i]) * random.random() \ + for i in range(input_dim)] + crop_size = [int(self.output_size[i] * scale[i]) for i in range(input_dim)] + else: + crop_size = self.output_size + + crop_margin = [input_size[i] - crop_size[i] for i in range(input_dim)] + pad_image = min(crop_margin) < 0 + if(pad_image): # pad the image if necessary + pad_size = [max(0, -crop_margin[i]) for i in range(input_dim)] + pad_lower = [int(pad_size[i] / 2) for i in range(input_dim)] + pad_upper = [pad_size[i] - pad_lower[i] for i in range(input_dim)] + pad = [(pad_lower[i], pad_upper[i]) for i in range(input_dim)] + pad = tuple([(0, 0)] + pad) + image = np.pad(image, pad, 'reflect') + crop_margin = [max(0, crop_margin[i]) for i in range(input_dim)] + + + crop_min = [random.randint(0, item) for item in crop_margin] + crop_max = [crop_min[i] + crop_size[i] for i in range(input_dim)] + crop_min = [0] + crop_min + crop_max = [channel] + crop_max + + crop_out = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + if(resize): + scale = [(self.output_size[i] + 0.0)/crop_size[i] for i in range(input_dim)] + scale = [1.0] + scale + crop_out = ndimage.interpolation.zoom(crop_out, scale, order = 1) + # crop_img.append(crop_out) + crop_img = [crop0, crop_out] + # add intensity augmentation + # image_t = gaussian_noise(image_t, self.noise_std_range[0], self.noise_std_range[1], 0.8) + # image_t = gaussian_blur(image_t, self.blur_sigma_range[0], self.blur_sigma_range[1], 0.8) + # image_t = brightness_multiplicative(image_t, self.inten_multi_range[0], self.inten_multi_range[1], 0.8) + # image_t = brightness_additive(image_t, self.inten_add_range[0], self.inten_add_range[1], 0.8) + # image_t = contrast_augment(image_t, self.contrast_f_range[0], self.contrast_f_range[1], 0.8) + # image_t = gamma_correction(image_t, self.gamma_range[0], self.gamma_range[1], 0.8) + sample['image'] = crop_img + return sample diff --git a/pymic/transform/crop4vf.py b/pymic/transform/crop4vf.py new file mode 100644 index 0000000..4e07357 --- /dev/null +++ b/pymic/transform/crop4vf.py @@ -0,0 +1,232 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch + +import json +import math +import random +import numpy as np +from imops import crop_to_box +from typing import * +from scipy import ndimage +from pymic import TaskType +from pymic.transform.abstract_transform import AbstractTransform +from pymic.transform.crop import CenterCrop +from pymic.util.image_process import * +from pymic.transform.intensity import * + + +def random_resized_crop(image, output_size, scale_lower, scale_upper): + input_size = image.shape + scale = [scale_lower[i] + (scale_upper[i] - scale_lower[i]) * random.random() \ + for i in range(3)] + crop_size = [min(int(output_size[i] * scale[i]), input_size[1+i]) for i in range(3)] + crop_margin = [input_size[1+i] - crop_size[i] for i in range(3)] + crop_min = [random.randint(0, item) for item in crop_margin] + crop_max = [crop_min[i] + crop_size[i] for i in range(3)] + crop_min = [0] + crop_min + crop_max = [input_size[0]] + crop_max + + image_t = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + scale = [(output_size[i] + 0.0)/crop_size[i] for i in range(3)] + scale = [1.0] + scale + image_t = ndimage.interpolation.zoom(image_t, scale, order = 1) + return image_t + +def random_flip(image): + flip_axis = [] + if(random.random() > 0.5): + flip_axis.append(-1) + if(random.random() > 0.5): + flip_axis.append(-2) + if(random.random() > 0.5): + flip_axis.append(-3) + if(len(flip_axis) > 0): + image = np.flip(image , flip_axis) + return image + + +class Crop4VolumeFusion(AbstractTransform): + """ + Randomly crop an volume into two views with augmentation. This is used for + self-supervised pretraining in Vox2vec. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `Crop4VolumeFusion_output_size`: (list/tuple) Desired output size [D, H, W]. + The output channel is the same as the input channel. + :param `Crop4VolumeFusion_rescale_lower_bound`: (list/tuple) Lower bound of the range of scale + for each dimension. e.g. (1.0, 0.5, 0.5). + param `Crop4VolumeFusion_rescale_upper_bound`: (list/tuple) Upper bound of the range of scale + for each dimension. e.g. (1.0, 2.0, 2.0). + :param `Crop4VolumeFusion_augentation_mode`: (optional, int) The mode for augmentation of cropped volume. + 0: no spatial or intensity augmentatin. + 1: intensity augmentation only +` 2: spatial augmentation only + 3: Both intensity and spatial augmentation (default). + """ + def __init__(self, params): + self.output_size = params['Crop4VolumeFusion_output_size'.lower()] + self.scale_lower = params.get('Crop4VolumeFusion_rescale_lower_bound'.lower(), [0.7, 0.7, 0.7]) + self.scale_upper = params.get('Crop4VolumeFusion_rescale_upper_bound'.lower(), [1.5, 1.5, 1.5]) + self.aug_mode = params.get('Crop4VolumeFusion_augentation_mode'.lower(), 3) + self.task = params['Task'.lower()] + assert isinstance(self.output_size, (list, tuple)) + + def __call__(self, sample): + image = sample['image'] + channel, input_size = image.shape[0], image.shape[1:] + input_dim = len(input_size) + assert channel == 1 + assert(input_dim == len(self.output_size)) + + if(self.aug_mode == 0 or self.aug_mode == 1): + self.scale_lower = [1.0, 1.0, 1.0] + self.scale_upper = [1.0, 1.0, 1.0] + patch_1 = random_resized_crop(image, self.output_size, self.scale_lower, self.scale_upper) + patch_2 = random_resized_crop(image, self.output_size, self.scale_lower, self.scale_upper) + if(self.aug_mode > 1): + patch_1 = random_flip(patch_1) + patch_2 = random_flip(patch_2) + if(self.aug_mode == 1 or self.aug_mode == 3): + p0, p1 = random.uniform(0.1, 2.0), random.uniform(98, 99.9) + patch_1 = adaptive_contrast_adjust(patch_1, p0, p1) + patch_1 = gamma_correction(patch_1, 0.7, 1.5) + + p0, p1 = random.uniform(0.1, 2.0), random.uniform(98, 99.9) + patch_2 = adaptive_contrast_adjust(patch_2, p0, p1) + patch_2 = gamma_correction(patch_2, 0.7, 1.5) + + if(random.random() < 0.25): + patch_1 = 1.0 - patch_1 + patch_2 = 1.0 - patch_2 + + sample['image'] = patch_1, patch_2 + return sample + +class VolumeFusion(AbstractTransform): + """ + Randomly crop an volume into two views with augmentation. This is used for + self-supervised pretraining in Vox2vec. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `DualViewCrop_output_size`: (list/tuple) Desired output size [D, H, W]. + The output channel is the same as the input channel. + :param `DualViewCrop_scale_lower_bound`: (list/tuple) Lower bound of the range of scale + for each dimension. e.g. (1.0, 0.5, 0.5). + param `DualViewCrop_scale_upper_bound`: (list/tuple) Upper bound of the range of scale + for each dimension. e.g. (1.0, 2.0, 2.0). + :param `DualViewCrop_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `False`. Currently, the inverse transform is not supported, and + this transform is assumed to be used only during training stage. + """ + def __init__(self, params): + self.cls_num = params.get('VolumeFusion_cls_num'.lower(), 5) + self.ratio = params.get('VolumeFusion_foreground_ratio'.lower(), 0.7) + self.size_min = params.get('VolumeFusion_patchsize_min'.lower(), [8, 8, 8]) + self.size_max = params.get('VolumeFusion_patchsize_max'.lower(), [32, 32, 32]) + self.task = params['Task'.lower()] + + def __call__(self, sample): + K = self.cls_num - 1 + image1, image2 = sample['image'] + C, D, H, W = image1.shape + db = random.randint(self.size_min[0], self.size_max[0]) + hb = random.randint(self.size_min[1], self.size_max[1]) + wb = random.randint(self.size_min[2], self.size_max[2]) + d_offset = random.randint(0, D % db) + h_offset = random.randint(0, H % hb) + w_offset = random.randint(0, W % wb) + d_n = D // db + h_n = H // hb + w_n = W // wb + Nblock = d_n * h_n * w_n + Nfg = int(d_n * h_n * w_n * self.ratio) + list_fg = [1] * Nfg + [0] * (Nblock - Nfg) + random.shuffle(list_fg) + mask = np.zeros([1, D, H, W], np.uint8) + for d in range(d_n): + for h in range(h_n): + for w in range(w_n): + d0, h0, w0 = d*db + d_offset, h*hb + h_offset, w*wb + w_offset + d1, h1, w1 = d0 + db, h0 + hb, w0 + wb + idx = d*h_n*w_n + h*w_n + w + if(list_fg[idx]> 0): + cls_k = random.randint(1, K) + mask[:, d0:d1, h0:h1, w0:w1] = cls_k + alpha = mask * 1.0 / K + x_fuse = alpha*image1 + (1.0 - alpha)*image2 + sample['image'] = x_fuse + sample['label'] = mask + return sample + +class VolumeFusionShuffle(AbstractTransform): + """ + Randomly crop an volume into two views with augmentation. This is used for + self-supervised pretraining in Vox2vec. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `DualViewCrop_output_size`: (list/tuple) Desired output size [D, H, W]. + The output channel is the same as the input channel. + :param `DualViewCrop_scale_lower_bound`: (list/tuple) Lower bound of the range of scale + for each dimension. e.g. (1.0, 0.5, 0.5). + param `DualViewCrop_scale_upper_bound`: (list/tuple) Upper bound of the range of scale + for each dimension. e.g. (1.0, 2.0, 2.0). + :param `DualViewCrop_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `False`. Currently, the inverse transform is not supported, and + this transform is assumed to be used only during training stage. + """ + def __init__(self, params): + self.cls_num = params.get('VolumeFusionShuffle_cls_num'.lower(), 5) + self.ratio = params.get('VolumeFusionShuffle_foreground_ratio'.lower(), 0.7) + self.size_min = params.get('VolumeFusionShuffle_patchsize_min'.lower(), [8, 8, 8]) + self.size_max = params.get('VolumeFusionShuffle_patchsize_max'.lower(), [32, 32, 32]) + self.task = params['Task'.lower()] + + def __call__(self, sample): + K = self.cls_num - 1 + image1, image2 = sample['image'] + C, D, H, W = image1.shape + x_fuse = image2 * 1.0 + mask = np.zeros([1, D, H, W], np.uint8) + db = random.randint(self.size_min[0], self.size_max[0]) + hb = random.randint(self.size_min[1], self.size_max[1]) + wb = random.randint(self.size_min[2], self.size_max[2]) + d_offset = random.randint(0, D % db) + h_offset = random.randint(0, H % hb) + w_offset = random.randint(0, W % wb) + d_n = D // db + h_n = H // hb + w_n = W // wb + coord_list_source = [] + for di in range(d_n): + for hi in range(h_n): + for wi in range(w_n): + coord_list_source.append([di, hi, wi]) + coord_list_target = copy.deepcopy(coord_list_source) + random.shuffle(coord_list_source) + random.shuffle(coord_list_target) + for i in range(int(len(coord_list_source)*self.ratio)): + ds_l = d_offset + db * coord_list_source[i][0] + hs_l = h_offset + hb * coord_list_source[i][1] + ws_l = w_offset + wb * coord_list_source[i][2] + dt_l = d_offset + db * coord_list_target[i][0] + ht_l = h_offset + hb * coord_list_target[i][1] + wt_l = w_offset + wb * coord_list_target[i][2] + s_crop = image1[:, ds_l:ds_l+db, hs_l:hs_l+hb, ws_l:ws_l+wb] + t_crop = image2[:, dt_l:dt_l+db, ht_l:ht_l+hb, wt_l:wt_l+wb] + fg_m = random.randint(1, K) + fg_w = fg_m / (K + 0.0) + x_fuse[:, dt_l:dt_l+db, ht_l:ht_l+hb, wt_l:wt_l+wb] = t_crop * (1.0 - fg_w) + s_crop * fg_w + mask[0, dt_l:dt_l+db, ht_l:ht_l+hb, wt_l:wt_l+wb] = \ + np.ones([1, db, hb, wb]) * fg_m + sample['image'] = x_fuse + sample['label'] = mask + return sample + diff --git a/pymic/transform/crop4vox2vec.py b/pymic/transform/crop4vox2vec.py new file mode 100644 index 0000000..6fdcf83 --- /dev/null +++ b/pymic/transform/crop4vox2vec.py @@ -0,0 +1,160 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch + +import json +import math +import random +import numpy as np +from imops import crop_to_box +from typing import * +from scipy import ndimage +from pymic import TaskType +from pymic.transform.abstract_transform import AbstractTransform +from pymic.transform.crop import CenterCrop +from pymic.util.image_process import * +from pymic.transform.intensity import * + +def normalize_axis_list(axis, ndim): + return list(np.core.numeric.normalize_axis_tuple(axis, ndim)) + +def scale_hu(image_hu: np.ndarray, window_hu: Tuple[float, float]) -> np.ndarray: + min_hu, max_hu = window_hu + assert min_hu < max_hu + return np.clip((image_hu - min_hu) / (max_hu - min_hu), 0, 1) + +# def gaussian_filter( +# x: np.ndarray, +# sigma: Union[float, Sequence[float]], +# axis: Union[int, Sequence[int]] +# ) -> np.ndarray: +# axis = normalize_axis_list(axis, x.ndim) +# sigma = np.broadcast_to(sigma, len(axis)) +# for sgm, ax in zip(sigma, axis): +# x = ndimage.gaussian_filter1d(x, sgm, ax) +# return x + +# def gaussian_sharpen( +# x: np.ndarray, +# sigma_1: Union[float, Sequence[float]], +# sigma_2: Union[float, Sequence[float]], +# alpha: float, +# axis: Union[int, Sequence[int]] +# ) -> np.ndarray: +# """ See https://docs.monai.io/en/stable/transforms.html#gaussiansharpen """ +# blurred = gaussian_filter(x, sigma_1, axis) +# return blurred + alpha * (blurred - gaussian_filter(blurred, sigma_2, axis)) + +def sample_box(image_size, patch_size, anchor_voxel=None): + image_size = np.array(image_size, ndmin=1) + patch_size = np.array(patch_size, ndmin=1) + + if not np.all(image_size >= patch_size): + raise ValueError(f'Can\'t sample patch of size {patch_size} from image of size {image_size}') + + min_start = 0 + max_start = image_size - patch_size + if anchor_voxel is not None: + anchor_voxel = np.array(anchor_voxel, ndmin=1) + min_start = np.maximum(min_start, anchor_voxel - patch_size + 1) + max_start = np.minimum(max_start, anchor_voxel) + start = np.random.randint(min_start, max_start + 1) + return np.array([start, start + patch_size]) + +def sample_views( + image: np.ndarray, + min_overlap: Tuple[int, int, int], + patch_size: Tuple[int, int, int], + max_num_voxels: int, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """ For 3D volumes, the image shape should be [C, D, H, W]. + """ + img_size = image.shape[1:] + overlap = [random.randint(min_overlap[i], patch_size[i]) for i in range(3)] + union_size = [2*patch_size[i] - overlap[i] for i in range(3)] + anchor_max = [img_size[i] - union_size[i] for i in range(3)] + crop_min_1 = [random.randint(0, anchor_max[i]) for i in range(3)] + crop_min_2 = [crop_min_1[i] + patch_size[i] - overlap[i] for i in range(3)] + patch_1 = sample_view(image, crop_min_1, patch_size) + patch_2 = sample_view(image, crop_min_2, patch_size) + + coords = [range(crop_min_2[i], crop_min_2[i] + overlap[i]) for i in range(3)] + coords = np.asarray(np.meshgrid(coords[0], coords[1], coords[2])) + coords = coords.reshape(3, -1).transpose() + roi_voxels_1 = coords - crop_min_1 + roi_voxels_2 = coords - crop_min_2 + + indices = range(coords.shape[0]) + if len(indices) > max_num_voxels: + indices = np.random.choice(indices, max_num_voxels, replace=False) + + return patch_1, patch_2, roi_voxels_1[indices], roi_voxels_2[indices] + + +def sample_view(image, crop_min, patch_size): + """ For 3D volumes, the image shape should be [C, D, H, W]. + """ + assert image.ndim == 4 + C = image.shape[0] + crop_max = [crop_min[i] + patch_size[i] for i in range(3)] + out = crop_ND_volume_with_bounding_box(image, [0] + crop_min, [C] + crop_max) + + # intensity augmentations + for c in range(C): + if(random.random() < 0.8): + out[c] = gaussian_noise(out[c], 0.05, 0.1) + if(random.random() < 0.5): + out[c] = gaussian_blur(out[c], 0.5, 1.5) + else: + alpha = random.uniform(0.0, 2.0) + out[c] = gaussian_sharpen(out[c], 0.5, 2.0, alpha) + if(random.random() < 0.8): + out[c] = gamma_correction(out[c], 0.5, 2.0) + if(random.random() < 0.8): + out[c] = window_level_augment(out[c]) + return out + +class Crop4Vox2Vec(CenterCrop): + """ + Randomly crop an volume into two views with augmentation. This is used for + self-supervised pretraining in Vox2vec. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `DualViewCrop_output_size`: (list/tuple) Desired output size [D, H, W]. + The output channel is the same as the input channel. + :param `DualViewCrop_scale_lower_bound`: (list/tuple) Lower bound of the range of scale + for each dimension. e.g. (1.0, 0.5, 0.5). + param `DualViewCrop_scale_upper_bound`: (list/tuple) Upper bound of the range of scale + for each dimension. e.g. (1.0, 2.0, 2.0). + :param `DualViewCrop_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `False`. Currently, the inverse transform is not supported, and + this transform is assumed to be used only during training stage. + """ + def __init__(self, params): + self.output_size = params['Crop4Vox2Vec_output_size'.lower()] + self.min_overlap = params.get('Crop4Vox2Vec_min_overlap'.lower(), [8, 12, 12]) + self.max_voxel = params.get('Crop4Vox2Vec_max_voxel'.lower(), 1024) + self.inverse = params.get('Crop4Vox2Vec_inverse'.lower(), False) + self.task = params['Task'.lower()] + assert isinstance(self.output_size, (list, tuple)) + + def __call__(self, sample): + image = sample['image'] + channel, input_size = image.shape[0], image.shape[1:] + input_dim = len(input_size) + assert channel == 1 + assert(input_dim == len(self.output_size)) + invalid_size = [input_size[i] < self.output_size[i]*2 - self.min_overlap[i] for i in range(3)] + if True in invalid_size: + raise ValueError("The overlap requirement {0:} is too weak for the given patch size \ + {1:} and input size {2:}".format( self.min_overlap, self.output_size,input_size)) + + patches_1, patches_2, voxels_1, voxels_2 = sample_views(image, + self.min_overlap, self.output_size, self.max_voxel) + sample['image'] = patches_1, patches_2, voxels_1, voxels_2 + return sample + + diff --git a/pymic/transform/flip.py b/pymic/transform/flip.py index 6ea017c..6ffd535 100644 --- a/pymic/transform/flip.py +++ b/pymic/transform/flip.py @@ -33,8 +33,7 @@ def __init__(self, params): def __call__(self, sample): image = sample['image'] - input_shape = image.shape - input_dim = len(input_shape) - 1 + input_dim = image.ndim flip_axis = [] if(self.flip_width): if(random.random() > 0.5): diff --git a/pymic/transform/rescale.py b/pymic/transform/rescale.py index 2896a4e..47271ec 100644 --- a/pymic/transform/rescale.py +++ b/pymic/transform/rescale.py @@ -165,11 +165,10 @@ class Resample(Rescale): The arguments should be written in the `params` dictionary, and it has the following fields: - :param `Rescale_output_size`: (list/tuple or int) The output size along each spatial axis, - such as [D, H, W] or [H, W]. If D is None, the input image is only reslcaled in 2D. - If int, the smallest axis is matched to output_size keeping aspect ratio the same - as the input. - :param `Rescale_inverse`: (optional, bool) + :param `Resample_output_spacing`: (list/tuple or int) The output spacing along each spatial axis, + such as [Ds, Hs, Ws] or [Hs, Ws]. If Ds is None, the input image is only reslcaled in 2D. + :param `Resample_ignore_zspacing_range`: (list/tuple) The range of zspacing that would be ingored. + :param `Resample_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `True`. """ def __init__(self, params): @@ -177,11 +176,11 @@ def __init__(self, params): self.output_spacing = params["Resample_output_spacing".lower()] self.ignore_zspacing= params.get("Resample_ignore_zspacing_range".lower(), None) self.inverse = params.get("Resample_inverse".lower(), True) - # assert isinstance(self.output_size, (int, list, tuple)) def __call__(self, sample): image = sample['image'] input_shape = image.shape + input_dim = len(input_shape) - 1 spacing = sample['spacing'] out_spacing = [item for item in self.output_spacing] diff --git a/pymic/transform/volume_fusion.py b/pymic/transform/volume_fusion.py deleted file mode 100644 index 38c74c0..0000000 --- a/pymic/transform/volume_fusion.py +++ /dev/null @@ -1,117 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division -import random -import numpy as np -from pymic.transform.abstract_transform import AbstractTransform -from pymic.util.image_process import * -try: # SciPy >= 0.19 - from scipy.special import comb -except ImportError: - from scipy.misc import comb - -def random_resized_crop(x, output_shape): - img_shape = x.shape[1:] - ratio = [img_shape[i] / output_shape[i] for i in range(3)] - r_max = [min(ratio[i], 1.25) for i in range(3)] - r_min = (0.8, 0.8, 0.8) - scale = [r_min[i] + random.random() * (r_max[i] - r_min[i]) for i in range(3)] - crop_size = [int(output_shape[i] * scale[i]) for i in range(3)] - - bb_min = [random.randint(0, img_shape[i] - crop_size[i]) for i in range(3)] - bb_max = [bb_min[i] + crop_size[i] for i in range(3)] - bb_min = [0] + bb_min - bb_max = [x.shape[0]] + bb_max - crop_volume = crop_ND_volume_with_bounding_box(x, bb_min, bb_max) - - scale = [(output_shape[i] + 0.0)/crop_size[i] for i in range(3)] - scale = [1.0] + scale - y = ndimage.interpolation.zoom(crop_volume, scale, order = 1) - return y - -def nonlinear_transform(x): - v_min = np.min(x) - v_max = np.max(x) - x = (x - v_min)/(v_max - v_min) - a = random.random() * 0.7 + 0.15 - b = random.random() * 0.7 + 0.15 - alpha = b / a - beta = (1 - b) / (1 - a) - if(alpha < 1.0 ): - y = np.maximum(alpha*x, beta*x + 1 - beta) - else: - y = np.minimum(alpha*x, beta*x + 1 - beta) - if(random.random() < 0.5): - y = 1.0 - y - y = y * (v_max - v_min) + v_min - return y - -def random_flip(x): - flip_axis = [] - if(random.random() > 0.5): - flip_axis.append(-1) - if(random.random() > 0.5): - flip_axis.append(-2) - if(random.random() > 0.5): - flip_axis.append(-3) - - if(len(flip_axis) > 0): - # use .copy() to avoid negative strides of numpy array - # current pytorch does not support negative strides - y = np.flip(x, flip_axis).copy() - else: - y = x - return y - - -class VolumeFusion(AbstractTransform): - """ - fusing two subvolumes of an image, used for self-supervised learning - """ - def __init__(self, params): - super(VolumeFusion, self).__init__(params) - self.inverse = params.get('VolumeFusion_inverse'.lower(), False) - self.crop_size = params.get('VolumeFusion_crop_size'.lower(), [64, 128, 128]) - self.block_range = params.get('VolumeFusion_block_range'.lower(), [20, 40]) - self.size_min = params.get('VolumeFusion_size_min'.lower(), [8, 16, 16]) - self.size_max = params.get('VolumeFusion_size_max'.lower(), [16, 32, 32]) - - def __call__(self, sample): - x = sample['image'] - x0 = random_resized_crop(x, self.crop_size) - x1 = random_resized_crop(x, self.crop_size) - x0 = random_flip(x0) - x1 = random_flip(x1) - # nonlinear transform - x0a = nonlinear_transform(x0) - x0b = nonlinear_transform(x0) - x1 = nonlinear_transform(x1) - - D, H, W = x0.shape[1:] - mask = np.zeros_like(x0, np.uint8) - p_num = random.randint(self.block_range[0], self.block_range[1]) - for i in range(p_num): - d = random.randint(self.size_min[0], self.size_max[0]) - h = random.randint(self.size_min[1], self.size_max[1]) - w = random.randint(self.size_min[2], self.size_max[2]) - dc = random.randint(0, D - 1) - hc = random.randint(0, H - 1) - wc = random.randint(0, W - 1) - d0 = dc - d // 2 - h0 = hc - h // 2 - w0 = wc - w // 2 - d1 = min(D, d0 + d) - h1 = min(H, h0 + h) - w1 = min(W, w0 + w) - d0, h0, w0 = max(0, d0), max(0, h0), max(0, w0) - temp_m = np.ones([d1 - d0, h1 - h0, w1 - w0]) - if(random.random() < 0.5): - temp_m = temp_m * 2 - mask[:, d0:d1, h0:h1, w0:w1] = temp_m - - mask1 = np.asarray(mask == 1, np.uint8) - mask2 = np.asarray(mask == 2, np.uint8) - y = x0a * (1.0 - mask1) + x0b * mask1 - y = y * (1.0 - mask2) + x1 * mask2 - sample['image'] = y - sample['label'] = mask - return sample From 2bb8b0813e96713af067f9a5eb325bde061d1fe3 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 30 Sep 2024 16:22:40 +0800 Subject: [PATCH 52/86] update util files --- pymic/transform/trans_dict.py | 1 - pymic/util/image_process.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index 2a15857..332e594 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -45,7 +45,6 @@ from pymic.transform.crop4dino import Crop4Dino from pymic.transform.crop4vox2vec import Crop4Vox2Vec from pymic.transform.crop4vf import Crop4VolumeFusion, VolumeFusion, VolumeFusionShuffle -from pymic.transform.volume_fusion import * from pymic.transform.label_convert import * TransformDict = { diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index 158569c..c31f28f 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -73,7 +73,7 @@ def crop_ND_volume_with_bounding_box(volume, bb_min, bb_max): output = volume[bb_min[0]:bb_max[0], bb_min[1]:bb_max[1], bb_min[2]:bb_max[2], bb_min[3]:bb_max[3], bb_min[4]:bb_max[4]] else: raise ValueError("the dimension number shoud be 2 to 5") - return output + return output * 1 def set_ND_volume_roi_with_bounding_box_range(volume, bb_min, bb_max, sub_volume, addition = True): """ From 9cc16a6e3e29c34259083c8d5f409affc7869316 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 30 Sep 2024 16:24:05 +0800 Subject: [PATCH 53/86] update loss function --- pymic/loss/cls/basic.py | 2 +- pymic/loss/cls/infoNCE.py | 39 +++++++++++++++++++++++++++++++++++++ pymic/loss/loss_dict_cls.py | 3 ++- pymic/loss/seg/deep_sup.py | 3 ++- 4 files changed, 44 insertions(+), 3 deletions(-) create mode 100644 pymic/loss/cls/infoNCE.py diff --git a/pymic/loss/cls/basic.py b/pymic/loss/cls/basic.py index 56925fc..6b71e23 100644 --- a/pymic/loss/cls/basic.py +++ b/pymic/loss/cls/basic.py @@ -39,7 +39,7 @@ def forward(self, loss_input_dict): class SigmoidCELoss(AbstractClassificationLoss): """ - Sigmoid-based CE loss. + Sigmoid-based CE loss, should be used when task_type = cls_coexist """ def __init__(self, params = None): super(SigmoidCELoss, self).__init__(params) diff --git a/pymic/loss/cls/infoNCE.py b/pymic/loss/cls/infoNCE.py new file mode 100644 index 0000000..fb6f1c1 --- /dev/null +++ b/pymic/loss/cls/infoNCE.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import torch.nn as nn + +class InfoNCELoss(nn.Module): + """ + Abstract Classification Loss. + """ + def __init__(self, params = None): + super(InfoNCELoss, self).__init__() + self.temp = params.get("temperature", 0.1) + + def forward(self, input_1, input_2): + """ + The arguments should be written in the `loss_input_dict` dictionary, and it has the + following fields. + + :param prediction: A prediction with shape of [N, C] where C is the class number. + :param ground_truth: The corresponding ground truth, with shape of [N, 1]. + + Note that `prediction` is the digit output of a network, before using softmax. + """ + B = list(input_1.shape)[0] + loss = 0.0 + for b in range(B): + embeds_1 = input_1[b] + embeds_2 = input_2[b] + logits_11 = torch.matmul(embeds_1, embeds_1.T) / self.temp + logits_11.fill_diagonal_(float('-inf')) + logits_12 = torch.matmul(embeds_1, embeds_2.T) / self.temp + logits_22 = torch.matmul(embeds_2, embeds_2.T) / self.temp + logits_22.fill_diagonal_(float('-inf')) + loss_1 = torch.mean(-logits_12.diag() + torch.logsumexp(torch.cat([logits_11, logits_12], dim=1), dim=1)) + loss_2 = torch.mean(-logits_12.diag() + torch.logsumexp(torch.cat([logits_12.T, logits_22], dim=1), dim=1)) + loss = loss + (loss_1 + loss_2) / 2 + loss = loss / B + return loss \ No newline at end of file diff --git a/pymic/loss/loss_dict_cls.py b/pymic/loss/loss_dict_cls.py index e07f46b..44744cb 100644 --- a/pymic/loss/loss_dict_cls.py +++ b/pymic/loss/loss_dict_cls.py @@ -11,9 +11,10 @@ """ from __future__ import print_function, division from pymic.loss.cls.basic import * - +from pymic.loss.cls.infoNCE import InfoNCELoss PyMICClsLossDict = {"CrossEntropyLoss": CrossEntropyLoss, "SigmoidCELoss": SigmoidCELoss, + 'InfoNCELoss': InfoNCELoss, "L1Loss": L1Loss, "MSELoss": MSELoss, "NLLLoss": NLLLoss} diff --git a/pymic/loss/seg/deep_sup.py b/pymic/loss/seg/deep_sup.py index da6d9ef..6669486 100644 --- a/pymic/loss/seg/deep_sup.py +++ b/pymic/loss/seg/deep_sup.py @@ -2,6 +2,7 @@ from __future__ import print_function, division import torch.nn as nn +import numpy as np from torch.nn.functional import interpolate from pymic.loss.seg.abstract import AbstractSegLoss @@ -69,7 +70,7 @@ def forward(self, loss_input_dict): be a list or a tuple""") pred_num = len(pred) if(self.deep_sup_weight is None): - self.deep_sup_weight = [1.0] * pred_num + self.deep_sup_weight = [1.0 / pow(2, i) for i in range(pred_num)] else: assert(pred_num == len(self.deep_sup_weight)) loss_sum, weight_sum = 0.0, 0.0 From 474188eaed3bcb4dd373054f10e8546e2450de73 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 30 Sep 2024 16:26:26 +0800 Subject: [PATCH 54/86] Update torch_pretrained_net.py --- pymic/net/cls/torch_pretrained_net.py | 48 +++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/pymic/net/cls/torch_pretrained_net.py b/pymic/net/cls/torch_pretrained_net.py index 5017f72..c1959a4 100644 --- a/pymic/net/cls/torch_pretrained_net.py +++ b/pymic/net/cls/torch_pretrained_net.py @@ -2,7 +2,6 @@ from __future__ import print_function, division import itertools -import torch import torch.nn as nn import torchvision.models as models @@ -61,7 +60,49 @@ class ResNet18(BuiltInNet): """ def __init__(self, params): super(ResNet18, self).__init__(params) - self.net = models.resnet18(pretrained = self.pretrain) + weights = 'IMAGENET1K_V1' if self.pretrain else None + self.net = models.resnet18(weights = weights) + + # replace the last layer + num_ftrs = self.net.fc.in_features + self.net.fc = nn.Linear(num_ftrs, params['class_num']) + + # replace the first layer when in_chns is not 3 + if(self.in_chns != 3): + self.net.conv1 = nn.Conv2d(self.in_chns, 64, kernel_size=(7, 7), + stride=(2, 2), padding=(3, 3), bias=False) + + def get_parameters_to_update(self): + if(self.update_mode == "all"): + return self.net.parameters() + elif(self.update_mode == "last"): + params = self.net.fc.parameters() + if(self.in_chns !=3): + # combining the two iterables into a single one + # see: https://dzone.com/articles/python-joining-multiple + params = itertools.chain() + for pram in [self.net.fc.parameters(), self.net.conv1.parameters()]: + params = itertools.chain(params, pram) + return params + else: + raise(ValueError("update_mode can only be 'all' or 'last'.")) + +class ResNet50(BuiltInNet): + """ + ResNet18 for classification. + Parameters should be set in the `params` dictionary that contains the + following fields: + + :param input_chns: (int) Input channel number, default is 3. + :param pretrain: (bool) Using pretrained model or not, default is True. + :param update_mode: (str) The strategy for updating layers: "`all`" means updating + all the layers, and "`last`" (by default) means updating the last layer, + as well as the first layer when `input_chns` is not 3. + """ + def __init__(self, params): + super(ResNet50, self).__init__(params) + weights = 'IMAGENET1K_V1' if self.pretrain else None + self.net = models.resnet50(weights = weights) # replace the last layer num_ftrs = self.net.fc.in_features @@ -101,7 +142,8 @@ class VGG16(BuiltInNet): """ def __init__(self, params): super(VGG16, self).__init__(params) - self.net = models.vgg16(pretrained = self.pretrain) + weights = 'IMAGENET1K_V1' if self.pretrain else None + self.net = models.vgg16(weights = weights) # replace the last layer num_ftrs = self.net.classifier[-1].in_features From 1615412608e86471f94f53c08895bb153f2de836 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 30 Sep 2024 16:28:32 +0800 Subject: [PATCH 55/86] Create canet.py --- pymic/net/net2d/canet.py | 229 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 229 insertions(+) create mode 100644 pymic/net/net2d/canet.py diff --git a/pymic/net/net2d/canet.py b/pymic/net/net2d/canet.py new file mode 100644 index 0000000..ab64ab5 --- /dev/null +++ b/pymic/net/net2d/canet.py @@ -0,0 +1,229 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import torch +import torch.nn as nn + +class ConvLayer(nn.Module): + """ + A combination of Conv2d, BatchNorm2d and LeakyReLU. + """ + def __init__(self, in_channels, out_channels, kernel_size = 1): + super(ConvLayer, self).__init__() + padding = int((kernel_size - 1) / 2) + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), + nn.BatchNorm2d(out_channels), + nn.LeakyReLU() + ) + + def forward(self, x): + return self.conv(x) + +class SEBlock(nn.Module): + """ + A Modified Squeeze-and-Excitation block for spatial attention. + """ + def __init__(self, in_channels, r): + super(SEBlock, self).__init__() + + redu_chns = int(in_channels / r) + self.se_layers = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, redu_chns, kernel_size=1, padding=0), + nn.LeakyReLU(), + nn.Conv2d(redu_chns, in_channels, kernel_size=1, padding=0), + nn.ReLU()) + + def forward(self, x): + f = self.se_layers(x) + return f*x + x + +class ASPPBlock(nn.Module): + """ + ASPP block. + """ + def __init__(self,in_channels, out_channels_list, kernel_size_list, dilation_list): + super(ASPPBlock, self).__init__() + self.conv_num = len(out_channels_list) + assert(self.conv_num == 4) + assert(self.conv_num == len(kernel_size_list) and self.conv_num == len(dilation_list)) + pad0 = int((kernel_size_list[0] - 1) / 2 * dilation_list[0]) + pad1 = int((kernel_size_list[1] - 1) / 2 * dilation_list[1]) + pad2 = int((kernel_size_list[2] - 1) / 2 * dilation_list[2]) + pad3 = int((kernel_size_list[3] - 1) / 2 * dilation_list[3]) + self.conv_1 = nn.Conv2d(in_channels, out_channels_list[0], kernel_size = kernel_size_list[0], + dilation = dilation_list[0], padding = pad0 ) + self.conv_2 = nn.Conv2d(in_channels, out_channels_list[1], kernel_size = kernel_size_list[1], + dilation = dilation_list[1], padding = pad1 ) + self.conv_3 = nn.Conv2d(in_channels, out_channels_list[2], kernel_size = kernel_size_list[2], + dilation = dilation_list[2], padding = pad2 ) + self.conv_4 = nn.Conv2d(in_channels, out_channels_list[3], kernel_size = kernel_size_list[3], + dilation = dilation_list[3], padding = pad3 ) + + out_channels = out_channels_list[0] + out_channels_list[1] + out_channels_list[2] + out_channels_list[3] + self.conv_1x1 = nn.Sequential( + nn.Conv2d(out_channels, out_channels, kernel_size=1, padding=0), + nn.BatchNorm2d(out_channels), + nn.LeakyReLU()) + + def forward(self, x): + x1 = self.conv_1(x) + x2 = self.conv_2(x) + x3 = self.conv_3(x) + x4 = self.conv_4(x) + + y = torch.cat([x1, x2, x3, x4], dim=1) + y = self.conv_1x1(y) + return y + +class ConvBNActBlock(nn.Module): + """ + Two convolution layers with batch norm, leaky relu, + dropout and SE block. + """ + def __init__(self,in_channels, out_channels, dropout_p): + super(ConvBNActBlock, self).__init__() + self.conv_conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.LeakyReLU(), + nn.Dropout(dropout_p), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.LeakyReLU(), + SEBlock(out_channels, 2) + ) + + def forward(self, x): + return self.conv_conv(x) + +class DownBlock(nn.Module): + """ + Downsampling by a concantenation of max-pool and avg-pool, + followed by ConvBNActBlock. + """ + def __init__(self, in_channels, out_channels, dropout_p): + super(DownBlock, self).__init__() + self.maxpool = nn.MaxPool2d(2) + self.avgpool = nn.AvgPool2d(2) + self.conv = ConvBNActBlock(2 * in_channels, out_channels, dropout_p) + + def forward(self, x): + x_max = self.maxpool(x) + x_avg = self.avgpool(x) + x_cat = torch.cat([x_max, x_avg], dim=1) + y = self.conv(x_cat) + return y + x_cat + +class UpBlock(nn.Module): + """ + Upssampling followed by ConvBNActBlock. + """ + def __init__(self, in_channels1, in_channels2, out_channels, + bilinear=True, dropout_p = 0.5): + super(UpBlock, self).__init__() + self.bilinear = bilinear + if bilinear: + self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) + self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + else: + self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2) + self.conv = ConvBNActBlock(in_channels2 * 2, out_channels, dropout_p) + + def forward(self, x1, x2): + if self.bilinear: + x1 = self.conv1x1(x1) + x1 = self.up(x1) + x_cat = torch.cat([x2, x1], dim=1) + y = self.conv(x_cat) + return y + x_cat + +class CANet(nn.Module): + """ + Implementation of of CA-Net for biomedical image segmentation. + + * Reference: R. Gu et al. `CA-Net: Comprehensive Attention Convolutional Neural Networks + for Explainable Medical Image Segmentation `_. + IEEE Transactions on Medical Imaging, 40(2),2021:699-711. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param bilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. + """ + def __init__(self, params): + super(COPLENet, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + self.n_class = self.params['class_num'] + self.bilinear = self.params['bilinear'] + assert(len(self.ft_chns) == 5) + + f0_half = int(self.ft_chns[0] / 2) + f1_half = int(self.ft_chns[1] / 2) + f2_half = int(self.ft_chns[2] / 2) + f3_half = int(self.ft_chns[3] / 2) + self.in_conv= ConvBNActBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) + self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) + self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) + self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) + self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) + + self.bridge0= ConvLayer(self.ft_chns[0], f0_half) + self.bridge1= ConvLayer(self.ft_chns[1], f1_half) + self.bridge2= ConvLayer(self.ft_chns[2], f2_half) + self.bridge3= ConvLayer(self.ft_chns[3], f3_half) + + self.up1 = UpBlock(self.ft_chns[4], f3_half, self.ft_chns[3], dropout_p = self.dropout[3]) + self.up2 = UpBlock(self.ft_chns[3], f2_half, self.ft_chns[2], dropout_p = self.dropout[2]) + self.up3 = UpBlock(self.ft_chns[2], f1_half, self.ft_chns[1], dropout_p = self.dropout[1]) + self.up4 = UpBlock(self.ft_chns[1], f0_half, self.ft_chns[0], dropout_p = self.dropout[0]) + + f4 = self.ft_chns[4] + aspp_chns = [int(f4 / 4), int(f4 / 4), int(f4 / 4), int(f4 / 4)] + aspp_knls = [1, 3, 3, 3] + aspp_dila = [1, 2, 4, 6] + self.aspp = ASPPBlock(f4, aspp_chns, aspp_knls, aspp_dila) + + + self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, + kernel_size = 3, padding = 1) + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + x0 = self.in_conv(x) + x0b = self.bridge0(x0) + x1 = self.down1(x0) + x1b = self.bridge1(x1) + x2 = self.down2(x1) + x2b = self.bridge2(x2) + x3 = self.down3(x2) + x3b = self.bridge3(x3) + x4 = self.down4(x3) + x4 = self.aspp(x4) + + x = self.up1(x4, x3b) + x = self.up2(x, x2b) + x = self.up3(x, x1b) + x = self.up4(x, x0b) + output = self.out_conv(x) + + if(len(x_shape) == 5): + new_shape = [N, D] + list(output.shape)[1:] + output = torch.reshape(output, new_shape) + output = torch.transpose(output, 1, 2) + return output \ No newline at end of file From 11e1c48549387a0e343e5715630ba6e4c47a1cb5 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 30 Sep 2024 21:42:34 +0800 Subject: [PATCH 56/86] update 3d segmentation networks --- pymic/net/net3d/grunet.py | 264 ++++++ pymic/net/net3d/trans3d/__init__.py | 0 pymic/net/net3d/trans3d/transunet3d.py | 1053 ++++++++++++++++++++++++ pymic/net/net3d/unet3d.py | 22 +- pymic/net/net_dict_seg.py | 24 +- pymic/net_run/self_sup/__init__.py | 8 +- pymic/net_run/self_sup/self_vox2vec.py | 304 +++++++ 7 files changed, 1655 insertions(+), 20 deletions(-) create mode 100644 pymic/net/net3d/grunet.py create mode 100644 pymic/net/net3d/trans3d/__init__.py create mode 100644 pymic/net/net3d/trans3d/transunet3d.py create mode 100644 pymic/net_run/self_sup/self_vox2vec.py diff --git a/pymic/net/net3d/grunet.py b/pymic/net/net3d/grunet.py new file mode 100644 index 0000000..bae6447 --- /dev/null +++ b/pymic/net/net3d/grunet.py @@ -0,0 +1,264 @@ +# -*- coding: utf-8 -*- +# Note: this is a renamed version of fmunetv3. fmunetv3 will be removed in a later version. +# This network is originally used in the VolF paper. +from __future__ import print_function, division + +import itertools +import logging +import torch +import torch.nn as nn +from pymic.net.net_init import Initialization_He, Initialization_XavierUniform + +dim0 = {0:3, 1:2, 2:2} +dim1 = {0:3, 1:3, 2:2} +conv_knl = {2: (1, 3, 3), 3: 3} +conv_pad = {2: (0, 1, 1), 3: 1} +pool_knl = {2: (1, 2, 2), 3: 2} +down_stride = {2: (1, 2, 2), 3: 2} + +class ResConv(nn.Module): + def __init__(self, out_channels, dim = 3, dropout_p = 0.0, depth = 2): + super(ResConv, self).__init__() + assert(dim == 2 or dim == 3) + self.out_channels = out_channels + self.conv_list = nn.ModuleList([nn.Sequential( + nn.InstanceNorm3d(out_channels, affine = True), + nn.LeakyReLU(), + nn.Dropout(dropout_p), + nn.Conv3d(out_channels, out_channels, kernel_size=conv_knl[dim], padding=conv_pad[dim])) + for i in range(depth)]) + + def forward(self, x): + for conv in self.conv_list: + x = conv(x) + x + return x + +class DownSample(nn.Module): + """downsampling based on convolution + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + :param downsample: (bool) Use downsample or not after convolution. + """ + def __init__(self, in_channels, out_channels, dim = 3): + super(DownSample, self).__init__() + self.down = nn.Sequential( + nn.InstanceNorm3d(in_channels, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=conv_knl[dim], + padding=conv_pad[dim], stride = down_stride[dim]) + ) + + def forward(self, x): + return self.down(x) + +class UpCatConv(nn.Module): + """Upsampling followed by `ResConv` block + + :param in_channels1: (int) Input channel number for low-resolution feature map. + :param in_channels2: (int) Input channel number for high-resolution feature map. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear` for 3D and `Bilinear` for 2D). + The default value is 2. + """ + def __init__(self, in_channels1, in_channels2, out_channels, dim = 3): + super(UpCatConv, self).__init__() + + self.up = nn.Sequential( + nn.InstanceNorm3d(in_channels1, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels1, in_channels2, kernel_size=1, padding=0), + nn.Upsample(scale_factor=pool_knl[dim], mode='trilinear', align_corners=True) + ) + + self.conv = nn.Sequential( + nn.InstanceNorm3d(in_channels2*2, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels2 * 2, out_channels, kernel_size=conv_knl[dim], padding=conv_pad[dim]) + ) + + def forward(self, x_l, x_h): + """ + x_l: low-resolution feature map. + x_h: high-resolution feature map. + """ + y = torch.cat([x_h, self.up(x_l)], dim=1) + return self.conv(y) + +class Encoder(nn.Module): + """ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. + + res_mode: resolution mode: 0-- isotrpic, 1-- near isotrpic, 2-- isotropic + """ + def __init__(self, ft_chns, res_mode = 0, dropout_p = 0, depth = 2): + super(Encoder, self).__init__() + d0, d1 = dim0[res_mode], dim1[res_mode] + + self.en_conv0 = ResConv(ft_chns[0], d0, 0, depth) + self.en_conv1 = ResConv(ft_chns[1], d1, 0, depth) + self.en_conv2 = ResConv(ft_chns[2], 3, dropout_p, depth) + self.en_conv3 = ResConv(ft_chns[3], 3, dropout_p, depth) + self.en_conv4 = ResConv(ft_chns[4], 3, dropout_p, depth) + + self.down0 = DownSample(ft_chns[0], ft_chns[1], d0) + self.down1 = DownSample(ft_chns[1], ft_chns[2], d1) + self.down2 = DownSample(ft_chns[2], ft_chns[3], 3) + self.down3 = DownSample(ft_chns[3], ft_chns[4], 3) + + def forward(self, x): + x0 = self.en_conv0(x) + x1 = self.en_conv1(self.down0(x0)) + x2 = self.en_conv2(self.down1(x1)) + x3 = self.en_conv3(self.down2(x2)) + x4 = self.en_conv4(self.down3(x3)) + return [x0, x1, x2, x3, x4] + +class Decoder(nn.Module): + """ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. + """ + def __init__(self, ft_chns, res_mode = 0, dropout_p = 0, depth = 2): + super(Decoder, self).__init__() + d0, d1 = dim0[res_mode], dim1[res_mode] + + self.upcat0 = UpCatConv(ft_chns[1], ft_chns[0], ft_chns[0], d0) + self.upcat1 = UpCatConv(ft_chns[2], ft_chns[1], ft_chns[1], d1) + self.upcat2 = UpCatConv(ft_chns[3], ft_chns[2], ft_chns[2], 3) + self.upcat3 = UpCatConv(ft_chns[4], ft_chns[3], ft_chns[3], 3) + + self.de_conv0 = ResConv(ft_chns[0], d0, 0, depth) + self.de_conv1 = ResConv(ft_chns[1], d1, 0, depth) + self.de_conv2 = ResConv(ft_chns[2], 3, dropout_p, depth) + self.de_conv3 = ResConv(ft_chns[3], 3, dropout_p, depth) + self.de_conv4 = ResConv(ft_chns[4], 3, dropout_p, depth) + + def forward(self, x): + x0, x1, x2, x3, x4 = x + x4_de = self.de_conv4(x4) + x3_de = self.de_conv3(self.upcat3(x4_de, x3)) + x2_de = self.de_conv2(self.upcat2(x3_de, x2)) + x1_de = self.de_conv1(self.upcat1(x2_de, x1)) + x0_de = self.de_conv0(self.upcat0(x1_de, x0)) + return [x0_de, x1_de, x2_de, x3_de] + +class GRUNet(nn.Module): + """ + A General Residual UNet. + + * Reference: Guotai Wang, Jonathan Shapey, Wenqi Li, Reuben Dorent, Alex Demitriadis, + Sotirios Bisdas, Ian Paddick, Robert Bradford, Shaoting Zhang, Sébastien Ourselin, + Tom Vercauteren: Automatic Segmentation of Vestibular Schwannoma from T2-Weighted + MRI by Deep Spatial Attention with Hardness-Weighted Loss. + `MICCAI (2) 2019: 264-272. `_ + + Note that the attention module in the orininal paper is not used here. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param conv_dims: (list) The convolution dimension (2 or 3) for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, params): + super(GRUNet, self).__init__() + params = self.get_default_parameters(params) + + self.stage = 'train' + in_chns = params['in_chns'] + ft_chns = params['feature_chns'] + res_mode = params['res_mode'] + dropout = params['dropout'] + depth = params['depth'] + cls_num = params['class_num'] + self.mul_pred = params.get('multiscale_pred', True) + self.tune_mode= params.get('finetune_mode', 'all') + self.load_mode= params.get('weights_load_mode', 'all') + + d0 = dim0[res_mode] + self.project = nn.Conv3d(in_chns, ft_chns[0], kernel_size=conv_knl[d0], padding=conv_pad[d0]) + self.encoder = Encoder(ft_chns, res_mode, dropout, depth) + # self.decoder = Decoder(ft_chns, res_mode, dropout, depth = 2) + self.decoder = Decoder(ft_chns, res_mode, dropout, depth) + + self.out_layers = nn.ModuleList() + dims = [dim0[res_mode], dim1[res_mode], 3, 3] + for i in range(4): + out_layer = nn.Sequential( + nn.InstanceNorm3d(ft_chns[i], affine = True), + nn.LeakyReLU(), + nn.Conv3d(ft_chns[i], cls_num, kernel_size=conv_knl[dims[i]], padding=conv_pad[dims[i]])) + self.out_layers.append(out_layer) + + init = params['initialization'].lower() + weightInitializer = Initialization_He(1e-2) if init == 'he' else Initialization_XavierUniform() + self.apply(weightInitializer) + + def get_default_parameters(self, params): + default_param = { + 'finetune_mode': 'all', + 'initialization': 'he', + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': 0.2, + 'res_mode': 0, + 'depth': 2, + 'multiscale_pred': True + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def set_stage(self, stage): + self.stage = stage + + def forward(self, x): + x_en = self.encoder(self.project(x)) + x_de = self.decoder(x_en) + output = self.out_layers[0](x_de[0]) + if(self.mul_pred and self.stage == 'train'): + output = [output] + for i in range(1, len(x_de)): + output.append(self.out_layers[i](x_de[i])) + return output + + def get_parameters_to_update(self): + if(self.tune_mode == 'all'): + return self.parameters() + + up_params = itertools.chain() + if(self.tune_mode == 'decoder'): + up_blocks = [self.decoder, self.out_layers] + else: + raise ValueError("undefined fine-tune mode for GRUNet: {0:}".format(self.tune_mode)) + for block in up_blocks: + up_params = itertools.chain(up_params, block.parameters()) + return up_params + + def get_parameters_to_load(self): + state_dict = self.state_dict() + if(self.load_mode == 'encoder'): + state_dict = {k:v for k, v in state_dict.items() if "project" in k or "encoder" in k } + return state_dict \ No newline at end of file diff --git a/pymic/net/net3d/trans3d/__init__.py b/pymic/net/net3d/trans3d/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pymic/net/net3d/trans3d/transunet3d.py b/pymic/net/net3d/trans3d/transunet3d.py new file mode 100644 index 0000000..183834e --- /dev/null +++ b/pymic/net/net3d/trans3d/transunet3d.py @@ -0,0 +1,1053 @@ +# 3D version of TransUNet; Copyright Johns Hopkins University +# Modified from nnUNet + + + +import torch +import numpy as np +import torch.nn.functional +import torch.nn.functional as F + +from copy import deepcopy +from torch import nn +from torch.cuda.amp import autocast +from scipy.optimize import linear_sum_assignment + +from ..networks.neural_network import SegmentationNetwork +from .vit_modeling import Transformer +from .vit_modeling import CONFIGS as CONFIGS_ViT + +softmax_helper = lambda x: F.softmax(x, 1) + +class InitWeights_He(object): + def __init__(self, neg_slope=1e-2): + self.neg_slope = neg_slope + + def __call__(self, module): + if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d): + module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) + if module.bias is not None: + module.bias = nn.init.constant_(module.bias, 0) + +class ConvDropoutNormNonlin(nn.Module): + """ + fixes a bug in ConvDropoutNormNonlin where lrelu was used regardless of nonlin. Bad. + """ + + def __init__(self, input_channels, output_channels, + conv_op=nn.Conv2d, conv_kwargs=None, + norm_op=nn.BatchNorm2d, norm_op_kwargs=None, + dropout_op=nn.Dropout2d, dropout_op_kwargs=None, + nonlin=nn.LeakyReLU, nonlin_kwargs=None): + super(ConvDropoutNormNonlin, self).__init__() + if nonlin_kwargs is None: + nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True} + if dropout_op_kwargs is None: + dropout_op_kwargs = {'p': 0.5, 'inplace': True} + if norm_op_kwargs is None: + norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1} + if conv_kwargs is None: + conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True} + + self.nonlin_kwargs = nonlin_kwargs + self.nonlin = nonlin + self.dropout_op = dropout_op + self.dropout_op_kwargs = dropout_op_kwargs + self.norm_op_kwargs = norm_op_kwargs + self.conv_kwargs = conv_kwargs + self.conv_op = conv_op + self.norm_op = norm_op + + self.conv = self.conv_op(input_channels, output_channels, **self.conv_kwargs) + if self.dropout_op is not None and self.dropout_op_kwargs['p'] is not None and self.dropout_op_kwargs[ + 'p'] > 0: + self.dropout = self.dropout_op(**self.dropout_op_kwargs) + else: + self.dropout = None + self.instnorm = self.norm_op(output_channels, **self.norm_op_kwargs) + self.lrelu = self.nonlin(**self.nonlin_kwargs) + + def forward(self, x): + x = self.conv(x) + if self.dropout is not None: + x = self.dropout(x) + return self.lrelu(self.instnorm(x)) + + +class ConvDropoutNonlinNorm(ConvDropoutNormNonlin): + def forward(self, x): + x = self.conv(x) + if self.dropout is not None: + x = self.dropout(x) + return self.instnorm(self.lrelu(x)) + + +class StackedConvLayers(nn.Module): + def __init__(self, input_feature_channels, output_feature_channels, num_convs, + conv_op=nn.Conv2d, conv_kwargs=None, + norm_op=nn.BatchNorm2d, norm_op_kwargs=None, + dropout_op=nn.Dropout2d, dropout_op_kwargs=None, + nonlin=nn.LeakyReLU, nonlin_kwargs=None, first_stride=None, basic_block=ConvDropoutNormNonlin): + ''' + stacks ConvDropoutNormLReLU layers. initial_stride will only be applied to first layer in the stack. The other parameters affect all layers + :param input_feature_channels: + :param output_feature_channels: + :param num_convs: + :param dilation: + :param kernel_size: + :param padding: + :param dropout: + :param initial_stride: + :param conv_op: + :param norm_op: + :param dropout_op: + :param inplace: + :param neg_slope: + :param norm_affine: + :param conv_bias: + ''' + self.input_channels = input_feature_channels + self.output_channels = output_feature_channels + + if nonlin_kwargs is None: + nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True} + if dropout_op_kwargs is None: + dropout_op_kwargs = {'p': 0.5, 'inplace': True} + if norm_op_kwargs is None: + norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1} + if conv_kwargs is None: + conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True} + + self.nonlin_kwargs = nonlin_kwargs + self.nonlin = nonlin + self.dropout_op = dropout_op + self.dropout_op_kwargs = dropout_op_kwargs + self.norm_op_kwargs = norm_op_kwargs + self.conv_kwargs = conv_kwargs + self.conv_op = conv_op + self.norm_op = norm_op + + if first_stride is not None: + self.conv_kwargs_first_conv = deepcopy(conv_kwargs) + self.conv_kwargs_first_conv['stride'] = first_stride + else: + self.conv_kwargs_first_conv = conv_kwargs + + super(StackedConvLayers, self).__init__() + self.blocks = nn.Sequential( + *([basic_block(input_feature_channels, output_feature_channels, self.conv_op, + self.conv_kwargs_first_conv, + self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, + self.nonlin, self.nonlin_kwargs)] + + [basic_block(output_feature_channels, output_feature_channels, self.conv_op, + self.conv_kwargs, + self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, + self.nonlin, self.nonlin_kwargs) for _ in range(num_convs - 1)])) + + def forward(self, x): + return self.blocks(x) + + +def print_module_training_status(module): + if isinstance(module, nn.Conv2d) or isinstance(module, nn.Conv3d) or isinstance(module, nn.Dropout3d) or \ + isinstance(module, nn.Dropout2d) or isinstance(module, nn.Dropout) or isinstance(module, nn.InstanceNorm3d) \ + or isinstance(module, nn.InstanceNorm2d) or isinstance(module, nn.InstanceNorm1d) \ + or isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d) or isinstance(module, + nn.BatchNorm1d): + print(str(module), module.training) + + +class Upsample(nn.Module): + def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=False): + super(Upsample, self).__init__() + self.align_corners = align_corners + self.mode = mode + self.scale_factor = scale_factor + self.size = size + + def forward(self, x): + return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, + align_corners=self.align_corners) + +def c2_xavier_fill(module: nn.Module) -> None: + """ + Initialize `module.weight` using the "XavierFill" implemented in Caffe2. + Also initializes `module.bias` to 0. + Args: + module (torch.nn.Module): module to initialize. + """ + # Caffe2 implementation of XavierFill in fact + # corresponds to kaiming_uniform_ in PyTorch + nn.init.kaiming_uniform_(module.weight, a=1) + if module.bias is not None: + # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module, + # torch.Tensor]`. + nn.init.constant_(module.bias, 0) + +class Generic_TransUNet_max_ppbp(SegmentationNetwork): + DEFAULT_BATCH_SIZE_3D = 2 + DEFAULT_PATCH_SIZE_3D = (64, 192, 160) + SPACING_FACTOR_BETWEEN_STAGES = 2 + BASE_NUM_FEATURES_3D = 30 + MAX_NUMPOOL_3D = 999 + MAX_NUM_FILTERS_3D = 320 + + DEFAULT_PATCH_SIZE_2D = (256, 256) + BASE_NUM_FEATURES_2D = 30 + DEFAULT_BATCH_SIZE_2D = 50 + MAX_NUMPOOL_2D = 999 + MAX_FILTERS_2D = 480 + + use_this_for_batch_size_computation_2D = 19739648 + use_this_for_batch_size_computation_3D = 520000000 # 505789440 + + def __init__(self, input_channels, base_num_features, num_classes, num_pool, num_conv_per_stage=2, + feat_map_mul_on_downscale=2, conv_op=nn.Conv2d, + norm_op=nn.BatchNorm2d, norm_op_kwargs=None, + dropout_op=nn.Dropout2d, dropout_op_kwargs=None, + nonlin=nn.LeakyReLU, nonlin_kwargs=None, deep_supervision=True, dropout_in_localization=False, + final_nonlin=softmax_helper, weightInitializer=InitWeights_He(1e-2), pool_op_kernel_sizes=None, + conv_kernel_sizes=None, + upscale_logits=False, convolutional_pooling=False, convolutional_upsampling=False, # TODO default False + max_num_features=None, basic_block=ConvDropoutNormNonlin, + seg_output_use_bias=False, + patch_size=None, is_vit_pretrain=False, + vit_depth=12, vit_hidden_size=768, vit_mlp_dim=3072, vit_num_heads=12, + max_msda='', is_max_ms=True, is_max_ms_fpn=False, max_n_fpn=4, max_ms_idxs=[-4,-3,-2], max_ss_idx=0, + is_max_bottleneck_transformer=False, max_seg_weight=1.0, max_hidden_dim=256, max_dec_layers=10, + mw = 0.5, + is_max=True, is_masked_attn=False, is_max_ds=False, is_masking=False, is_masking_argmax=False, + is_fam=False, fam_k=5, fam_reduct_ratio=8, + is_max_hungarian=False, num_queries=None, is_max_cls=False, + point_rend=False, num_point_rend=None, no_object_weight=None, is_mhsa_float32=False, no_max_hw_pe=False, + max_infer=None, cost_weight=[2.0, 5.0, 5.0], vit_layer_scale=False, decoder_layer_scale=False): + + super(Generic_TransUNet_max_ppbp, self).__init__() + + # newly added + self.is_fam = is_fam + self.is_max, self.max_msda, self.is_max_ms, self.is_max_ms_fpn, self.max_n_fpn, self.max_ss_idx, self.mw = is_max, max_msda, is_max_ms, is_max_ms_fpn, max_n_fpn, max_ss_idx, mw + self.max_ms_idxs = max_ms_idxs + + self.is_max_cls = is_max_cls + self.is_masked_attn, self.is_max_ds = is_masked_attn, is_max_ds + self.is_max_bottleneck_transformer = is_max_bottleneck_transformer + + self.convolutional_upsampling = convolutional_upsampling + self.convolutional_pooling = convolutional_pooling + self.upscale_logits = upscale_logits + if nonlin_kwargs is None: + nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True} + if dropout_op_kwargs is None: + dropout_op_kwargs = {'p': 0.5, 'inplace': True} + if norm_op_kwargs is None: + norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1} + + self.conv_kwargs = {'stride': 1, 'dilation': 1, 'bias': True} + + self.nonlin = nonlin + self.nonlin_kwargs = nonlin_kwargs + self.dropout_op_kwargs = dropout_op_kwargs + self.norm_op_kwargs = norm_op_kwargs + self.weightInitializer = weightInitializer + self.conv_op = conv_op + self.norm_op = norm_op + self.dropout_op = dropout_op + self.num_classes = num_classes + self.final_nonlin = final_nonlin + self._deep_supervision = deep_supervision + self.do_ds = deep_supervision + + if conv_op == nn.Conv2d: + upsample_mode = 'bilinear' + pool_op = nn.MaxPool2d + transpconv = nn.ConvTranspose2d + if pool_op_kernel_sizes is None: + pool_op_kernel_sizes = [(2, 2)] * num_pool + if conv_kernel_sizes is None: + conv_kernel_sizes = [(3, 3)] * (num_pool + 1) + elif conv_op == nn.Conv3d: + upsample_mode = 'trilinear' + pool_op = nn.MaxPool3d + transpconv = nn.ConvTranspose3d + if pool_op_kernel_sizes is None: + pool_op_kernel_sizes = [(2, 2, 2)] * num_pool + if conv_kernel_sizes is None: + conv_kernel_sizes = [(3, 3, 3)] * (num_pool + 1) + else: + raise ValueError("unknown convolution dimensionality, conv op: %s" % str(conv_op)) + + self.input_shape_must_be_divisible_by = np.prod(pool_op_kernel_sizes, 0, dtype=np.int64) + self.pool_op_kernel_sizes = pool_op_kernel_sizes + self.conv_kernel_sizes = conv_kernel_sizes + + self.conv_pad_sizes = [] + for krnl in self.conv_kernel_sizes: + self.conv_pad_sizes.append([1 if i == 3 else 0 for i in krnl]) + + if max_num_features is None: + if self.conv_op == nn.Conv3d: + self.max_num_features = self.MAX_NUM_FILTERS_3D + else: + self.max_num_features = self.MAX_FILTERS_2D + else: + self.max_num_features = max_num_features + + self.conv_blocks_context = [] + self.conv_blocks_localization = [] + self.td = [] + self.tu = [] + + + self.fams = [] + + output_features = base_num_features + input_features = input_channels + + for d in range(num_pool): + # determine the first stride + if d != 0 and self.convolutional_pooling: + first_stride = pool_op_kernel_sizes[d - 1] + else: + first_stride = None + + self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[d] + self.conv_kwargs['padding'] = self.conv_pad_sizes[d] + # add convolutions + self.conv_blocks_context.append(StackedConvLayers(input_features, output_features, num_conv_per_stage, + self.conv_op, self.conv_kwargs, self.norm_op, + self.norm_op_kwargs, self.dropout_op, + self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs, + first_stride, basic_block=basic_block)) + if not self.convolutional_pooling: + self.td.append(pool_op(pool_op_kernel_sizes[d])) + input_features = output_features + output_features = int(np.round(output_features * feat_map_mul_on_downscale)) + + output_features = min(output_features, self.max_num_features) + + # now the bottleneck. + # determine the first stride + if self.convolutional_pooling: + first_stride = pool_op_kernel_sizes[-1] + else: + first_stride = None + + # the output of the last conv must match the number of features from the skip connection if we are not using + # convolutional upsampling. If we use convolutional upsampling then the reduction in feature maps will be + # done by the transposed conv + if self.convolutional_upsampling: + final_num_features = output_features + else: + final_num_features = self.conv_blocks_context[-1].output_channels + + self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[num_pool] + self.conv_kwargs['padding'] = self.conv_pad_sizes[num_pool] + self.conv_blocks_context.append(nn.Sequential( + StackedConvLayers(input_features, output_features, num_conv_per_stage - 1, self.conv_op, self.conv_kwargs, + self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin, + self.nonlin_kwargs, first_stride, basic_block=basic_block), + StackedConvLayers(output_features, final_num_features, 1, self.conv_op, self.conv_kwargs, + self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin, + self.nonlin_kwargs, basic_block=basic_block))) + + # if we don't want to do dropout in the localization pathway then we set the dropout prob to zero here + if not dropout_in_localization: + old_dropout_p = self.dropout_op_kwargs['p'] + self.dropout_op_kwargs['p'] = 0.0 + + # now lets build the localization pathway + for u in range(num_pool): + nfeatures_from_down = final_num_features + nfeatures_from_skip = self.conv_blocks_context[ + -(2 + u)].output_channels # self.conv_blocks_context[-1] is bottleneck, so start with -2 + n_features_after_tu_and_concat = nfeatures_from_skip * 2 + + # the first conv reduces the number of features to match those of skip + # the following convs work on that number of features + # if not convolutional upsampling then the final conv reduces the num of features again + if u != num_pool - 1 and not self.convolutional_upsampling: + final_num_features = self.conv_blocks_context[-(3 + u)].output_channels + else: + final_num_features = nfeatures_from_skip + + if not self.convolutional_upsampling: + self.tu.append(Upsample(scale_factor=pool_op_kernel_sizes[-(u + 1)], mode=upsample_mode)) + else: + self.tu.append(transpconv(nfeatures_from_down, nfeatures_from_skip, pool_op_kernel_sizes[-(u + 1)], + pool_op_kernel_sizes[-(u + 1)], bias=False)) + + self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[- (u + 1)] + self.conv_kwargs['padding'] = self.conv_pad_sizes[- (u + 1)] + self.conv_blocks_localization.append(nn.Sequential( + StackedConvLayers(n_features_after_tu_and_concat, nfeatures_from_skip, num_conv_per_stage - 1, + self.conv_op, self.conv_kwargs, self.norm_op, self.norm_op_kwargs, self.dropout_op, + self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs, basic_block=basic_block), + StackedConvLayers(nfeatures_from_skip, final_num_features, 1, self.conv_op, self.conv_kwargs, + self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, + self.nonlin, self.nonlin_kwargs, basic_block=basic_block) + )) + + + + if self.is_fam: + self.fams = nn.ModuleList(self.fams) + + if self.do_ds: + self.seg_outputs = [] + for ds in range(len(self.conv_blocks_localization)): + self.seg_outputs.append(conv_op(self.conv_blocks_localization[ds][-1].output_channels, num_classes, + 1, 1, 0, 1, 1, seg_output_use_bias)) + self.seg_outputs = nn.ModuleList(self.seg_outputs) + + self.upscale_logits_ops = [] + cum_upsample = np.cumprod(np.vstack(pool_op_kernel_sizes), axis=0)[::-1] + for usl in range(num_pool - 1): + if self.upscale_logits: + self.upscale_logits_ops.append(Upsample(scale_factor=tuple([int(i) for i in cum_upsample[usl + 1]]), + mode=upsample_mode)) + else: + self.upscale_logits_ops.append(lambda x: x) + + if not dropout_in_localization: + self.dropout_op_kwargs['p'] = old_dropout_p + + # register all modules properly + self.conv_blocks_localization = nn.ModuleList(self.conv_blocks_localization) + self.conv_blocks_context = nn.ModuleList(self.conv_blocks_context) + self.td = nn.ModuleList(self.td) + self.tu = nn.ModuleList(self.tu) + + if self.upscale_logits: + self.upscale_logits_ops = nn.ModuleList( + self.upscale_logits_ops) # lambda x:x is not a Module so we need to distinguish here + + if self.weightInitializer is not None: + self.apply(self.weightInitializer) + # self.apply(print_module_training_status) + + # Transformer configuration + if self.is_max_bottleneck_transformer: + self.patch_size = patch_size # e.g. [48, 192, 192] + config_vit = CONFIGS_ViT['R50-ViT-B_16'] + config_vit.transformer.num_layers = vit_depth + config_vit.hidden_size = vit_hidden_size # 768 + config_vit.transformer.mlp_dim = vit_mlp_dim # 3072 + config_vit.transformer.num_heads = vit_num_heads # 12 + self.conv_more = nn.Conv3d(config_vit.hidden_size, output_features, 1) + num_pool_per_axis = np.prod(np.array(pool_op_kernel_sizes), axis=0) + num_pool_per_axis = np.log2(num_pool_per_axis).astype(np.uint8) + feat_size = [int(self.patch_size[0]/2**num_pool_per_axis[0]), int(self.patch_size[1]/2**num_pool_per_axis[1]), int(self.patch_size[2]/2**num_pool_per_axis[2])] + self.transformer = Transformer(config_vit, feat_size=feat_size, vis=False, feat_channels=output_features, use_layer_scale=vit_layer_scale) + if is_vit_pretrain: + self.transformer.load_from(weights=np.load(config_vit.pretrained_path)) + + + if self.is_max: + # Max PPB+ configuration (i.e. MultiScaleStandardTransformerDecoder) + cfg = { + "num_classes": num_classes, + "hidden_dim": max_hidden_dim, + "num_queries": num_classes if num_queries is None else num_queries, # N=K if 'fixed matching', else default=100, + "nheads": 8, + "dim_feedforward": max_hidden_dim * 8, # 2048, + "dec_layers": max_dec_layers, # 9 decoder layers, add one for the loss on learnable query? + "pre_norm": False, + "enforce_input_project": False, + "mask_dim": max_hidden_dim, # input feat of segm head? + "non_object": False, + "use_layer_scale": decoder_layer_scale, + } + cfg['non_object'] = is_max_cls + input_proj_list = [] # from low resolution to high resolution (res4 -> res1), [1, 1024, 14, 14], [1, 512, 28, 28], 1, 256, 56, 56], [1, 64, 112, 112] + decoder_channels = [320, 320, 256, 128, 64, 32] + if self.is_max_ms: # use multi-scale feature as Transformer decoder input + if self.is_max_ms_fpn: + for idx, in_channels in enumerate(decoder_channels[:max_n_fpn]): # max_n_fpn=4: 1/32, 1/16, 1/8, 1/4 + input_proj_list.append(nn.Sequential( + nn.Conv3d(in_channels, max_hidden_dim, kernel_size=1), + nn.GroupNorm(32, max_hidden_dim), + nn.Upsample(size=(int(patch_size[0]/2), int(patch_size[1]/4), int(patch_size[2]/4)), mode='trilinear') + )) # proj to scale (1, 1/2, 1/2), TODO: init + self.input_proj = nn.ModuleList(input_proj_list) + self.linear_encoder_feature = nn.Conv3d(max_hidden_dim * max_n_fpn, max_hidden_dim, 1, 1) # concat four-level feature + else: + for idx, in_channels in enumerate([decoder_channels[i] for i in self.max_ms_idxs]): + input_proj_list.append(nn.Sequential( + nn.Conv3d(in_channels, max_hidden_dim, kernel_size=1), + nn.GroupNorm(32, max_hidden_dim), + )) + self.input_proj = nn.ModuleList(input_proj_list) + + # self.linear_mask_features =nn.Conv3d(decoder_channels[max_n_fpn-1], cfg["mask_dim"], kernel_size=1, stride=1, padding=0,) # low-level feat, dot product Trans-feat + self.linear_mask_features =nn.Conv3d(decoder_channels[-1], cfg["mask_dim"], kernel_size=1, stride=1, padding=0,) # following SingleScale, high-level feat, obtain seg_map + else: + self.linear_encoder_feature = nn.Conv3d(decoder_channels[max_ss_idx], cfg["mask_dim"], kernel_size=1) + self.linear_mask_features = nn.Conv3d(decoder_channels[-1], cfg["mask_dim"], kernel_size=1, stride=1, padding=0,) # low-level feat, dot product Trans-feat + + if self.is_masked_attn: + from .mask2former_modeling.transformer_decoder.mask2former_transformer_decoder3d import MultiScaleMaskedTransformerDecoder3d + cfg['num_feature_levels'] = 1 if not self.is_max_ms or self.is_max_ms_fpn else 3 + cfg["is_masking"] = True if is_masking else False + cfg["is_masking_argmax"] = True if is_masking_argmax else False + cfg["is_mhsa_float32"] = True if is_mhsa_float32 else False + cfg["no_max_hw_pe"] = True if no_max_hw_pe else False + self.predictor = MultiScaleMaskedTransformerDecoder3d(in_channels=max_hidden_dim, mask_classification=is_max_cls, **cfg) + else: + from .mask2former_modeling.transformer_decoder.maskformer_transformer_decoder3d import StandardTransformerDecoder + cfg["dropout"], cfg["enc_layers"], cfg["deep_supervision"] = 0.1, 0, False + self.predictor = StandardTransformerDecoder(in_channels=max_hidden_dim, mask_classification=is_max_cls, **cfg) + + def forward(self, x): + skips = [] + seg_outputs = [] + for d in range(len(self.conv_blocks_context) - 1): + x = self.conv_blocks_context[d](x) + skips.append(x) + if not self.convolutional_pooling: + x = self.td[d](x) + + x = self.conv_blocks_context[-1](x) + ######### TransUNet ######### + if self.is_max_bottleneck_transformer: + x, attn = self.transformer(x) # [b, hidden, d/8, h/16, w/16] + x = self.conv_more(x) + ############################# + + ds_feats = [] # obtain multi-scale feature + ds_feats.append(x) + for u in range(len(self.tu)): + if unm", inputs, targets) + denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :] + loss = 1 - (numerator + 1) / (denominator + 1) + return loss + + +batch_dice_loss_jit = torch.jit.script( + batch_dice_loss +) # type: torch.jit.ScriptModule + + + +def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor): + """ + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + Returns: + Loss tensor + """ + hw = inputs.shape[1] + + pos = F.binary_cross_entropy_with_logits( + inputs, torch.ones_like(inputs), reduction="none" + ) + neg = F.binary_cross_entropy_with_logits( + inputs, torch.zeros_like(inputs), reduction="none" + ) + + loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum( + "nc,mc->nm", neg, (1 - targets) + ) + + return loss / hw + + +batch_sigmoid_ce_loss_jit = torch.jit.script( + batch_sigmoid_ce_loss +) # type: torch.jit.ScriptModule + + +class HungarianMatcher3D(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). + """ + + def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1, ): + """Creates the matcher + + Params: + cost_class: This is the relative weight of the classification error in the matching cost + cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost + cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost + """ + super().__init__() + self.cost_class = cost_class + self.cost_mask = cost_mask + self.cost_dice = cost_dice + + def compute_cls_loss(self, inputs, targets): + """ Classification loss (NLL) + implemented in compute_loss() + """ + raise NotImplementedError + + + def compute_dice_loss(self, inputs, targets): + """ mask dice loss + inputs (B*K, C, H, W) + target (B*K, D, H, W) + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + targets = targets.flatten(1) + num_masks = len(inputs) + + numerator = 2 * (inputs * targets).sum(-1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_masks + + + def compute_ce_loss(self, inputs, targets): + """mask ce loss""" + num_masks = len(inputs) + loss = F.binary_cross_entropy_with_logits(inputs.flatten(1), targets.flatten(1), reduction="none") + loss = loss.mean(1).sum() / num_masks + return loss + + def compute_dice(self, inputs, targets): + """ output (N_q, C, H, W) + target (K, D, H, W) + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + targets = targets.flatten(1) + numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets) + denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :] + loss = 1 - (numerator + 1) / (denominator + 1) + return loss # [N_q, K] + + + def compute_ce(self, inputs, targets): + """ output (N_q, C, H, W) + target (K, D, H, W) + return (N_q, K) + """ + inputs = inputs.flatten(1) + targets = targets.flatten(1) + hw = inputs.shape[1] + + pos = F.binary_cross_entropy_with_logits( + inputs, torch.ones_like(inputs), reduction="none" + ) + + neg = F.binary_cross_entropy_with_logits( + inputs, torch.zeros_like(inputs), reduction="none" + ) + + loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum( + "nc,mc->nm", neg, (1 - targets) + ) + + return loss / hw + + # target_onehot = torch.zeros_like(output, device=output.device) + # target_onehot.scatter_(1, target.long(), 1) + # assert (torch.argmax(target_onehot, dim=1) == target[:, 0].long()).all() + # ce_loss = F.binary_cross_entropy_with_logits(output, target_onehot) + # return ce_loss + + + @torch.no_grad() + def memory_efficient_forward(self, outputs, targets): + """More memory-friendly matching for single aux, outputs: (b, q, d, h, w)""" + """suppose each crop must contain foreground class""" + bs, num_queries = outputs["pred_logits"].shape[:2] + indices = [] + + # Iterate through batch size + for b in range(bs): + out_prob = outputs["pred_logits"][b].softmax(-1) # [num_queries, num_classes+1] + out_mask = outputs["pred_masks"][b] # [num_queries, H_pred, W_pred] + + tgt_ids = targets[b]["labels"] + tgt_mask = targets[b]["masks"].to(out_mask) # [K, D, H, W], K is number of classes shown in this image, and K < n_class + + # target_onehot = torch.zeros_like(tgt_mask, device=out_mask.device) + # target_onehot.scatter_(1, targets.long(), 1) + + cost_class = -out_prob[:, tgt_ids] # [num_queries, K] + + with autocast(enabled=False): + out_mask = out_mask.float() + tgt_mask = tgt_mask.float() + cost_dice = self.compute_dice(out_mask, tgt_mask) + cost_mask = self.compute_ce(out_mask, tgt_mask) + + # Final cost matrix + C = ( + self.cost_class * cost_class + + self.cost_mask * cost_mask + + self.cost_dice * cost_dice + ) + + C = C.reshape(num_queries, -1).cpu() # (num_queries, K) + + # linear_sum_assignment return a tuple of two arrays: row_ind, col_ind, the length of array is min(N_q, K) + # The cost of the assignment can be computed as cost_matrix[row_ind, col_ind].sum() + + indices.append(linear_sum_assignment(C)) + + final_indices = [ + (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) + for i, j in indices + ] + + return final_indices + + @torch.no_grad() + def forward(self, outputs, targets): + """Performs the matching + + Params: + outputs: This is a dict that contains at least these entries: + "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks + + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth + objects in the target) containing the class labels + "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + + return self.memory_efficient_forward(outputs, targets) + + def __repr__(self, _repr_indent=4): + head = "Matcher " + self.__class__.__name__ + body = [ + "cost_class: {}".format(self.cost_class), + "cost_mask: {}".format(self.cost_mask), + "cost_dice: {}".format(self.cost_dice), + ] + lines = [head] + [" " * _repr_indent + line for line in body] + return "\n".join(lines) + + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + +def compute_loss_hungarian(outputs, targets, idx, matcher, num_classes, point_rend=False, num_points=12544, oversample_ratio=3.0, importance_sample_ratio=0.75, no_object_weight=None, cost_weight=[2,5,5]): + """output is a dict only contain keys ['pred_masks', 'pred_logits'] """ + # outputs_without_aux = {k: v for k, v in output.items() if k != "aux_outputs"} + + indices = matcher(outputs, targets) + src_idx = matcher._get_src_permutation_idx(indices) # return a tuple of (batch_idx, src_idx) + tgt_idx = matcher._get_tgt_permutation_idx(indices) # return a tuple of (batch_idx, tgt_idx) + assert len(tgt_idx[0]) == sum([len(t["masks"]) for t in targets]) # verify that all masks of (K1, K2, ..) are used + + # step2 : compute mask loss + src_masks = outputs["pred_masks"] + src_masks = src_masks[src_idx] # [len(src_idx[0]), D, H, W] -> (K1+K2+..., D, H, W) + target_masks = torch.cat([t["masks"] for t in targets], dim=0) # (K1+K2+..., D, H, W) actually + src_masks = src_masks[:, None] # [K..., 1, D, H, W] + target_masks = target_masks[:, None] + + if point_rend: # only calculate hard example + with torch.no_grad(): + # num_points=12544 config in cityscapes + + # sample point_coords + point_coords = get_uncertain_point_coords_with_randomness( + src_masks.float(), + lambda logits: calculate_uncertainty(logits), + num_points, + oversample_ratio, + importance_sample_ratio, + ) # [K, num_points=12544, 3] + + point_labels = point_sample_3d( + target_masks.float(), + point_coords.float(), + align_corners=False, + ).squeeze(1) # [K, 12544] + + point_logits = point_sample_3d( + src_masks.float(), + point_coords.float(), + align_corners=False, + ).squeeze(1) # [K, 12544] + + src_masks, target_masks = point_logits, point_labels + + loss_mask_ce = matcher.compute_ce_loss(src_masks, target_masks) + loss_mask_dice = matcher.compute_dice_loss(src_masks, target_masks) + + # step3: compute class loss + src_logits = outputs["pred_logits"].float() # (B, num_query, num_class+1) + target_classes_o = torch.cat([t["labels"] for t in targets], dim=0) # (K1+K2+, ) + target_classes = torch.full( + src_logits.shape[:2], num_classes, dtype=torch.int64, device=src_logits.device + ) # (B, num_query, num_class+1) + target_classes[src_idx] = target_classes_o + + + if no_object_weight is not None: + empty_weight = torch.ones(num_classes + 1).to(src_logits.device) + empty_weight[-1] = no_object_weight + loss_cls = F.cross_entropy(src_logits.transpose(1, 2), target_classes, empty_weight) + else: + loss_cls = F.cross_entropy(src_logits.transpose(1, 2), target_classes) + + loss = (cost_weight[0]/10)*loss_cls + (cost_weight[1]/10)*loss_mask_ce + (cost_weight[2]/10)*loss_mask_dice # 2:5:5, like hungarian matching + # print("idx {}, loss {}, loss_cls {}, loss_mask_ce {}, loss_mask_dice {}".format(idx, loss, loss_cls, loss_mask_ce, loss_mask_dice)) + return loss + + +def point_sample_3d(input, point_coords, **kwargs): + """ + from detectron2.projects.point_rend.point_features + A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors. + Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside + [0, 1] x [0, 1] square. + Args: + input (Tensor): A tensor of shape (N, C, D, H, W) that contains features map on a D x H x W grid. + point_coords (Tensor): A tensor of shape (N, P, 3) or (N, Dgrid, Hgrid, Wgrid, 3) that contains + [0, 1] x [0, 1] x [0, 1] normalized point coordinates. + Returns: + output (Tensor): A tensor of shape (N, C, P) or (N, C, Dgrid, Hgrid, Wgrid) that contains + features for points in `point_coords`. The features are obtained via bilinear + interplation from `input` the same way as :function:`torch.nn.functional.grid_sample`. + """ + add_dim = False + if point_coords.dim() == 3: + add_dim = True + point_coords = point_coords.unsqueeze(2).unsqueeze(2) # why + + # point_coords should be (N, D, H, W, 3) + output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs) + + if add_dim: + output = output.squeeze(3).squeeze(3) + + return output + + +def calculate_uncertainty(logits): + """ + We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the + foreground class in `classes`. + Args: + logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or + class-agnostic, where R is the total number of predicted masks in all images and C is + the number of foreground classes. The values are logits. + Returns: + scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with + the most uncertain locations having the highest uncertainty score. + """ + assert logits.shape[1] == 1 + gt_class_logits = logits.clone() + return -(torch.abs(gt_class_logits)) + + +# implemented! +def get_uncertain_point_coords_with_randomness( + coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio): + """ + Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The unceratinties + are calculated for each point using 'uncertainty_func' function that takes point's logit + prediction as input. + See PointRend paper for details. + Args: + coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for + class-specific or class-agnostic prediction. + uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that + contains logit predictions for P points and returns their uncertainties as a Tensor of + shape (N, 1, P). + num_points (int): The number of points P to sample. + oversample_ratio (int): Oversampling parameter. + importance_sample_ratio (float): Ratio of points that are sampled via importnace sampling. + Returns: + point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P + sampled points. + """ + assert oversample_ratio >= 1 + assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0 + n_dim = 3 + num_boxes = coarse_logits.shape[0] + num_sampled = int(num_points * oversample_ratio) # 12544 * 3, oversampled + point_coords = torch.rand(num_boxes, num_sampled, n_dim, device=coarse_logits.device) # (K, 37632, 3); uniform dist [0, 1) + point_logits = point_sample_3d(coarse_logits, point_coords, align_corners=False) # (K, 1, 37632) + + # It is crucial to calculate uncertainty based on the sampled prediction value for the points. + # Calculating uncertainties of the coarse predictions first and sampling them for points leads + # to incorrect results. + # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between + # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value. + # However, if we calculate uncertainties for the coarse predictions first, + # both will have -1 uncertainty, and the sampled point will get -1 uncertainty. + point_uncertainties = uncertainty_func(point_logits) + num_uncertain_points = int(importance_sample_ratio * num_points) # 9408 + + num_random_points = num_points - num_uncertain_points # 3136 + idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + + shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device) + idx += shift[:, None] # [K, 9408] + + point_coords = point_coords.view(-1, n_dim)[idx.view(-1), :].view( + num_boxes, num_uncertain_points, n_dim + ) # [K, 9408, 3] + + if num_random_points > 0: + # from detectron2.layers import cat + point_coords = torch.cat( + [ + point_coords, + torch.rand(num_boxes, num_random_points, n_dim, device=coarse_logits.device), + ], + dim=1, + ) # [K, 12544, 3] + + return point_coords \ No newline at end of file diff --git a/pymic/net/net3d/unet3d.py b/pymic/net/net3d/unet3d.py index b66ea38..5954869 100644 --- a/pymic/net/net3d/unet3d.py +++ b/pymic/net/net3d/unet3d.py @@ -236,7 +236,8 @@ def __init__(self, params): for p in params: print(p, params[p]) self.stage = 'train' - self.update_mode = params.get("update_mode", "all") + self.tune_mode= params.get('finetune_mode', 'all') + self.load_mode= params.get('weights_load_mode', 'all') self.encoder = Encoder(params) self.decoder = Decoder(params) @@ -263,17 +264,24 @@ def set_stage(self, stage): self.stage = stage self.decoder.set_stage(stage) + + def forward(self, x): + f = self.encoder(x) + output = self.decoder(f) + return output + def get_parameters_to_update(self): - if(self.update_mode == "all"): + if(self.tune_mode == "all"): return self.parameters() - elif(self.update_mode == "decoder"): + elif(self.tune_mode == "decoder"): print("only update parameters in decoder") params = self.decoder.parameters() return params else: raise(ValueError("update_mode can only be 'all' or 'decoder'.")) - def forward(self, x): - f = self.encoder(x) - output = self.decoder(f) - return output + def get_parameters_to_load(self): + state_dict = self.state_dict() + if(self.load_mode == 'encoder'): + state_dict = {k:v for k, v in state_dict.items() if "encoder" in k } + return state_dict diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index e381421..741eed7 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -16,10 +16,10 @@ """ from __future__ import print_function, division from pymic.net.net2d.unet2d import UNet2D -from pymic.net.net2d.unet2d_dual_branch import UNet2D_DualBranch +from pymic.net.net2d.unet2d_multi_decoder import UNet2D_DualBranch, MCNet2D from pymic.net.net2d.unet2d_canet import CANet from pymic.net.net2d.unet2d_cct import UNet2D_CCT -from pymic.net.net2d.unet2d_mcnet import MCNet2D +from pymic.net.net2d.unet2d_mtnet import MTNet2D from pymic.net.net2d.cople_net import COPLENet from pymic.net.net2d.unet2d_attention import AttentionUNet2D from pymic.net.net2d.unet2d_nest import NestedUNet2D @@ -28,8 +28,13 @@ from pymic.net.net2d.trans2d.swinunet import SwinUNet from pymic.net.net3d.unet2d5 import UNet2D5 from pymic.net.net3d.unet3d import UNet3D +from pymic.net.net3d.grunet import GRUNet +from pymic.net.net3d.fmunetv3 import FMUNetV3 from pymic.net.net3d.unet3d_scse import UNet3D_ScSE from pymic.net.net3d.unet3d_dual_branch import UNet3D_DualBranch +# from pymic.net.net3d.stunet_wrap import STUNet_wrap +# from pymic.net.net3d.mystunet import MySTUNet + # from pymic.net.net3d.trans3d.nnFormer_wrap import nnFormer_wrap # from pymic.net.net3d.trans3d.unetr import UNETR # from pymic.net.net3d.trans3d.unetr_pp import UNETR_PP @@ -49,6 +54,7 @@ 'UNet2D_DualBranch': UNet2D_DualBranch, 'UNet2D_CCT': UNet2D_CCT, 'MCNet2D': MCNet2D, + 'MTNet2D': MTNet2D, 'CANet': CANet, 'COPLENet': COPLENet, 'AttentionUNet2D': AttentionUNet2D, @@ -57,20 +63,14 @@ 'TransUNet': TransUNet, 'SwinUNet': SwinUNet, 'UNet2D5': UNet2D5, + 'GRUNet': GRUNet, + 'FMUNetV3': FMUNetV3, 'UNet3D': UNet3D, 'UNet3D_ScSE': UNet3D_ScSE, 'UNet3D_DualBranch': UNet3D_DualBranch, + # 'STUNet': STUNet_wrap, + # 'MySTUNet': MySTUNet, # 'nnFormer': nnFormer_wrap, # 'UNETR': UNETR, # 'UNETR_PP': UNETR_PP, - # 'MedFormerV1': MedFormerV1, - # 'MedFormerV2': MedFormerV2, - # 'MedFormerV3': MedFormerV3, - # 'MedFormerVA1':MedFormerVA1, - # 'HiFormer_v1': HiFormer_v1, - # 'HiFormer_v2': HiFormer_v2, - # 'HiFormer_v3': HiFormer_v3, - # 'HiFormer_v4': HiFormer_v4, - # 'HiFormer_v5': HiFormer_v5 - # 'SwitchNet': SwitchNet } diff --git a/pymic/net_run/self_sup/__init__.py b/pymic/net_run/self_sup/__init__.py index d73e42a..86482f9 100644 --- a/pymic/net_run/self_sup/__init__.py +++ b/pymic/net_run/self_sup/__init__.py @@ -1,10 +1,16 @@ from __future__ import absolute_import from pymic.net_run.self_sup.self_genesis import SelfSupModelGenesis from pymic.net_run.self_sup.self_patch_swapping import SelfSupPatchSwapping -from pymic.net_run.self_sup.self_volume_fusion import SelfSupVolumeFusion +# from pymic.net_run.self_sup.self_mim import SelfSupMIM +# from pymic.net_run.self_sup.self_dino import SelfSupDINO +from pymic.net_run.self_sup.self_vox2vec import SelfSupVox2Vec +from pymic.net_run.self_sup.self_volf import SelfSupVolumeFusion SelfSupMethodDict = { + # 'DINO': SelfSupDINO, + 'Vox2Vec': SelfSupVox2Vec, 'ModelGenesis': SelfSupModelGenesis, 'PatchSwapping': SelfSupPatchSwapping, 'VolumeFusion': SelfSupVolumeFusion + # 'MaskedImageModeling': SelfSupMIM } \ No newline at end of file diff --git a/pymic/net_run/self_sup/self_vox2vec.py b/pymic/net_run/self_sup/self_vox2vec.py new file mode 100644 index 0000000..94e2b16 --- /dev/null +++ b/pymic/net_run/self_sup/self_vox2vec.py @@ -0,0 +1,304 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import copy +import logging +import time +import logging +import torch +import torch.nn as nn +from datetime import datetime +from torch.optim import lr_scheduler +from tensorboardX import SummaryWriter +from pymic.io.image_read_write import save_nd_array_as_image +from pymic.net.net3d.fmunetv3 import FMUNetV3 +from pymic.net_run.agent_seg import SegmentationAgent +from pymic.loss.cls.infoNCE import InfoNCELoss + +def select_from_pyramid(feature_pyramid, indices): + """Select features from feature pyramid by their indices w.r.t. base feature map. + + Args: + feature_pyramid (Sequence[torch.Tensor]): Sequence of tensors of shapes ``(B, C_i, D_i, H_i, W_i)``. + indices (torch.Tensor): tensor of shape ``(B, N, 3)`` + + Returns: + torch.Tensor: tensor of shape ``(B, N, \sum_i c_i)`` + """ + out = [] + for i, x in enumerate(feature_pyramid): + batch_size = list(x.shape)[0] + x_move = x.moveaxis(1, -1) + index_i = indices // 2 ** i + x_i = [x_move[b][index_i[b][:, 0], index_i[b][:, 1], index_i[b][:, 2], :] for \ + b in range(batch_size)] + x_i = torch.stack(x_i) + out.append(x_i) + out = torch.cat(out, dim = -1) + return out + +class Vox2VecHead(nn.Module): + def __init__(self, params): + super(Vox2VecHead, self).__init__() + ft_chns = params['feature_chns'] + hidden_dim = params['hidden_dim'] + proj_dim = params['project_dim'] + embed_dim = sum(ft_chns) + self.proj_head = nn.Sequential( + nn.Linear(embed_dim, hidden_dim), + nn.BatchNorm1d(hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.BatchNorm1d(hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, proj_dim) + ) + + def forward(self, x): + output = self.proj_head(x) + output = nn.functional.normalize(output) + return output + +class Vox2VecWrapper(nn.Module): + """ + Perform forward pass separately on each resolution input. + The inputs corresponding to a single resolution are clubbed and single + forward is run on the same resolution inputs. Hence we do several + forward passes = number of different resolutions used. We then + concatenate all the output features and run the head forward on these + concatenated features. + """ + def __init__(self, backbone, head): + super(Vox2VecWrapper, self).__init__() + self.backbone = backbone + self.head = head + + def forward(self, x, vex_idx): + if(isinstance(self.backbone, FMUNetV3)): + x = self.backbone.project(x) + f = self.backbone.encoder(x) + B = list(f[0].shape)[0] + f_fpn = select_from_pyramid(f, vex_idx) + feature_dim = list(f_fpn.shape)[-1] + f_fpn = f_fpn.view(-1, feature_dim) + output = self.head(f_fpn) + proj_dim = list(output.shape)[-1] + output = output.view(B, -1, proj_dim) + return output + +class SelfSupVox2Vec(SegmentationAgent): + """ + An agent for image self-supervised learning with DeSD. + """ + def __init__(self, config, stage = 'train'): + super(SelfSupVox2Vec, self).__init__(config, stage) + + def create_network(self): + super(SelfSupVox2Vec, self).create_network() + proj_dim = self.config['self_supervised_learning'].get('project_dim', 1024) + hidden_dim = self.config['self_supervised_learning'].get('hidden_dim', 1024) + head_params= {'feature_chns': self.config['network']['feature_chns'], + 'hidden_dim':hidden_dim, + 'project_dim':proj_dim} + self.head = Vox2VecHead(head_params) + self.net_wrapper = Vox2VecWrapper(self.net, self.head) + + def create_loss_calculator(self): + # constrastive loss + self_sup_params = self.config['self_supervised_learning'] + self.loss_calculator = InfoNCELoss(self_sup_params) + + def get_parameters_to_update(self): + params = self.net_wrapper.parameters() + return params + + def training(self): + iter_valid = self.config['training']['iter_valid'] + train_loss = 0 + err_info = None + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 + self.net_wrapper.train() + for it in range(iter_valid): + t0 = time.time() + try: + data = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data = next(self.trainIter) + t1 = time.time() + patch1, patch2, vox_ids1, vox_ids2 = data['image'] + inputs = torch.cat([patch1, patch2], dim = 0) + vox_ids = torch.cat([vox_ids1, vox_ids2], dim = 0) + inputs = self.convert_tensor_type(inputs) + inputs = inputs.to(self.device) + vox_ids = vox_ids.to(self.device) + + # for debug + # for i in range(patch1.shape[0]): + # v1_i = patch1[i][0] + # v2_i = patch2[i][0] + # print("patch shape", v1_i.shape, v2_i.shape) + # image_name0 = "temp/image_{0:}_{1:}_v0.nii.gz".format(it, i) + # image_name1 = "temp/image_{0:}_{1:}_v1.nii.gz".format(it, i) + # save_nd_array_as_image(v1_i, image_name0, reference_name = None) + # save_nd_array_as_image(v2_i, image_name1, reference_name = None) + # if(it > 10): + # return + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + out = self.net_wrapper(inputs, vox_ids) + out1, out2 = out.chunk(2) + + t2 = time.time() + loss = self.loss_calculator(out1, out2) + t3 = time.time() + + loss.backward() + self.optimizer.step() + train_loss = train_loss + loss.item() + t4 = time.time() + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 + + train_avg_loss = train_loss / iter_valid + train_scalers = {'loss': train_avg_loss, 'data_time': data_time, + 'gpu_time':gpu_time, 'loss_time':loss_time, 'back_time':back_time, + 'err_info': err_info} + return train_scalers + + + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): + loss_scalar ={'train':train_scalars['loss']} + self.summ_writer.add_scalars('loss', loss_scalar, glob_it) + self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) + logging.info('train loss {0:.4f}'.format(train_scalars['loss'])) + + def train_valid(self): + device_ids = self.config['training']['gpus'] + if(len(device_ids) > 1): + self.device = torch.device("cuda:0") + self.net_wrapper = nn.DataParallel(self.net_wrapper, device_ids = device_ids) + else: + self.device = torch.device("cuda:{0:}".format(device_ids[0])) + self.net_wrapper.to(self.device) + + ckpt_dir = self.config['training']['ckpt_dir'] + ckpt_prefix = self.config['training'].get('ckpt_prefix', None) + if(ckpt_prefix is None): + ckpt_prefix = ckpt_dir.split('/')[-1] + # iter_start = self.config['training']['iter_start'] + iter_start = 0 + iter_max = self.config['training']['iter_max'] + iter_valid = self.config['training']['iter_valid'] + iter_save = self.config['training'].get('iter_save', None) + early_stop_it = self.config['training'].get('early_stop_patience', None) + if(iter_save is None): + iter_save_list = [iter_max] + elif(isinstance(iter_save, (tuple, list))): + iter_save_list = iter_save + else: + iter_save_list = range(0, iter_max + 1, iter_save) + + self.min_loss = 10000.0 + self.min_loss_it = 0 + self.best_model_wts = None + self.bett_head_wts = None + checkpoint = None + # initialize the network with pre-trained weights + ckpt_init_name = self.config['training'].get('ckpt_init_name', None) + ckpt_init_mode = self.config['training'].get('ckpt_init_mode', 0) + ckpt_for_optm = None + if(ckpt_init_name is not None): + checkpoint = torch.load(ckpt_dir + "/" + ckpt_init_name, map_location = self.device) + pretrained_dict = checkpoint['model_state_dict'] + pretrain_head_dict = checkpoint['head_state_dict'] + self.load_pretrained_weights(self.net, pretrained_dict, device_ids) + self.load_pretrained_weights(self.head, pretrain_head_dict, device_ids) + + if(ckpt_init_mode > 0): # Load other information + self.min_loss = checkpoint.get('train_loss', 10000) + iter_start = checkpoint['iteration'] + self.min_loss_it = iter_start + self.best_model_wts = checkpoint['model_state_dict'] + self.best_head_wts = checkpoint['head_state_dict'] + ckpt_for_optm = checkpoint + + self.create_optimizer(self.get_parameters_to_update(), ckpt_for_optm) + self.create_loss_calculator() + + self.trainIter = iter(self.train_loader) + + logging.info("{0:} training start".format(str(datetime.now())[:-7])) + self.summ_writer = SummaryWriter(self.config['training']['ckpt_dir']) + self.glob_it = iter_start + for it in range(iter_start, iter_max, iter_valid): + lr_value = self.optimizer.param_groups[0]['lr'] + + t0 = time.time() + train_scalars = self.training() + t1 = time.time() + if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step(-train_scalars['loss']) + else: + self.scheduler.step() + + self.glob_it = it + iter_valid + logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) + logging.info('learning rate {0:}'.format(lr_value)) + logging.info("training time: {0:.2f}s".format(t1-t0)) + logging.info("data: {0:.2f}s, gpu: {1:.2f}s, loss: {2:.2f}s, back: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['gpu_time'], + train_scalars['loss_time'], train_scalars['back_time'])) + + self.write_scalars(train_scalars, None, lr_value, self.glob_it) + if(train_scalars['loss'] < self.min_loss): + self.min_loss = train_scalars['loss'] + self.min_loss_it = self.glob_it + if(len(device_ids) > 1): + self.best_model_wts = copy.deepcopy(self.net.module.state_dict()) + self.best_head_wts = copy.deepcopy(self.head.module.state_dict()) + else: + self.best_model_wts = copy.deepcopy(self.net.state_dict()) + self.best_head_wts = copy.deepcopy(self.head.state_dict()) + + save_dict = {'iteration': self.min_loss_it, + 'train_loss': self.min_loss, + 'model_state_dict': self.best_model_wts, + 'head_state_dict': self.best_head_wts, + 'optimizer_state_dict': self.optimizer.state_dict()} + save_name = "{0:}/{1:}_best.pt".format(ckpt_dir, ckpt_prefix) + torch.save(save_dict, save_name) + txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') + txt_file.write(str(self.min_loss_it)) + txt_file.close() + + stop_now = True if(early_stop_it is not None and \ + self.glob_it - self.min_loss_it > early_stop_it) else False + if(train_scalars['err_info'] is not None): + logging.info("Early stopped due to error: {0:}".format(train_scalars['err_info'])) + stop_now = True + if ((self.glob_it in iter_save_list) or stop_now): + save_dict = {'iteration': self.glob_it, + 'train_loss': train_scalars['loss'], + 'model_state_dict': self.net.module.state_dict() \ + if len(device_ids) > 1 else self.net.state_dict(), + 'head_state_dict': self.head.module.state_dict() \ + if len(device_ids) > 1 else self.head.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict()} + save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.glob_it) + torch.save(save_dict, save_name) + txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt') + txt_file.write(str(self.glob_it)) + txt_file.close() + if(stop_now): + logging.info("The training is early stopped") + break + # save the best performing checkpoint + logging.info('The best performing iter is {0:}, train loss {1:}'.format(\ + self.min_loss_it, self.min_loss)) + self.summ_writer.close() \ No newline at end of file From e2f68d06fecc9d6f36d896f6ab6cd92ce8236c3c Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 30 Sep 2024 22:34:19 +0800 Subject: [PATCH 57/86] add vitb16 --- pymic/net/cls/torch_pretrained_net.py | 35 ++++++++++++++++++++++++++- pymic/net/net_dict_cls.py | 3 ++- pymic/net/net_init.py | 26 ++++++++++++++++++++ 3 files changed, 62 insertions(+), 2 deletions(-) create mode 100644 pymic/net/net_init.py diff --git a/pymic/net/cls/torch_pretrained_net.py b/pymic/net/cls/torch_pretrained_net.py index c1959a4..9d05c28 100644 --- a/pymic/net/cls/torch_pretrained_net.py +++ b/pymic/net/cls/torch_pretrained_net.py @@ -181,7 +181,8 @@ class MobileNetV2(BuiltInNet): """ def __init__(self, params): super(MobileNetV2, self).__init__(params) - self.net = models.mobilenet_v2(pretrained = self.pretrain) + weights = 'IMAGENET1K_V1' if self.pretrain else None + self.net = models.mobilenet_v2(weights = weights) # replace the last layer num_ftrs = self.net.last_channel @@ -204,6 +205,38 @@ def get_parameters_to_update(self): return params else: raise(ValueError("update_mode can only be 'all' or 'last'.")) + +class ViTB16(BuiltInNet): + """ + ViTB16 for classification. + Parameters should be set in the `params` dictionary that contains the + following fields: + + :param input_chns: (int) Input channel number, default is 3. + :param pretrain: (bool) Using pretrained model or not, default is True. + :param update_mode: (str) The strategy for updating layers: "`all`" means updating + all the layers, and "`last`" (by default) means updating the last layer, + as well as the first layer when `input_chns` is not 3. + """ + def __init__(self, params): + super(ViTB16, self).__init__(params) + weights = 'IMAGENET1K_V1' if self.pretrain else None + self.net = models.vit_b_16(weights = weights) + + # replace the last layer + num_ftrs = self.net.representation_size + if(num_ftrs is None): + num_ftrs = self.net.hidden_dim + self.net.heads[-1] = nn.Linear(num_ftrs, params['class_num']) + + def get_parameters_to_update(self): + if(self.update_mode == "all"): + return self.net.parameters() + elif(self.update_mode == "last"): + params = self.net.heads[-1].parameters() + return params + else: + raise(ValueError("update_mode can only be 'all' or 'last'.")) if __name__ == "__main__": params = {"class_num": 2, "pretrain": False, "input_chns": 3} diff --git a/pymic/net/net_dict_cls.py b/pymic/net/net_dict_cls.py index 3a7808b..a83334a 100644 --- a/pymic/net/net_dict_cls.py +++ b/pymic/net/net_dict_cls.py @@ -13,5 +13,6 @@ TorchClsNetDict = { 'resnet18': ResNet18, 'vgg16': VGG16, - 'mobilenetv2':MobileNetV2 + 'mobilenetv2':MobileNetV2, + 'vitb16': ViTB16 } diff --git a/pymic/net/net_init.py b/pymic/net/net_init.py new file mode 100644 index 0000000..1f9b48e --- /dev/null +++ b/pymic/net/net_init.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +from torch import nn + + +class Initialization_He(object): + def __init__(self, neg_slope=1e-2): + self.neg_slope = neg_slope + + def __call__(self, module): + if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)): + module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) + if module.bias is not None: + module.bias = nn.init.constant_(module.bias, 0) + + +class Initialization_XavierUniform(object): + def __init__(self, gain=1): + self.gain = gain + + def __call__(self, module): + if isinstance(module, (nn.Conv3d ,nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)): + module.weight = nn.init.xavier_uniform_(module.weight, self.gain) + if module.bias is not None: + module.bias = nn.init.constant_(module.bias, 0) From 2db9a3a0fc2429a38e44e1bad7376c42d4c5e77f Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 2 Oct 2024 15:48:05 +0800 Subject: [PATCH 58/86] add CANet for segmentation --- pymic/net/net2d/canet.py | 229 --------- pymic/net/net2d/canet_module.py | 578 ---------------------- pymic/net/net2d/unet2d_canet.py | 840 ++++++++++++++++++++++++++------ 3 files changed, 687 insertions(+), 960 deletions(-) delete mode 100644 pymic/net/net2d/canet.py delete mode 100644 pymic/net/net2d/canet_module.py diff --git a/pymic/net/net2d/canet.py b/pymic/net/net2d/canet.py deleted file mode 100644 index ab64ab5..0000000 --- a/pymic/net/net2d/canet.py +++ /dev/null @@ -1,229 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division -import torch -import torch.nn as nn - -class ConvLayer(nn.Module): - """ - A combination of Conv2d, BatchNorm2d and LeakyReLU. - """ - def __init__(self, in_channels, out_channels, kernel_size = 1): - super(ConvLayer, self).__init__() - padding = int((kernel_size - 1) / 2) - self.conv = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), - nn.BatchNorm2d(out_channels), - nn.LeakyReLU() - ) - - def forward(self, x): - return self.conv(x) - -class SEBlock(nn.Module): - """ - A Modified Squeeze-and-Excitation block for spatial attention. - """ - def __init__(self, in_channels, r): - super(SEBlock, self).__init__() - - redu_chns = int(in_channels / r) - self.se_layers = nn.Sequential( - nn.AdaptiveAvgPool2d(1), - nn.Conv2d(in_channels, redu_chns, kernel_size=1, padding=0), - nn.LeakyReLU(), - nn.Conv2d(redu_chns, in_channels, kernel_size=1, padding=0), - nn.ReLU()) - - def forward(self, x): - f = self.se_layers(x) - return f*x + x - -class ASPPBlock(nn.Module): - """ - ASPP block. - """ - def __init__(self,in_channels, out_channels_list, kernel_size_list, dilation_list): - super(ASPPBlock, self).__init__() - self.conv_num = len(out_channels_list) - assert(self.conv_num == 4) - assert(self.conv_num == len(kernel_size_list) and self.conv_num == len(dilation_list)) - pad0 = int((kernel_size_list[0] - 1) / 2 * dilation_list[0]) - pad1 = int((kernel_size_list[1] - 1) / 2 * dilation_list[1]) - pad2 = int((kernel_size_list[2] - 1) / 2 * dilation_list[2]) - pad3 = int((kernel_size_list[3] - 1) / 2 * dilation_list[3]) - self.conv_1 = nn.Conv2d(in_channels, out_channels_list[0], kernel_size = kernel_size_list[0], - dilation = dilation_list[0], padding = pad0 ) - self.conv_2 = nn.Conv2d(in_channels, out_channels_list[1], kernel_size = kernel_size_list[1], - dilation = dilation_list[1], padding = pad1 ) - self.conv_3 = nn.Conv2d(in_channels, out_channels_list[2], kernel_size = kernel_size_list[2], - dilation = dilation_list[2], padding = pad2 ) - self.conv_4 = nn.Conv2d(in_channels, out_channels_list[3], kernel_size = kernel_size_list[3], - dilation = dilation_list[3], padding = pad3 ) - - out_channels = out_channels_list[0] + out_channels_list[1] + out_channels_list[2] + out_channels_list[3] - self.conv_1x1 = nn.Sequential( - nn.Conv2d(out_channels, out_channels, kernel_size=1, padding=0), - nn.BatchNorm2d(out_channels), - nn.LeakyReLU()) - - def forward(self, x): - x1 = self.conv_1(x) - x2 = self.conv_2(x) - x3 = self.conv_3(x) - x4 = self.conv_4(x) - - y = torch.cat([x1, x2, x3, x4], dim=1) - y = self.conv_1x1(y) - return y - -class ConvBNActBlock(nn.Module): - """ - Two convolution layers with batch norm, leaky relu, - dropout and SE block. - """ - def __init__(self,in_channels, out_channels, dropout_p): - super(ConvBNActBlock, self).__init__() - self.conv_conv = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), - nn.LeakyReLU(), - nn.Dropout(dropout_p), - nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), - nn.LeakyReLU(), - SEBlock(out_channels, 2) - ) - - def forward(self, x): - return self.conv_conv(x) - -class DownBlock(nn.Module): - """ - Downsampling by a concantenation of max-pool and avg-pool, - followed by ConvBNActBlock. - """ - def __init__(self, in_channels, out_channels, dropout_p): - super(DownBlock, self).__init__() - self.maxpool = nn.MaxPool2d(2) - self.avgpool = nn.AvgPool2d(2) - self.conv = ConvBNActBlock(2 * in_channels, out_channels, dropout_p) - - def forward(self, x): - x_max = self.maxpool(x) - x_avg = self.avgpool(x) - x_cat = torch.cat([x_max, x_avg], dim=1) - y = self.conv(x_cat) - return y + x_cat - -class UpBlock(nn.Module): - """ - Upssampling followed by ConvBNActBlock. - """ - def __init__(self, in_channels1, in_channels2, out_channels, - bilinear=True, dropout_p = 0.5): - super(UpBlock, self).__init__() - self.bilinear = bilinear - if bilinear: - self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) - else: - self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2) - self.conv = ConvBNActBlock(in_channels2 * 2, out_channels, dropout_p) - - def forward(self, x1, x2): - if self.bilinear: - x1 = self.conv1x1(x1) - x1 = self.up(x1) - x_cat = torch.cat([x2, x1], dim=1) - y = self.conv(x_cat) - return y + x_cat - -class CANet(nn.Module): - """ - Implementation of of CA-Net for biomedical image segmentation. - - * Reference: R. Gu et al. `CA-Net: Comprehensive Attention Convolutional Neural Networks - for Explainable Medical Image Segmentation `_. - IEEE Transactions on Medical Imaging, 40(2),2021:699-711. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - """ - def __init__(self, params): - super(COPLENet, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] - assert(len(self.ft_chns) == 5) - - f0_half = int(self.ft_chns[0] / 2) - f1_half = int(self.ft_chns[1] / 2) - f2_half = int(self.ft_chns[2] / 2) - f3_half = int(self.ft_chns[3] / 2) - self.in_conv= ConvBNActBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - - self.bridge0= ConvLayer(self.ft_chns[0], f0_half) - self.bridge1= ConvLayer(self.ft_chns[1], f1_half) - self.bridge2= ConvLayer(self.ft_chns[2], f2_half) - self.bridge3= ConvLayer(self.ft_chns[3], f3_half) - - self.up1 = UpBlock(self.ft_chns[4], f3_half, self.ft_chns[3], dropout_p = self.dropout[3]) - self.up2 = UpBlock(self.ft_chns[3], f2_half, self.ft_chns[2], dropout_p = self.dropout[2]) - self.up3 = UpBlock(self.ft_chns[2], f1_half, self.ft_chns[1], dropout_p = self.dropout[1]) - self.up4 = UpBlock(self.ft_chns[1], f0_half, self.ft_chns[0], dropout_p = self.dropout[0]) - - f4 = self.ft_chns[4] - aspp_chns = [int(f4 / 4), int(f4 / 4), int(f4 / 4), int(f4 / 4)] - aspp_knls = [1, 3, 3, 3] - aspp_dila = [1, 2, 4, 6] - self.aspp = ASPPBlock(f4, aspp_chns, aspp_knls, aspp_dila) - - - self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, - kernel_size = 3, padding = 1) - - def forward(self, x): - x_shape = list(x.shape) - if(len(x_shape) == 5): - [N, C, D, H, W] = x_shape - new_shape = [N*D, C, H, W] - x = torch.transpose(x, 1, 2) - x = torch.reshape(x, new_shape) - x0 = self.in_conv(x) - x0b = self.bridge0(x0) - x1 = self.down1(x0) - x1b = self.bridge1(x1) - x2 = self.down2(x1) - x2b = self.bridge2(x2) - x3 = self.down3(x2) - x3b = self.bridge3(x3) - x4 = self.down4(x3) - x4 = self.aspp(x4) - - x = self.up1(x4, x3b) - x = self.up2(x, x2b) - x = self.up3(x, x1b) - x = self.up4(x, x0b) - output = self.out_conv(x) - - if(len(x_shape) == 5): - new_shape = [N, D] + list(output.shape)[1:] - output = torch.reshape(output, new_shape) - output = torch.transpose(output, 1, 2) - return output \ No newline at end of file diff --git a/pymic/net/net2d/canet_module.py b/pymic/net/net2d/canet_module.py deleted file mode 100644 index 097a4f1..0000000 --- a/pymic/net/net2d/canet_module.py +++ /dev/null @@ -1,578 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Building blcoks for CA-Net. - -Oringinal file is on `Github. -`_ -""" - -from __future__ import print_function, division -import torch -import torch.nn as nn -import functools -from torch.nn import functional as F - - -class conv_block(nn.Module): - def __init__(self, ch_in, ch_out, drop_out=False): - super(conv_block, self).__init__() - self.conv = nn.Sequential( - nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), - nn.BatchNorm2d(ch_out), - nn.ReLU(inplace=True), - nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True), - nn.BatchNorm2d(ch_out), - nn.ReLU(inplace=True), - ) - self.dropout = drop_out - - def forward(self, x): - x = self.conv(x) - if self.dropout: - x = nn.Dropout2d(0.5)(x) - return x - - -# # UpCat(nn.Module) for U-net UP convolution -class UpCat(nn.Module): - def __init__(self, in_feat, out_feat, is_deconv=True): - super(UpCat, self).__init__() - if is_deconv: - self.up = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2) - else: - self.up = nn.Upsample(scale_factor=2, mode='bilinear') - - def forward(self, inputs, down_outputs): - # TODO: Upsampling required after deconv? - outputs = self.up(down_outputs) - offset = inputs.size()[3] - outputs.size()[3] - if offset == 1: - addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2]), out=None).unsqueeze( - 3).cuda() - outputs = torch.cat([outputs, addition], dim=3) - elif offset > 1: - addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2], offset), out=None).cuda() - outputs = torch.cat([outputs, addition], dim=3) - out = torch.cat([inputs, outputs], dim=1) - - return out - - -# # UpCatconv(nn.Module) for up convolution -class UpCatconv(nn.Module): - def __init__(self, in_feat, out_feat, is_deconv=True, drop_out=False): - super(UpCatconv, self).__init__() - - if is_deconv: - self.conv = conv_block(in_feat, out_feat, drop_out=drop_out) - self.up = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2) - else: - self.conv = conv_block(in_feat + out_feat, out_feat, drop_out=drop_out) - self.up = nn.Upsample(scale_factor=2, mode='bilinear') - - def forward(self, inputs, down_outputs): - # TODO: Upsampling required after deconv - outputs = self.up(down_outputs) - offset = inputs.size()[3] - outputs.size()[3] - if offset == 1: - addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2]), out=None).unsqueeze( - 3).cuda() - outputs = torch.cat([outputs, addition], dim=3) - elif offset > 1: - addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2], offset), out=None).cuda() - outputs = torch.cat([outputs, addition], dim=3) - out = self.conv(torch.cat([inputs, outputs], dim=1)) - - return out - - -class UnetDsv3(nn.Module): - def __init__(self, in_size, out_size, scale_factor): - super(UnetDsv3, self).__init__() - self.dsv = nn.Sequential(nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0), - nn.Upsample(size=scale_factor, mode='bilinear'), ) - - def forward(self, input): - return self.dsv(input) - - -###### Intial weights ##### -def weights_init_normal(m): - classname = m.__class__.__name__ - #print(classname) - if classname.find('Conv') != -1: - nn.init.normal(m.weight.data, 0.0, 0.02) - elif classname.find('Linear') != -1: - nn.init.normal(m.weight.data, 0.0, 0.02) - elif classname.find('BatchNorm') != -1: - nn.init.normal(m.weight.data, 1.0, 0.02) - nn.init.constant(m.bias.data, 0.0) - - -def weights_init_xavier(m): - classname = m.__class__.__name__ - #print(classname) - if classname.find('Conv') != -1: - nn.init.xavier_normal(m.weight.data, gain=1) - elif classname.find('Linear') != -1: - nn.init.xavier_normal(m.weight.data, gain=1) - elif classname.find('BatchNorm') != -1: - nn.init.normal(m.weight.data, 1.0, 0.02) - nn.init.constant(m.bias.data, 0.0) - - -def weights_init_kaiming(m): - classname = m.__class__.__name__ - #print(classname) - if classname.find('Conv') != -1: - nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in') - elif classname.find('Linear') != -1: - nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in') - elif classname.find('BatchNorm') != -1: - nn.init.normal(m.weight.data, 1.0, 0.02) - nn.init.constant(m.bias.data, 0.0) - - -def weights_init_orthogonal(m): - classname = m.__class__.__name__ - #print(classname) - if classname.find('Conv') != -1: - nn.init.orthogonal(m.weight.data, gain=1) - elif classname.find('Linear') != -1: - nn.init.orthogonal(m.weight.data, gain=1) - elif classname.find('BatchNorm') != -1: - nn.init.normal(m.weight.data, 1.0, 0.02) - nn.init.constant(m.bias.data, 0.0) - - -def init_weights(net, init_type='normal'): - #print('initialization method [%s]' % init_type) - if init_type == 'normal': - net.apply(weights_init_normal) - elif init_type == 'xavier': - net.apply(weights_init_xavier) - elif init_type == 'kaiming': - net.apply(weights_init_kaiming) - elif init_type == 'orthogonal': - net.apply(weights_init_orthogonal) - else: - raise NotImplementedError('initialization method [%s] is not implemented' % init_type) - - -def get_norm_layer(norm_type='instance'): - if norm_type == 'batch': - norm_layer = functools.partial(nn.BatchNorm2d, affine=True) - elif norm_type == 'instance': - norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) - elif norm_type == 'none': - norm_layer = None - else: - raise NotImplementedError('normalization layer [%s] is not found' % norm_type) - return norm_layer - - -###### For attention ###### -class _GridAttentionBlockND(nn.Module): - def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation', - sub_sample_factor=(2,2,2)): - super(_GridAttentionBlockND, self).__init__() - - assert dimension in [2, 3] - assert mode in ['concatenation', 'concatenation_debug', 'concatenation_residual'] - - # Downsampling rate for the input featuremap - if isinstance(sub_sample_factor, tuple): self.sub_sample_factor = sub_sample_factor - elif isinstance(sub_sample_factor, list): self.sub_sample_factor = tuple(sub_sample_factor) - else: self.sub_sample_factor = tuple([sub_sample_factor]) * dimension - - # Default parameter set - self.mode = mode - self.dimension = dimension - self.sub_sample_kernel_size = self.sub_sample_factor - - # Number of channels (pixel dimensions) - self.in_channels = in_channels - self.gating_channels = gating_channels - self.inter_channels = inter_channels - - if self.inter_channels is None: - self.inter_channels = in_channels // 2 - if self.inter_channels == 0: - self.inter_channels = 1 - - if dimension == 3: - conv_nd = nn.Conv3d - bn = nn.BatchNorm3d - self.upsample_mode = 'trilinear' - elif dimension == 2: - conv_nd = nn.Conv2d - bn = nn.BatchNorm2d - self.upsample_mode = 'bilinear' - else: - raise NotImplemented - - # Output transform - self.W = nn.Sequential( - conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), - bn(self.in_channels), - ) - - # Theta^T * x_ij + Phi^T * gating_signal + bias - self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, - kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=True) - self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels, - kernel_size=(1, 1), stride=1, padding=0, bias=True) - self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True) - - # Initialise weights - for m in self.children(): - init_weights(m, init_type='kaiming') - - # Define the operation - if mode == 'concatenation': - self.operation_function = self._concatenation - elif mode == 'concatenation_debug': - self.operation_function = self._concatenation_debug - elif mode == 'concatenation_residual': - self.operation_function = self._concatenation_residual - else: - raise NotImplementedError('Unknown operation function.') - - - def forward(self, x, g): - ''' - :param x: (b, c, t, h, w) - :param g: (b, g_d) - :return: - ''' - - output = self.operation_function(x, g) - return output - - def _concatenation(self, x, g): - input_size = x.size() - batch_size = input_size[0] - assert batch_size == g.size(0) - - # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) - # phi => (b, g_d) -> (b, i_c) - theta_x = self.theta(x) - theta_x_size = theta_x.size() - - # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') - # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) - phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) - f = F.relu(theta_x + phi_g, inplace=True) - - # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) - sigm_psi_f = F.sigmoid(self.psi(f)) - - # upsample the attentions and multiply - sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) - y = sigm_psi_f.expand_as(x) * x - W_y = self.W(y) - - return W_y, sigm_psi_f - - def _concatenation_debug(self, x, g): - input_size = x.size() - batch_size = input_size[0] - assert batch_size == g.size(0) - - # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) - # phi => (b, g_d) -> (b, i_c) - theta_x = self.theta(x) - theta_x_size = theta_x.size() - - # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') - # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) - phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) - f = F.softplus(theta_x + phi_g) - - # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) - sigm_psi_f = F.sigmoid(self.psi(f)) - - # upsample the attentions and multiply - sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) - y = sigm_psi_f.expand_as(x) * x - W_y = self.W(y) - - return W_y, sigm_psi_f - - - def _concatenation_residual(self, x, g): - input_size = x.size() - batch_size = input_size[0] - assert batch_size == g.size(0) - - # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) - # phi => (b, g_d) -> (b, i_c) - theta_x = self.theta(x) - theta_x_size = theta_x.size() - - # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') - # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) - phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) - f = F.relu(theta_x + phi_g, inplace=True) - - # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) - f = self.psi(f).view(batch_size, 1, -1) - sigm_psi_f = F.softmax(f, dim=2).view(batch_size, 1, *theta_x.size()[2:]) - - # upsample the attentions and multiply - sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) - y = sigm_psi_f.expand_as(x) * x - W_y = self.W(y) - - return W_y, sigm_psi_f - - -class GridAttentionBlock2D(_GridAttentionBlockND): - def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', - sub_sample_factor=(2, 2)): - super(GridAttentionBlock2D, self).__init__(in_channels, - inter_channels=inter_channels, - gating_channels=gating_channels, - dimension=2, mode=mode, - sub_sample_factor=sub_sample_factor, - ) - - -class GridAttentionBlock3D(_GridAttentionBlockND): - def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', - sub_sample_factor=(2,2,2)): - super(GridAttentionBlock3D, self).__init__(in_channels, - inter_channels=inter_channels, - gating_channels=gating_channels, - dimension=3, mode=mode, - sub_sample_factor=sub_sample_factor, - ) - -class _GridAttentionBlockND_TORR(nn.Module): - def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation', - sub_sample_factor=(1,1,1), bn_layer=True, use_W=True, use_phi=True, use_theta=True, use_psi=True, nonlinearity1='relu'): - super(_GridAttentionBlockND_TORR, self).__init__() - - assert dimension in [2, 3] - assert mode in ['concatenation', 'concatenation_softmax', - 'concatenation_sigmoid', 'concatenation_mean', - 'concatenation_range_normalise', 'concatenation_mean_flow'] - - # Default parameter set - self.mode = mode - self.dimension = dimension - self.sub_sample_factor = sub_sample_factor if isinstance(sub_sample_factor, tuple) else tuple([sub_sample_factor])*dimension - self.sub_sample_kernel_size = self.sub_sample_factor - - # Number of channels (pixel dimensions) - self.in_channels = in_channels - self.gating_channels = gating_channels - self.inter_channels = inter_channels - - if self.inter_channels is None: - self.inter_channels = in_channels // 2 - if self.inter_channels == 0: - self.inter_channels = 1 - - if dimension == 3: - conv_nd = nn.Conv3d - bn = nn.BatchNorm3d - self.upsample_mode = 'trilinear' - elif dimension == 2: - conv_nd = nn.Conv2d - bn = nn.BatchNorm2d - self.upsample_mode = 'bilinear' - else: - raise NotImplemented - - # initialise id functions - # Theta^T * x_ij + Phi^T * gating_signal + bias - self.W = lambda x: x - self.theta = lambda x: x - self.psi = lambda x: x - self.phi = lambda x: x - self.nl1 = lambda x: x - - if use_W: - if bn_layer: - self.W = nn.Sequential( - conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), - bn(self.in_channels), - ) - else: - self.W = conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0) - - if use_theta: - self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, - kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=False) - - - if use_phi: - self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels, - kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=False) - - - if use_psi: - self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True) - - - if nonlinearity1: - if nonlinearity1 == 'relu': - self.nl1 = lambda x: F.relu(x, inplace=True) - - if 'concatenation' in mode: - self.operation_function = self._concatenation - else: - raise NotImplementedError('Unknown operation function.') - - # Initialise weights - for m in self.children(): - init_weights(m, init_type='kaiming') - - - if use_psi and self.mode == 'concatenation_sigmoid': - nn.init.constant(self.psi.bias.data, 3.0) - - if use_psi and self.mode == 'concatenation_softmax': - nn.init.constant(self.psi.bias.data, 10.0) - - # if use_psi and self.mode == 'concatenation_mean': - # nn.init.constant(self.psi.bias.data, 3.0) - - # if use_psi and self.mode == 'concatenation_range_normalise': - # nn.init.constant(self.psi.bias.data, 3.0) - - parallel = False - if parallel: - if use_W: self.W = nn.DataParallel(self.W) - if use_phi: self.phi = nn.DataParallel(self.phi) - if use_psi: self.psi = nn.DataParallel(self.psi) - if use_theta: self.theta = nn.DataParallel(self.theta) - - def forward(self, x, g): - ''' - :param x: (b, c, t, h, w) - :param g: (b, g_d) - :return: - ''' - - output = self.operation_function(x, g) - return output - - def _concatenation(self, x, g): - input_size = x.size() - batch_size = input_size[0] - assert batch_size == g.size(0) - - ############################# - # compute compatibility score - - # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) - # phi => (b, c, t, h, w) -> (b, i_c, t, h, w) - theta_x = self.theta(x) - theta_x_size = theta_x.size() - - # nl(theta.x + phi.g + bias) -> f = (b, i_c, t/s1, h/s2, w/s3) - phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) - - f = theta_x + phi_g - f = self.nl1(f) - - psi_f = self.psi(f) - - ############################################ - # normalisation -- scale compatibility score - # psi^T . f -> (b, 1, t/s1, h/s2, w/s3) - if self.mode == 'concatenation_softmax': - sigm_psi_f = F.softmax(psi_f.view(batch_size, 1, -1), dim=2) - sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:]) - elif self.mode == 'concatenation_mean': - psi_f_flat = psi_f.view(batch_size, 1, -1) - psi_f_sum = torch.sum(psi_f_flat, dim=2)#clamp(1e-6) - psi_f_sum = psi_f_sum[:,:,None].expand_as(psi_f_flat) - - sigm_psi_f = psi_f_flat / psi_f_sum - sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:]) - elif self.mode == 'concatenation_mean_flow': - psi_f_flat = psi_f.view(batch_size, 1, -1) - ss = psi_f_flat.shape - psi_f_min = psi_f_flat.min(dim=2)[0].view(ss[0],ss[1],1) - psi_f_flat = psi_f_flat - psi_f_min - psi_f_sum = torch.sum(psi_f_flat, dim=2).view(ss[0],ss[1],1).expand_as(psi_f_flat) - - sigm_psi_f = psi_f_flat / psi_f_sum - sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:]) - elif self.mode == 'concatenation_range_normalise': - psi_f_flat = psi_f.view(batch_size, 1, -1) - ss = psi_f_flat.shape - psi_f_max = torch.max(psi_f_flat, dim=2)[0].view(ss[0], ss[1], 1) - psi_f_min = torch.min(psi_f_flat, dim=2)[0].view(ss[0], ss[1], 1) - - sigm_psi_f = (psi_f_flat - psi_f_min) / (psi_f_max - psi_f_min).expand_as(psi_f_flat) - sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:]) - - elif self.mode == 'concatenation_sigmoid': - sigm_psi_f = F.sigmoid(psi_f) - else: - raise NotImplementedError - - # sigm_psi_f is attention map! upsample the attentions and multiply - sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) - y = sigm_psi_f.expand_as(x) * x - W_y = self.W(y) - - return W_y, sigm_psi_f - - -class GridAttentionBlock2D_TORR(_GridAttentionBlockND_TORR): - def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', - sub_sample_factor=(1,1), bn_layer=True, - use_W=True, use_phi=True, use_theta=True, use_psi=True, - nonlinearity1='relu'): - super(GridAttentionBlock2D_TORR, self).__init__(in_channels, - inter_channels=inter_channels, - gating_channels=gating_channels, - dimension=2, mode=mode, - sub_sample_factor=sub_sample_factor, - bn_layer=bn_layer, - use_W=use_W, - use_phi=use_phi, - use_theta=use_theta, - use_psi=use_psi, - nonlinearity1=nonlinearity1) - - -class GridAttentionBlock3D_TORR(_GridAttentionBlockND_TORR): - def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', - sub_sample_factor=(1,1,1), bn_layer=True): - super(GridAttentionBlock3D_TORR, self).__init__(in_channels, - inter_channels=inter_channels, - gating_channels=gating_channels, - dimension=3, mode=mode, - sub_sample_factor=sub_sample_factor, - bn_layer=bn_layer) - - -class MultiAttentionBlock(nn.Module): - def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor): - super(MultiAttentionBlock, self).__init__() - self.gate_block_1 = GridAttentionBlock2D(in_channels=in_size, gating_channels=gate_size, - inter_channels=inter_size, mode=nonlocal_mode, - sub_sample_factor=sub_sample_factor) - self.gate_block_2 = GridAttentionBlock2D(in_channels=in_size, gating_channels=gate_size, - inter_channels=inter_size, mode=nonlocal_mode, - sub_sample_factor=sub_sample_factor) - self.combine_gates = nn.Sequential(nn.Conv2d(in_size*2, in_size, kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(in_size), - nn.ReLU(inplace=True)) - - # initialise the blocks - for m in self.children(): - if m.__class__.__name__.find('GridAttentionBlock2D') != -1: continue - init_weights(m, init_type='kaiming') - - def forward(self, input, gating_signal): - gate_1, attention_1 = self.gate_block_1(input, gating_signal) - gate_2, attention_2 = self.gate_block_2(input, gating_signal) - - return self.combine_gates(torch.cat([gate_1, gate_2], 1)), torch.cat([attention_1, attention_2], 1) \ No newline at end of file diff --git a/pymic/net/net2d/unet2d_canet.py b/pymic/net/net2d/unet2d_canet.py index defcb60..0af3159 100644 --- a/pymic/net/net2d/unet2d_canet.py +++ b/pymic/net/net2d/unet2d_canet.py @@ -1,18 +1,327 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division - -import numpy as np +import numpy as np import torch import torch.nn as nn +from torch.nn import init from torch.nn import functional as F -from pymic.net.net2d.canet_module import * + +## init +def weights_init_normal(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('Linear') != -1: + init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm') != -1: + init.normal_(m.weight.data, 1.0, 0.02) + init.constant_(m.bias.data, 0.0) + + +def weights_init_xavier(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + init.xavier_normal_(m.weight.data, gain=1) + elif classname.find('Linear') != -1: + init.xavier_normal_(m.weight.data, gain=1) + elif classname.find('BatchNorm') != -1: + init.normal_(m.weight.data, 1.0, 0.02) + init.constant_(m.bias.data, 0.0) + + +def weights_init_kaiming(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif classname.find('Linear') != -1: + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif classname.find('BatchNorm') != -1: + init.normal_(m.weight.data, 1.0, 0.02) + init.constant_(m.bias.data, 0.0) + + +def weights_init_orthogonal(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + init.orthogonal_(m.weight.data, gain=1) + elif classname.find('Linear') != -1: + init.orthogonal_(m.weight.data, gain=1) + elif classname.find('BatchNorm') != -1: + init.normal_(m.weight.data, 1.0, 0.02) + init.constant_(m.bias.data, 0.0) + + +def init_weights(net, init_type='normal'): + #print('initialization method [%s]' % init_type) + if init_type == 'normal': + net.apply(weights_init_normal) + elif init_type == 'xavier': + net.apply(weights_init_xavier) + elif init_type == 'kaiming': + net.apply(weights_init_kaiming) + elif init_type == 'orthogonal': + net.apply(weights_init_orthogonal) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + +## 1, modules +def conv1x1(in_planes, out_planes, stride=1, bias=False): + "1x1 convolution" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, + padding=0, bias=bias) def conv3x3(in_planes, out_planes, stride=1, bias=False, group=1): """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,padding=1, groups=group, bias=bias) +# conv_block(nn.Module) for U-net convolution block +class conv_block(nn.Module): + def __init__(self, ch_in, ch_out, drop_out=False): + super(conv_block, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), + nn.BatchNorm2d(ch_out), + nn.ReLU(inplace=True), + nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True), + nn.BatchNorm2d(ch_out), + nn.ReLU(inplace=True), + ) + self.dropout = drop_out + + def forward(self, x): + x = self.conv(x) + if self.dropout: + x = nn.Dropout2d(0.5)(x) + return x + + +# # UpCat(nn.Module) for U-net UP convolution +class UpCat(nn.Module): + def __init__(self, in_feat, out_feat, is_deconv=True): + super(UpCat, self).__init__() + + if is_deconv: + self.up = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2) + else: + self.up = nn.Upsample(scale_factor=2, mode='bilinear') + + def forward(self, inputs, down_outputs): + # TODO: Upsampling required after deconv? + outputs = self.up(down_outputs) + offset = inputs.size()[3] - outputs.size()[3] + if offset == 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2]), out=None).unsqueeze( + 3).cuda() + outputs = torch.cat([outputs, addition], dim=3) + elif offset > 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2], offset), out=None).cuda() + outputs = torch.cat([outputs, addition], dim=3) + out = torch.cat([inputs, outputs], dim=1) + + return out + + +# # UpCatconv(nn.Module) for up convolution +class UpCatconv(nn.Module): + def __init__(self, in_feat, out_feat, is_deconv=True, drop_out=False): + super(UpCatconv, self).__init__() + + if is_deconv: + self.conv = conv_block(in_feat, out_feat, drop_out=drop_out) + self.up = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2) + else: + self.conv = conv_block(in_feat + out_feat, out_feat, drop_out=drop_out) + self.up = nn.Upsample(scale_factor=2, mode='bilinear') + + def forward(self, inputs, down_outputs): + # TODO: Upsampling required after deconv + outputs = self.up(down_outputs) + offset = inputs.size()[3] - outputs.size()[3] + if offset == 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2]), out=None).unsqueeze( + 3).cuda() + outputs = torch.cat([outputs, addition], dim=3) + elif offset > 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2], offset), out=None).cuda() + outputs = torch.cat([outputs, addition], dim=3) + out = self.conv(torch.cat([inputs, outputs], dim=1)) + + return out + + + +class _GridAttentionBlockND(nn.Module): + def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation', + sub_sample_factor=(2,2,2)): + super(_GridAttentionBlockND, self).__init__() + + assert dimension in [2, 3] + assert mode in ['concatenation', 'concatenation_debug', 'concatenation_residual'] + + # Downsampling rate for the input featuremap + if isinstance(sub_sample_factor, tuple): self.sub_sample_factor = sub_sample_factor + elif isinstance(sub_sample_factor, list): self.sub_sample_factor = tuple(sub_sample_factor) + else: self.sub_sample_factor = tuple([sub_sample_factor]) * dimension + + # Default parameter set + self.mode = mode + self.dimension = dimension + self.sub_sample_kernel_size = self.sub_sample_factor + + # Number of channels (pixel dimensions) + self.in_channels = in_channels + self.gating_channels = gating_channels + self.inter_channels = inter_channels + + if self.inter_channels is None: + self.inter_channels = in_channels // 2 + if self.inter_channels == 0: + self.inter_channels = 1 + + if dimension == 3: + conv_nd = nn.Conv3d + bn = nn.BatchNorm3d + self.upsample_mode = 'trilinear' + elif dimension == 2: + conv_nd = nn.Conv2d + bn = nn.BatchNorm2d + self.upsample_mode = 'bilinear' + else: + raise NotImplemented + + # Output transform + self.W = nn.Sequential( + conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), + bn(self.in_channels), + ) + + # Theta^T * x_ij + Phi^T * gating_signal + bias + self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=True) + self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels, + kernel_size=(1, 1), stride=1, padding=0, bias=True) + self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True) + + # Initialise weights + for m in self.children(): + init_weights(m, init_type='kaiming') + + # Define the operation + if mode == 'concatenation': + self.operation_function = self._concatenation + elif mode == 'concatenation_debug': + self.operation_function = self._concatenation_debug + elif mode == 'concatenation_residual': + self.operation_function = self._concatenation_residual + else: + raise NotImplementedError('Unknown operation function.') + + + def forward(self, x, g): + ''' + :param x: (b, c, t, h, w) + :param g: (b, g_d) + :return: + ''' + + output = self.operation_function(x, g) + return output + + def _concatenation(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.relu(theta_x + phi_g, inplace=True) + + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + sigm_psi_f = torch.sigmoid(self.psi(f)) + + # upsample the attentions and multiply + sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + def _concatenation_debug(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.softplus(theta_x + phi_g) + + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + sigm_psi_f = torch.sigmoid(self.psi(f)) + + # upsample the attentions and multiply + sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + def _concatenation_residual(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.relu(theta_x + phi_g, inplace=True) + + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + f = self.psi(f).view(batch_size, 1, -1) + sigm_psi_f = F.softmax(f, dim=2).view(batch_size, 1, *theta_x.size()[2:]) + + # upsample the attentions and multiply + sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + +class GridAttentionBlock2D(_GridAttentionBlockND): + def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', + sub_sample_factor=(2, 2)): + super(GridAttentionBlock2D, self).__init__(in_channels, + inter_channels=inter_channels, + gating_channels=gating_channels, + dimension=2, mode=mode, + sub_sample_factor=sub_sample_factor, + ) + + +## 2, channel attention class SE_Conv_Block(nn.Module): expansion = 4 @@ -29,6 +338,9 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, drop_out=False): self.stride = stride self.dropout = drop_out + self.globalAvgPool = nn.AdaptiveAvgPool2d(1) + self.globalMaxPool = nn.AdaptiveMaxPool2d(1) + self.fc1 = nn.Linear(in_features=planes * 2, out_features=round(planes / 2)) self.fc2 = nn.Linear(in_features=round(planes / 2), out_features=planes * 2) self.sigmoid = nn.Sigmoid() @@ -54,7 +366,7 @@ def forward(self, x): original_out = out out1 = out # For global average pool - out = F.adaptive_avg_pool2d(out, (1,1)) + out = self.globalAvgPool(out) out = out.view(out.size(0), -1) out = self.fc1(out) out = self.relu(out) @@ -64,7 +376,7 @@ def forward(self, x): avg_att = out out = out * original_out # For global maximum pool - out1 = F.adaptive_max_pool2d(out1, (1,1)) + out1 = self.globalMaxPool(out1) out1 = out1.view(out1.size(0), -1) out1 = self.fc1(out1) out1 = self.relu(out1) @@ -87,175 +399,197 @@ def forward(self, x): return out, att_weight -# # CBAM Convolutional block attention module -class BasicConv(nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, - relu=True, bn=True, bias=False): - super(BasicConv, self).__init__() - self.out_channels = out_planes - self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, groups=groups, bias=bias) - self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None - self.relu = nn.ReLU() if relu else None +## 3, grid attention +class _GridAttentionBlockND(nn.Module): + def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation', + sub_sample_factor=(2,2,2)): + super(_GridAttentionBlockND, self).__init__() - def forward(self, x): - x = self.conv(x) - if self.bn is not None: - x = self.bn(x) - if self.relu is not None: - x = self.relu(x) - return x + assert dimension in [2, 3] + assert mode in ['concatenation', 'concatenation_debug', 'concatenation_residual'] + # Downsampling rate for the input featuremap + if isinstance(sub_sample_factor, tuple): self.sub_sample_factor = sub_sample_factor + elif isinstance(sub_sample_factor, list): self.sub_sample_factor = tuple(sub_sample_factor) + else: self.sub_sample_factor = tuple([sub_sample_factor]) * dimension -class Flatten(nn.Module): - def forward(self, x): - return x.view(x.size(0), -1) + # Default parameter set + self.mode = mode + self.dimension = dimension + self.sub_sample_kernel_size = self.sub_sample_factor + # Number of channels (pixel dimensions) + self.in_channels = in_channels + self.gating_channels = gating_channels + self.inter_channels = inter_channels -class ChannelGate(nn.Module): - def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): - super(ChannelGate, self).__init__() - self.gate_channels = gate_channels - self.mlp = nn.Sequential( - Flatten(), - nn.Linear(gate_channels, gate_channels // reduction_ratio), - nn.ReLU(), - nn.Linear(gate_channels // reduction_ratio, gate_channels) - ) - self.pool_types = pool_types + if self.inter_channels is None: + self.inter_channels = in_channels // 2 + if self.inter_channels == 0: + self.inter_channels = 1 - def forward(self, x): - channel_att_sum = None - for pool_type in self.pool_types: - if pool_type == 'avg': - avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) - channel_att_raw = self.mlp(avg_pool) - elif pool_type == 'max': - max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) - channel_att_raw = self.mlp(max_pool) - elif pool_type == 'lp': - lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) - channel_att_raw = self.mlp(lp_pool) - elif pool_type == 'lse': - # LSE pool only - lse_pool = logsumexp_2d(x) - channel_att_raw = self.mlp(lse_pool) + if dimension == 3: + conv_nd = nn.Conv3d + bn = nn.BatchNorm3d + self.upsample_mode = 'trilinear' + elif dimension == 2: + conv_nd = nn.Conv2d + bn = nn.BatchNorm2d + self.upsample_mode = 'bilinear' + else: + raise NotImplemented - if channel_att_sum is None: - channel_att_sum = channel_att_raw - else: - channel_att_sum = channel_att_sum + channel_att_raw + # Output transform + self.W = nn.Sequential( + conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), + bn(self.in_channels), + ) - # scalecoe = F.sigmoid(channel_att_sum) - # print("channel att_sum", channel_att_sum.shape) - # channel_att_sum = channel_att_sum.reshape(channel_att_sum.shape[0], 4, 4) - # avg_weight = torch.mean(channel_att_sum, dim=2).unsqueeze(2) - # avg_weight = avg_weight.expand(channel_att_sum.shape[0], 4, 4).reshape(channel_att_sum.shape[0], 16) - # scale = F.sigmoid(avg_weight).unsqueeze(2).unsqueeze(3).expand_as(x) - scale = F.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x) - return x * scale, scale + # Theta^T * x_ij + Phi^T * gating_signal + bias + self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=True) + self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels, + kernel_size=(1, 1), stride=1, padding=0, bias=True) + self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True) + # Initialise weights + for m in self.children(): + init_weights(m, init_type='kaiming') -def logsumexp_2d(tensor): - tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) - s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) - outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() - return outputs + # Define the operation + if mode == 'concatenation': + self.operation_function = self._concatenation + elif mode == 'concatenation_debug': + self.operation_function = self._concatenation_debug + elif mode == 'concatenation_residual': + self.operation_function = self._concatenation_residual + else: + raise NotImplementedError('Unknown operation function.') -class ChannelPool(nn.Module): - def forward(self, x): - return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1) + def forward(self, x, g): + ''' + :param x: (b, c, t, h, w) + :param g: (b, g_d) + :return: + ''' + output = self.operation_function(x, g) + return output -class SpatialGate(nn.Module): - def __init__(self): - super(SpatialGate, self).__init__() - kernel_size = 7 - self.compress = ChannelPool() - self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) + def _concatenation(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) - def forward(self, x): - x_compress = self.compress(x) - x_out = self.spatial(x_compress) - scale = F.sigmoid(x_out) # broadcasting - return x * scale, scale + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() -class SpatialAtten(nn.Module): - def __init__(self, in_size, out_size, kernel_size=3, stride=1): - super(SpatialAtten, self).__init__() - self.conv1 = BasicConv(in_size, out_size, kernel_size, stride=stride, - padding=(kernel_size-1) // 2, relu=True) - self.conv2 = BasicConv(out_size, in_size, kernel_size=1, stride=stride, - padding=0, relu=True, bn=False) + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.relu(theta_x + phi_g, inplace=True) - def forward(self, x): - residual = x - x_out = self.conv1(x) - x_out = self.conv2(x_out) - spatial_att = F.sigmoid(x_out) - # .unsqueeze(4).permute(0, 1, 4, 2, 3) - # spatial_att = spatial_att.expand(spatial_att.shape[0], 4, 4, spatial_att.shape[3], spatial_att.shape[4]).reshape( - # spatial_att.shape[0], 16, spatial_att.shape[3], spatial_att.shape[4]) - x_out = residual * spatial_att + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + sigm_psi_f = torch.sigmoid(self.psi(f)) - x_out += residual + # upsample the attentions and multiply + sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) - return x_out, spatial_att + return W_y, sigm_psi_f -class Scale_atten_block(nn.Module): - def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): - super(Scale_atten_block, self).__init__() - self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) - self.no_spatial = no_spatial - if not no_spatial: - self.SpatialGate = SpatialAtten(gate_channels, gate_channels //reduction_ratio) + def _concatenation_debug(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) - def forward(self, x): - x_out, ca_atten = self.ChannelGate(x) - if not self.no_spatial: - x_out, sa_atten = self.SpatialGate(x_out) + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() - return x_out, ca_atten, sa_atten + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.softplus(theta_x + phi_g) + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + sigm_psi_f = torch.sigmoid(self.psi(f)) -class scale_atten_convblock(nn.Module): - def __init__(self, in_size, out_size, stride=1, downsample=None, use_cbam=True, no_spatial=False, drop_out=False): - super(scale_atten_convblock, self).__init__() - self.downsample = downsample - self.stride = stride - self.no_spatial = no_spatial - self.dropout = drop_out + # upsample the attentions and multiply + sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) - self.relu = nn.ReLU(inplace=True) - self.conv3 = conv3x3(in_size, out_size) - self.bn3 = nn.BatchNorm2d(out_size) + return W_y, sigm_psi_f - if use_cbam: - self.cbam = Scale_atten_block(in_size, reduction_ratio=4, no_spatial=self.no_spatial) # out_size - else: - self.cbam = None - def forward(self, x): - residual = x + def _concatenation_residual(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) - if self.downsample is not None: - residual = self.downsample(x) + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() - if not self.cbam is None: - out, scale_c_atten, scale_s_atten = self.cbam(x) + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.relu(theta_x + phi_g, inplace=True) - out += residual - out = self.relu(out) - out = self.conv3(out) - out = self.bn3(out) - out = self.relu(out) + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + f = self.psi(f).view(batch_size, 1, -1) + sigm_psi_f = F.softmax(f, dim=2).view(batch_size, 1, *theta_x.size()[2:]) - if self.dropout: - out = nn.Dropout2d(0.5)(out) + # upsample the attentions and multiply + sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) - return out - + return W_y, sigm_psi_f + + +class GridAttentionBlock2D(_GridAttentionBlockND): + def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', + sub_sample_factor=(2, 2)): + super(GridAttentionBlock2D, self).__init__(in_channels, + inter_channels=inter_channels, + gating_channels=gating_channels, + dimension=2, mode=mode, + sub_sample_factor=sub_sample_factor, + ) + +class MultiAttentionBlock(nn.Module): + def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor): + super(MultiAttentionBlock, self).__init__() + self.gate_block_1 = GridAttentionBlock2D(in_channels=in_size, gating_channels=gate_size, + inter_channels=inter_size, mode=nonlocal_mode, + sub_sample_factor=sub_sample_factor) + self.gate_block_2 = GridAttentionBlock2D(in_channels=in_size, gating_channels=gate_size, + inter_channels=inter_size, mode=nonlocal_mode, + sub_sample_factor=sub_sample_factor) + self.combine_gates = nn.Sequential(nn.Conv2d(in_size*2, in_size, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(in_size), + nn.ReLU(inplace=True)) + + # initialise the blocks + for m in self.children(): + if m.__class__.__name__.find('GridAttentionBlock2D') != -1: continue + init_weights(m, init_type='kaiming') + + def forward(self, input, gating_signal): + gate_1, attention_1 = self.gate_block_1(input, gating_signal) + gate_2, attention_2 = self.gate_block_2(input, gating_signal) + + return self.combine_gates(torch.cat([gate_1, gate_2], 1)), torch.cat([attention_1, attention_2], 1) + +## 4, Non-local layers class _NonLocalBlockND(nn.Module): def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian', sub_sample_factor=4, bn_layer=True): @@ -300,13 +634,13 @@ def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded kernel_size=1, stride=1, padding=0), bn(self.in_channels) ) - nn.init.constant(self.W[1].weight, 0) - nn.init.constant(self.W[1].bias, 0) + nn.init.constant_(self.W[1].weight, 0) + nn.init.constant_(self.W[1].bias, 0) else: self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0) - nn.init.constant(self.W.weight, 0) - nn.init.constant(self.W.bias, 0) + nn.init.constant_(self.W.weight, 0) + nn.init.constant_(self.W.bias, 0) self.theta = None self.phi = None @@ -527,7 +861,7 @@ def _concatenation_proper_down(self, x): y = y.contiguous().view(batch_size, self.inter_channels, *downsampled_size[2:]) # upsample the final featuremaps # (b,0.5c,t/s1,h/s2,w/s3) - y = F.upsample(y, size=x.size()[2:], mode='trilinear') + y = F.interpolate(y, size=x.size()[2:], mode='trilinear') # attention block output W_y = self.W(y) @@ -544,7 +878,191 @@ def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', s sub_sample_factor=sub_sample_factor, bn_layer=bn_layer) - +## 5, scale attention +class BasicConv(nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, + relu=True, bn=True, bias=False): + super(BasicConv, self).__init__() + self.out_channels = out_planes + self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, groups=groups, bias=bias) + self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None + self.relu = nn.ReLU() if relu else None + + def forward(self, x): + x = self.conv(x) + if self.bn is not None: + x = self.bn(x) + if self.relu is not None: + x = self.relu(x) + return x + + +class Flatten(nn.Module): + def forward(self, x): + return x.view(x.size(0), -1) + + +class ChannelGate(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): + super(ChannelGate, self).__init__() + self.gate_channels = gate_channels + self.mlp = nn.Sequential( + Flatten(), + nn.Linear(gate_channels, gate_channels // reduction_ratio), + nn.ReLU(), + nn.Linear(gate_channels // reduction_ratio, gate_channels) + ) + self.pool_types = pool_types + + def forward(self, x): + channel_att_sum = None + for pool_type in self.pool_types: + if pool_type == 'avg': + avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp(avg_pool) + elif pool_type == 'max': + max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp(max_pool) + elif pool_type == 'lp': + lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp(lp_pool) + elif pool_type == 'lse': + # LSE pool only + lse_pool = logsumexp_2d(x) + channel_att_raw = self.mlp(lse_pool) + + if channel_att_sum is None: + channel_att_sum = channel_att_raw + else: + channel_att_sum = channel_att_sum + channel_att_raw + + # scalecoe = F.sigmoid(channel_att_sum) + channel_att_sum = channel_att_sum.reshape(channel_att_sum.shape[0], 4, 4) + avg_weight = torch.mean(channel_att_sum, dim=2).unsqueeze(2) + avg_weight = avg_weight.expand(channel_att_sum.shape[0], 4, 4).reshape(channel_att_sum.shape[0], 16) + scale = torch.sigmoid(avg_weight).unsqueeze(2).unsqueeze(3).expand_as(x) + + return x * scale, scale + + +def logsumexp_2d(tensor): + tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) + s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) + outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() + return outputs + + +class ChannelPool(nn.Module): + def forward(self, x): + return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1) + + +class SpatialGate(nn.Module): + def __init__(self): + super(SpatialGate, self).__init__() + kernel_size = 7 + self.compress = ChannelPool() + self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) + + def forward(self, x): + x_compress = self.compress(x) + x_out = self.spatial(x_compress) + scale = torch.sigmoid(x_out) # broadcasting + # spa_scale = scale.expand_as(x) + # print(spa_scale.shape) + return x * scale, scale + +class SpatialAtten(nn.Module): + def __init__(self, in_size, out_size, kernel_size=3, stride=1): + super(SpatialAtten, self).__init__() + self.conv1 = BasicConv(in_size, out_size, kernel_size, stride=stride, + padding=(kernel_size-1) // 2, relu=True) + self.conv2 = BasicConv(out_size, out_size, kernel_size=1, stride=stride, + padding=0, relu=True, bn=False) + + def forward(self, x): + residual = x + x_out = self.conv1(x) + x_out = self.conv2(x_out) + spatial_att = torch.sigmoid(x_out).unsqueeze(4).permute(0, 1, 4, 2, 3) + spatial_att = spatial_att.expand(spatial_att.shape[0], 4, 4, spatial_att.shape[3], spatial_att.shape[4]).reshape( + spatial_att.shape[0], 16, spatial_att.shape[3], spatial_att.shape[4]) + x_out = residual * spatial_att + + x_out += residual + + return x_out, spatial_att + +class Scale_atten_block(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): + super(Scale_atten_block, self).__init__() + self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) + self.no_spatial = no_spatial + if not no_spatial: + self.SpatialGate = SpatialAtten(gate_channels, gate_channels //reduction_ratio) + + def forward(self, x): + x_out, ca_atten = self.ChannelGate(x) + if not self.no_spatial: + x_out, sa_atten = self.SpatialGate(x_out) + + return x_out, ca_atten, sa_atten + + +class scale_atten_convblock(nn.Module): + def __init__(self, in_size, out_size, stride=1, downsample=None, use_cbam=True, no_spatial=False, drop_out=False): + super(scale_atten_convblock, self).__init__() + # if stride != 1 or in_size != out_size: + # downsample = nn.Sequential( + # nn.Conv2d(in_size, out_size, + # kernel_size=1, stride=stride, bias=False), + # nn.BatchNorm2d(out_size), + # ) + self.downsample = downsample + self.stride = stride + self.no_spatial = no_spatial + self.dropout = drop_out + + self.relu = nn.ReLU(inplace=True) + self.conv3 = conv3x3(in_size, out_size) + self.bn3 = nn.BatchNorm2d(out_size) + + if use_cbam: + self.cbam = Scale_atten_block(in_size, reduction_ratio=4, no_spatial=self.no_spatial) # out_size + else: + self.cbam = None + + def forward(self, x): + residual = x + + if self.downsample is not None: + residual = self.downsample(x) + + if not self.cbam is None: + out, scale_c_atten, scale_s_atten = self.cbam(x) + + # scale_c_atten = nn.Sigmoid()(scale_c_atten) + # scale_s_atten = nn.Sigmoid()(scale_s_atten) + # scale_atten = channel_atten_c * spatial_atten_s + + # scale_max = torch.argmax(scale_atten, dim=1, keepdim=True) + # scale_max_soft = get_soft_label(input_tensor=scale_max, num_class=8) + # scale_max_soft = scale_max_soft.permute(0, 3, 1, 2) + # scale_atten_soft = scale_atten * scale_max_soft + + out += residual + out = self.relu(out) + out = self.conv3(out) + out = self.bn3(out) + out = self.relu(out) + + if self.dropout: + out = nn.Dropout2d(0.5)(out) + + return out + +## 6, CANet class CANet(nn.Module): """ Implementation of CANet (Comprehensive Attention Network) for image segmentation. @@ -559,24 +1077,25 @@ class CANet(nn.Module): :param in_chns: (int) Input channel number. :param feature_chns: (list) Feature channel for each resolution level. The length should be 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + :param is_deconv: (bool) Using deconvolution for up-sampling or not. + If False, bilinear interpolation will be used for up-sampling. Default is True. + :param is_batchnorm: (bool) If batch normalization is or not. Default is True. + :param feature_scale: (int) The scale of resolution levels. Default is 4. """ def __init__(self, params): #args, in_ch=3, n_classes=2, feature_scale=4, is_deconv=True, is_batchnorm=True, # nonlocal_mode='concatenation', attention_dsample=(1, 1)): super(CANet, self).__init__() self.in_channels = params['in_chns'] self.num_classes = params['class_num'] + self.feature_chns= params.get('feature_chns', [32, 64, 128, 256, 512]) self.is_deconv = params.get('is_deconv', True) self.is_batchnorm = params.get('is_batchnorm', True) self.feature_scale = params.get('feature_scale', 4) nonlocal_mode = 'concatenation' attention_dsample = (1, 1) - filters = [64, 128, 256, 512, 1024] + filters = self.feature_chns filters = [int(x / self.feature_scale) for x in filters] # downsampling @@ -623,6 +1142,13 @@ def __init__(self, params): #args, in_ch=3, n_classes=2, feature_scale=4, is_de self.final = nn.Conv2d(filters[0], self.num_classes, kernel_size=1) def forward(self, inputs): + x_shape = list(inputs.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + inputs = torch.transpose(inputs, 1, 2) + inputs = torch.reshape(inputs, new_shape) + # Feature Extraction conv1 = self.conv1(inputs) maxpool1 = self.maxpool1(conv1) @@ -671,6 +1197,14 @@ def forward(self, inputs): out = self.scale_att(dsv_cat) out = self.final(out) + if(len(x_shape) == 5): + if(isinstance(out, (list,tuple))): + for i in range(len(out)): + new_shape = [N, D] + list(out[i].shape)[1:] + out[i] = torch.transpose(torch.reshape(out[i], new_shape), 1, 2) + else: + new_shape = [N, D] + list(out.shape)[1:] + out = torch.transpose(torch.reshape(out, new_shape), 1, 2) return out From 1aebb5e053309639bb7720f77281fd13e735d5f9 Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 2 Oct 2024 15:51:18 +0800 Subject: [PATCH 59/86] Update unet2d_canet.py --- pymic/net/net2d/unet2d_canet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymic/net/net2d/unet2d_canet.py b/pymic/net/net2d/unet2d_canet.py index 0af3159..f578025 100644 --- a/pymic/net/net2d/unet2d_canet.py +++ b/pymic/net/net2d/unet2d_canet.py @@ -1216,7 +1216,7 @@ def forward(self, inputs): x = np.random.rand(4, 3, 224, 224) xt = torch.from_numpy(x) - xt = torch.tensor(xt) + xt = xt.clone().detach() y = Net(xt) print(len(y.size())) From d0ebdc2f8c193d431224689406e70018ab4e93a6 Mon Sep 17 00:00:00 2001 From: taigw Date: Wed, 2 Oct 2024 21:05:32 +0800 Subject: [PATCH 60/86] update unetpp --- pymic/net/net2d/{unet2d_nest.py => unet2d_pp.py} | 8 ++++---- pymic/net/net_dict_seg.py | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) rename pymic/net/net2d/{unet2d_nest.py => unet2d_pp.py} (96%) diff --git a/pymic/net/net2d/unet2d_nest.py b/pymic/net/net2d/unet2d_pp.py similarity index 96% rename from pymic/net/net2d/unet2d_nest.py rename to pymic/net/net2d/unet2d_pp.py index efa048f..f2a003b 100644 --- a/pymic/net/net2d/unet2d_nest.py +++ b/pymic/net/net2d/unet2d_pp.py @@ -3,9 +3,9 @@ import torch.nn as nn from pymic.net.net2d.unet2d import * -class NestedUNet2D(nn.Module): +class UNet2Dpp(nn.Module): """ - An implementation of the Nested U-Net. + An implementation of the U-Net++. * Reference: Zongwei Zhou, et al.: `UNet++: A Nested U-Net Architecture for Medical Image Segmentation. `_ @@ -25,7 +25,7 @@ class NestedUNet2D(nn.Module): :param class_num: (int) The class number for segmentation task. """ def __init__(self, params): - super(NestedUNet2D, self).__init__() + super(UNet2Dpp, self).__init__() self.params = params self.in_chns = self.params['in_chns'] self.filters = self.params['feature_chns'] @@ -96,7 +96,7 @@ def forward(self, x): 'feature_chns':[2, 8, 32, 48, 64], 'dropout': [0, 0, 0.3, 0.4, 0.5], 'class_num': 2} - Net = NestedUNet2D(params) + Net = UNet2Dpp(params) Net = Net.double() x = np.random.rand(4, 4, 10, 96, 96) diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index 741eed7..99d1693 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -22,7 +22,7 @@ from pymic.net.net2d.unet2d_mtnet import MTNet2D from pymic.net.net2d.cople_net import COPLENet from pymic.net.net2d.unet2d_attention import AttentionUNet2D -from pymic.net.net2d.unet2d_nest import NestedUNet2D +from pymic.net.net2d.unet2d_pp import UNet2Dpp from pymic.net.net2d.unet2d_scse import UNet2D_ScSE from pymic.net.net2d.trans2d.transunet import TransUNet from pymic.net.net2d.trans2d.swinunet import SwinUNet @@ -50,15 +50,15 @@ # from pymic.net.net3d.trans3d.SwitchNet import SwitchNet SegNetDict = { + 'AttentionUNet2D': AttentionUNet2D, + 'CANet': CANet, + 'COPLENet': COPLENet, + 'MCNet2D': MCNet2D, + 'MTNet2D': MTNet2D, 'UNet2D': UNet2D, 'UNet2D_DualBranch': UNet2D_DualBranch, 'UNet2D_CCT': UNet2D_CCT, - 'MCNet2D': MCNet2D, - 'MTNet2D': MTNet2D, - 'CANet': CANet, - 'COPLENet': COPLENet, - 'AttentionUNet2D': AttentionUNet2D, - 'NestedUNet2D': NestedUNet2D, + 'UNet2Dpp': UNet2Dpp, 'UNet2D_ScSE': UNet2D_ScSE, 'TransUNet': TransUNet, 'SwinUNet': SwinUNet, From 9670b63225fbd0cc8fe1989bb08a5a670913bd96 Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 3 Oct 2024 21:41:13 +0800 Subject: [PATCH 61/86] add lcovnet --- pymic/net/net3d/lcovnet.py | 246 +++++++++++++++++++++++++++++++++++++ pymic/net/net_dict_seg.py | 2 + pymic/test/test_net3d.py | 97 ++++++++++++++- 3 files changed, 344 insertions(+), 1 deletion(-) create mode 100644 pymic/net/net3d/lcovnet.py diff --git a/pymic/net/net3d/lcovnet.py b/pymic/net/net3d/lcovnet.py new file mode 100644 index 0000000..bd91878 --- /dev/null +++ b/pymic/net/net3d/lcovnet.py @@ -0,0 +1,246 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import logging +import torch +import torch.nn as nn +import numpy as np +from pymic.net.net_init import Initialization_He, Initialization_XavierUniform + +class UnetBlock_Encode(nn.Module): + def __init__(self, in_channels, out_channel): + super(UnetBlock_Encode, self).__init__() + + self.in_chns = in_channels + self.out_chns = out_channel + + self.conv1 = nn.Sequential( + nn.Conv3d(self.in_chns, self.out_chns, kernel_size=(1, 1, 3), + padding=(0, 0, 1)), + nn.BatchNorm3d(self.out_chns), + nn.ReLU6(inplace=True) + ) + + self.conv2_1 = nn.Sequential( + nn.Conv3d(self.out_chns, self.out_chns, kernel_size=(3, 3, 1), + padding=(1, 1, 0), groups=1), + nn.BatchNorm3d(self.out_chns), + nn.ReLU6(inplace=True), + nn.Dropout(p=0.2) + ) + + self.conv2_2 = nn.Sequential( + nn.AvgPool3d(kernel_size=4, stride=2, padding=1), + nn.Conv3d(self.out_chns, self.out_chns, kernel_size=1, + padding=0), + nn.BatchNorm3d(self.out_chns), + nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False) + ) + + def forward(self, x): + # print(x.shape) + x = self.conv1(x) + + x1 = self.conv2_1(x) + x2 = self.conv2_2(x) + x2 = torch.sigmoid(x2) + x = x1 + x2 * x + return x + + +class UnetBlock_Encode_BottleNeck(nn.Module): + def __init__(self, in_channels, out_channel): + super(UnetBlock_Encode_BottleNeck, self).__init__() + + self.in_chns = in_channels + self.out_chns = out_channel + + self.conv1 = nn.Sequential( + nn.Conv3d(self.in_chns, self.out_chns, kernel_size=(1, 1, 3), + padding=(0, 0, 1)), + nn.BatchNorm3d(self.out_chns), + nn.ReLU6(inplace=True) + ) + + self.conv2_1 = nn.Sequential( + nn.Conv3d(self.out_chns, self.out_chns, kernel_size=(3, 3, 1), + padding=(1, 1, 0), groups=self.out_chns), + nn.BatchNorm3d(self.out_chns), + nn.ReLU6(inplace=True), + nn.Dropout(p=0.2) + ) + + self.conv2_2 = nn.Sequential( + # nn.AvgPool3d(kernel_size=4, stride=2), + nn.Conv3d(self.out_chns, self.out_chns, kernel_size=1, + padding=0), + nn.BatchNorm3d(self.out_chns), + nn.ReLU6(inplace=True), + nn.Dropout(p=0.2) + ) + + def forward(self, x): + x = self.conv1(x) + + x1 = self.conv2_1(x) + x2 = self.conv2_2(x) + x2 = torch.sigmoid(x2) + x = x1 + x2 * x + return x + + +class UnetBlock_Down(nn.Module): + def __init__(self): + super(UnetBlock_Down, self).__init__() + self.avg_pool = nn.MaxPool3d(kernel_size=2, stride=2) + + def forward(self, x): + x = self.avg_pool(x) + return x + + +class UnetBlock_Up(nn.Module): + def __init__(self, in_channels, out_channel): + super(UnetBlock_Up, self).__init__() + self.conv = self.conv1 = nn.Sequential( + nn.Conv3d(in_channels, out_channel, kernel_size=1, + padding=0, groups=1), + nn.BatchNorm3d(out_channel), + nn.ReLU6(inplace=True), + nn.Dropout(p=0.2) + ) + + self.up = nn.Upsample( + scale_factor=2, mode='trilinear', align_corners=False) + + def forward(self, x): + x = self.conv(x) + x = self.up(x) + return x + + +class LCOVNet(nn.Module): + """ + An implementation of the LCOVNet. + + * Reference: Q. Zhao, L. Zhong, J. Xiao, J. Zhang, Y. Chen , W. Liao, S. Zhang, and G. Wang: + Efficient Multi-Organ Segmentation From 3D Abdominal CT Images With Lightweight Network and Knowledge Distillation. + `IEEE TMI 42(9) 2023: 2513 - 2523. `_ + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, params): + super(LCOVNet, self).__init__() + params = self.get_default_parameters(params) + for p in params: + print(p, params[p]) + self.stage = 'train' + # C_in=32, n_classes=17, m=1, is_ds=True): + + in_chns = params['in_chns'] + n_class = params['class_num'] + self.ft_chns = params['feature_chns'] + self.mul_pred = params.get('multiscale_pred', False) + + self.Encode_block1 = UnetBlock_Encode(in_chns, self.ft_chns[0]) + self.down1 = UnetBlock_Down() + + self.Encode_block2 = UnetBlock_Encode(self.ft_chns[0], self.ft_chns[1]) + self.down2 = UnetBlock_Down() + + self.Encode_block3 = UnetBlock_Encode(self.ft_chns[1], self.ft_chns[2]) + self.down3 = UnetBlock_Down() + + self.Encode_block4 = UnetBlock_Encode(self.ft_chns[2], self.ft_chns[3]) + self.down4 = UnetBlock_Down() + + self.Encode_BottleNeck_block5 = UnetBlock_Encode_BottleNeck( + self.ft_chns[3], self.ft_chns[4]) + + self.up1 = UnetBlock_Up(self.ft_chns[4], self.ft_chns[3]) + self.Decode_block1 = UnetBlock_Encode( + self.ft_chns[3]*2, self.ft_chns[3]) + self.segout1 = nn.Conv3d( + self.ft_chns[3], n_class, kernel_size=1, padding=0) + + self.up2 = UnetBlock_Up(self.ft_chns[3], self.ft_chns[2]) + self.Decode_block2 = UnetBlock_Encode( + self.ft_chns[2]*2, self.ft_chns[2]) + self.segout2 = nn.Conv3d( + self.ft_chns[2], n_class, kernel_size=1, padding=0) + + self.up3 = UnetBlock_Up(self.ft_chns[2], self.ft_chns[1]) + self.Decode_block3 = UnetBlock_Encode( + self.ft_chns[1]*2, self.ft_chns[1]) + self.segout3 = nn.Conv3d( + self.ft_chns[1], n_class, kernel_size=1, padding=0) + + self.up4 = UnetBlock_Up(self.ft_chns[1], self.ft_chns[0]) + self.Decode_block4 = UnetBlock_Encode( + self.ft_chns[0]*2, self.ft_chns[0]) + self.segout4 = nn.Conv3d( + self.ft_chns[0], n_class, kernel_size=1, padding=0) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'initialization': 'he', + 'multiscale_pred': False + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def forward(self, x): + _x1 = self.Encode_block1(x) + x1 = self.down1(_x1) + + _x2 = self.Encode_block2(x1) + x2 = self.down2(_x2) + + _x3 = self.Encode_block3(x2) + x3 = self.down2(_x3) + + _x4 = self.Encode_block4(x3) + x4 = self.down2(_x4) + + x5 = self.Encode_BottleNeck_block5(x4) + + x6 = self.up1(x5) + x6 = torch.cat((x6, _x4), dim=1) + x6 = self.Decode_block1(x6) + segout1 = self.segout1(x6) + + x7 = self.up2(x6) + x7 = torch.cat((x7, _x3), dim=1) + x7 = self.Decode_block2(x7) + segout2 = self.segout2(x7) + + x8 = self.up3(x7) + x8 = torch.cat((x8, _x2), dim=1) + x8 = self.Decode_block3(x8) + segout3 = self.segout3(x8) + + x9 = self.up4(x8) + x9 = torch.cat((x9, _x1), dim=1) + x9 = self.Decode_block4(x9) + segout4 = self.segout4(x9) + + if (self.mul_pred == True and self.stage == 'train'): + return [segout4, segout3, segout2, segout1] + else: + return segout4 \ No newline at end of file diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index 99d1693..687710b 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -30,6 +30,7 @@ from pymic.net.net3d.unet3d import UNet3D from pymic.net.net3d.grunet import GRUNet from pymic.net.net3d.fmunetv3 import FMUNetV3 +from pymic.net.net3d.lcovnet import LCOVNet from pymic.net.net3d.unet3d_scse import UNet3D_ScSE from pymic.net.net3d.unet3d_dual_branch import UNet3D_DualBranch # from pymic.net.net3d.stunet_wrap import STUNet_wrap @@ -64,6 +65,7 @@ 'SwinUNet': SwinUNet, 'UNet2D5': UNet2D5, 'GRUNet': GRUNet, + 'LCOVNet': LCOVNet, 'FMUNetV3': FMUNetV3, 'UNet3D': UNet3D, 'UNet3D_ScSE': UNet3D_ScSE, diff --git a/pymic/test/test_net3d.py b/pymic/test/test_net3d.py index 180dcff..058a6fd 100644 --- a/pymic/test/test_net3d.py +++ b/pymic/test/test_net3d.py @@ -6,6 +6,9 @@ from pymic.net.net3d.unet3d import UNet3D from pymic.net.net3d.unet3d_scse import UNet3D_ScSE from pymic.net.net3d.unet2d5 import UNet2D5 +from pymic.net.net3d.grunet import GRUNet +from pymic.net.net3d.lcovnet import LCOVNet +from pymic.net.net3d.trans3d.unetr_pp import UNETR_PP def test_unet3d(): params = {'in_chns':4, @@ -61,6 +64,22 @@ def test_unet3d_scse(): y = y.detach().numpy() print(y.shape) +def test_lcovnet(): + params = {'in_chns':4, + 'feature_chns':[16, 32, 64, 128, 256], + 'class_num': 2} + Net = LCOVNet(params) + Net = Net.double() + + x = np.random.rand(4, 4, 96, 96, 96) + xt = torch.from_numpy(x) + xt = xt.clone().detach() + + y = Net(xt) + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) + def test_unet2d5(): params = {'in_chns':4, 'feature_chns':[8, 16, 32, 64, 128], @@ -100,9 +119,85 @@ def test_unet2d5(): y = y.detach().numpy() print(y.shape) +def test_mystunet(): + in_chns = 4 + num_class = 4 + # input_channels, num_classes, depth=[1,1,1,1,1,1], dims=[32, 64, 128, 256, 512, 512], + # pool_op_kernel_sizes=None, conv_kernel_sizes=None) + dims=[16, 32, 64, 128, 256, 512] + Net = MySTUNet(in_chns, num_class, dims = dims, pool_op_kernel_sizes = [[2, 2, 2], [2,2,2], [2,2,2], [2,2,2], [1, 1, 1]], + conv_kernel_sizes = [[3, 3, 3], [3,3,3], [3,3,3], [3,3,3], [3,3,3], [3, 3, 3]]) + Net = Net.double() + + x = np.random.rand(4, 4, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + for y in out: + y = y.detach().numpy() + print(y.shape) + +def test_grunet(): + params = {'in_chns':4, + 'feature_chns':[8, 16, 32, 64, 128], + 'dims': [2, 3, 3, 3, 3], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'depth': 2, + 'multiscale_pred': True} + x = np.random.rand(4, 4, 64, 128, 128) + + # params = {'in_chns':4, + # 'feature_chns':[8, 16, 32, 64, 128], + # 'dims': [3, 3, 3, 3, 3], + # 'dropout': [0, 0, 0.3, 0.4, 0.5], + # 'class_num': 2, + # 'depth': 4, + # 'multiscale_pred': True} + # x = np.random.rand(4, 4, 96, 96, 96) + + Net = GRUNet(params) + Net = Net.double() + + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + for y in out: + y = y.detach().numpy() + print(y.shape) + +def test_unetr_pp(): + depths = [128, 64, 32] + for i in range(3): + params = {'in_chns': 4, + 'class_num': 2, + 'img_size': [depths[i], 128, 128], + 'resolution_mode': i + } + net = UNETR_PP(params) + net.double() + + x = np.random.rand(2, 4, depths[i], 128, 128) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = net(xt) + print(len(y)) + for yi in y: + yi = yi.detach().numpy() + print(yi.shape) + + + if __name__ == "__main__": # test_unet3d() # test_unet3d_scse() - test_unet2d5() + test_lcovnet() + # test_unetr_pp() + # test_unet2d5() + # test_mystunet() + # test_fmunetv2() \ No newline at end of file From a74efa89ff91b4da59bfcd78152c257bd6b01515 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 4 Oct 2024 15:23:41 +0800 Subject: [PATCH 62/86] Create unet2d_multi_decoder.py --- pymic/net/net2d/unet2d_multi_decoder.py | 162 ++++++++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 pymic/net/net2d/unet2d_multi_decoder.py diff --git a/pymic/net/net2d/unet2d_multi_decoder.py b/pymic/net/net2d/unet2d_multi_decoder.py new file mode 100644 index 0000000..2c7b3c4 --- /dev/null +++ b/pymic/net/net2d/unet2d_multi_decoder.py @@ -0,0 +1,162 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import torch.nn as nn +from pymic.net.net2d.unet2d import * + +class UNet2D_DualBranch(nn.Module): + """ + A dual branch network using UNet2D as backbone. + + * Reference: Xiangde Luo, Minhao Hu, Wenjun Liao, Shuwei Zhai, Tao Song, Guotai Wang, + Shaoting Zhang. ScribblScribble-Supervised Medical Image Segmentation via + Dual-Branch Network and Dynamically Mixed Pseudo Labels Supervision. + `MICCAI 2022. `_ + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.UNet2D` for details. + In addition, the following field should be included: + + :param output_mode: (str) How to obtain the result during the inference. + `average`: taking average of the two branches. + `first`: takeing the result in the first branch. + `second`: taking the result in the second branch. + """ + def __init__(self, params): + super(UNet2D_DualBranch, self).__init__() + params = self.get_default_parameters(params) + self.output_mode = params["output_mode"] + self.encoder = Encoder(params) + self.decoder1 = Decoder(params) + self.decoder2 = Decoder(params) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], + 'up_mode': 2, + 'multiscale_pred': False, + 'output_mode': "average" + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + f = self.encoder(x) + output1 = self.decoder1(f) + output2 = self.decoder2(f) + if(len(x_shape) == 5): + new_shape = [N, D] + list(output1.shape)[1:] + output1 = torch.reshape(output1, new_shape) + output1 = torch.transpose(output1, 1, 2) + output2 = torch.reshape(output2, new_shape) + output2 = torch.transpose(output2, 1, 2) + + if(self.training): + return output1, output2 + else: + if(self.output_mode == "average"): + return (output1 + output2)/2 + elif(self.output_mode == "first"): + return output1 + else: + return output2 + +class UNet2D_TriBranch(nn.Module): + """ + A tri-branch network using UNet2D as backbone. The super class for MCNet2D and MTNet2D. + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.UNet2D` for details. + """ + def __init__(self, params): + super(UNet2D_TriBranch, self).__init__() + params = self.get_default_parameters(params) + self.encoder = Encoder(params) + self.decoder1 = Decoder(params) + self.decoder2 = Decoder(params) + self.decoder3 = Decoder(params) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], + 'up_mode': 2, + 'multiscale_pred': False, + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + feature = self.encoder(x) + output1 = self.decoder1(feature) + new_shape = [N, D] + list(output1.shape)[1:] + output1 = torch.transpose(torch.reshape(output1, new_shape), 1, 2) + if(not self.training): + return output1 + output2 = self.decoder2(feature) + output3 = self.decoder3(feature) + if(len(x_shape) == 5): + output2 = torch.transpose(torch.reshape(output2, new_shape), 1, 2) + output3 = torch.transpose(torch.reshape(output3, new_shape), 1, 2) + return output1, output2, output3 + +class MCNet2D(UNet2D_TriBranch): + """ + A tri-branch network using UNet2D as backbone. + + * Reference: Yicheng Wu, Zongyuan Ge et al. Mutual consistency learning for + semi-supervised medical image segmentation. + `Medical Image Analysis 2022. `_ + + The original code is at: https://github.com/ycwu1997/MC-Net + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.UNet2D` for details. + """ + def __init__(self, params): + super(MCNet2D, self).__init__(params) + in_chns = params['in_chns'] + class_num = params['class_num'] + ft_chns = params['feature_chns'] + dropout = params['dropout'] + params1 = {'in_chns': in_chns, + 'feature_chns': ft_chns, + 'dropout': dropout, + 'class_num': class_num, + 'up_mode': 0 } + params2 = {'in_chns': in_chns, + 'feature_chns': ft_chns, + 'dropout': dropout, + 'class_num': class_num, + 'up_mode': 1 } + params3 = {'in_chns': in_chns, + 'feature_chns': ft_chns, + 'dropout': dropout, + 'class_num': class_num, + 'up_mode': 2 } + self.encoder = Encoder(params1) + self.decoder1 = Decoder(params1) + self.decoder2 = Decoder(params2) + self.decoder3 = Decoder(params3) \ No newline at end of file From 59c82f3ff8230be5586bae6f40000dfdf42b18b4 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 4 Oct 2024 15:27:21 +0800 Subject: [PATCH 63/86] update 3D networks --- pymic/net/net3d/fmunetv3.py | 262 ++++++++++++++++++++++++++++++++++++ pymic/net/net_dict_seg.py | 4 +- 2 files changed, 264 insertions(+), 2 deletions(-) create mode 100644 pymic/net/net3d/fmunetv3.py diff --git a/pymic/net/net3d/fmunetv3.py b/pymic/net/net3d/fmunetv3.py new file mode 100644 index 0000000..1367209 --- /dev/null +++ b/pymic/net/net3d/fmunetv3.py @@ -0,0 +1,262 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import itertools +import logging +import torch +import torch.nn as nn +from pymic.net.net_init import Initialization_He, Initialization_XavierUniform + +dim0 = {0:3, 1:2, 2:2} +dim1 = {0:3, 1:3, 2:2} +conv_knl = {2: (1, 3, 3), 3: 3} +conv_pad = {2: (0, 1, 1), 3: 1} +pool_knl = {2: (1, 2, 2), 3: 2} +down_stride = {2: (1, 2, 2), 3: 2} + +class ResConv(nn.Module): + def __init__(self, out_channels, dim = 3, dropout_p = 0.0, depth = 2): + super(ResConv, self).__init__() + assert(dim == 2 or dim == 3) + self.out_channels = out_channels + self.conv_list = nn.ModuleList([nn.Sequential( + nn.InstanceNorm3d(out_channels, affine = True), + nn.LeakyReLU(), + nn.Dropout(dropout_p), + nn.Conv3d(out_channels, out_channels, kernel_size=conv_knl[dim], padding=conv_pad[dim])) + for i in range(depth)]) + + def forward(self, x): + for conv in self.conv_list: + x = conv(x) + x + return x + +class DownSample(nn.Module): + """downsampling based on convolution + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + :param downsample: (bool) Use downsample or not after convolution. + """ + def __init__(self, in_channels, out_channels, dim = 3): + super(DownSample, self).__init__() + self.down = nn.Sequential( + nn.InstanceNorm3d(in_channels, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=conv_knl[dim], + padding=conv_pad[dim], stride = down_stride[dim]) + ) + + def forward(self, x): + return self.down(x) + +class UpCatConv(nn.Module): + """Upsampling followed by `ResConv` block + + :param in_channels1: (int) Input channel number for low-resolution feature map. + :param in_channels2: (int) Input channel number for high-resolution feature map. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear` for 3D and `Bilinear` for 2D). + The default value is 2. + """ + def __init__(self, in_channels1, in_channels2, out_channels, dim = 3): + super(UpCatConv, self).__init__() + + self.up = nn.Sequential( + nn.InstanceNorm3d(in_channels1, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels1, in_channels2, kernel_size=1, padding=0), + nn.Upsample(scale_factor=pool_knl[dim], mode='trilinear', align_corners=True) + ) + + self.conv = nn.Sequential( + nn.InstanceNorm3d(in_channels2*2, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels2 * 2, out_channels, kernel_size=conv_knl[dim], padding=conv_pad[dim]) + ) + + def forward(self, x_l, x_h): + """ + x_l: low-resolution feature map. + x_h: high-resolution feature map. + """ + y = torch.cat([x_h, self.up(x_l)], dim=1) + return self.conv(y) + +class Encoder(nn.Module): + """ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. + + res_mode: resolution mode: 0-- isotrpic, 1-- near isotrpic, 2-- isotropic + """ + def __init__(self, ft_chns, res_mode = 0, dropout_p = 0, depth = 2): + super(Encoder, self).__init__() + d0, d1 = dim0[res_mode], dim1[res_mode] + + self.en_conv0 = ResConv(ft_chns[0], d0, 0, depth) + self.en_conv1 = ResConv(ft_chns[1], d1, 0, depth) + self.en_conv2 = ResConv(ft_chns[2], 3, dropout_p, depth) + self.en_conv3 = ResConv(ft_chns[3], 3, dropout_p, depth) + self.en_conv4 = ResConv(ft_chns[4], 3, dropout_p, depth) + + self.down0 = DownSample(ft_chns[0], ft_chns[1], d0) + self.down1 = DownSample(ft_chns[1], ft_chns[2], d1) + self.down2 = DownSample(ft_chns[2], ft_chns[3], 3) + self.down3 = DownSample(ft_chns[3], ft_chns[4], 3) + + def forward(self, x): + x0 = self.en_conv0(x) + x1 = self.en_conv1(self.down0(x0)) + x2 = self.en_conv2(self.down1(x1)) + x3 = self.en_conv3(self.down2(x2)) + x4 = self.en_conv4(self.down3(x3)) + return [x0, x1, x2, x3, x4] + +class Decoder(nn.Module): + """ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. + """ + def __init__(self, ft_chns, res_mode = 0, dropout_p = 0, depth = 2): + super(Decoder, self).__init__() + d0, d1 = dim0[res_mode], dim1[res_mode] + + self.upcat0 = UpCatConv(ft_chns[1], ft_chns[0], ft_chns[0], d0) + self.upcat1 = UpCatConv(ft_chns[2], ft_chns[1], ft_chns[1], d1) + self.upcat2 = UpCatConv(ft_chns[3], ft_chns[2], ft_chns[2], 3) + self.upcat3 = UpCatConv(ft_chns[4], ft_chns[3], ft_chns[3], 3) + + self.de_conv0 = ResConv(ft_chns[0], d0, 0, depth) + self.de_conv1 = ResConv(ft_chns[1], d1, 0, depth) + self.de_conv2 = ResConv(ft_chns[2], 3, dropout_p, depth) + self.de_conv3 = ResConv(ft_chns[3], 3, dropout_p, depth) + self.de_conv4 = ResConv(ft_chns[4], 3, dropout_p, depth) + + def forward(self, x): + x0, x1, x2, x3, x4 = x + x4_de = self.de_conv4(x4) + x3_de = self.de_conv3(self.upcat3(x4_de, x3)) + x2_de = self.de_conv2(self.upcat2(x3_de, x2)) + x1_de = self.de_conv1(self.upcat1(x2_de, x1)) + x0_de = self.de_conv0(self.upcat0(x1_de, x0)) + return [x0_de, x1_de, x2_de, x3_de] + +class FMUNetV3(nn.Module): + """ + A 2.5D network combining 3D convolutions with 2D convolutions. + + * Reference: Guotai Wang, Jonathan Shapey, Wenqi Li, Reuben Dorent, Alex Demitriadis, + Sotirios Bisdas, Ian Paddick, Robert Bradford, Shaoting Zhang, Sébastien Ourselin, + Tom Vercauteren: Automatic Segmentation of Vestibular Schwannoma from T2-Weighted + MRI by Deep Spatial Attention with Hardness-Weighted Loss. + `MICCAI (2) 2019: 264-272. `_ + + Note that the attention module in the orininal paper is not used here. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param conv_dims: (list) The convolution dimension (2 or 3) for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, params): + super(FMUNetV3, self).__init__() + params = self.get_default_parameters(params) + + self.stage = 'train' + in_chns = params['in_chns'] + ft_chns = params['feature_chns'] + res_mode = params['res_mode'] + dropout = params['dropout'] + depth = params['depth'] + cls_num = params['class_num'] + self.mul_pred = params.get('multiscale_pred', True) + self.tune_mode= params.get('finetune_mode', 'all') + self.load_mode= params.get('weights_load_mode', 'all') + + d0 = dim0[res_mode] + self.project = nn.Conv3d(in_chns, ft_chns[0], kernel_size=conv_knl[d0], padding=conv_pad[d0]) + self.encoder = Encoder(ft_chns, res_mode, dropout, depth) + # self.decoder = Decoder(ft_chns, res_mode, dropout, depth = 2) + self.decoder = Decoder(ft_chns, res_mode, dropout, depth) + + self.out_layers = nn.ModuleList() + dims = [dim0[res_mode], dim1[res_mode], 3, 3] + for i in range(4): + out_layer = nn.Sequential( + nn.InstanceNorm3d(ft_chns[i], affine = True), + nn.LeakyReLU(), + nn.Conv3d(ft_chns[i], cls_num, kernel_size=conv_knl[dims[i]], padding=conv_pad[dims[i]])) + self.out_layers.append(out_layer) + + init = params['initialization'].lower() + weightInitializer = Initialization_He(1e-2) if init == 'he' else Initialization_XavierUniform() + self.apply(weightInitializer) + + def get_default_parameters(self, params): + default_param = { + 'finetune_mode': 'all', + 'initialization': 'he', + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': 0.2, + 'res_mode': 0, + 'depth': 2, + 'multiscale_pred': True + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def set_stage(self, stage): + self.stage = stage + + def forward(self, x): + x_en = self.encoder(self.project(x)) + x_de = self.decoder(x_en) + output = self.out_layers[0](x_de[0]) + if(self.mul_pred and self.stage == 'train'): + output = [output] + for i in range(1, len(x_de)): + output.append(self.out_layers[i](x_de[i])) + return output + + def get_parameters_to_update(self): + if(self.tune_mode == 'all'): + return self.parameters() + + up_params = itertools.chain() + if(self.tune_mode == 'decoder'): + up_blocks = [self.decoder, self.out_layers] + else: + raise ValueError("undefined fine-tune mode for FMUNet: {0:}".format(self.tune_mode)) + for block in up_blocks: + up_params = itertools.chain(up_params, block.parameters()) + return up_params + + def get_parameters_to_load(self): + state_dict = self.state_dict() + if(self.load_mode == 'encoder'): + state_dict = {k:v for k, v in state_dict.items() if "project" in k or "encoder" in k } + return state_dict \ No newline at end of file diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index 687710b..d7f759b 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -19,7 +19,7 @@ from pymic.net.net2d.unet2d_multi_decoder import UNet2D_DualBranch, MCNet2D from pymic.net.net2d.unet2d_canet import CANet from pymic.net.net2d.unet2d_cct import UNet2D_CCT -from pymic.net.net2d.unet2d_mtnet import MTNet2D +# from pymic.net.net2d.unet2d_mtnet import MTNet2D from pymic.net.net2d.cople_net import COPLENet from pymic.net.net2d.unet2d_attention import AttentionUNet2D from pymic.net.net2d.unet2d_pp import UNet2Dpp @@ -55,7 +55,7 @@ 'CANet': CANet, 'COPLENet': COPLENet, 'MCNet2D': MCNet2D, - 'MTNet2D': MTNet2D, + # 'MTNet2D': MTNet2D, 'UNet2D': UNet2D, 'UNet2D_DualBranch': UNet2D_DualBranch, 'UNet2D_CCT': UNet2D_CCT, From b99b5de9a330f7334c0d45527b08b2b4205fefd9 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 4 Oct 2024 15:28:36 +0800 Subject: [PATCH 64/86] Update __init__.py --- pymic/net_run/semi_sup/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymic/net_run/semi_sup/__init__.py b/pymic/net_run/semi_sup/__init__.py index cb5d1a3..769a66b 100644 --- a/pymic/net_run/semi_sup/__init__.py +++ b/pymic/net_run/semi_sup/__init__.py @@ -3,7 +3,7 @@ from pymic.net_run.semi_sup.ssl_em import SSLEntropyMinimization from pymic.net_run.semi_sup.ssl_mt import SSLMeanTeacher from pymic.net_run.semi_sup.ssl_mcnet import SSLMCNet -from pymic.net_run.semi_sup.ssl_cdma import SSLCDMA +# from pymic.net_run.semi_sup.ssl_cdma import SSLCDMA from pymic.net_run.semi_sup.ssl_uamt import SSLUncertaintyAwareMeanTeacher from pymic.net_run.semi_sup.ssl_cct import SSLCCT from pymic.net_run.semi_sup.ssl_cps import SSLCPS @@ -13,7 +13,7 @@ SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization, 'MeanTeacher': SSLMeanTeacher, 'MCNet': SSLMCNet, - 'CDMA': SSLCDMA, + # 'CDMA': SSLCDMA, 'UAMT': SSLUncertaintyAwareMeanTeacher, 'CCT': SSLCCT, 'CPS': SSLCPS, From 981b772f8e75002c1f13d789fb2c5454c5683a80 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 4 Oct 2024 15:47:37 +0800 Subject: [PATCH 65/86] load pretrained weights for classification models --- pymic/net_run/agent_cls.py | 49 +++++++++++++++++++++++++++----------- pymic/net_run/train.py | 2 ++ 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index a31df84..728f805 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -19,7 +19,7 @@ from pymic.net.net_dict_cls import TorchClsNetDict from pymic.transform.trans_dict import TransformDict from pymic.net_run.agent_abstract import NetRunAgent -from pymic.util.general import mixup +from pymic.util.general import mixup, tensor_shape_match import warnings warnings.filterwarnings('ignore', '.*output shape of zoom.*') @@ -212,6 +212,27 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): logging.info('valid loss {0:.4f}, avg {1:} {2:.4f}'.format( valid_scalars['loss'], metrics, valid_scalars[metrics])) + def load_pretrained_weights(self, network, pretrained_dict, device_ids): + if(len(device_ids) > 1): + if(hasattr(network.module, "get_parameters_to_load")): + model_dict = network.module.get_parameters_to_load() + else: + model_dict = network.module.state_dict() + else: + if(hasattr(network, "get_parameters_to_load")): + model_dict = network.get_parameters_to_load() + else: + model_dict = network.state_dict() + pretrained_dict = {k: v for k, v in pretrained_dict.items() if \ + k in model_dict and tensor_shape_match(pretrained_dict[k], model_dict[k])} + logging.info("Initializing the following parameters with pre-trained model") + for k in pretrained_dict: + logging.info(k) + if (len(device_ids) > 1): + network.module.load_state_dict(pretrained_dict, strict = False) + else: + network.load_state_dict(pretrained_dict, strict = False) + def train_valid(self): device_ids = self.config['training']['gpus'] if(len(device_ids) > 1): @@ -227,7 +248,7 @@ def train_valid(self): ckpt_prefix = self.config['training'].get('ckpt_prefix', None) if(ckpt_prefix is None): ckpt_prefix = ckpt_dir.split('/')[-1] - iter_start = self.config['training']['iter_start'] + iter_start = 0 iter_max = self.config['training']['iter_max'] iter_valid = self.config['training']['iter_valid'] iter_save = self.config['training']['iter_save'] @@ -243,18 +264,18 @@ def train_valid(self): self.max_val_score = 0.0 self.max_val_it = 0 self.best_model_wts = None - self.checkpoint = None - if(iter_start > 0): - checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start) - self.checkpoint = torch.load(checkpoint_file, map_location = self.device) - assert(self.checkpoint['iteration'] == iter_start) - if(len(device_ids) > 1): - self.net.module.load_state_dict(self.checkpoint['model_state_dict']) - else: - self.net.load_state_dict(self.checkpoint['model_state_dict']) - self.max_val_score = self.checkpoint.get('valid_pred', 0) - self.max_val_it = self.checkpoint['iteration'] - self.best_model_wts = self.checkpoint['model_state_dict'] + ckpt_init_name = self.config['training'].get('ckpt_init_name', None) + ckpt_init_mode = self.config['training'].get('ckpt_init_mode', 0) + + if(ckpt_init_name is not None): + checkpoint = torch.load(ckpt_dir + "/" + ckpt_init_name, map_location = self.device) + pretrained_dict = checkpoint['model_state_dict'] + self.load_pretrained_weights(self.net, pretrained_dict, device_ids) + if(ckpt_init_mode > 0): # Load other information + iter_start = checkpoint['iteration'] + self.max_val_score = checkpoint.get('valid_pred', 0) + self.max_val_it = checkpoint['iteration'] + self.best_model_wts = checkpoint['model_state_dict'] self.create_optimizer(self.get_parameters_to_update()) self.create_loss_calculator() diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index ed60fa1..ec0002f 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -60,6 +60,8 @@ def main(): required=False, default=None) parser.add_argument("-ckpt_dir", help="the output dir for trained model", required=False, default=None) + parser.add_argument("-iter_max", help="the maximal iteration number for training", + required=False, default=None) parser.add_argument("-gpus", help="the gpus for runing, e.g., [0]", required=False, default=None) args = parser.parse_args() From d833edc2528afd297798a30d0ce918559f64ed13 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 4 Oct 2024 16:31:54 +0800 Subject: [PATCH 66/86] update 3d networks --- pymic/net/net2d/cople_net.py | 34 ++++++++++++++++++++++------------ pymic/net/net3d/unet3d_scse.py | 23 +++++++++++++---------- 2 files changed, 35 insertions(+), 22 deletions(-) diff --git a/pymic/net/net2d/cople_net.py b/pymic/net/net2d/cople_net.py index bd54fd6..046dee8 100644 --- a/pymic/net/net2d/cople_net.py +++ b/pymic/net/net2d/cople_net.py @@ -120,20 +120,30 @@ class UpBlock(nn.Module): Upssampling followed by ConvBNActBlock. """ def __init__(self, in_channels1, in_channels2, out_channels, - bilinear=True, dropout_p = 0.5): + up_mode = 2, dropout_p = 0.5): super(UpBlock, self).__init__() - self.bilinear = bilinear - if bilinear: - self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + if(isinstance(up_mode, int)): + up_mode_values = ["transconv", "nearest", "bilinear", "bicubic"] + if(up_mode > 3): + raise ValueError("The upsample mode should be 0-3, but {0:} is given.".format(up_mode)) + self.up_mode = up_mode_values[up_mode] else: + self.up_mode = up_mode.lower() + + if (self.up_mode == "transconv"): self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2) + else: + self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) + if(self.up_mode == "nearest"): + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode) + else: + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode, align_corners=True) self.conv = ConvBNActBlock(in_channels2 * 2, out_channels, dropout_p) def forward(self, x1, x2): - if self.bilinear: + if self.up_mode != "transconv": x1 = self.conv1x1(x1) - x1 = self.up(x1) + x1 = self.up(x1) x_cat = torch.cat([x2, x1], dim=1) y = self.conv(x_cat) return y + x_cat @@ -165,7 +175,7 @@ def __init__(self, params): self.ft_chns = self.params['feature_chns'] self.dropout = self.params['dropout'] self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] + self.up_mode = self.params.get('up_mode', 2) assert(len(self.ft_chns) == 5) f0_half = int(self.ft_chns[0] / 2) @@ -183,10 +193,10 @@ def __init__(self, params): self.bridge2= ConvLayer(self.ft_chns[2], f2_half) self.bridge3= ConvLayer(self.ft_chns[3], f3_half) - self.up1 = UpBlock(self.ft_chns[4], f3_half, self.ft_chns[3], dropout_p = self.dropout[3]) - self.up2 = UpBlock(self.ft_chns[3], f2_half, self.ft_chns[2], dropout_p = self.dropout[2]) - self.up3 = UpBlock(self.ft_chns[2], f1_half, self.ft_chns[1], dropout_p = self.dropout[1]) - self.up4 = UpBlock(self.ft_chns[1], f0_half, self.ft_chns[0], dropout_p = self.dropout[0]) + self.up1 = UpBlock(self.ft_chns[4], f3_half, self.ft_chns[3], self.up_mode, dropout_p = self.dropout[3]) + self.up2 = UpBlock(self.ft_chns[3], f2_half, self.ft_chns[2], self.up_mode, dropout_p = self.dropout[2]) + self.up3 = UpBlock(self.ft_chns[2], f1_half, self.ft_chns[1], self.up_mode, dropout_p = self.dropout[1]) + self.up4 = UpBlock(self.ft_chns[1], f0_half, self.ft_chns[0], self.up_mode, dropout_p = self.dropout[0]) f4 = self.ft_chns[4] aspp_chns = [int(f4 / 4), int(f4 / 4), int(f4 / 4), int(f4 / 4)] diff --git a/pymic/net/net3d/unet3d_scse.py b/pymic/net/net3d/unet3d_scse.py index 79abf9c..49cecc4 100644 --- a/pymic/net/net3d/unet3d_scse.py +++ b/pymic/net/net3d/unet3d_scse.py @@ -76,12 +76,14 @@ class EncoderScSE(Encoder): def __init__(self, params): super(EncoderScSE, self).__init__(params) - self.in_conv= ConvScSEBlock3D(self.in_chns, self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) + in_chns = self.params['in_chns'] + dropout = self.params['dropout'] + self.in_conv= ConvScSEBlock3D(in_chns, self.ft_chns[0], dropout[0]) + self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], dropout[1]) + self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], dropout[2]) + self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], dropout[3]) if(len(self.ft_chns) == 5): - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) + self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], dropout[4]) class DecoderScSE(Decoder): """ @@ -92,12 +94,13 @@ class DecoderScSE(Decoder): """ def __init__(self, params): super(DecoderScSE, self).__init__(params) - + dropout = self.params['dropout'] + up_mode = self.params.get('up_mode', 2) if(len(self.ft_chns) == 5): - self.up1 = UpBlockScSE(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.up_mode) - self.up2 = UpBlockScSE(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.up_mode) - self.up3 = UpBlockScSE(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.up_mode) - self.up4 = UpBlockScSE(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.up_mode) + self.up1 = UpBlockScSE(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout[3], up_mode) + self.up2 = UpBlockScSE(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout[2], up_mode) + self.up3 = UpBlockScSE(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout[1], up_mode) + self.up4 = UpBlockScSE(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout[0], up_mode) class UNet3D_ScSE(UNet3D): From bbc75bfcc616698f8ffc50d3fcb2619095510a66 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 4 Oct 2024 17:16:01 +0800 Subject: [PATCH 67/86] Update wsl_gatedcrf.py --- pymic/net_run/weak_sup/wsl_gatedcrf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymic/net_run/weak_sup/wsl_gatedcrf.py b/pymic/net_run/weak_sup/wsl_gatedcrf.py index 1ecae4a..7eaa67d 100644 --- a/pymic/net_run/weak_sup/wsl_gatedcrf.py +++ b/pymic/net_run/weak_sup/wsl_gatedcrf.py @@ -73,6 +73,7 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) + t2 = time.time() loss_sup = self.get_loss_value(data, outputs, y) # for gated CRF loss, the input should be like NCHW From 908eb6f67ee54064c7147dec85674d41db6a600c Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 4 Oct 2024 17:24:39 +0800 Subject: [PATCH 68/86] Update nll_co_teaching.py --- pymic/net_run/noisy_label/nll_co_teaching.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymic/net_run/noisy_label/nll_co_teaching.py b/pymic/net_run/noisy_label/nll_co_teaching.py index d46b05b..33e375c 100644 --- a/pymic/net_run/noisy_label/nll_co_teaching.py +++ b/pymic/net_run/noisy_label/nll_co_teaching.py @@ -50,7 +50,8 @@ def training(self): rampup_end = nll_cfg.get('rampup_end', iter_max) train_loss_no_select1, train_loss_no_select2 = 0, 0 - train_loss1, train_avg_loss2 = 0, 0 + train_loss1, train_avg_loss1 = 0, 0 + train_loss2, train_avg_loss2 = 0, 0 train_dice_list = [] data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() From 6f04440b9417142bd9f8c03efb12547e311e5294 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 15 Nov 2024 13:51:19 +0800 Subject: [PATCH 69/86] Update __init__.py --- pymic/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymic/__init__.py b/pymic/__init__.py index 33943e4..b7531d9 100644 --- a/pymic/__init__.py +++ b/pymic/__init__.py @@ -1,7 +1,7 @@ from __future__ import absolute_import from enum import Enum -__version__ = "0.4.1" +__version__ = "0.4.2" # 2024.11.15 class TaskType(Enum): CLASSIFICATION_ONE_HOT = 1 From 63e9a531ab6b0c99321a5a7b8e2d7c1d005c3312 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 15 Nov 2024 14:34:15 +0800 Subject: [PATCH 70/86] update config for evaluation --- pymic/util/evaluation_cls.py | 58 +++++++++++++++++++++--------------- setup.py | 2 +- 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/pymic/util/evaluation_cls.py b/pymic/util/evaluation_cls.py index af11a17..a65953a 100644 --- a/pymic/util/evaluation_cls.py +++ b/pymic/util/evaluation_cls.py @@ -3,7 +3,7 @@ Evaluation module for classification tasks. """ from __future__ import absolute_import, print_function - +import argparse import os import csv import sys @@ -75,15 +75,15 @@ def binary_evaluation(config): The arguments are given in the `config` dictionary. It should have the following fields: - :param metric_list: (list) A list of evaluation metrics. + :param metric: (list) A list of evaluation metrics. The supported metrics are {`accuracy`, `recall`, `sensitivity`, `specificity`, `precision`, `auc`}. - :param ground_truth_csv: (str) The csv file for ground truth. - :param predict_prob_csv: (str) The csv file for prediction probability. + :param gt_csv: (str) The csv file for ground truth. + :param pred_prob_csv: (str) The csv file for prediction probability. """ - metric_list = config['metric_list'] - gt_csv = config['ground_truth_csv'] - prob_csv= config['predict_prob_csv'] + metric_list = config['metric'] + gt_csv = config['gt_csv'] + prob_csv= config['pred_prob_csv'] gt_items = pd.read_csv(gt_csv) prob_items = pd.read_csv(prob_csv) assert(len(gt_items) == len(prob_items)) @@ -111,15 +111,15 @@ def nexcl_evaluation(config): The arguments are given in the `config` dictionary. It should have the following fields: - :param metric_list: (list) A list of evaluation metrics. + :param metric: (list) A list of evaluation metrics. The supported metrics are {`accuracy`, `recall`, `sensitivity`, `specificity`, `precision`, `auc`}. - :param ground_truth_csv: (str) The csv file for ground truth. - :param predict_prob_csv: (str) The csv file for prediction probability. + :param gt_csv: (str) The csv file for ground truth. + :param pred_prob_csv: (str) The csv file for prediction probability. """ - metric_list = config['metric_list'] - gt_csv = config['ground_truth_csv'] - prob_csv = config['predict_prob_csv'] + metric_list = config['metric'] + gt_csv = config['gt_csv'] + prob_csv = config['pred_prob_csv'] gt_items = pd.read_csv(gt_csv) prob_items= pd.read_csv(prob_csv) assert(len(gt_items) == len(prob_items)) @@ -163,25 +163,35 @@ def main(): .. code-block:: none - pymic_evaluate_cls config.cfg + pymic_evaluate_cls -cfg config.cfg The configuration file should have an `evaluation` section with the following fields: :param task_type: (str) `cls` or `cls_nexcl`. - :param metric_list: (list) A list of evaluation metrics. + :param metric: (list) A list of evaluation metrics. The supported metrics are {`accuracy`, `recall`, `sensitivity`, `specificity`, `precision`, `auc`}. - :param ground_truth_csv: (str) The csv file for ground truth. - :param predict_prob_csv: (str) The csv file for prediction probability. + :param gt_csv: (str) The csv file for ground truth. + :param pred_prob_csv: (str) The csv file for prediction probability. """ - if(len(sys.argv) < 2): - print('Number of arguments should be 2. e.g.') - print(' pymic_evaluate_cls config.cfg') - exit() - config_file = str(sys.argv[1]) - assert(os.path.isfile(config_file)) - config = parse_config(config_file)['evaluation'] + parser = argparse.ArgumentParser() + parser.add_argument("-cfg", help="configuration file for evaluation", + required=False, default=None) + parser.add_argument("-metric", help="evaluation metrics, e.g., accuracy, or [accuracy, auc]", + required=False, default=None) + parser.add_argument("-gt_csv", help="csv file for ground truth", + required=False, default=None) + parser.add_argument("-pred_prob_csv", help="csv file for probability prediction", + required=False, default=None) + args = parser.parse_args() + print(args) + if(args.cfg is not None): + config = parse_config(args)['evaluation'] + + # config_file = str(sys.argv[1]) + # assert(os.path.isfile(config_file)) + # config = parse_config(config_file)['evaluation'] task_type = config.get('task_type', "cls") if(task_type == "cls"): # default exclusive classification binary_evaluation(config) diff --git a/setup.py b/setup.py index ebb738f..312b1b0 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.4.1", + version = "0.4.2", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, From 6dff55e782e8d7c055c73bc3e7dd6818d952e11a Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 15 Nov 2024 16:20:48 +0800 Subject: [PATCH 71/86] upgrade to v0.5.0 --- README.md | 4 ++-- pymic/__init__.py | 2 +- setup.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 6f34f92..bedeaae 100644 --- a/README.md +++ b/README.md @@ -47,10 +47,10 @@ Run the following command to install the latest released version of PyMIC: ```bash pip install PYMIC ``` -To install a specific version of PYMIC such as 0.4.1, run: +To install a specific version of PYMIC such as 0.5.0, run: ```bash -pip install PYMIC==0.4.1 +pip install PYMIC==0.5.0 ``` Alternatively, you can download the source code for the latest version. Run the following command to compile and install: diff --git a/pymic/__init__.py b/pymic/__init__.py index b7531d9..ae1775d 100644 --- a/pymic/__init__.py +++ b/pymic/__init__.py @@ -1,7 +1,7 @@ from __future__ import absolute_import from enum import Enum -__version__ = "0.4.2" # 2024.11.15 +__version__ = "0.5.0" # 2024.11.15 class TaskType(Enum): CLASSIFICATION_ONE_HOT = 1 diff --git a/setup.py b/setup.py index 312b1b0..cbf7355 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.4.2", + version = "0.5.0", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, From 808dfbef597d455b640a9e4df25b6fd3c30b2eb1 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 3 Dec 2024 15:21:58 +0800 Subject: [PATCH 72/86] update dataloader and nework add FMUNet, and allow missing modalities for multi-modal inputs --- pymic/io/nifty_dataset.py | 16 ++- pymic/net/net3d/fmunet.py | 265 +++++++++++++++++++++++++++++++++++++ pymic/net/net_dict_seg.py | 2 + pymic/net_run/agent_seg.py | 6 +- 4 files changed, 284 insertions(+), 5 deletions(-) create mode 100644 pymic/net/net3d/fmunet.py diff --git a/pymic/io/nifty_dataset.py b/pymic/io/nifty_dataset.py index c3253c9..048ba64 100644 --- a/pymic/io/nifty_dataset.py +++ b/pymic/io/nifty_dataset.py @@ -22,11 +22,13 @@ class NiftyDataset(Dataset): :param transform: (list) List of transforms to be applied on a sample. The built-in transforms can listed in :mod:`pymic.transform.trans_dict`. """ - def __init__(self, root_dir, csv_file, modal_num = 1, + # def __init__(self, root_dir, csv_file, modal_num = 1, + def __init__(self, root_dir, csv_file, modal_num = 1, allow_missing_modal = False, with_label = False, transform=None, task = TaskType.SEGMENTATION): self.root_dir = root_dir self.csv_items = pd.read_csv(csv_file) self.modal_num = modal_num + self.allow_emtpy= allow_missing_modal self.with_label = with_label self.transform = transform self.task = task @@ -89,11 +91,19 @@ def __get_pixel_weight__(self, idx): def __getitem__(self, idx): names_list, image_list = [], [] + image_shape = None for i in range (self.modal_num): image_name = self.csv_items.iloc[idx, i] image_full_name = "{0:}/{1:}".format(self.root_dir, image_name) - image_dict = load_image_as_nd_array(image_full_name) - image_data = image_dict['data_array'] + if(os.path.exists(image_full_name)): + image_dict = load_image_as_nd_array(image_full_name) + image_data = image_dict['data_array'] + elif(self.allow_emtpy and image_shape is not None): + image_data = np.zeros(image_shape) + else: + raise KeyError("File not found: {0:}".format(image_full_name)) + if(i == 0): + image_shape = image_data.shape names_list.append(image_name) image_list.append(image_data) image = np.concatenate(image_list, axis = 0) diff --git a/pymic/net/net3d/fmunet.py b/pymic/net/net3d/fmunet.py new file mode 100644 index 0000000..84ee385 --- /dev/null +++ b/pymic/net/net3d/fmunet.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import itertools +import logging +import torch +import torch.nn as nn +from pymic.net.net_init import Initialization_He, Initialization_XavierUniform + +''' +A copy of fmunetv3, and rename the class as FMUNet. +''' +dim0 = {0:3, 1:2, 2:2} +dim1 = {0:3, 1:3, 2:2} +conv_knl = {2: (1, 3, 3), 3: 3} +conv_pad = {2: (0, 1, 1), 3: 1} +pool_knl = {2: (1, 2, 2), 3: 2} +down_stride = {2: (1, 2, 2), 3: 2} + +class ResConv(nn.Module): + def __init__(self, out_channels, dim = 3, dropout_p = 0.0, depth = 2): + super(ResConv, self).__init__() + assert(dim == 2 or dim == 3) + self.out_channels = out_channels + self.conv_list = nn.ModuleList([nn.Sequential( + nn.InstanceNorm3d(out_channels, affine = True), + nn.LeakyReLU(), + nn.Dropout(dropout_p), + nn.Conv3d(out_channels, out_channels, kernel_size=conv_knl[dim], padding=conv_pad[dim])) + for i in range(depth)]) + + def forward(self, x): + for conv in self.conv_list: + x = conv(x) + x + return x + +class DownSample(nn.Module): + """downsampling based on convolution + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + :param downsample: (bool) Use downsample or not after convolution. + """ + def __init__(self, in_channels, out_channels, dim = 3): + super(DownSample, self).__init__() + self.down = nn.Sequential( + nn.InstanceNorm3d(in_channels, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=conv_knl[dim], + padding=conv_pad[dim], stride = down_stride[dim]) + ) + + def forward(self, x): + return self.down(x) + +class UpCatConv(nn.Module): + """Upsampling followed by `ResConv` block + + :param in_channels1: (int) Input channel number for low-resolution feature map. + :param in_channels2: (int) Input channel number for high-resolution feature map. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear` for 3D and `Bilinear` for 2D). + The default value is 2. + """ + def __init__(self, in_channels1, in_channels2, out_channels, dim = 3): + super(UpCatConv, self).__init__() + + self.up = nn.Sequential( + nn.InstanceNorm3d(in_channels1, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels1, in_channels2, kernel_size=1, padding=0), + nn.Upsample(scale_factor=pool_knl[dim], mode='trilinear', align_corners=True) + ) + + self.conv = nn.Sequential( + nn.InstanceNorm3d(in_channels2*2, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels2 * 2, out_channels, kernel_size=conv_knl[dim], padding=conv_pad[dim]) + ) + + def forward(self, x_l, x_h): + """ + x_l: low-resolution feature map. + x_h: high-resolution feature map. + """ + y = torch.cat([x_h, self.up(x_l)], dim=1) + return self.conv(y) + +class Encoder(nn.Module): + """ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. + + res_mode: resolution mode: 0-- isotrpic, 1-- near isotrpic, 2-- isotropic + """ + def __init__(self, ft_chns, res_mode = 0, dropout_p = 0, depth = 2): + super(Encoder, self).__init__() + d0, d1 = dim0[res_mode], dim1[res_mode] + + self.en_conv0 = ResConv(ft_chns[0], d0, 0, depth) + self.en_conv1 = ResConv(ft_chns[1], d1, 0, depth) + self.en_conv2 = ResConv(ft_chns[2], 3, dropout_p, depth) + self.en_conv3 = ResConv(ft_chns[3], 3, dropout_p, depth) + self.en_conv4 = ResConv(ft_chns[4], 3, dropout_p, depth) + + self.down0 = DownSample(ft_chns[0], ft_chns[1], d0) + self.down1 = DownSample(ft_chns[1], ft_chns[2], d1) + self.down2 = DownSample(ft_chns[2], ft_chns[3], 3) + self.down3 = DownSample(ft_chns[3], ft_chns[4], 3) + + def forward(self, x): + x0 = self.en_conv0(x) + x1 = self.en_conv1(self.down0(x0)) + x2 = self.en_conv2(self.down1(x1)) + x3 = self.en_conv3(self.down2(x2)) + x4 = self.en_conv4(self.down3(x3)) + return [x0, x1, x2, x3, x4] + +class Decoder(nn.Module): + """ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. + """ + def __init__(self, ft_chns, res_mode = 0, dropout_p = 0, depth = 2): + super(Decoder, self).__init__() + d0, d1 = dim0[res_mode], dim1[res_mode] + + self.upcat0 = UpCatConv(ft_chns[1], ft_chns[0], ft_chns[0], d0) + self.upcat1 = UpCatConv(ft_chns[2], ft_chns[1], ft_chns[1], d1) + self.upcat2 = UpCatConv(ft_chns[3], ft_chns[2], ft_chns[2], 3) + self.upcat3 = UpCatConv(ft_chns[4], ft_chns[3], ft_chns[3], 3) + + self.de_conv0 = ResConv(ft_chns[0], d0, 0, depth) + self.de_conv1 = ResConv(ft_chns[1], d1, 0, depth) + self.de_conv2 = ResConv(ft_chns[2], 3, dropout_p, depth) + self.de_conv3 = ResConv(ft_chns[3], 3, dropout_p, depth) + self.de_conv4 = ResConv(ft_chns[4], 3, dropout_p, depth) + + def forward(self, x): + x0, x1, x2, x3, x4 = x + x4_de = self.de_conv4(x4) + x3_de = self.de_conv3(self.upcat3(x4_de, x3)) + x2_de = self.de_conv2(self.upcat2(x3_de, x2)) + x1_de = self.de_conv1(self.upcat1(x2_de, x1)) + x0_de = self.de_conv0(self.upcat0(x1_de, x0)) + return [x0_de, x1_de, x2_de, x3_de] + +class FMUNet(nn.Module): + """ + A 2.5D network combining 3D convolutions with 2D convolutions. + + * Reference: Guotai Wang, Jonathan Shapey, Wenqi Li, Reuben Dorent, Alex Demitriadis, + Sotirios Bisdas, Ian Paddick, Robert Bradford, Shaoting Zhang, Sébastien Ourselin, + Tom Vercauteren: Automatic Segmentation of Vestibular Schwannoma from T2-Weighted + MRI by Deep Spatial Attention with Hardness-Weighted Loss. + `MICCAI (2) 2019: 264-272. `_ + + Note that the attention module in the orininal paper is not used here. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param conv_dims: (list) The convolution dimension (2 or 3) for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, params): + super(FMUNet, self).__init__() + params = self.get_default_parameters(params) + + self.stage = 'train' + in_chns = params['in_chns'] + ft_chns = params['feature_chns'] + res_mode = params['res_mode'] + dropout = params['dropout'] + depth = params['depth'] + cls_num = params['class_num'] + self.mul_pred = params.get('multiscale_pred', True) + self.tune_mode= params.get('finetune_mode', 'all') + self.load_mode= params.get('weights_load_mode', 'all') + + d0 = dim0[res_mode] + self.project = nn.Conv3d(in_chns, ft_chns[0], kernel_size=conv_knl[d0], padding=conv_pad[d0]) + self.encoder = Encoder(ft_chns, res_mode, dropout, depth) + # self.decoder = Decoder(ft_chns, res_mode, dropout, depth = 2) + self.decoder = Decoder(ft_chns, res_mode, dropout, depth) + + self.out_layers = nn.ModuleList() + dims = [dim0[res_mode], dim1[res_mode], 3, 3] + for i in range(4): + out_layer = nn.Sequential( + nn.InstanceNorm3d(ft_chns[i], affine = True), + nn.LeakyReLU(), + nn.Conv3d(ft_chns[i], cls_num, kernel_size=conv_knl[dims[i]], padding=conv_pad[dims[i]])) + self.out_layers.append(out_layer) + + init = params['initialization'].lower() + weightInitializer = Initialization_He(1e-2) if init == 'he' else Initialization_XavierUniform() + self.apply(weightInitializer) + + def get_default_parameters(self, params): + default_param = { + 'finetune_mode': 'all', + 'initialization': 'he', + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': 0.2, + 'res_mode': 0, + 'depth': 2, + 'multiscale_pred': True + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def set_stage(self, stage): + self.stage = stage + + def forward(self, x): + x_en = self.encoder(self.project(x)) + x_de = self.decoder(x_en) + output = self.out_layers[0](x_de[0]) + if(self.mul_pred and self.stage == 'train'): + output = [output] + for i in range(1, len(x_de)): + output.append(self.out_layers[i](x_de[i])) + return output + + def get_parameters_to_update(self): + if(self.tune_mode == 'all'): + return self.parameters() + + up_params = itertools.chain() + if(self.tune_mode == 'decoder'): + up_blocks = [self.decoder, self.out_layers] + else: + raise ValueError("undefined fine-tune mode for FMUNet: {0:}".format(self.tune_mode)) + for block in up_blocks: + up_params = itertools.chain(up_params, block.parameters()) + return up_params + + def get_parameters_to_load(self): + state_dict = self.state_dict() + if(self.load_mode == 'encoder'): + state_dict = {k:v for k, v in state_dict.items() if "project" in k or "encoder" in k } + return state_dict \ No newline at end of file diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index d7f759b..8877ade 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -30,6 +30,7 @@ from pymic.net.net3d.unet3d import UNet3D from pymic.net.net3d.grunet import GRUNet from pymic.net.net3d.fmunetv3 import FMUNetV3 +from pymic.net.net3d.fmunet import FMUNet from pymic.net.net3d.lcovnet import LCOVNet from pymic.net.net3d.unet3d_scse import UNet3D_ScSE from pymic.net.net3d.unet3d_dual_branch import UNet3D_DualBranch @@ -66,6 +67,7 @@ 'UNet2D5': UNet2D5, 'GRUNet': GRUNet, 'LCOVNet': LCOVNet, + 'FMUNet': FMUNet, 'FMUNetV3': FMUNetV3, 'UNet3D': UNet3D, 'UNet3D_ScSE': UNet3D_ScSE, diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 59ffe05..cb30b45 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -68,8 +68,9 @@ def get_stage_dataset_from_config(self, stage): self.test_transforms = transform_list else: with_label = self.config['dataset'].get(stage + '_label', True) - modal_num = self.config['dataset'].get('modal_num', 1) - stage_dir = self.config['dataset'].get('train_dir', None) + modal_num = self.config['dataset'].get('modal_num', 1) + allow_miss = self.config['dataset'].get('allow_missing_modal', False) + stage_dir = self.config['dataset'].get('train_dir', None) if(stage == 'valid' and "valid_dir" in self.config['dataset']): stage_dir = self.config['dataset']['valid_dir'] if(stage == 'test' and "test_dir" in self.config['dataset']): @@ -78,6 +79,7 @@ def get_stage_dataset_from_config(self, stage): dataset = NiftyDataset(root_dir = stage_dir, csv_file = csv_file, modal_num = modal_num, + allow_missing_modal = allow_miss, with_label= with_label, transform = data_transform, task = self.task_type) From 2a240b1751c6139af1abc7882dd3498687e4d04a Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 3 Dec 2024 16:58:07 +0800 Subject: [PATCH 73/86] add postprocess dictionary --- pymic/net_run/agent_abstract.py | 34 +++++++++++++++++++++++---------- pymic/net_run/agent_rec.py | 14 ++++++-------- pymic/net_run/agent_seg.py | 11 +---------- 3 files changed, 31 insertions(+), 28 deletions(-) diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index d8abb19..01ad808 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -56,6 +56,8 @@ def __init__(self, config, stage = 'train'): self.loss_dict = None self.transform_dict = None self.inferer = None + self.postprocess_dict = None + self.postprocessor = None self.tensor_type = config['dataset']['tensor_type'] self.task_type = config['dataset']['task_type'] self.deterministic = config['training'].get('deterministic', True) @@ -102,6 +104,14 @@ def set_net_dict(self, net_dict): """ self.net_dict = net_dict + def set_postprocess_dict(self, postprocess_dict): + """ + Set the available methods for postprocess, including customized postprocess methods. + + :param postprocess_dict: (dictionary) A dictionary of available postprocess methods. + """ + self.postprocess_dict = postprocess_dict + def set_loss_dict(self, loss_dict): """ Set the available loss functions, including customized loss functions. @@ -258,7 +268,11 @@ def create_dataset(self): if(self.train_set is None): self.train_set = self.get_stage_dataset_from_config('train') if(self.valid_set is None): - self.valid_set = self.get_stage_dataset_from_config('valid') + valid_csv = self.config['dataset'].get('valid_csv', None) + if valid_csv is not None: + self.valid_set = self.get_stage_dataset_from_config('valid') + else: + logging.warning("Dataset for validation is not created, as valid_dir is not provided.") if(self.deterministic): def worker_init_fn(worker_id): # workder_seed = self.random_seed+worker_id @@ -269,18 +283,20 @@ def worker_init_fn(worker_id): else: worker_init = None - bn_train = self.config['dataset']['train_batch_size'] - bn_valid = self.config['dataset'].get('valid_batch_size', 1) num_worker = self.config['dataset'].get('num_worker', 8) - g_train, g_valid = torch.Generator(), torch.Generator() + bn_train = self.config['dataset']['train_batch_size'] + g_train = torch.Generator() g_train.manual_seed(self.random_seed) - g_valid.manual_seed(self.random_seed) self.train_loader = torch.utils.data.DataLoader(self.train_set, batch_size = bn_train, shuffle=True, num_workers= num_worker, worker_init_fn=worker_init, generator = g_train, drop_last = True) - self.valid_loader = torch.utils.data.DataLoader(self.valid_set, - batch_size = bn_valid, shuffle=False, num_workers= num_worker, - worker_init_fn=worker_init, generator = g_valid) + if(self.valid_set is not None): + bn_valid = self.config['dataset'].get('valid_batch_size', 1) + g_valid = torch.Generator() + g_valid.manual_seed(self.random_seed) + self.valid_loader = torch.utils.data.DataLoader(self.valid_set, + batch_size = bn_valid, shuffle=False, num_workers= num_worker, + worker_init_fn=worker_init, generator = g_valid) else: bn_test = self.config['dataset'].get('test_batch_size', 1) if(self.test_set is None): @@ -327,5 +343,3 @@ def run(self): else: self.infer() - - diff --git a/pymic/net_run/agent_rec.py b/pymic/net_run/agent_rec.py index cc83526..0d78b43 100644 --- a/pymic/net_run/agent_rec.py +++ b/pymic/net_run/agent_rec.py @@ -108,8 +108,8 @@ def training(self): loss = self.get_loss_value(data, outputs, label) t3 = time.time() loss.backward() - t4 = time.time() self.optimizer.step() + t4 = time.time() train_loss = train_loss + loss.item() if(isinstance(outputs, tuple) or isinstance(outputs, list)): @@ -175,8 +175,7 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): 'valid':valid_scalars['loss']} self.summ_writer.add_scalars('loss', loss_scalar, glob_it) self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) - logging.info('train loss {0:.4f}'.format(train_scalars['loss'])) - logging.info('valid loss {0:.4f}'.format(valid_scalars['loss'])) + logging.info('train/valid loss {0:.4f}/{1:.4f}'.format(train_scalars['loss'],valid_scalars['loss'])) logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( train_scalars['data_time'], train_scalars['forward_time'], train_scalars['loss_time'], train_scalars['backward_time'])) @@ -259,9 +258,6 @@ def train_valid(self): logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) logging.info('learning rate {0:}'.format(lr_value)) logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1)) - logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( - train_scalars['data_time'], train_scalars['forward_time'], - train_scalars['loss_time'], train_scalars['backward_time'])) self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it) if(valid_scalars['loss'] < self.min_val_loss): self.min_val_loss = valid_scalars['loss'] @@ -320,19 +316,21 @@ def save_outputs(self, data): if(isinstance(pred, (list, tuple))): pred = pred[0] pred = np.tanh(pred) + if(self.postprocessor is not None): + pred = self.postprocessor(pred) # pred = scipy.special.expit(pred) # save the output predictions test_dir = self.config['dataset'].get('test_dir', None) if(test_dir is None): test_dir = self.config['dataset']['train_dir'] - for i in range(len(names)): + for i in range(pred.shape[1]): save_name = names[i][0].split('/')[-1] if ignore_dir else \ names[i][0].replace('/', '_') if((filename_replace_source is not None) and (filename_replace_target is not None)): save_name = save_name.replace(filename_replace_source, filename_replace_target) print(save_name) save_name = "{0:}/{1:}".format(output_dir, save_name) - save_nd_array_as_image(pred[i][i], save_name, test_dir + '/' + names[i][0]) + save_nd_array_as_image(pred[i][0], save_name, test_dir + '/' + names[i][0]) \ No newline at end of file diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index cb30b45..45e815b 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -150,15 +150,6 @@ def get_loss_value(self, data, pred, gt, param = None): loss_input_dict['class_weight'] = class_weight.to(device) loss_value = self.loss_calculator(loss_input_dict) return loss_value - - def set_postprocessor(self, postprocessor): - """ - Set post processor after prediction. - - :param postprocessor: post processor, such as an instance of - `pymic.util.post_process.PostProcess`. - """ - self.postprocessor = postprocessor def training(self): class_num = self.config['network']['class_num'] @@ -476,7 +467,7 @@ def test_time_dropout(m): self.inferer = Inferer(infer_cfg) postpro_name = self.config['testing'].get('post_process', None) if(self.postprocessor is None and postpro_name is not None): - self.postprocessor = PostProcessDict[postpro_name](self.config['testing']) + self.postprocessor = self.postprocess_dict[postpro_name](self.config['testing']) infer_time_list = [] with torch.no_grad(): for data in self.test_loader: From 4625b0e048d0c9d053301f1fe406f1502ba48349 Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 31 Jan 2025 17:40:48 +0800 Subject: [PATCH 74/86] add umamba to pymic --- pymic/net/net2d/umamba.py | 1234 +++++++++++++++++++++++++++++++++++++ pymic/net/net_dict_seg.py | 3 + pymic/test/test_net2d.py | 15 +- 3 files changed, 1250 insertions(+), 2 deletions(-) create mode 100644 pymic/net/net2d/umamba.py diff --git a/pymic/net/net2d/umamba.py b/pymic/net/net2d/umamba.py new file mode 100644 index 0000000..63ccacf --- /dev/null +++ b/pymic/net/net2d/umamba.py @@ -0,0 +1,1234 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import logging +import numpy as np +import math +import torch +from torch import nn +from torch.nn import functional as F +from typing import Union, Type, List, Tuple + +from torch.nn.modules.conv import _ConvNd +from torch.nn.modules.dropout import _DropoutNd +from mamba_ssm import Mamba +from torch.cuda.amp import autocast + +# from dynamic_network_architectures.building_blocks.helper import convert_conv_op_to_dim +# from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager +# from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op +# from nnunetv2.utilities.network_initialization import InitWeights_He +# from dynamic_network_architectures.building_blocks.helper import maybe_convert_scalar_to_list, get_matching_pool_op + +def dim_of_conv_op(conv_op: Type[_ConvNd]) -> int: + """ + :param conv_op: conv class + :return: dimension: 1, 2 or 3 + """ + if conv_op == nn.Conv1d: + return 1 + elif conv_op == nn.Conv2d: + return 2 + elif conv_op == nn.Conv3d: + return 3 + else: + raise ValueError("Unknown dimension. Only 1d 2d and 3d conv are supported. got %s" % str(conv_op)) + +def get_matching_pool_op(conv_op: Type[_ConvNd] = None, + dimension: int = None, + adaptive=False, + pool_type: str = 'avg') -> Type[torch.nn.Module]: + """ + You MUST set EITHER conv_op OR dimension. Do not set both! + :param conv_op: + :param dimension: + :param adaptive: + :param pool_type: either 'avg' or 'max' + :return: + """ + assert not ((conv_op is not None) and (dimension is not None)), \ + "You MUST set EITHER conv_op OR dimension. Do not set both!" + assert pool_type in ['avg', 'max'], 'pool_type must be either avg or max' + if conv_op is not None: + dimension = dim_of_conv_op(conv_op) + assert dimension in [1, 2, 3], 'Dimension must be 1, 2 or 3' + + if conv_op is not None: + dimension = dim_of_conv_op(conv_op) + + if dimension == 1: + if pool_type == 'avg': + if adaptive: + return nn.AdaptiveAvgPool1d + else: + return nn.AvgPool1d + elif pool_type == 'max': + if adaptive: + return nn.AdaptiveMaxPool1d + else: + return nn.MaxPool1d + elif dimension == 2: + if pool_type == 'avg': + if adaptive: + return nn.AdaptiveAvgPool2d + else: + return nn.AvgPool2d + elif pool_type == 'max': + if adaptive: + return nn.AdaptiveMaxPool2d + else: + return nn.MaxPool2d + elif dimension == 3: + if pool_type == 'avg': + if adaptive: + return nn.AdaptiveAvgPool3d + else: + return nn.AvgPool3d + elif pool_type == 'max': + if adaptive: + return nn.AdaptiveMaxPool3d + else: + return nn.MaxPool3d + +def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): + """ + This function is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py). + + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + +class DropPath(nn.Module): + """ + This class is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py). + + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + +def make_divisible(v, divisor=8, min_value=None, round_limit=.9): + """ + This function is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/b7cb8d0337b3e7b50516849805ddb9be5fc11644/timm/models/layers/helpers.py#L25) + """ + min_value = min_value or divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < round_limit * v: + new_v += divisor + return new_v + +class SqueezeExcite(nn.Module): + """ + This class is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/squeeze_excite.py) + and slightly modified so that the convolution type can be adapted. + + SE Module as defined in original SE-Nets with a few additions + Additions include: + * divisor can be specified to keep channels % div == 0 (default: 8) + * reduction channels can be specified directly by arg (if rd_channels is set) + * reduction channels can be specified by float rd_ratio (default: 1/16) + * global max pooling can be added to the squeeze aggregation + * customizable activation, normalization, and gate layer + """ + def __init__( + self, channels, conv_op, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False, + act_layer=nn.ReLU, norm_layer=None, gate_layer=nn.Sigmoid): + super(SqueezeExcite, self).__init__() + self.add_maxpool = add_maxpool + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + self.fc1 = conv_op(channels, rd_channels, kernel_size=1, bias=True) + self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity() + self.act = act_layer(inplace=True) + self.fc2 = conv_op(rd_channels, channels, kernel_size=1, bias=True) + self.gate = gate_layer() + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) + x_se = self.fc1(x_se) + x_se = self.act(self.bn(x_se)) + x_se = self.fc2(x_se) + return x * self.gate(x_se) + + +class ConvDropoutNormReLU(nn.Module): + def __init__(self, + conv_op: Type[_ConvNd], + input_channels: int, + output_channels: int, + kernel_size: Union[int, List[int], Tuple[int, ...]], + stride: Union[int, List[int], Tuple[int, ...]], + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: dict = None, + dropout_op: Union[None, Type[_DropoutNd]] = None, + dropout_op_kwargs: dict = None, + nonlin: Union[None, Type[torch.nn.Module]] = None, + nonlin_kwargs: dict = None, + nonlin_first: bool = False + ): + super(ConvDropoutNormReLU, self).__init__() + self.input_channels = input_channels + self.output_channels = output_channels + if not isinstance(stride, (tuple, list, np.ndarray)): + stride = [stride] * dim_of_conv_op(conv_op) + self.stride = stride + + if not isinstance(kernel_size, (tuple, list, np.ndarray)): + kernel_size = [kernel_size] * dim_of_conv_op(conv_op) + if norm_op_kwargs is None: + norm_op_kwargs = {} + if nonlin_kwargs is None: + nonlin_kwargs = {} + + ops = [] + + self.conv = conv_op( + input_channels, + output_channels, + kernel_size, + stride, + padding=[(i - 1) // 2 for i in kernel_size], + dilation=1, + bias=conv_bias, + ) + ops.append(self.conv) + + if dropout_op is not None: + self.dropout = dropout_op(**dropout_op_kwargs) + ops.append(self.dropout) + + if norm_op is not None: + self.norm = norm_op(output_channels, **norm_op_kwargs) + ops.append(self.norm) + + if nonlin is not None: + self.nonlin = nonlin(**nonlin_kwargs) + ops.append(self.nonlin) + + if nonlin_first and (norm_op is not None and nonlin is not None): + ops[-1], ops[-2] = ops[-2], ops[-1] + + self.all_modules = nn.Sequential(*ops) + + def forward(self, x): + return self.all_modules(x) + + def compute_conv_feature_map_size(self, input_size): + assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \ + "batch channel. Do not give input_size=(b, c, x, y(, z)). " \ + "Give input_size=(x, y(, z))!" + output_size = [i // j for i, j in zip(input_size, self.stride)] # we always do same padding + return np.prod([self.output_channels, *output_size], dtype=np.int64) + +# # from dynamic_network_architectures.building_blocks.residual import BasicBlockD +class BasicBlockD(nn.Module): + def __init__(self, + conv_op: Type[_ConvNd], + input_channels: int, + output_channels: int, + kernel_size: Union[int, List[int], Tuple[int, ...]], + stride: Union[int, List[int], Tuple[int, ...]], + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: dict = None, + dropout_op: Union[None, Type[_DropoutNd]] = None, + dropout_op_kwargs: dict = None, + nonlin: Union[None, Type[torch.nn.Module]] = None, + nonlin_kwargs: dict = None, + stochastic_depth_p: float = 0.0, + squeeze_excitation: bool = False, + squeeze_excitation_reduction_ratio: float = 1. / 16, + # todo wideresnet? + ): + """ + This implementation follows ResNet-D: + + He, Tong, et al. "Bag of tricks for image classification with convolutional neural networks." + Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019. + + The skip has an avgpool (if needed) followed by 1x1 conv instead of just a strided 1x1 conv + + :param conv_op: + :param input_channels: + :param output_channels: + :param kernel_size: refers only to convs in feature extraction path, not to 1x1x1 conv in skip + :param stride: only applies to first conv (and skip). Second conv always has stride 1 + :param conv_bias: + :param norm_op: + :param norm_op_kwargs: + :param dropout_op: only the first conv can have dropout. The second never has + :param dropout_op_kwargs: + :param nonlin: + :param nonlin_kwargs: + :param stochastic_depth_p: + :param squeeze_excitation: + :param squeeze_excitation_reduction_ratio: + """ + super().__init__() + self.input_channels = input_channels + self.output_channels = output_channels + if not isinstance(stride, (tuple, list, np.ndarray)): + stride = [stride] * dim_of_conv_op(conv_op) + self.stride = stride + + if not isinstance(kernel_size, (tuple, list, np.ndarray)): + kernel_size = [kernel_size] * dim_of_conv_op(conv_op) + + if norm_op_kwargs is None: + norm_op_kwargs = {} + if nonlin_kwargs is None: + nonlin_kwargs = {} + + self.conv1 = ConvDropoutNormReLU(conv_op, input_channels, output_channels, kernel_size, stride, conv_bias, + norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs) + self.conv2 = ConvDropoutNormReLU(conv_op, output_channels, output_channels, kernel_size, 1, conv_bias, norm_op, + norm_op_kwargs, None, None, None, None) + + self.nonlin2 = nonlin(**nonlin_kwargs) if nonlin is not None else lambda x: x + + # Stochastic Depth + self.apply_stochastic_depth = False if stochastic_depth_p == 0.0 else True + if self.apply_stochastic_depth: + self.drop_path = DropPath(drop_prob=stochastic_depth_p) + + # Squeeze Excitation + self.apply_se = squeeze_excitation + if self.apply_se: + self.squeeze_excitation = SqueezeExcite(self.output_channels, conv_op, + rd_ratio=squeeze_excitation_reduction_ratio, rd_divisor=8) + + has_stride = (isinstance(stride, int) and stride != 1) or any([i != 1 for i in stride]) + requires_projection = (input_channels != output_channels) + + if has_stride or requires_projection: + ops = [] + if has_stride: + ops.append(get_matching_pool_op(conv_op=conv_op, adaptive=False, pool_type='avg')(stride, stride)) + if requires_projection: + ops.append( + ConvDropoutNormReLU(conv_op, input_channels, output_channels, 1, 1, False, norm_op, + norm_op_kwargs, None, None, None, None + ) + ) + self.skip = nn.Sequential(*ops) + else: + self.skip = lambda x: x + + def forward(self, x): + residual = self.skip(x) + out = self.conv2(self.conv1(x)) + if self.apply_stochastic_depth: + out = self.drop_path(out) + if self.apply_se: + out = self.squeeze_excitation(out) + out += residual + return self.nonlin2(out) + + def compute_conv_feature_map_size(self, input_size): + assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \ + "batch channel. Do not give input_size=(b, c, x, y(, z)). " \ + "Give input_size=(x, y(, z))!" + size_after_stride = [i // j for i, j in zip(input_size, self.stride)] + # conv1 + output_size_conv1 = np.prod([self.output_channels, *size_after_stride], dtype=np.int64) + # conv2 + output_size_conv2 = np.prod([self.output_channels, *size_after_stride], dtype=np.int64) + # skip conv (if applicable) + if (self.input_channels != self.output_channels) or any([i != j for i, j in zip(input_size, size_after_stride)]): + assert isinstance(self.skip, nn.Sequential) + output_size_skip = np.prod([self.output_channels, *size_after_stride], dtype=np.int64) + else: + assert not isinstance(self.skip, nn.Sequential) + output_size_skip = 0 + return output_size_conv1 + output_size_conv2 + output_size_skip + +class UpsampleLayer(nn.Module): + def __init__( + self, + conv_op, + input_channels, + output_channels, + pool_op_kernel_size, + mode='nearest' + ): + super().__init__() + self.conv = conv_op(input_channels, output_channels, kernel_size=1) + self.pool_op_kernel_size = pool_op_kernel_size + self.mode = mode + + def forward(self, x): + x = F.interpolate(x, scale_factor=self.pool_op_kernel_size, mode=self.mode) + x = self.conv(x) + return x + +# class MambaLayer(nn.Module): +# def __init__(self, dim, d_state = 16, d_conv = 4, expand = 2): +# super().__init__() +# self.dim = dim +# self.norm = nn.LayerNorm(dim) +# self.mamba = Mamba( +# d_model=dim, # Model dimension d_model +# d_state=d_state, # SSM state expansion factor +# d_conv=d_conv, # Local convolution width +# expand=expand, # Block expansion factor +# ) + +# @autocast(enabled=False) +# def forward(self, x): +# if x.dtype == torch.float16: +# x = x.type(torch.float32) +# B, C = x.shape[:2] +# assert C == self.dim +# n_tokens = x.shape[2:].numel() +# img_dims = x.shape[2:] +# x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2) +# x_norm = self.norm(x_flat) +# x_mamba = self.mamba(x_norm) +# out = x_mamba.transpose(-1, -2).reshape(B, C, *img_dims) + +# return out + +class MambaLayer(nn.Module): + def __init__(self, dim, d_state = 16, d_conv = 4, expand = 2, channel_token = False): + super().__init__() + self.dim = dim + self.norm = nn.LayerNorm(dim) + self.mamba = Mamba( + d_model=dim, # Model dimension d_model + d_state=d_state, # SSM state expansion factor + d_conv=d_conv, # Local convolution width + expand=expand, # Block expansion factor + ) + self.channel_token = channel_token ## whether to use channel as tokens + + def forward_patch_token(self, x): + B, d_model = x.shape[:2] + assert d_model == self.dim + n_tokens = x.shape[2:].numel() + img_dims = x.shape[2:] + x_flat = x.reshape(B, d_model, n_tokens).transpose(-1, -2) + x_norm = self.norm(x_flat) + x_mamba = self.mamba(x_norm) + out = x_mamba.transpose(-1, -2).reshape(B, d_model, *img_dims) + + return out + + def forward_channel_token(self, x): + B, n_tokens = x.shape[:2] + d_model = x.shape[2:].numel() + assert d_model == self.dim, f"d_model: {d_model}, self.dim: {self.dim}" + img_dims = x.shape[2:] + x_flat = x.flatten(2) + assert x_flat.shape[2] == d_model, f"x_flat.shape[2]: {x_flat.shape[2]}, d_model: {d_model}" + x_norm = self.norm(x_flat) + x_mamba = self.mamba(x_norm) + out = x_mamba.reshape(B, n_tokens, *img_dims) + + return out + + @autocast(enabled=False) + def forward(self, x): + if x.dtype == torch.float16: + x = x.type(torch.float32) + + if self.channel_token: + out = self.forward_channel_token(x) + else: + out = self.forward_patch_token(x) + + return out + + +class BasicResBlock(nn.Module): + def __init__( + self, + conv_op, + input_channels, + output_channels, + norm_op, + norm_op_kwargs, + kernel_size=3, + padding=1, + stride=1, + use_1x1conv=False, + nonlin=nn.LeakyReLU, + nonlin_kwargs={'inplace': True} + ): + super().__init__() + + self.conv1 = conv_op(input_channels, output_channels, kernel_size, stride=stride, padding=padding) + self.norm1 = norm_op(output_channels, **norm_op_kwargs) + self.act1 = nonlin(**nonlin_kwargs) + + self.conv2 = conv_op(output_channels, output_channels, kernel_size, padding=padding) + self.norm2 = norm_op(output_channels, **norm_op_kwargs) + self.act2 = nonlin(**nonlin_kwargs) + + if use_1x1conv: + self.conv3 = conv_op(input_channels, output_channels, kernel_size=1, stride=stride) + else: + self.conv3 = None + + def forward(self, x): + y = self.conv1(x) + y = self.act1(self.norm1(y)) + y = self.norm2(self.conv2(y)) + if self.conv3: + x = self.conv3(x) + y += x + return self.act2(y) + +class UNetResEncoder(nn.Module): + def __init__(self, + input_channels: int, + n_stages: int, + features_per_stage: Union[int, List[int], Tuple[int, ...]], + conv_op: Type[_ConvNd], + kernel_sizes: Union[int, List[int], Tuple[int, ...]], + strides: Union[int, List[int], Tuple[int, ...], Tuple[Tuple[int, ...], ...]], + n_blocks_per_stage: Union[int, List[int], Tuple[int, ...]], + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: dict = None, + nonlin: Union[None, Type[torch.nn.Module]] = None, + nonlin_kwargs: dict = None, + return_skips: bool = False, + stem_channels: int = None, + pool_type: str = 'conv', + ): + super().__init__() + if isinstance(kernel_sizes, int): + kernel_sizes = [kernel_sizes] * n_stages + if isinstance(features_per_stage, int): + features_per_stage = [features_per_stage] * n_stages + if isinstance(n_blocks_per_stage, int): + n_blocks_per_stage = [n_blocks_per_stage] * n_stages + if isinstance(strides, int): + strides = [strides] * n_stages + + assert len( + kernel_sizes) == n_stages, "kernel_sizes must have as many entries as we have resolution stages (n_stages)" + assert len( + n_blocks_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have resolution stages (n_stages)" + assert len( + features_per_stage) == n_stages, "features_per_stage must have as many entries as we have resolution stages (n_stages)" + assert len(strides) == n_stages, "strides must have as many entries as we have resolution stages (n_stages). " \ + "Important: first entry is recommended to be 1, else we run strided conv drectly on the input" + + pool_op = get_matching_pool_op(conv_op, pool_type=pool_type) if pool_type != 'conv' else None + + self.conv_pad_sizes = [] + for krnl in kernel_sizes: + self.conv_pad_sizes.append([i // 2 for i in krnl]) + + stem_channels = features_per_stage[0] + + self.stem = nn.Sequential( + BasicResBlock( + conv_op = conv_op, + input_channels = input_channels, + output_channels = stem_channels, + norm_op=norm_op, + norm_op_kwargs=norm_op_kwargs, + kernel_size=kernel_sizes[0], + padding=self.conv_pad_sizes[0], + stride=1, + nonlin=nonlin, + nonlin_kwargs=nonlin_kwargs, + use_1x1conv=True + ), + *[ + BasicBlockD( + conv_op = conv_op, + input_channels = stem_channels, + output_channels = stem_channels, + kernel_size = kernel_sizes[0], + stride = 1, + conv_bias = conv_bias, + norm_op = norm_op, + norm_op_kwargs = norm_op_kwargs, + nonlin = nonlin, + nonlin_kwargs = nonlin_kwargs, + ) for _ in range(n_blocks_per_stage[0] - 1) + ] + ) + + + input_channels = stem_channels + + # now build the network + stages = [] + for s in range(n_stages): + stage = nn.Sequential( + BasicResBlock( + conv_op = conv_op, + norm_op = norm_op, + norm_op_kwargs = norm_op_kwargs, + input_channels = input_channels, + output_channels = features_per_stage[s], + kernel_size = kernel_sizes[s], + padding=self.conv_pad_sizes[s], + stride=strides[s], + use_1x1conv=True, + nonlin = nonlin, + nonlin_kwargs = nonlin_kwargs + ), + *[ + BasicBlockD( + conv_op = conv_op, + input_channels = features_per_stage[s], + output_channels = features_per_stage[s], + kernel_size = kernel_sizes[s], + stride = 1, + conv_bias = conv_bias, + norm_op = norm_op, + norm_op_kwargs = norm_op_kwargs, + nonlin = nonlin, + nonlin_kwargs = nonlin_kwargs, + ) for _ in range(n_blocks_per_stage[s] - 1) + ] + ) + + stages.append(stage) + input_channels = features_per_stage[s] + + self.stages = nn.Sequential(*stages) + self.output_channels = features_per_stage + self.strides = [[item] * dim_of_conv_op(conv_op) if not isinstance(item, (tuple, list, np.ndarray)) \ + else item for item in strides] + # self.strides = [maybe_convert_scalar_to_list(conv_op, i) for i in strides] + + self.return_skips = return_skips + + self.conv_op = conv_op + self.norm_op = norm_op + self.norm_op_kwargs = norm_op_kwargs + self.nonlin = nonlin + self.nonlin_kwargs = nonlin_kwargs + #self.dropout_op = dropout_op + #self.dropout_op_kwargs = dropout_op_kwargs + self.conv_bias = conv_bias + self.kernel_sizes = kernel_sizes + + def forward(self, x): + if self.stem is not None: + x = self.stem(x) + ret = [] + for s in self.stages: + x = s(x) + ret.append(x) + if self.return_skips: + return ret + else: + return ret[-1] + + def compute_conv_feature_map_size(self, input_size): + if self.stem is not None: + output = self.stem.compute_conv_feature_map_size(input_size) + else: + output = np.int64(0) + + for s in range(len(self.stages)): + output += self.stages[s].compute_conv_feature_map_size(input_size) + input_size = [i // j for i, j in zip(input_size, self.strides[s])] + + return output + + +class UNetResDecoder(nn.Module): + def __init__(self, + encoder, + num_classes, + n_conv_per_stage: Union[int, Tuple[int, ...], List[int]], + deep_supervision, nonlin_first: bool = False): + + super().__init__() + self.deep_supervision = deep_supervision + self.encoder = encoder + self.num_classes = num_classes + n_stages_encoder = len(encoder.output_channels) + if isinstance(n_conv_per_stage, int): + n_conv_per_stage = [n_conv_per_stage] * (n_stages_encoder - 1) + assert len(n_conv_per_stage) == n_stages_encoder - 1, "n_conv_per_stage must have as many entries as we have " \ + "resolution stages - 1 (n_stages in encoder - 1), " \ + "here: %d" % n_stages_encoder + + stages = [] + upsample_layers = [] + + seg_layers = [] + for s in range(1, n_stages_encoder): + input_features_below = encoder.output_channels[-s] + input_features_skip = encoder.output_channels[-(s + 1)] + stride_for_upsampling = encoder.strides[-s] + upsample_layers.append(UpsampleLayer( + conv_op = encoder.conv_op, + input_channels = input_features_below, + output_channels = input_features_skip, + pool_op_kernel_size = stride_for_upsampling, + mode='nearest' + )) + + stages.append(nn.Sequential( + BasicResBlock( + conv_op = encoder.conv_op, + norm_op = encoder.norm_op, + norm_op_kwargs = encoder.norm_op_kwargs, + nonlin = encoder.nonlin, + nonlin_kwargs = encoder.nonlin_kwargs, + input_channels = 2 * input_features_skip if s < n_stages_encoder - 1 else input_features_skip, + output_channels = input_features_skip, + kernel_size = encoder.kernel_sizes[-(s + 1)], + padding=encoder.conv_pad_sizes[-(s + 1)], + stride=1, + use_1x1conv=True + ), + *[ + BasicBlockD( + conv_op = encoder.conv_op, + input_channels = input_features_skip, + output_channels = input_features_skip, + kernel_size = encoder.kernel_sizes[-(s + 1)], + stride = 1, + conv_bias = encoder.conv_bias, + norm_op = encoder.norm_op, + norm_op_kwargs = encoder.norm_op_kwargs, + nonlin = encoder.nonlin, + nonlin_kwargs = encoder.nonlin_kwargs, + ) for _ in range(n_conv_per_stage[s-1] - 1) + ] + )) + seg_layers.append(encoder.conv_op(input_features_skip, num_classes, 1, 1, 0, bias=True)) + + self.stages = nn.ModuleList(stages) + self.upsample_layers = nn.ModuleList(upsample_layers) + self.seg_layers = nn.ModuleList(seg_layers) + + def forward(self, skips): + lres_input = skips[-1] + seg_outputs = [] + for s in range(len(self.stages)): + x = self.upsample_layers[s](lres_input) + if s < (len(self.stages) - 1): + x = torch.cat((x, skips[-(s+2)]), 1) + x = self.stages[s](x) + if self.deep_supervision: + seg_outputs.append(self.seg_layers[s](x)) + elif s == (len(self.stages) - 1): + seg_outputs.append(self.seg_layers[-1](x)) + lres_input = x + seg_outputs = seg_outputs[::-1] + + if not self.deep_supervision: + r = seg_outputs[0] + else: + r = seg_outputs + return r + + def compute_conv_feature_map_size(self, input_size): + skip_sizes = [] + for s in range(len(self.encoder.strides) - 1): + skip_sizes.append([i // j for i, j in zip(input_size, self.encoder.strides[s])]) + input_size = skip_sizes[-1] + + assert len(skip_sizes) == len(self.stages) + + output = np.int64(0) + for s in range(len(self.stages)): + output += self.stages[s].compute_conv_feature_map_size(skip_sizes[-(s+1)]) + output += np.prod([self.encoder.output_channels[-(s+2)], *skip_sizes[-(s+1)]], dtype=np.int64) + if self.deep_supervision or (s == (len(self.stages) - 1)): + output += np.prod([self.num_classes, *skip_sizes[-(s+1)]], dtype=np.int64) + return output + + +class UMambaBot(nn.Module): + """ + UMambaBot that uses Mamba block at the bottleneck of UNet. + + * Reference: Jun Ma, Feifei Li, Bo Wang. + U-Mamba: Enhancing long-range dependency for biomedical image segmentation. + arxiv 2403.20035, 2024. + + The implementation is based on the code at: + https://github.com/bowang-lab/U-Mamba. + + The parameters for the backbone should be given in the `params` dictionary. + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param class_num: (int) The class number for segmentation task. + :param n_blocks_per_stage: (int) the number of con blocks at each stage. + """ + def __init__(self, params): + super(UMambaBot, self).__init__() + params = self.get_default_parameters(params) + for p in params: + print(p, params[p]) + + # def __init__(self, + # input_channels: int, + # n_stages: int, + # features_per_stage: Union[int, List[int], Tuple[int, ...]], + # conv_op: Type[_ConvNd], + # kernel_sizes: Union[int, List[int], Tuple[int, ...]], + # strides: Union[int, List[int], Tuple[int, ...]], + # n_conv_per_stage: Union[int, List[int], Tuple[int, ...]], + # num_classes: int, + # n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]], + # conv_bias: bool = False, + # norm_op: Union[None, Type[nn.Module]] = None, + # norm_op_kwargs: dict = None, + # dropout_op: Union[None, Type[_DropoutNd]] = None, + # dropout_op_kwargs: dict = None, + # nonlin: Union[None, Type[torch.nn.Module]] = None, + # nonlin_kwargs: dict = None, + # deep_supervision: bool = False, + # stem_channels: int = None + # ): + # super().__init__() + + input_channels = params['in_chns'] + features_per_stage = params['feature_chns'] + num_classes = params['class_num'] + n_blocks_per_stage = params['n_blocks_per_stage'] + n_conv_per_stage_decoder = n_blocks_per_stage + n_stages = len(features_per_stage) + conv_op = nn.Conv2d + kernel_sizes = [(3,3)] * len(features_per_stage) + strides = [(1, 1)] + [(2,2)] * (len(features_per_stage) - 1) + # strides = [(1,1)] * len(features_per_stage) + + conv_bias = True + norm_op = nn.InstanceNorm2d + norm_op_kwargs = {"affine":True} + nonlin=nn.LeakyReLU + nonlin_kwargs={'inplace': True} + deep_supervision = False + stem_channels = None + + if isinstance(n_blocks_per_stage, int): + n_blocks_per_stage = [n_blocks_per_stage] * n_stages + if isinstance(n_conv_per_stage_decoder, int): + n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1) + + for s in range(math.ceil(n_stages / 2), n_stages): + n_blocks_per_stage[s] = 1 + + for s in range(math.ceil((n_stages - 1) / 2 + 0.5), n_stages - 1): + n_conv_per_stage_decoder[s] = 1 + + + assert len(n_blocks_per_stage) == n_stages, "n_blocks_per_stage must have as many entries as we have " \ + f"resolution stages. here: {n_stages}. " \ + f"n_blocks_per_stage: {n_blocks_per_stage}" + assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one less entries " \ + f"as we have resolution stages. here: {n_stages} " \ + f"stages, so it should have {n_stages - 1} entries. " \ + f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}" + + + self.encoder = UNetResEncoder( + input_channels, + n_stages, + features_per_stage, + conv_op, + kernel_sizes, + strides, + n_blocks_per_stage, + conv_bias, + norm_op, + norm_op_kwargs, + nonlin, + nonlin_kwargs, + return_skips=True, + stem_channels=stem_channels + ) + + self.mamba_layer = MambaLayer(dim = features_per_stage[-1]) + + self.decoder = UNetResDecoder(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'n_blocks_per_stage': 2 + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + skips = self.encoder(x) + # for skip in skips: + # print(skip.shape) + skips[-1] = self.mamba_layer(skips[-1]) + output = self.decoder(skips) + if(len(x_shape) == 5): + if(isinstance(output, (list,tuple))): + for i in range(len(output)): + new_shape = [N, D] + list(output[i].shape)[1:] + output[i] = torch.transpose(torch.reshape(output[i], new_shape), 1, 2) + else: + new_shape = [N, D] + list(output.shape)[1:] + output = torch.transpose(torch.reshape(output, new_shape), 1, 2) + return output + + def compute_conv_feature_map_size(self, input_size): + assert len(input_size) == dim_of_conv_op(self.encoder.conv_op), "just give the image size without color/feature channels or " \ + "batch channel. Do not give input_size=(b, c, x, y(, z)). " \ + "Give input_size=(x, y(, z))!" + return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size) + +class ResidualMambaEncoder(nn.Module): + def __init__(self, + input_size: Tuple[int, ...], + input_channels: int, + n_stages: int, + features_per_stage: Union[int, List[int], Tuple[int, ...]], + conv_op: Type[_ConvNd], + kernel_sizes: Union[int, List[int], Tuple[int, ...]], + strides: Union[int, List[int], Tuple[int, ...], Tuple[Tuple[int, ...], ...]], + n_blocks_per_stage: Union[int, List[int], Tuple[int, ...]], + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: dict = None, + nonlin: Union[None, Type[torch.nn.Module]] = None, + nonlin_kwargs: dict = None, + return_skips: bool = False, + stem_channels: int = None, + pool_type: str = 'conv', + ): + super().__init__() + if isinstance(kernel_sizes, int): + kernel_sizes = [kernel_sizes] * n_stages + if isinstance(features_per_stage, int): + features_per_stage = [features_per_stage] * n_stages + if isinstance(n_blocks_per_stage, int): + n_blocks_per_stage = [n_blocks_per_stage] * n_stages + if isinstance(strides, int): + strides = [strides] * n_stages + assert len( + kernel_sizes) == n_stages, "kernel_sizes must have as many entries as we have resolution stages (n_stages)" + assert len( + n_blocks_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have resolution stages (n_stages)" + assert len( + features_per_stage) == n_stages, "features_per_stage must have as many entries as we have resolution stages (n_stages)" + assert len(strides) == n_stages, "strides must have as many entries as we have resolution stages (n_stages). " \ + "Important: first entry is recommended to be 1, else we run strided conv drectly on the input" + + pool_op = get_matching_pool_op(conv_op, pool_type=pool_type) if pool_type != 'conv' else None + + do_channel_token = [False] * n_stages + feature_map_sizes = [] + feature_map_size = input_size + for s in range(n_stages): + feature_map_sizes.append([i // j for i, j in zip(feature_map_size, strides[s])]) + feature_map_size = feature_map_sizes[-1] + if np.prod(feature_map_size) <= features_per_stage[s]: + do_channel_token[s] = True + + + print(f"feature_map_sizes: {feature_map_sizes}") + print(f"do_channel_token: {do_channel_token}") + + self.conv_pad_sizes = [] + for krnl in kernel_sizes: + self.conv_pad_sizes.append([i // 2 for i in krnl]) + + stem_channels = features_per_stage[0] + self.stem = nn.Sequential( + BasicResBlock( + conv_op = conv_op, + input_channels = input_channels, + output_channels = stem_channels, + norm_op=norm_op, + norm_op_kwargs=norm_op_kwargs, + kernel_size=kernel_sizes[0], + padding=self.conv_pad_sizes[0], + stride=1, + nonlin=nonlin, + nonlin_kwargs=nonlin_kwargs, + use_1x1conv=True + ), + *[ + BasicBlockD( + conv_op = conv_op, + input_channels = stem_channels, + output_channels = stem_channels, + kernel_size = kernel_sizes[0], + stride = 1, + conv_bias = conv_bias, + norm_op = norm_op, + norm_op_kwargs = norm_op_kwargs, + nonlin = nonlin, + nonlin_kwargs = nonlin_kwargs, + ) for _ in range(n_blocks_per_stage[0] - 1) + ] + ) + + input_channels = stem_channels + + stages = [] + mamba_layers = [] + for s in range(n_stages): + stage = nn.Sequential( + BasicResBlock( + conv_op = conv_op, + norm_op = norm_op, + norm_op_kwargs = norm_op_kwargs, + input_channels = input_channels, + output_channels = features_per_stage[s], + kernel_size = kernel_sizes[s], + padding=self.conv_pad_sizes[s], + stride=strides[s], + use_1x1conv=True, + nonlin = nonlin, + nonlin_kwargs = nonlin_kwargs + ), + *[ + BasicBlockD( + conv_op = conv_op, + input_channels = features_per_stage[s], + output_channels = features_per_stage[s], + kernel_size = kernel_sizes[s], + stride = 1, + conv_bias = conv_bias, + norm_op = norm_op, + norm_op_kwargs = norm_op_kwargs, + nonlin = nonlin, + nonlin_kwargs = nonlin_kwargs, + ) for _ in range(n_blocks_per_stage[s] - 1) + ] + ) + + if bool(s % 2) ^ bool(n_stages % 2): ## gurantee the last stage has mamaba layer + mamba_layers.append( + MambaLayer( + dim = np.prod(feature_map_sizes[s]) if do_channel_token[s] else features_per_stage[s], + channel_token = do_channel_token[s] + ) + ) + else: + mamba_layers.append(nn.Identity()) + + stages.append(stage) + input_channels = features_per_stage[s] + + self.mamba_layers = nn.ModuleList(mamba_layers) + self.stages = nn.ModuleList(stages) + self.output_channels = features_per_stage + self.strides = [[item] * dim_of_conv_op(conv_op) if not isinstance(item, (tuple, list, np.ndarray)) \ + else item for item in strides] + # self.strides = [maybe_convert_scalar_to_list(conv_op, i) for i in strides] + self.return_skips = return_skips + + self.conv_op = conv_op + self.norm_op = norm_op + self.norm_op_kwargs = norm_op_kwargs + self.nonlin = nonlin + self.nonlin_kwargs = nonlin_kwargs + #self.dropout_op = dropout_op + #self.dropout_op_kwargs = dropout_op_kwargs + self.conv_bias = conv_bias + self.kernel_sizes = kernel_sizes + + def forward(self, x): + if self.stem is not None: + x = self.stem(x) + ret = [] + for s in range(len(self.stages)): + x = self.stages[s](x) + x = self.mamba_layers[s](x) + ret.append(x) + if self.return_skips: + return ret + else: + return ret[-1] + + def compute_conv_feature_map_size(self, input_size): + if self.stem is not None: + output = self.stem.compute_conv_feature_map_size(input_size) + else: + output = np.int64(0) + + for s in range(len(self.stages)): + output += self.stages[s].compute_conv_feature_map_size(input_size) + input_size = [i // j for i, j in zip(input_size, self.strides[s])] + + return output + + +class UMambaEnc(nn.Module): + """ + UMambaEnc that uses Mamba block at the encoder and bottleneck of UNet. + + * Reference: Jun Ma, Feifei Li, Bo Wang. + U-Mamba: Enhancing long-range dependency for biomedical image segmentation. + arxiv 2403.20035, 2024. + + The implementation is based on the code at: + https://github.com/bowang-lab/U-Mamba. + + The parameters for the backbone should be given in the `params` dictionary. + + :param input_size: (list) the size of input image, such as [256, 256] + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param class_num: (int) The class number for segmentation task. + :param n_blocks_per_stage: (int) the number of con blocks at each stage. + """ + def __init__(self, params): + super(UMambaEnc, self).__init__() + params = self.get_default_parameters(params) + for p in params: + print(p, params[p]) + # def __init__(self, + # input_size: Tuple[int, ...], + # input_channels: int, + # n_stages: int, + # features_per_stage: Union[int, List[int], Tuple[int, ...]], + # conv_op: Type[_ConvNd], + # kernel_sizes: Union[int, List[int], Tuple[int, ...]], + # strides: Union[int, List[int], Tuple[int, ...]], + # n_conv_per_stage: Union[int, List[int], Tuple[int, ...]], + # num_classes: int, + # n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]], + # conv_bias: bool = False, + # norm_op: Union[None, Type[nn.Module]] = None, + # norm_op_kwargs: dict = None, + # dropout_op: Union[None, Type[_DropoutNd]] = None, + # dropout_op_kwargs: dict = None, + # nonlin: Union[None, Type[torch.nn.Module]] = None, + # nonlin_kwargs: dict = None, + # deep_supervision: bool = False, + # stem_channels: int = None + # ): + # super().__init__() + + input_size = params['input_size'] + input_channels = params['in_chns'] + features_per_stage = params['feature_chns'] + num_classes = params['class_num'] + n_blocks_per_stage = params['n_blocks_per_stage'] + n_conv_per_stage_decoder = n_blocks_per_stage + n_stages = len(features_per_stage) + conv_op = nn.Conv2d + kernel_sizes = [(3,3)] * len(features_per_stage) + strides = [(1, 1)] + [(2,2)] * (len(features_per_stage) - 1) + + conv_bias = True + norm_op = nn.InstanceNorm2d + norm_op_kwargs = {"affine":True} + nonlin=nn.LeakyReLU + nonlin_kwargs={'inplace': True} + deep_supervision = False + stem_channels = None + + if isinstance(n_blocks_per_stage, int): + n_blocks_per_stage = [n_blocks_per_stage] * n_stages + if isinstance(n_conv_per_stage_decoder, int): + n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1) + + for s in range(math.ceil(n_stages / 2), n_stages): + n_blocks_per_stage[s] = 1 + + for s in range(math.ceil((n_stages - 1) / 2 + 0.5), n_stages - 1): + n_conv_per_stage_decoder[s] = 1 + + + assert len(n_blocks_per_stage) == n_stages, "n_blocks_per_stage must have as many entries as we have " \ + f"resolution stages. here: {n_stages}. " \ + f"n_blocks_per_stage: {n_blocks_per_stage}" + assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one less entries " \ + f"as we have resolution stages. here: {n_stages} " \ + f"stages, so it should have {n_stages - 1} entries. " \ + f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}" + self.encoder = ResidualMambaEncoder( + input_size, + input_channels, + n_stages, + features_per_stage, + conv_op, + kernel_sizes, + strides, + n_blocks_per_stage, + conv_bias, + norm_op, + norm_op_kwargs, + nonlin, + nonlin_kwargs, + return_skips=True, + stem_channels=stem_channels + ) + + self.decoder = UNetResDecoder(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'n_blocks_per_stage': 2 + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + skips = self.encoder(x) + output = self.decoder(skips) + if(len(x_shape) == 5): + if(isinstance(output, (list,tuple))): + for i in range(len(output)): + new_shape = [N, D] + list(output[i].shape)[1:] + output[i] = torch.transpose(torch.reshape(output[i], new_shape), 1, 2) + else: + new_shape = [N, D] + list(output.shape)[1:] + output = torch.transpose(torch.reshape(output, new_shape), 1, 2) + return output + + def compute_conv_feature_map_size(self, input_size): + assert len(input_size) == dim_of_conv_op(self.encoder.conv_op), "just give the image size without color/feature channels or " \ + "batch channel. Do not give input_size=(b, c, x, y(, z)). " \ + "Give input_size=(x, y(, z))!" + return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size) + diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index 8877ade..ffd82ae 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -26,6 +26,7 @@ from pymic.net.net2d.unet2d_scse import UNet2D_ScSE from pymic.net.net2d.trans2d.transunet import TransUNet from pymic.net.net2d.trans2d.swinunet import SwinUNet +from pymic.net.net2d.umamba import UMambaBot, UMambaEnc from pymic.net.net3d.unet2d5 import UNet2D5 from pymic.net.net3d.unet3d import UNet3D from pymic.net.net3d.grunet import GRUNet @@ -62,6 +63,8 @@ 'UNet2D_CCT': UNet2D_CCT, 'UNet2Dpp': UNet2Dpp, 'UNet2D_ScSE': UNet2D_ScSE, + 'UMambaBot': UMambaBot, + 'UMambaEnc': UMambaEnc, 'TransUNet': TransUNet, 'SwinUNet': SwinUNet, 'UNet2D5': UNet2D5, diff --git a/pymic/test/test_net2d.py b/pymic/test/test_net2d.py index aafaf20..9013386 100644 --- a/pymic/test/test_net2d.py +++ b/pymic/test/test_net2d.py @@ -5,7 +5,7 @@ import numpy as np from pymic.net.net2d.unet2d import UNet2D from pymic.net.net2d.unet2d_scse import UNet2D_ScSE - +from pymic.net.net2d.umamba_bot import UMambaBot def test_unet2d(): params = {'in_chns':4, 'feature_chns':[16, 32, 64, 128, 256], @@ -52,6 +52,17 @@ def test_unet2d_scse(): else: print(out.shape) +def test_umamba(): + x = np.random.rand(4, 4, 10, 256, 256) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + Net = UMambaBot() + out = Net(xt) + y = out.detach().numpy() + print(y.shape) + if __name__ == "__main__": # test_unet2d() - test_unet2d_scse() \ No newline at end of file + # test_unet2d_scse() + test_umamba() \ No newline at end of file From 9890b77ccaab4f963e0208291df93d242239967f Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 31 Jan 2025 22:47:09 +0800 Subject: [PATCH 75/86] add ultralight_vm_unet --- pymic/net/net2d/unet2d_vm_light.py | 284 +++++++++++++++++++++++++++++ pymic/net/net_dict_seg.py | 2 + 2 files changed, 286 insertions(+) create mode 100644 pymic/net/net2d/unet2d_vm_light.py diff --git a/pymic/net/net2d/unet2d_vm_light.py b/pymic/net/net2d/unet2d_vm_light.py new file mode 100644 index 0000000..ed3f1c5 --- /dev/null +++ b/pymic/net/net2d/unet2d_vm_light.py @@ -0,0 +1,284 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import math +import torch +from torch import nn +import torch.nn.functional as F + +from timm.models.layers import trunc_normal_ +from mamba_ssm import Mamba + + +class PVMLayer(nn.Module): + def __init__(self, input_dim, output_dim, d_state = 16, d_conv = 4, expand = 2): + super().__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.norm = nn.LayerNorm(input_dim) + self.mamba = Mamba( + d_model=input_dim//4, # Model dimension d_model + d_state=d_state, # SSM state expansion factor + d_conv=d_conv, # Local convolution width + expand=expand, # Block expansion factor + ) + self.proj = nn.Linear(input_dim, output_dim) + self.skip_scale= nn.Parameter(torch.ones(1)) + + def forward(self, x): + if x.dtype == torch.float16: + x = x.type(torch.float32) + B, C = x.shape[:2] + assert C == self.input_dim + n_tokens = x.shape[2:].numel() + img_dims = x.shape[2:] + x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2) + x_norm = self.norm(x_flat) + + x1, x2, x3, x4 = torch.chunk(x_norm, 4, dim=2) + x_mamba1 = self.mamba(x1) + self.skip_scale * x1 + x_mamba2 = self.mamba(x2) + self.skip_scale * x2 + x_mamba3 = self.mamba(x3) + self.skip_scale * x3 + x_mamba4 = self.mamba(x4) + self.skip_scale * x4 + x_mamba = torch.cat([x_mamba1, x_mamba2,x_mamba3,x_mamba4], dim=2) + + x_mamba = self.norm(x_mamba) + x_mamba = self.proj(x_mamba) + out = x_mamba.transpose(-1, -2).reshape(B, self.output_dim, *img_dims) + return out + + +class Channel_Att_Bridge(nn.Module): + def __init__(self, c_list, split_att='fc'): + super().__init__() + c_list_sum = sum(c_list) - c_list[-1] + self.split_att = split_att + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.get_all_att = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False) + self.att1 = nn.Linear(c_list_sum, c_list[0]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[0], 1) + self.att2 = nn.Linear(c_list_sum, c_list[1]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[1], 1) + self.att3 = nn.Linear(c_list_sum, c_list[2]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[2], 1) + self.att4 = nn.Linear(c_list_sum, c_list[3]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[3], 1) + self.att5 = nn.Linear(c_list_sum, c_list[4]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[4], 1) + self.sigmoid = nn.Sigmoid() + + def forward(self, t1, t2, t3, t4, t5): + att = torch.cat((self.avgpool(t1), + self.avgpool(t2), + self.avgpool(t3), + self.avgpool(t4), + self.avgpool(t5)), dim=1) + att = self.get_all_att(att.squeeze(-1).transpose(-1, -2)) + if self.split_att != 'fc': + att = att.transpose(-1, -2) + att1 = self.sigmoid(self.att1(att)) + att2 = self.sigmoid(self.att2(att)) + att3 = self.sigmoid(self.att3(att)) + att4 = self.sigmoid(self.att4(att)) + att5 = self.sigmoid(self.att5(att)) + if self.split_att == 'fc': + att1 = att1.transpose(-1, -2).unsqueeze(-1).expand_as(t1) + att2 = att2.transpose(-1, -2).unsqueeze(-1).expand_as(t2) + att3 = att3.transpose(-1, -2).unsqueeze(-1).expand_as(t3) + att4 = att4.transpose(-1, -2).unsqueeze(-1).expand_as(t4) + att5 = att5.transpose(-1, -2).unsqueeze(-1).expand_as(t5) + else: + att1 = att1.unsqueeze(-1).expand_as(t1) + att2 = att2.unsqueeze(-1).expand_as(t2) + att3 = att3.unsqueeze(-1).expand_as(t3) + att4 = att4.unsqueeze(-1).expand_as(t4) + att5 = att5.unsqueeze(-1).expand_as(t5) + + return att1, att2, att3, att4, att5 + + +class Spatial_Att_Bridge(nn.Module): + def __init__(self): + super().__init__() + self.shared_conv2d = nn.Sequential(nn.Conv2d(2, 1, 7, stride=1, padding=9, dilation=3), + nn.Sigmoid()) + + def forward(self, t1, t2, t3, t4, t5): + t_list = [t1, t2, t3, t4, t5] + att_list = [] + for t in t_list: + avg_out = torch.mean(t, dim=1, keepdim=True) + max_out, _ = torch.max(t, dim=1, keepdim=True) + att = torch.cat([avg_out, max_out], dim=1) + att = self.shared_conv2d(att) + att_list.append(att) + return att_list[0], att_list[1], att_list[2], att_list[3], att_list[4] + + +class SC_Att_Bridge(nn.Module): + def __init__(self, c_list, split_att='fc'): + super().__init__() + + self.catt = Channel_Att_Bridge(c_list, split_att=split_att) + self.satt = Spatial_Att_Bridge() + + def forward(self, t1, t2, t3, t4, t5): + r1, r2, r3, r4, r5 = t1, t2, t3, t4, t5 + + satt1, satt2, satt3, satt4, satt5 = self.satt(t1, t2, t3, t4, t5) + t1, t2, t3, t4, t5 = satt1 * t1, satt2 * t2, satt3 * t3, satt4 * t4, satt5 * t5 + + r1_, r2_, r3_, r4_, r5_ = t1, t2, t3, t4, t5 + t1, t2, t3, t4, t5 = t1 + r1, t2 + r2, t3 + r3, t4 + r4, t5 + r5 + + catt1, catt2, catt3, catt4, catt5 = self.catt(t1, t2, t3, t4, t5) + t1, t2, t3, t4, t5 = catt1 * t1, catt2 * t2, catt3 * t3, catt4 * t4, catt5 * t5 + + return t1 + r1_, t2 + r2_, t3 + r3_, t4 + r4_, t5 + r5_ + + +class UltraLight_VM_UNet(nn.Module): + def __init__(self, params): + """ + UltraLight_VM_UNet that is a lightweight model using CNN and Mamba. + + * Reference: Renkai Wu, Yinghao Liu, Pengchen Liang, Qing Chang. + UltraLight VM-UNet: Parallel Vision Mamba Significantly Reduces Parameters for Skin Lesion Segmentation. + arxiv 2403.20035, 2024. + + The implementation is based on the code at: + https://github.com/wurenkai/UltraLight-VM-UNet. + + The parameters for the backbone should be given in the `params` dictionary. + + :param in_chns: (int) Input channel number. + :param class_num: (int) The class number for segmentation task. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 6, by default it is [8, 16, 24, 32, 48, 64]. + :param bridge: (int) If the bridge based on spatial and channel attentions is used or not. + By default it is True. + """ + super(UltraLight_VM_UNet, self).__init__() + + input_channels = params['in_chns'] + num_classes = params['class_num'] + c_list = params.get('feature_chns', [8, 16, 24, 32, 48, 64]) + self.bridge = params.get('bridge', True) + split_att = 'fc' + # def __init__(self, num_classes=1, input_channels=3, c_list=[8,16,24,32,48,64], + # split_att='fc', bridge=True): + # super().__init__() + # self.bridge = bridge + + self.encoder1 = nn.Sequential( + nn.Conv2d(input_channels, c_list[0], 3, stride=1, padding=1), + ) + self.encoder2 =nn.Sequential( + nn.Conv2d(c_list[0], c_list[1], 3, stride=1, padding=1), + ) + self.encoder3 = nn.Sequential( + nn.Conv2d(c_list[1], c_list[2], 3, stride=1, padding=1), + ) + self.encoder4 = nn.Sequential( + PVMLayer(input_dim=c_list[2], output_dim=c_list[3]) + ) + self.encoder5 = nn.Sequential( + PVMLayer(input_dim=c_list[3], output_dim=c_list[4]) + ) + self.encoder6 = nn.Sequential( + PVMLayer(input_dim=c_list[4], output_dim=c_list[5]) + ) + + if self.bridge: + self.scab = SC_Att_Bridge(c_list, split_att) + print('SC_Att_Bridge was used') + + self.decoder1 = nn.Sequential( + PVMLayer(input_dim=c_list[5], output_dim=c_list[4]) + ) + self.decoder2 = nn.Sequential( + PVMLayer(input_dim=c_list[4], output_dim=c_list[3]) + ) + self.decoder3 = nn.Sequential( + PVMLayer(input_dim=c_list[3], output_dim=c_list[2]) + ) + self.decoder4 = nn.Sequential( + nn.Conv2d(c_list[2], c_list[1], 3, stride=1, padding=1), + ) + self.decoder5 = nn.Sequential( + nn.Conv2d(c_list[1], c_list[0], 3, stride=1, padding=1), + ) + self.ebn1 = nn.GroupNorm(4, c_list[0]) + self.ebn2 = nn.GroupNorm(4, c_list[1]) + self.ebn3 = nn.GroupNorm(4, c_list[2]) + self.ebn4 = nn.GroupNorm(4, c_list[3]) + self.ebn5 = nn.GroupNorm(4, c_list[4]) + self.dbn1 = nn.GroupNorm(4, c_list[4]) + self.dbn2 = nn.GroupNorm(4, c_list[3]) + self.dbn3 = nn.GroupNorm(4, c_list[2]) + self.dbn4 = nn.GroupNorm(4, c_list[1]) + self.dbn5 = nn.GroupNorm(4, c_list[0]) + + self.final = nn.Conv2d(c_list[0], num_classes, kernel_size=1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv1d): + n = m.kernel_size[0] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + out = F.gelu(F.max_pool2d(self.ebn1(self.encoder1(x)),2,2)) + t1 = out # b, c0, H/2, W/2 + + out = F.gelu(F.max_pool2d(self.ebn2(self.encoder2(out)),2,2)) + t2 = out # b, c1, H/4, W/4 + + out = F.gelu(F.max_pool2d(self.ebn3(self.encoder3(out)),2,2)) + t3 = out # b, c2, H/8, W/8 + + out = F.gelu(F.max_pool2d(self.ebn4(self.encoder4(out)),2,2)) + t4 = out # b, c3, H/16, W/16 + + out = F.gelu(F.max_pool2d(self.ebn5(self.encoder5(out)),2,2)) + t5 = out # b, c4, H/32, W/32 + + if self.bridge: t1, t2, t3, t4, t5 = self.scab(t1, t2, t3, t4, t5) + + out = F.gelu(self.encoder6(out)) # b, c5, H/32, W/32 + + out5 = F.gelu(self.dbn1(self.decoder1(out))) # b, c4, H/32, W/32 + out5 = torch.add(out5, t5) # b, c4, H/32, W/32 + + out4 = F.gelu(F.interpolate(self.dbn2(self.decoder2(out5)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c3, H/16, W/16 + out4 = torch.add(out4, t4) # b, c3, H/16, W/16 + + out3 = F.gelu(F.interpolate(self.dbn3(self.decoder3(out4)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c2, H/8, W/8 + out3 = torch.add(out3, t3) # b, c2, H/8, W/8 + + out2 = F.gelu(F.interpolate(self.dbn4(self.decoder4(out3)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c1, H/4, W/4 + out2 = torch.add(out2, t2) # b, c1, H/4, W/4 + + out1 = F.gelu(F.interpolate(self.dbn5(self.decoder5(out2)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c0, H/2, W/2 + out1 = torch.add(out1, t1) # b, c0, H/2, W/2 + + out0 = F.interpolate(self.final(out1),scale_factor=(2,2),mode ='bilinear',align_corners=True) # b, num_class, H, W + + if(len(x_shape) == 5): + new_shape = [N, D] + list(out0.shape)[1:] + out0 = torch.transpose(torch.reshape(out0, new_shape), 1, 2) + return out0 + diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index ffd82ae..b557bec 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -27,6 +27,7 @@ from pymic.net.net2d.trans2d.transunet import TransUNet from pymic.net.net2d.trans2d.swinunet import SwinUNet from pymic.net.net2d.umamba import UMambaBot, UMambaEnc +from pymic.net.net2d.unet2d_vm_light import UltraLight_VM_UNet from pymic.net.net3d.unet2d5 import UNet2D5 from pymic.net.net3d.unet3d import UNet3D from pymic.net.net3d.grunet import GRUNet @@ -65,6 +66,7 @@ 'UNet2D_ScSE': UNet2D_ScSE, 'UMambaBot': UMambaBot, 'UMambaEnc': UMambaEnc, + 'UltraLight_VM_UNet': UltraLight_VM_UNet, 'TransUNet': TransUNet, 'SwinUNet': SwinUNet, 'UNet2D5': UNet2D5, From a279271326ddea66a630f26556673d4332bdf7e6 Mon Sep 17 00:00:00 2001 From: taigw Date: Sun, 2 Feb 2025 22:19:21 +0800 Subject: [PATCH 76/86] add VMUNet --- pymic/net/net2d/unet2d_vm.py | 820 +++++++++++++++++++++++++++++ pymic/net/net2d/unet2d_vm_light.py | 2 +- pymic/net/net_dict_seg.py | 2 + 3 files changed, 823 insertions(+), 1 deletion(-) create mode 100644 pymic/net/net2d/unet2d_vm.py diff --git a/pymic/net/net2d/unet2d_vm.py b/pymic/net/net2d/unet2d_vm.py new file mode 100644 index 0000000..126fa30 --- /dev/null +++ b/pymic/net/net2d/unet2d_vm.py @@ -0,0 +1,820 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import math +from functools import partial +from typing import Optional, Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from einops import rearrange, repeat +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +try: + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref +except: + pass + +# an alternative for mamba_ssm (in which causal_conv1d is needed) +try: + from selective_scan import selective_scan_fn as selective_scan_fn_v1 + from selective_scan import selective_scan_ref as selective_scan_ref_v1 +except: + pass + +DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})" + + +def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False): + """ + u: r(B D L) + delta: r(B D L) + A: r(D N) + B: r(B N L) + C: r(B N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + + ignores: + [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] + """ + import numpy as np + + # fvcore.nn.jit_handles + def get_flops_einsum(input_shapes, equation): + np_arrs = [np.zeros(s) for s in input_shapes] + optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1] + for line in optim.split("\n"): + if "optimized flop" in line.lower(): + # divided by 2 because we count MAC (multiply-add counted as one flop) + flop = float(np.floor(float(line.split(":")[-1]) / 2)) + return flop + + + assert not with_complex + + flops = 0 # below code flops = 0 + if False: + ... + """ + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + if A.is_complex(): + if is_variable_B: + B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) + if is_variable_C: + C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2)) + else: + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) + ys = [] + """ + + flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln") + if with_Group: + flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln") + else: + flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln") + if False: + ... + """ + deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + """ + + in_for_flops = B * D * N + if with_Group: + in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd") + else: + in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd") + flops += L * in_for_flops + if False: + ... + """ + for i in range(u.shape[2]): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum('bdn,dn->bd', x, C) + else: + if C.dim() == 3: + y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + else: + y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + if y.is_complex(): + y = y.real * 2 + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + """ + + if with_D: + flops += B * D * L + if with_Z: + flops += B * D * L + if False: + ... + """ + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + """ + + return flops + + +class PatchEmbed2D(nn.Module): + r""" Image to Patch Embedding + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs): + super().__init__() + if isinstance(patch_size, int): + patch_size = (patch_size, patch_size) + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = self.proj(x).permute(0, 2, 3, 1) + if self.norm is not None: + x = self.norm(x) + return x + + +class PatchMerging2D(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + B, H, W, C = x.shape + + SHAPE_FIX = [-1, -1] + if (W % 2 != 0) or (H % 2 != 0): + print(f"Warning, x.shape {x.shape} is not match even ===========", flush=True) + SHAPE_FIX[0] = H // 2 + SHAPE_FIX[1] = W // 2 + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + + if SHAPE_FIX[0] > 0: + x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] + x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] + x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] + x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] + + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, H//2, W//2, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class PatchExpand2D(nn.Module): + def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim*2 + self.dim_scale = dim_scale + self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False) + self.norm = norm_layer(self.dim // dim_scale) + + def forward(self, x): + B, H, W, C = x.shape + x = self.expand(x) + + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale) + x= self.norm(x) + + return x + + +class Final_PatchExpand2D(nn.Module): + def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.dim_scale = dim_scale + self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False) + self.norm = norm_layer(self.dim // dim_scale) + + def forward(self, x): + B, H, W, C = x.shape + x = self.expand(x) + + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale) + x= self.norm(x) + + return x + + +class SS2D(nn.Module): + def __init__( + self, + d_model, + d_state=16, + # d_state="auto", # 20240109 + d_conv=3, + expand=2, + dt_rank="auto", + dt_min=0.001, + dt_max=0.1, + dt_init="random", + dt_scale=1.0, + dt_init_floor=1e-4, + dropout=0., + conv_bias=True, + bias=False, + device=None, + dtype=None, + **kwargs, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + # self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109 + self.d_conv = d_conv + self.expand = expand + self.d_inner = int(self.expand * self.d_model) + self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank + + self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) + self.conv2d = nn.Conv2d( + in_channels=self.d_inner, + out_channels=self.d_inner, + groups=self.d_inner, + bias=conv_bias, + kernel_size=d_conv, + padding=(d_conv - 1) // 2, + **factory_kwargs, + ) + self.act = nn.SiLU() + + self.x_proj = ( + nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), + nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), + nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), + nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), + ) + self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner) + del self.x_proj + + self.dt_projs = ( + self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), + self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), + self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), + self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), + ) + self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank) + self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner) + del self.dt_projs + + self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K=4, D, N) + self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K=4, D, N) + + # self.selective_scan = selective_scan_fn + self.forward_core = self.forward_corev0 + + self.out_norm = nn.LayerNorm(self.d_inner) + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + self.dropout = nn.Dropout(dropout) if dropout > 0. else None + + @staticmethod + def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs): + dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) + + # Initialize special dt projection to preserve variance at initialization + dt_init_std = dt_rank**-0.5 * dt_scale + if dt_init == "constant": + nn.init.constant_(dt_proj.weight, dt_init_std) + elif dt_init == "random": + nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) + else: + raise NotImplementedError + + # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max + dt = torch.exp( + torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ).clamp(min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + dt_proj.bias.copy_(inv_dt) + # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit + dt_proj.bias._no_reinit = True + + return dt_proj + + @staticmethod + def A_log_init(d_state, d_inner, copies=1, device=None, merge=True): + # S4D real initialization + A = repeat( + torch.arange(1, d_state + 1, dtype=torch.float32, device=device), + "n -> d n", + d=d_inner, + ).contiguous() + A_log = torch.log(A) # Keep A_log in fp32 + if copies > 1: + A_log = repeat(A_log, "d n -> r d n", r=copies) + if merge: + A_log = A_log.flatten(0, 1) + A_log = nn.Parameter(A_log) + A_log._no_weight_decay = True + return A_log + + @staticmethod + def D_init(d_inner, copies=1, device=None, merge=True): + # D "skip" parameter + D = torch.ones(d_inner, device=device) + if copies > 1: + D = repeat(D, "n1 -> r n1", r=copies) + if merge: + D = D.flatten(0, 1) + D = nn.Parameter(D) # Keep in fp32 + D._no_weight_decay = True + return D + + def forward_corev0(self, x: torch.Tensor): + self.selective_scan = selective_scan_fn + + B, C, H, W = x.shape + L = H * W + K = 4 + + x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) + xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) + + x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight) + # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) + dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) + dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight) + # dts = dts + self.dt_projs_bias.view(1, K, -1, 1) + + xs = xs.float().view(B, -1, L) # (b, k * d, l) + dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) + Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l) + Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l) + Ds = self.Ds.float().view(-1) # (k * d) + As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state) + dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) + + out_y = self.selective_scan( + xs, dts, + As, Bs, Cs, Ds, z=None, + delta_bias=dt_projs_bias, + delta_softplus=True, + return_last_state=False, + ).view(B, K, -1, L) + assert out_y.dtype == torch.float + + inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) + wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) + invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) + + return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y + + # an alternative to forward_corev1 + def forward_corev1(self, x: torch.Tensor): + self.selective_scan = selective_scan_fn_v1 + + B, C, H, W = x.shape + L = H * W + K = 4 + + x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) + xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) + + x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight) + # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) + dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) + dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight) + # dts = dts + self.dt_projs_bias.view(1, K, -1, 1) + + xs = xs.float().view(B, -1, L) # (b, k * d, l) + dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) + Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l) + Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l) + Ds = self.Ds.float().view(-1) # (k * d) + As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state) + dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) + + out_y = self.selective_scan( + xs, dts, + As, Bs, Cs, Ds, + delta_bias=dt_projs_bias, + delta_softplus=True, + ).view(B, K, -1, L) + assert out_y.dtype == torch.float + + inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) + wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) + invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) + + return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y + + def forward(self, x: torch.Tensor, **kwargs): + B, H, W, C = x.shape + + xz = self.in_proj(x) + x, z = xz.chunk(2, dim=-1) # (b, h, w, d) + + x = x.permute(0, 3, 1, 2).contiguous() + x = self.act(self.conv2d(x)) # (b, d, h, w) + y1, y2, y3, y4 = self.forward_core(x) + assert y1.dtype == torch.float32 + y = y1 + y2 + y3 + y4 + y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) + y = self.out_norm(y) + y = y * F.silu(z) + out = self.out_proj(y) + if self.dropout is not None: + out = self.dropout(out) + return out + + +class VSSBlock(nn.Module): + def __init__( + self, + hidden_dim: int = 0, + drop_path: float = 0, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + attn_drop_rate: float = 0, + d_state: int = 16, + **kwargs, + ): + super().__init__() + self.ln_1 = norm_layer(hidden_dim) + self.self_attention = SS2D(d_model=hidden_dim, dropout=attn_drop_rate, d_state=d_state, **kwargs) + self.drop_path = DropPath(drop_path) + + def forward(self, input: torch.Tensor): + x = input + self.drop_path(self.self_attention(self.ln_1(input))) + return x + + +class VSSLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + depth (int): Number of blocks. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim, + depth, + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + d_state=16, + **kwargs, + ): + super().__init__() + self.dim = dim + self.use_checkpoint = use_checkpoint + + self.blocks = nn.ModuleList([ + VSSBlock( + hidden_dim=dim, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + attn_drop_rate=attn_drop, + d_state=d_state, + ) + for i in range(depth)]) + + if True: # is this really applied? Yes, but been overriden later in VSSM! + def _init_weights(module: nn.Module): + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + p = p.clone().detach_() # fake init, just to keep the seed .... + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + self.apply(_init_weights) + + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + + if self.downsample is not None: + x = self.downsample(x) + + return x + + + +class VSSLayer_up(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + depth (int): Number of blocks. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim, + depth, + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + upsample=None, + use_checkpoint=False, + d_state=16, + **kwargs, + ): + super().__init__() + self.dim = dim + self.use_checkpoint = use_checkpoint + + self.blocks = nn.ModuleList([ + VSSBlock( + hidden_dim=dim, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + attn_drop_rate=attn_drop, + d_state=d_state, + ) + for i in range(depth)]) + + if True: # is this really applied? Yes, but been overriden later in VSSM! + def _init_weights(module: nn.Module): + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + p = p.clone().detach_() # fake init, just to keep the seed .... + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + self.apply(_init_weights) + + if upsample is not None: + self.upsample = upsample(dim=dim, norm_layer=norm_layer) + else: + self.upsample = None + + + def forward(self, x): + if self.upsample is not None: + x = self.upsample(x) + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + return x + + + +# class VSSM(nn.Module): +class VMUNet(nn.Module): + """ + VM_UNet that is a UNet-like pure vision mambal network for segmentation. + + * Reference: Jiacheng Ruan et al., VM-UNet: Vision Mamba UNet for Medical Image Segmentation. + arxiv 2403.09157, 2024. + + The implementation is based on the code at: + https://github.com/JCruan519/VM-UNet. + + The parameters for the backbone should be given in the `params` dictionary. + + :param in_chns: (int) Input channel number. + :param class_num: (int) The class number for segmentation task. + :param depths: (list) The depth of VSS block at each resolution level. + The length should be 4, by default it is [2, 2, 9, 2]. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4, by default it is [96, 192, 384, 768]. + """ + # def __init__(self, c=4, in_chans=3, num_classes=1000, depths=[2, 2, 9, 2], depths_decoder=[2, 9, 2, 2], + # dims=[96, 192, 384, 768], dims_decoder=[768, 384, 192, 96], d_state=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + # norm_layer=nn.LayerNorm, patch_norm=True, + # use_checkpoint=False, **kwargs): + # super().__init__() + def __init__(self, params): + super(VMUNet, self).__init__() + in_chans = params['in_chns'] + num_classes = params['class_num'] + patch_size = params.get('patch_size', 4) + depths = params.get('depths', [2, 2, 9, 2]) + depths_decoder = depths.copy() + depths_decoder.reverse() + dims = params.get('feature_chns', [96, 192, 384, 768]) + dims_decoder = dims.copy() + dims_decoder.reverse() + d_state = params.get('d_state', 16) + drop_rate = params.get('drop_rate', 0.) + attn_drop_rate = params.get('att_drop_rate', 0.) + drop_path_rate = params.get('path_drop_rate', 0.1) + + norm_layer = nn.LayerNorm + patch_norm = True + use_checkpoint = False + + self.num_classes = num_classes + self.num_layers = len(depths) + if isinstance(dims, int): + dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)] + self.embed_dim = dims[0] + self.num_features = dims[-1] + self.dims = dims + + self.patch_embed = PatchEmbed2D(patch_size=patch_size, in_chans=in_chans, embed_dim=self.embed_dim, + norm_layer=norm_layer if patch_norm else None) + + # WASTED absolute position embedding ====================== + self.ape = False + # self.ape = False + # drop_rate = 0.0 + if self.ape: + self.patches_resolution = self.patch_embed.patches_resolution + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, *self.patches_resolution, self.embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_decoder))][::-1] + + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = VSSLayer( + dim=dims[i_layer], + depth=depths[i_layer], + d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109 + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging2D if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + ) + self.layers.append(layer) + + self.layers_up = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = VSSLayer_up( + dim=dims_decoder[i_layer], + depth=depths_decoder[i_layer], + d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109 + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr_decoder[sum(depths_decoder[:i_layer]):sum(depths_decoder[:i_layer + 1])], + norm_layer=norm_layer, + upsample=PatchExpand2D if (i_layer != 0) else None, + use_checkpoint=use_checkpoint, + ) + self.layers_up.append(layer) + + self.final_up = Final_PatchExpand2D(dim=dims_decoder[-1], dim_scale=4, norm_layer=norm_layer) + self.final_conv = nn.Conv2d(dims_decoder[-1]//4, num_classes, 1) + + # self.norm = norm_layer(self.num_features) + # self.avgpool = nn.AdaptiveAvgPool1d(1) + # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m: nn.Module): + """ + out_proj.weight which is previously initilized in VSSBlock, would be cleared in nn.Linear + no fc.weight found in the any of the model parameters + no nn.Embedding found in the any of the model parameters + so the thing is, VSSBlock initialization is useless + + Conv2D is not intialized !!! + """ + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward_features(self, x): + skip_list = [] + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + skip_list.append(x) + x = layer(x) + return x, skip_list + + def forward_features_up(self, x, skip_list): + for inx, layer_up in enumerate(self.layers_up): + if inx == 0: + x = layer_up(x) + else: + x = layer_up(x+skip_list[-inx]) + + return x + + def forward_final(self, x): + x = self.final_up(x) + x = x.permute(0,3,1,2) + x = self.final_conv(x) + return x + + def forward_backbone(self, x): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + return x + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + x, skip_list = self.forward_features(x) + x = self.forward_features_up(x, skip_list) + x = self.forward_final(x) + + if(len(x_shape) == 5): + new_shape = [N, D] + list(x.shape)[1:] + x = torch.transpose(torch.reshape(x, new_shape), 1, 2) + + return x + + + + + + diff --git a/pymic/net/net2d/unet2d_vm_light.py b/pymic/net/net2d/unet2d_vm_light.py index ed3f1c5..ab1de76 100644 --- a/pymic/net/net2d/unet2d_vm_light.py +++ b/pymic/net/net2d/unet2d_vm_light.py @@ -150,7 +150,7 @@ def __init__(self, params): :param class_num: (int) The class number for segmentation task. :param feature_chns: (list) Feature channel for each resolution level. The length should be 6, by default it is [8, 16, 24, 32, 48, 64]. - :param bridge: (int) If the bridge based on spatial and channel attentions is used or not. + :param bridge: (bool) If the bridge based on spatial and channel attentions is used or not. By default it is True. """ super(UltraLight_VM_UNet, self).__init__() diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index b557bec..6f0f0c6 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -27,6 +27,7 @@ from pymic.net.net2d.trans2d.transunet import TransUNet from pymic.net.net2d.trans2d.swinunet import SwinUNet from pymic.net.net2d.umamba import UMambaBot, UMambaEnc +from pymic.net.net2d.unet2d_vm import VMUNet from pymic.net.net2d.unet2d_vm_light import UltraLight_VM_UNet from pymic.net.net3d.unet2d5 import UNet2D5 from pymic.net.net3d.unet3d import UNet3D @@ -66,6 +67,7 @@ 'UNet2D_ScSE': UNet2D_ScSE, 'UMambaBot': UMambaBot, 'UMambaEnc': UMambaEnc, + 'VMUNet':VMUNet, 'UltraLight_VM_UNet': UltraLight_VM_UNet, 'TransUNet': TransUNet, 'SwinUNet': SwinUNet, From 9481ec32aab3dfb818be8c75056a6622046c1ffc Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 1 Aug 2025 11:24:36 +0800 Subject: [PATCH 77/86] add DMSPS add DMSPS for weakly supervised segmentation add adaptive region specific Tverskyloss --- pymic/loss/seg/ars_tversky.py | 67 ++++++++++++++ pymic/net_run/weak_sup/__init__.py | 4 +- pymic/net_run/weak_sup/wsl_dmpls.py | 10 ++- pymic/net_run/weak_sup/wsl_dmsps.py | 131 ++++++++++++++++++++++++++++ 4 files changed, 208 insertions(+), 4 deletions(-) create mode 100644 pymic/loss/seg/ars_tversky.py create mode 100644 pymic/net_run/weak_sup/wsl_dmsps.py diff --git a/pymic/loss/seg/ars_tversky.py b/pymic/loss/seg/ars_tversky.py new file mode 100644 index 0000000..4fafeae --- /dev/null +++ b/pymic/loss/seg/ars_tversky.py @@ -0,0 +1,67 @@ + +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch.nn as nn +from pymic.loss.seg.abstract import AbstractSegLoss + +class ARSTverskyLoss(AbstractSegLoss): + """ + The Adaptive Region-Specific Loss in this paper: + + * Y. Chen et al.: Adaptive Region-Specific Loss for Improved Medical Image Segmentation. + `IEEE TPAMI 2023. `_ + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `ARSTversky_patch_size`: (list) the patch size. + :param `A`: the lowest weight for FP or FN (default 0.3) + :param `B`: the gap between lowest and highest weight (default 0.4) + """ + def __init__(self, params): + super(ARSTverskyLoss, self).__init__(params) + self.patch_size = params['ARSTversky_patch_size'.lower()] + self.a = params.get('ARSTversky_a'.lower(), 0.3) + self.b = params.get('ARSTversky_b'.lower(), 0.4) + + self.dim = len(self.patch_size) + assert self.dim in [2, 3], "The num of dim must be 2 or 3." + if self.dim == 3: + self.pool = nn.AvgPool3d(kernel_size=self.patch_size, stride=self.patch_size) + elif self.dim == 2: + self.pool = nn.AvgPool2d(kernel_size=self.patch_size, stride=self.patch_size) + + def forward(self, loss_input_dict): + predict = loss_input_dict['prediction'] + soft_y = loss_input_dict['ground_truth'] + + if(isinstance(predict, (list, tuple))): + predict = predict[0] + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) + + smooth = 1e-5 + if self.dim == 2: + assert predict.shape[-2] % self.patch_size[0] == 0, "image size % patch size must be 0 in dimension y" + assert predict.shape[-1] % self.patch_size[1] == 0, "image size % patch size must be 0 in dimension x" + elif self.dim == 3: + assert predict.shape[-3] % self.patch_size[0] == 0, "image size % patch size must be 0 in dimension z" + assert predict.shape[-2] % self.patch_size[1] == 0, "image size % patch size must be 0 in dimension y" + assert predict.shape[-1] % self.patch_size[2] == 0, "image size % patch size must be 0 in dimension x" + + tp = predict * soft_y + fp = predict * (1 - soft_y) + fn = (1 - predict) * soft_y + + region_tp = self.pool(tp) + region_fp = self.pool(fp) + region_fn = self.pool(fn) + + alpha = self.a + self.b * (region_fp + smooth) / (region_fp + region_fn + smooth) + beta = self.a + self.b * (region_fn + smooth) / (region_fp + region_fn + smooth) + + region_tversky = (region_tp + smooth) / (region_tp + alpha * region_fp + beta * region_fn + smooth) + region_tversky = 1 - region_tversky + loss = region_tversky.mean() + return loss \ No newline at end of file diff --git a/pymic/net_run/weak_sup/__init__.py b/pymic/net_run/weak_sup/__init__.py index b3c8332..a583ae8 100644 --- a/pymic/net_run/weak_sup/__init__.py +++ b/pymic/net_run/weak_sup/__init__.py @@ -6,10 +6,12 @@ from pymic.net_run.weak_sup.wsl_tv import WSLTotalVariation from pymic.net_run.weak_sup.wsl_ustm import WSLUSTM from pymic.net_run.weak_sup.wsl_dmpls import WSLDMPLS +from pymic.net_run.weak_sup.wsl_dmsps import WSLDMSPS WSLMethodDict = {'EntropyMinimization': WSLEntropyMinimization, 'GatedCRF': WSLGatedCRF, 'MumfordShah': WSLMumfordShah, 'TotalVariation': WSLTotalVariation, 'USTM': WSLUSTM, - 'DMPLS': WSLDMPLS} \ No newline at end of file + 'DMPLS': WSLDMPLS, + 'DMSPS': WSLDMSPS} \ No newline at end of file diff --git a/pymic/net_run/weak_sup/wsl_dmpls.py b/pymic/net_run/weak_sup/wsl_dmpls.py index 01902c7..ea96bbb 100644 --- a/pymic/net_run/weak_sup/wsl_dmpls.py +++ b/pymic/net_run/weak_sup/wsl_dmpls.py @@ -5,11 +5,11 @@ import random import time import torch -from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.dice import DiceLoss +from pymic.loss.seg.ce import CrossEntropyLoss from pymic.net_run.weak_sup import WSLSegAgent from pymic.util.ramps import get_rampup_ratio @@ -42,9 +42,13 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] wsl_cfg = self.config['weakly_supervised_learning'] - iter_max = self.config['training']['iter_max'] + iter_max = self.config['training']['iter_max'] rampup_start = wsl_cfg.get('rampup_start', 0) rampup_end = wsl_cfg.get('rampup_end', iter_max) + pseudo_loss_type = wsl_cfg.get('pseudo_sup_loss', 'dice_loss') + if (pseudo_loss_type not in ('dice_loss', 'ce_loss')): + raise ValueError("""For pseudo supervision loss, only dice_loss and ce_loss \ + are supported.""") train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 @@ -83,7 +87,7 @@ def training(self): pseudo_lab = get_soft_label(pseudo_lab, class_num, self.tensor_type) # calculate the pseudo label supervision loss - loss_calculator = DiceLoss() + loss_calculator = DiceLoss() if pseudo_loss_type == 'dice_loss' else CrossEntropyLoss() loss_dict1 = {"prediction":outputs1, 'ground_truth':pseudo_lab} loss_dict2 = {"prediction":outputs2, 'ground_truth':pseudo_lab} loss_reg = 0.5 * (loss_calculator(loss_dict1) + loss_calculator(loss_dict2)) diff --git a/pymic/net_run/weak_sup/wsl_dmsps.py b/pymic/net_run/weak_sup/wsl_dmsps.py new file mode 100644 index 0000000..dca3643 --- /dev/null +++ b/pymic/net_run/weak_sup/wsl_dmsps.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import logging +import numpy as np +import random +import time +import torch +from pymic.loss.seg.util import get_soft_label +from pymic.loss.seg.util import reshape_prediction_and_ground_truth +from pymic.loss.seg.util import get_classwise_dice +from pymic.loss.seg.dice import DiceLoss +from pymic.loss.seg.ce import CrossEntropyLoss +from pymic.net_run.weak_sup import WSLSegAgent +from pymic.util.ramps import get_rampup_ratio + +class WSLDMSPS(WSLSegAgent): + """ + Weakly supervised segmentation based on Dynamically Mixed Pseudo Labels Supervision. + + * Reference: Meng Han, Xiangde Luo, Xiangjiang Xie, Wenjun Liao, Shichuan Zhang, Tao Song, + Guotai Wang, Shaoting Zhang. DMSPS: Dynamically mixed soft pseudo-label supervision for + scribble-supervised medical image segmentation. + `Medical Image Analysis 2024. `_ + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `weakly_supervised_learning` is needed. See :doc:`usage.wsl` for details. + """ + def __init__(self, config, stage = 'train'): + net_type = config['network']['net_type'] + if net_type not in ['UNet2D_DualBranch', 'UNet3D_DualBranch']: + raise ValueError("""For WSL_DMPLS, a dual branch network is expected. \ + It only supports UNet2D_DualBranch and UNet3D_DualBranch currently.""") + super(WSLDMSPS, self).__init__(config, stage) + + def training(self): + class_num = self.config['network']['class_num'] + iter_valid = self.config['training']['iter_valid'] + wsl_cfg = self.config['weakly_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = wsl_cfg.get('rampup_start', 0) + rampup_end = wsl_cfg.get('rampup_end', iter_max) + pseudo_loss_type = wsl_cfg.get('pseudo_sup_loss', 'ce_loss') + if (pseudo_loss_type not in ('dice_loss', 'ce_loss')): + raise ValueError("""For pseudo supervision loss, only dice_loss and ce_loss \ + are supported.""") + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 + train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 + self.net.train() + for it in range(iter_valid): + t0 = time.time() + try: + data = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data = next(self.trainIter) + t1 = time.time() + # get the inputs + inputs = self.convert_tensor_type(data['image']) + y = self.convert_tensor_type(data['label_prob']) + + inputs, y = inputs.to(self.device), y.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + outputs1, outputs2 = self.net(inputs) + t2 = time.time() + + loss_sup1 = self.get_loss_value(data, outputs1, y) + loss_sup2 = self.get_loss_value(data, outputs2, y) + loss_sup = 0.5 * (loss_sup1 + loss_sup2) + + # get pseudo label with dynamical mix + outputs_soft1 = torch.softmax(outputs1, dim=1) + outputs_soft2 = torch.softmax(outputs2, dim=1) + beta = random.random() + pseudo_lab = beta*outputs_soft1.detach() + (1.0-beta)*outputs_soft2.detach() + # pseudo_lab = torch.argmax(pseudo_lab, dim = 1, keepdim = True) + # pseudo_lab = get_soft_label(pseudo_lab, class_num, self.tensor_type) + + # calculate the pseudo label supervision loss + loss_calculator = DiceLoss() if pseudo_loss_type == 'dice_loss' else CrossEntropyLoss() + loss_dict1 = {"prediction":outputs1, 'ground_truth':pseudo_lab} + loss_dict2 = {"prediction":outputs2, 'ground_truth':pseudo_lab} + loss_reg = 0.5 * (loss_calculator(loss_dict1) + loss_calculator(loss_dict2)) + + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio + loss = loss_sup + regular_w*loss_reg + t3 = time.time() + loss.backward() + t4 = time.time() + self.optimizer.step() + + train_loss = train_loss + loss.item() + train_loss_sup = train_loss_sup + loss_sup.item() + train_loss_reg = train_loss_reg + loss_reg.item() + # get dice evaluation for each class in annotated images + if(isinstance(outputs1, tuple) or isinstance(outputs1, list)): + outputs1 = outputs1[0] + p_argmax = torch.argmax(outputs1, dim = 1, keepdim = True) + p_soft = get_soft_label(p_argmax, class_num, self.tensor_type) + p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) + dice_list = get_classwise_dice(p_soft, y) + train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 + train_avg_loss = train_loss / iter_valid + train_avg_loss_sup = train_loss_sup / iter_valid + train_avg_loss_reg = train_loss_reg / iter_valid + train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) + train_avg_dice = train_cls_dice[1:].mean() + + train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, + 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time} + return train_scalers + \ No newline at end of file From 70736d54f0f933614d8a0d08d9d47de1c2a384ae Mon Sep 17 00:00:00 2001 From: taigw Date: Fri, 1 Aug 2025 11:32:36 +0800 Subject: [PATCH 78/86] Create history.txt --- docs/history.txt | 1 + 1 file changed, 1 insertion(+) create mode 100644 docs/history.txt diff --git a/docs/history.txt b/docs/history.txt new file mode 100644 index 0000000..3ca7b28 --- /dev/null +++ b/docs/history.txt @@ -0,0 +1 @@ +2025.8.1 Add code of DMSPS \ No newline at end of file From a4634f26262a26b64a118fe9c6fe9e32c95b6ee8 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 2 Aug 2025 12:41:51 +0800 Subject: [PATCH 79/86] add support to h5 files 1, add support to h5 files 2, edit Rescale, Pad and Rotate so that setting output size to a 2D list is allowed for 3D images --- pymic/io/image_read_write.py | 4 +- pymic/io/nifty_dataset.py | 143 ++++++++++++++++++++++++----------- pymic/net_run/agent_seg.py | 12 ++- pymic/transform/pad.py | 5 ++ pymic/transform/rescale.py | 8 +- pymic/transform/rotate.py | 31 ++++++-- 6 files changed, 143 insertions(+), 60 deletions(-) diff --git a/pymic/io/image_read_write.py b/pymic/io/image_read_write.py index efbe656..cb17259 100644 --- a/pymic/io/image_read_write.py +++ b/pymic/io/image_read_write.py @@ -81,8 +81,8 @@ def load_image_as_nd_array(image_name): if (image_name.endswith(".nii.gz") or image_name.endswith(".nii") or image_name.endswith(".mha")): image_dict = load_nifty_volume_as_4d_array(image_name) - elif(image_name.endswith(".jpg") or image_name.endswith(".jpeg") or - image_name.endswith(".tif") or image_name.endswith(".png")): + elif(image_name.lower().endswith(".jpg") or image_name.lower().endswith(".jpeg") or + image_name.lower().endswith(".tif") or image_name.lower().endswith(".png")): image_dict = load_rgb_image_as_3d_array(image_name) else: raise ValueError("unsupported image format: {0:}".format(image_name)) diff --git a/pymic/io/nifty_dataset.py b/pymic/io/nifty_dataset.py index 048ba64..cb23e2f 100644 --- a/pymic/io/nifty_dataset.py +++ b/pymic/io/nifty_dataset.py @@ -3,12 +3,26 @@ import logging import os +import h5py import pandas as pd import numpy as np from torch.utils.data import Dataset from pymic import TaskType from pymic.io.image_read_write import load_image_as_nd_array +def check_and_expand_dim(x, img_dim): + """ + check the input dim and expand it with a channel dimension if necessary. + For 2D images, return a 3D numpy array with a shape of [C, H, W] + for 3D images, return a 3D numpy array with a shape of [C, D, H, W] + """ + input_dim = len(x.shape) + if(input_dim == 2 and img_dim == 2): + x = np.expand_dims(x, axis = 0) + elif(input_dim == 3 and img_dim == 3): + x = np.expand_dims(x, axis = 0) + return x + class NiftyDataset(Dataset): """ Dataset for loading images for segmentation. It generates 4D tensors with @@ -16,37 +30,64 @@ class NiftyDataset(Dataset): with dimention order [C, H, W] for 2D images. :param root_dir: (str) Directory with all the images. - :param csv_file: (str) Path to the csv file with image names. - :param modal_num: (int) Number of modalities. + :param csv: (str) Path to the csv file with image names. If it is None, + the images will be those under root_dir. This only works for testing with + a single input modality. If the images are stored in h5 files, the *.csv file + only has one column, while for other types of images such as .nii.gz and.png, + each column is for an input modality, and the last column is for label. + :param modal_num: (int) Number of modalities. This is only used if the data_file is *.csv. + :param image_dim: (int) Spacial dimension of the input image. This is ony used for h5 files. :param with_label: (bool) Load the data with segmentation ground truth or not. :param transform: (list) List of transforms to be applied on a sample. The built-in transforms can listed in :mod:`pymic.transform.trans_dict`. """ - # def __init__(self, root_dir, csv_file, modal_num = 1, - def __init__(self, root_dir, csv_file, modal_num = 1, allow_missing_modal = False, - with_label = False, transform=None, task = TaskType.SEGMENTATION): + def __init__(self, root_dir, csv_file, modal_num = 1, image_dim = 3, allow_missing_modal = False, + with_label = True, transform=None, task = TaskType.SEGMENTATION): self.root_dir = root_dir - self.csv_items = pd.read_csv(csv_file) + if(csv_file is not None): + self.csv_items = pd.read_csv(csv_file) + else: + img_names = os.listdir(root_dir) + img_names = [item for item in img_names if ("nii" in item or "jpg" in item or + "jpeg" in item or "bmp" in item or "png" in item)] + csv_dict = {"image":img_names} + self.csv_items = pd.DataFrame.from_dict(csv_dict) + self.modal_num = modal_num + self.image_dim = image_dim self.allow_emtpy= allow_missing_modal self.with_label = with_label self.transform = transform self.task = task + self.h5files = False assert self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION] - csv_keys = list(self.csv_items.keys()) - if('label' not in csv_keys): + # check if the files are h5 images, and if the labels are provided. + temp_name = self.csv_items.iloc[0, 0] + logging.warning(temp_name) + if(temp_name.endswith(".h5")): + self.h5files = True + temp_full_name = "{0:}/{1:}".format(self.root_dir, temp_name) + h5f = h5py.File(temp_full_name, 'r') + if('label' not in h5f): + self.with_label = False + else: + csv_keys = list(self.csv_items.keys()) + if('label' not in csv_keys): + self.with_label = False + + self.image_weight_idx = None + self.pixel_weight_idx = None + if('image_weight' in csv_keys): + self.image_weight_idx = csv_keys.index('image_weight') + if('pixel_weight' in csv_keys): + self.pixel_weight_idx = csv_keys.index('pixel_weight') + if(not self.with_label): logging.warning("`label` section is not found in the csv file {0:}".format( - csv_file) + "\n -- This is only allowed for self-supervised learning" + + csv_file) + "or the corresponding h5 file." + + "\n -- This is only allowed for self-supervised learning" + "\n -- when `SelfSuperviseLabel` is used in the transform, or when" + "\n -- loading the unlabeled data for preprocessing.") - self.with_label = False - self.image_weight_idx = None - self.pixel_weight_idx = None - if('image_weight' in csv_keys): - self.image_weight_idx = csv_keys.index('image_weight') - if('pixel_weight' in csv_keys): - self.pixel_weight_idx = csv_keys.index('pixel_weight') def __len__(self): return len(self.csv_items) @@ -92,36 +133,46 @@ def __get_pixel_weight__(self, idx): def __getitem__(self, idx): names_list, image_list = [], [] image_shape = None - for i in range (self.modal_num): - image_name = self.csv_items.iloc[idx, i] - image_full_name = "{0:}/{1:}".format(self.root_dir, image_name) - if(os.path.exists(image_full_name)): - image_dict = load_image_as_nd_array(image_full_name) - image_data = image_dict['data_array'] - elif(self.allow_emtpy and image_shape is not None): - image_data = np.zeros(image_shape) - else: - raise KeyError("File not found: {0:}".format(image_full_name)) - if(i == 0): - image_shape = image_data.shape - names_list.append(image_name) - image_list.append(image_data) - image = np.concatenate(image_list, axis = 0) - image = np.asarray(image, np.float32) - - sample = {'image': image, 'names' : names_list, - 'origin':image_dict['origin'], - 'spacing': image_dict['spacing'], - 'direction':image_dict['direction']} - if (self.with_label): - sample['label'], label_name = self.__getlabel__(idx) - sample['names'].append(label_name) - assert(image.shape[1:] == sample['label'].shape[1:]) - if (self.image_weight_idx is not None): - sample['image_weight'] = self.csv_items.iloc[idx, self.image_weight_idx] - if (self.pixel_weight_idx is not None): - sample['pixel_weight'] = self.__get_pixel_weight__(idx) - assert(image.shape[1:] == sample['pixel_weight'].shape[1:]) + if(self.h5files): + sample_name = self.csv_items.iloc[idx, 0] + h5f = h5py.File(self.root_dir + '/' + sample_name, 'r') + img = check_and_expand_dim(h5f['image'][:], self.image_dim) + sample = {'image':img} + if(self.with_label): + lab = check_and_expand_dim(h5f['label'][:], self.image_dim) + sample['label'] = lab + sample['names'] = [sample_name] + else: + for i in range (self.modal_num): + image_name = self.csv_items.iloc[idx, i] + image_full_name = "{0:}/{1:}".format(self.root_dir, image_name) + if(os.path.exists(image_full_name)): + image_dict = load_image_as_nd_array(image_full_name) + image_data = image_dict['data_array'] + elif(self.allow_emtpy and image_shape is not None): + image_data = np.zeros(image_shape) + else: + raise KeyError("File not found: {0:}".format(image_full_name)) + if(i == 0): + image_shape = image_data.shape + names_list.append(image_name) + image_list.append(image_data) + image = np.concatenate(image_list, axis = 0) + image = np.asarray(image, np.float32) + + sample = {'image': image, 'names' : names_list, + 'origin':image_dict['origin'], + 'spacing': image_dict['spacing'], + 'direction':image_dict['direction']} + if (self.with_label): + sample['label'], label_name = self.__getlabel__(idx) + sample['names'].append(label_name) + assert(image.shape[1:] == sample['label'].shape[1:]) + if (self.image_weight_idx is not None): + sample['image_weight'] = self.csv_items.iloc[idx, self.image_weight_idx] + if (self.pixel_weight_idx is not None): + sample['pixel_weight'] = self.__get_pixel_weight__(idx) + assert(image.shape[1:] == sample['pixel_weight'].shape[1:]) if self.transform: sample = self.transform(sample) diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 45e815b..e1a3377 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -71,14 +71,18 @@ def get_stage_dataset_from_config(self, stage): modal_num = self.config['dataset'].get('modal_num', 1) allow_miss = self.config['dataset'].get('allow_missing_modal', False) stage_dir = self.config['dataset'].get('train_dir', None) - if(stage == 'valid' and "valid_dir" in self.config['dataset']): - stage_dir = self.config['dataset']['valid_dir'] - if(stage == 'test' and "test_dir" in self.config['dataset']): - stage_dir = self.config['dataset']['test_dir'] + stage_dim = self.config['dataset'].get('train_dim', 3) + if(stage == 'valid'): # and "valid_dir" in self.config['dataset']): + stage_dir = self.config['dataset'].get('valid_dir', stage_dir) + stage_dim = self.config['dataset'].get('valid_dim', stage_dim) + if(stage == 'test'): # and "test_dir" in self.config['dataset']): + stage_dir = self.config['dataset'].get('test_dir', stage_dir) + stage_dim = self.config['dataset'].get('test_dim', stage_dim) logging.info("Creating dataset for {0:}".format(stage)) dataset = NiftyDataset(root_dir = stage_dir, csv_file = csv_file, modal_num = modal_num, + image_dim = stage_dim, allow_missing_modal = allow_miss, with_label= with_label, transform = data_transform, diff --git a/pymic/transform/pad.py b/pymic/transform/pad.py index 8624aa2..509643d 100644 --- a/pymic/transform/pad.py +++ b/pymic/transform/pad.py @@ -38,6 +38,11 @@ def __call__(self, sample): image = sample['image'] input_shape = image.shape input_dim = len(input_shape) - 1 + + if(input_dim == 3): + if(len(self.output_size) == 2): + # for 3D images, igore the z-axis + self.output_size = [input_shape[1]] + list(self.output_size) assert(len(self.output_size) == input_dim) if(self.ceil_mode): multiple = [int(math.ceil(float(input_shape[1+i])/self.output_size[i]))\ diff --git a/pymic/transform/rescale.py b/pymic/transform/rescale.py index 47271ec..154b1e0 100644 --- a/pymic/transform/rescale.py +++ b/pymic/transform/rescale.py @@ -17,9 +17,9 @@ class Rescale(AbstractTransform): following fields: :param `Rescale_output_size`: (list/tuple or int) The output size along each spatial axis, - such as [D, H, W] or [H, W]. If D is None, the input image is only reslcaled in 2D. - If int, the smallest axis is matched to output_size keeping aspect ratio the same - as the input. + such as [D, H, W] or [H, W]. For 3D images, if D is None, or the lenght of tuple/list is 2, + the input image is only reslcaled in 2D. If int, the smallest axis is matched to output_size + keeping aspect ratio the same as the input. :param `Rescale_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `True`. """ @@ -38,6 +38,8 @@ def __call__(self, sample): output_size = self.output_size if(output_size[0] is None): output_size[0] = input_shape[1] + if(input_dim == 3 and len(self.output_size) == 2): + output_size = [input_shape[1]] + list(output_size) assert(len(output_size) == input_dim) else: min_edge = min(input_shape[1:]) diff --git a/pymic/transform/rotate.py b/pymic/transform/rotate.py index 5f85e28..b09f8da 100644 --- a/pymic/transform/rotate.py +++ b/pymic/transform/rotate.py @@ -19,13 +19,19 @@ class RandomRotate(AbstractTransform): :param `RandomRotate_angle_range_d`: (list/tuple or None) Rotation angle (degree) range along depth axis (x-y plane), e.g., (-90, 90). + The length of the list/tuple can be larger than 2, when `RandomRotate_discrete_mode` is True. If None, no rotation along this axis. :param `RandomRotate_angle_range_h`: (list/tuple or None) Rotation angle (degree) range along height axis (x-z plane), e.g., (-90, 90). + The length of the list/tuple can be larger than 2, when `RandomRotate_discrete_mode` is True. If None, no rotation along this axis. Only used for 3D images. :param `RandomRotate_angle_range_w`: (list/tuple or None) Rotation angle (degree) range along width axis (y-z plane), e.g., (-90, 90). + The length of the list/tuple can be larger than 2, when `RandomRotate_discrete_mode` is True. If None, no rotation along this axis. Only used for 3D images. + :param `RandomRotate_discrete_mode`: (optional, bool) Whether the rotate angles + are discrete values in rangle range. For example, if you only want to rotate + the images with a fixed set of angles like (90, 180, 270), then set discrete_mode mode as True. :param `RandomRotate_probability`: (optional, float) The probability of applying RandomRotate. Default is 0.5. :param `RandomRotate_inverse`: (optional, bool) @@ -36,8 +42,11 @@ def __init__(self, params): self.angle_range_d = params['RandomRotate_angle_range_d'.lower()] self.angle_range_h = params.get('RandomRotate_angle_range_h'.lower(), None) self.angle_range_w = params.get('RandomRotate_angle_range_w'.lower(), None) + self.discrete_mode = params.get('RandomRotate_discrete_mode'.lower(), False) self.prob = params.get('RandomRotate_probability'.lower(), 0.5) self.inverse = params.get('RandomRotate_inverse'.lower(), True) + if(len(self.angle_range_d) > 2): + assert(self.discrete_mode) def __apply_transformation(self, image, transform_param_list, order = 1): """ @@ -63,15 +72,27 @@ def __call__(self, sample): transform_param_list = [] if(self.angle_range_d is not None): - angle_d = np.random.uniform(self.angle_range_d[0], self.angle_range_d[1]) + if(self.discrete_mode): + idx = random.randint(0, len(self.angle_range_d) - 1) + angle_d = self.angle_range_d[idx] + else: + angle_d = np.random.uniform(self.angle_range_d[0], self.angle_range_d[1]) transform_param_list.append([angle_d, (-1, -2)]) if(input_dim == 3): if(self.angle_range_h is not None): - angle_h = np.random.uniform(self.angle_range_h[0], self.angle_range_h[1]) - transform_param_list.append([angle_h, (-1, -3)]) + if(self.discrete_mode): + idx = random.randint(0, len(self.angle_range_h) - 1) + angle_h = self.angle_range_h[idx] + else: + angle_h = np.random.uniform(self.angle_range_h[0], self.angle_range_h[1]) + transform_param_list.append([angle_h, (-1, -3)]) if(self.angle_range_w is not None): - angle_w = np.random.uniform(self.angle_range_w[0], self.angle_range_w[1]) - transform_param_list.append([angle_w, (-2, -3)]) + if(self.discrete_mode): + idx = random.randint(0, len(self.angle_range_w) - 1) + angle_w = self.angle_range_w[idx] + else: + angle_w = np.random.uniform(self.angle_range_w[0], self.angle_range_w[1]) + transform_param_list.append([angle_w, (-2, -3)]) assert(len(transform_param_list) > 0) # select a random transform from the possible list rather than # use a combination for higher efficiency From 4590b3045d9248e516d3d8f74a49ada085781125 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 5 Aug 2025 10:42:56 +0800 Subject: [PATCH 80/86] update dataset and transform fix code for loading h5 images update transforms --- pymic/io/nifty_dataset.py | 20 +++---- pymic/net_run/agent_seg.py | 5 +- pymic/net_run/weak_sup/wsl_dmsps.py | 45 +++++++++------- pymic/transform/rotate.py | 81 ++++++++++++++++++----------- pymic/transform/trans_dict.py | 4 ++ 5 files changed, 94 insertions(+), 61 deletions(-) diff --git a/pymic/io/nifty_dataset.py b/pymic/io/nifty_dataset.py index cb23e2f..5424e91 100644 --- a/pymic/io/nifty_dataset.py +++ b/pymic/io/nifty_dataset.py @@ -41,8 +41,9 @@ class NiftyDataset(Dataset): :param transform: (list) List of transforms to be applied on a sample. The built-in transforms can listed in :mod:`pymic.transform.trans_dict`. """ - def __init__(self, root_dir, csv_file, modal_num = 1, image_dim = 3, allow_missing_modal = False, - with_label = True, transform=None, task = TaskType.SEGMENTATION): + def __init__(self, root_dir, csv_file, modal_num = 1, image_dim = 3, + allow_missing_modal = False, label_key = "label", + transform=None, task = TaskType.SEGMENTATION): self.root_dir = root_dir if(csv_file is not None): self.csv_items = pd.read_csv(csv_file) @@ -56,10 +57,11 @@ def __init__(self, root_dir, csv_file, modal_num = 1, image_dim = 3, allow_missi self.modal_num = modal_num self.image_dim = image_dim self.allow_emtpy= allow_missing_modal - self.with_label = with_label + self.label_key = label_key self.transform = transform self.task = task self.h5files = False + self.with_label = True assert self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION] # check if the files are h5 images, and if the labels are provided. @@ -69,11 +71,11 @@ def __init__(self, root_dir, csv_file, modal_num = 1, image_dim = 3, allow_missi self.h5files = True temp_full_name = "{0:}/{1:}".format(self.root_dir, temp_name) h5f = h5py.File(temp_full_name, 'r') - if('label' not in h5f): + if(self.label_key not in h5f): self.with_label = False else: csv_keys = list(self.csv_items.keys()) - if('label' not in csv_keys): + if(self.label_key not in csv_keys): self.with_label = False self.image_weight_idx = None @@ -84,7 +86,7 @@ def __init__(self, root_dir, csv_file, modal_num = 1, image_dim = 3, allow_missi self.pixel_weight_idx = csv_keys.index('pixel_weight') if(not self.with_label): logging.warning("`label` section is not found in the csv file {0:}".format( - csv_file) + "or the corresponding h5 file." + + csv_file) + " or the corresponding h5 file." + "\n -- This is only allowed for self-supervised learning" + "\n -- when `SelfSuperviseLabel` is used in the transform, or when" + "\n -- loading the unlabeled data for preprocessing.") @@ -94,7 +96,7 @@ def __len__(self): def __getlabel__(self, idx): csv_keys = list(self.csv_items.keys()) - label_idx = csv_keys.index('label') + label_idx = csv_keys.index(self.label_key) label_name = self.csv_items.iloc[idx, label_idx] label_name_full = "{0:}/{1:}".format(self.root_dir, label_name) label = load_image_as_nd_array(label_name_full)['data_array'] @@ -139,8 +141,8 @@ def __getitem__(self, idx): img = check_and_expand_dim(h5f['image'][:], self.image_dim) sample = {'image':img} if(self.with_label): - lab = check_and_expand_dim(h5f['label'][:], self.image_dim) - sample['label'] = lab + lab = check_and_expand_dim(h5f[self.label_key][:], self.image_dim) + sample['label'] = np.asarray(lab, np.float32) sample['names'] = [sample_name] else: for i in range (self.modal_num): diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index e1a3377..0259e3e 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -72,19 +72,22 @@ def get_stage_dataset_from_config(self, stage): allow_miss = self.config['dataset'].get('allow_missing_modal', False) stage_dir = self.config['dataset'].get('train_dir', None) stage_dim = self.config['dataset'].get('train_dim', 3) + stage_lab_key = self.config['dataset'].get('train_label_key', 'label') if(stage == 'valid'): # and "valid_dir" in self.config['dataset']): stage_dir = self.config['dataset'].get('valid_dir', stage_dir) stage_dim = self.config['dataset'].get('valid_dim', stage_dim) + stage_lab_key = self.config['dataset'].get('valid_label_key', 'label') if(stage == 'test'): # and "test_dir" in self.config['dataset']): stage_dir = self.config['dataset'].get('test_dir', stage_dir) stage_dim = self.config['dataset'].get('test_dim', stage_dim) + stage_lab_key = self.config['dataset'].get('test_label_key', 'label') logging.info("Creating dataset for {0:}".format(stage)) dataset = NiftyDataset(root_dir = stage_dir, csv_file = csv_file, modal_num = modal_num, image_dim = stage_dim, allow_missing_modal = allow_miss, - with_label= with_label, + label_key = stage_lab_key, transform = data_transform, task = self.task_type) return dataset diff --git a/pymic/net_run/weak_sup/wsl_dmsps.py b/pymic/net_run/weak_sup/wsl_dmsps.py index dca3643..5eec2a4 100644 --- a/pymic/net_run/weak_sup/wsl_dmsps.py +++ b/pymic/net_run/weak_sup/wsl_dmsps.py @@ -5,11 +5,13 @@ import random import time import torch +from PIL import Image from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.dice import DiceLoss from pymic.loss.seg.ce import CrossEntropyLoss +# from torch.nn.modules.loss import CrossEntropyLoss as TorchCELoss from pymic.net_run.weak_sup import WSLSegAgent from pymic.util.ramps import get_rampup_ratio @@ -33,11 +35,11 @@ class WSLDMSPS(WSLSegAgent): """ def __init__(self, config, stage = 'train'): net_type = config['network']['net_type'] - if net_type not in ['UNet2D_DualBranch', 'UNet3D_DualBranch']: - raise ValueError("""For WSL_DMPLS, a dual branch network is expected. \ - It only supports UNet2D_DualBranch and UNet3D_DualBranch currently.""") + # if net_type not in ['UNet2D_DualBranch', 'UNet3D_DualBranch']: + # raise ValueError("""For WSL_DMPLS, a dual branch network is expected. \ + # It only supports UNet2D_DualBranch and UNet3D_DualBranch currently.""") super(WSLDMSPS, self).__init__(config, stage) - + def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] @@ -49,10 +51,12 @@ def training(self): if (pseudo_loss_type not in ('dice_loss', 'ce_loss')): raise ValueError("""For pseudo supervision loss, only dice_loss and ce_loss \ are supported.""") + pseudo_loss_func = CrossEntropyLoss() if pseudo_loss_type == 'ce_loss' else DiceLoss() train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() + # ce_loss = CrossEntropyLoss() for it in range(iter_valid): t0 = time.time() try: @@ -66,7 +70,6 @@ def training(self): y = self.convert_tensor_type(data['label_prob']) inputs, y = inputs.to(self.device), y.to(self.device) - # zero the parameter gradients self.optimizer.zero_grad() @@ -78,23 +81,26 @@ def training(self): loss_sup2 = self.get_loss_value(data, outputs2, y) loss_sup = 0.5 * (loss_sup1 + loss_sup2) - # get pseudo label with dynamical mix + # torch_ce_loss = TorchCELoss(ignore_index=class_num) + # torch_ce_loss2 = TorchCELoss() + # loss_ce1 = torch_ce_loss(outputs1, label[:].long()) + # loss_ce2 = torch_ce_loss(outputs2, label[:].long()) + # loss_sup = 0.5 * (loss_ce1 + loss_ce2) + + # get pseudo label with dynamic mixture outputs_soft1 = torch.softmax(outputs1, dim=1) outputs_soft2 = torch.softmax(outputs2, dim=1) - beta = random.random() - pseudo_lab = beta*outputs_soft1.detach() + (1.0-beta)*outputs_soft2.detach() - # pseudo_lab = torch.argmax(pseudo_lab, dim = 1, keepdim = True) - # pseudo_lab = get_soft_label(pseudo_lab, class_num, self.tensor_type) - - # calculate the pseudo label supervision loss - loss_calculator = DiceLoss() if pseudo_loss_type == 'dice_loss' else CrossEntropyLoss() - loss_dict1 = {"prediction":outputs1, 'ground_truth':pseudo_lab} - loss_dict2 = {"prediction":outputs2, 'ground_truth':pseudo_lab} - loss_reg = 0.5 * (loss_calculator(loss_dict1) + loss_calculator(loss_dict2)) + alpha = random.random() + soft_pseudo_label = alpha * outputs_soft1.detach() + (1.0-alpha) * outputs_soft2.detach() + # loss_reg = 0.5*(torch_ce_loss2(outputs_soft1, soft_pseudo_label) +torch_ce_loss2(outputs_soft2, soft_pseudo_label) ) + loss_dict1 = {"prediction":outputs_soft1, 'ground_truth':soft_pseudo_label} + loss_dict2 = {"prediction":outputs_soft2, 'ground_truth':soft_pseudo_label} + loss_reg = 0.5 * (pseudo_loss_func(loss_dict1) + pseudo_loss_func(loss_dict2)) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") - regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio - loss = loss_sup + regular_w*loss_reg + regular_w = wsl_cfg.get('regularize_w', 8.0) * rampup_ratio + loss = loss_sup + regular_w*loss_reg t3 = time.time() loss.backward() t4 = time.time() @@ -127,5 +133,4 @@ def training(self): 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, 'data_time': data_time, 'forward_time':gpu_time, 'loss_time':loss_time, 'backward_time':back_time} - return train_scalers - \ No newline at end of file + return train_scalers \ No newline at end of file diff --git a/pymic/transform/rotate.py b/pymic/transform/rotate.py index b09f8da..5de77d7 100644 --- a/pymic/transform/rotate.py +++ b/pymic/transform/rotate.py @@ -19,19 +19,13 @@ class RandomRotate(AbstractTransform): :param `RandomRotate_angle_range_d`: (list/tuple or None) Rotation angle (degree) range along depth axis (x-y plane), e.g., (-90, 90). - The length of the list/tuple can be larger than 2, when `RandomRotate_discrete_mode` is True. If None, no rotation along this axis. :param `RandomRotate_angle_range_h`: (list/tuple or None) Rotation angle (degree) range along height axis (x-z plane), e.g., (-90, 90). - The length of the list/tuple can be larger than 2, when `RandomRotate_discrete_mode` is True. If None, no rotation along this axis. Only used for 3D images. :param `RandomRotate_angle_range_w`: (list/tuple or None) Rotation angle (degree) range along width axis (y-z plane), e.g., (-90, 90). - The length of the list/tuple can be larger than 2, when `RandomRotate_discrete_mode` is True. If None, no rotation along this axis. Only used for 3D images. - :param `RandomRotate_discrete_mode`: (optional, bool) Whether the rotate angles - are discrete values in rangle range. For example, if you only want to rotate - the images with a fixed set of angles like (90, 180, 270), then set discrete_mode mode as True. :param `RandomRotate_probability`: (optional, float) The probability of applying RandomRotate. Default is 0.5. :param `RandomRotate_inverse`: (optional, bool) @@ -42,11 +36,8 @@ def __init__(self, params): self.angle_range_d = params['RandomRotate_angle_range_d'.lower()] self.angle_range_h = params.get('RandomRotate_angle_range_h'.lower(), None) self.angle_range_w = params.get('RandomRotate_angle_range_w'.lower(), None) - self.discrete_mode = params.get('RandomRotate_discrete_mode'.lower(), False) self.prob = params.get('RandomRotate_probability'.lower(), 0.5) self.inverse = params.get('RandomRotate_inverse'.lower(), True) - if(len(self.angle_range_d) > 2): - assert(self.discrete_mode) def __apply_transformation(self, image, transform_param_list, order = 1): """ @@ -61,38 +52,21 @@ def __apply_transformation(self, image, transform_param_list, order = 1): return image def __call__(self, sample): - # if(random.random() > self.prob): - # sample['RandomRotate_triggered'] = False - # return sample - # else: - # sample['RandomRotate_triggered'] = True image = sample['image'] input_shape = image.shape input_dim = len(input_shape) - 1 transform_param_list = [] if(self.angle_range_d is not None): - if(self.discrete_mode): - idx = random.randint(0, len(self.angle_range_d) - 1) - angle_d = self.angle_range_d[idx] - else: - angle_d = np.random.uniform(self.angle_range_d[0], self.angle_range_d[1]) + angle_d = np.random.uniform(self.angle_range_d[0], self.angle_range_d[1]) transform_param_list.append([angle_d, (-1, -2)]) if(input_dim == 3): if(self.angle_range_h is not None): - if(self.discrete_mode): - idx = random.randint(0, len(self.angle_range_h) - 1) - angle_h = self.angle_range_h[idx] - else: - angle_h = np.random.uniform(self.angle_range_h[0], self.angle_range_h[1]) - transform_param_list.append([angle_h, (-1, -3)]) + angle_h = np.random.uniform(self.angle_range_h[0], self.angle_range_h[1]) + transform_param_list.append([angle_h, (-1, -3)]) if(self.angle_range_w is not None): - if(self.discrete_mode): - idx = random.randint(0, len(self.angle_range_w) - 1) - angle_w = self.angle_range_w[idx] - else: - angle_w = np.random.uniform(self.angle_range_w[0], self.angle_range_w[1]) - transform_param_list.append([angle_w, (-2, -3)]) + angle_w = np.random.uniform(self.angle_range_w[0], self.angle_range_w[1]) + transform_param_list.append([angle_w, (-2, -3)]) assert(len(transform_param_list) > 0) # select a random transform from the possible list rather than # use a combination for higher efficiency @@ -123,4 +97,49 @@ def inverse_transform_for_prediction(self, sample): transform_param_list[i][0] = - transform_param_list[i][0] sample['predict'] = self.__apply_transformation(sample['predict'] , transform_param_list, 1) + return sample + +class RandomRot90(AbstractTransform): + """ + Random rotate an image in x-y plane with angles in [90, 180, 270]. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `RandomRot90_probability`: (optional, float) + The probability of applying RandomRot90. Default is 0.75. + :param `RandomRot90_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `True`. + """ + def __init__(self, params): + super(RandomRot90, self).__init__(params) + self.prob = params.get('RandomRot90_probability'.lower(), 0.75) + self.inverse = params.get('RandomRot90_inverse'.lower(), True) + + def __call__(self, sample): + if(random.random() > self.prob): + sample['RandomRot90_triggered'] = False + sample['RandomRot90_Param'] = 0 + return sample + else: + sample['RandomRot90_triggered'] = True + image = sample['image'] + rote_k = random.randint(1, 3) + sample['RandomRot90_Param'] = rote_k + image_t = np.rot90(image, rote_k, (-2, -1)) + sample['image'] = image_t + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + sample['label'] = np.rot90(sample['label'], rote_k, (-2, -1)) + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + sample['pixel_weight'] = np.rot90(sample['pixel_weight'], rote_k, (-2, -1)) + return sample + + def inverse_transform_for_prediction(self, sample): + if(not sample['RandomRot90_triggered']): + return sample + rote_k = sample['RandomRot90_Param'] + rote_i = 4 - rote_k + sample['predict'] = np.rot90(sample['predict'], rote_i, (-2, -1)) return sample \ No newline at end of file diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index 332e594..e4bfe24 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -25,6 +25,7 @@ 'RandomRescale': RandomRescale, 'RandomFlip': RandomFlip, 'RandomRotate': RandomRotate, + 'RandomRot90': RandomRot90, 'ReduceLabelDim': ReduceLabelDim, 'Rescale': Rescale, 'SelfSuperviseLabel': SelfSuperviseLabel, @@ -43,6 +44,7 @@ from pymic.transform.normalize import * from pymic.transform.crop import * from pymic.transform.crop4dino import Crop4Dino +from pymic.transform.crop4voco import Crop4VoCo from pymic.transform.crop4vox2vec import Crop4Vox2Vec from pymic.transform.crop4vf import Crop4VolumeFusion, VolumeFusion, VolumeFusionShuffle from pymic.transform.label_convert import * @@ -57,6 +59,7 @@ 'CropHumanRegion': CropHumanRegion, 'CenterCrop': CenterCrop, 'Crop4Dino': Crop4Dino, + 'Crop4VoCo': Crop4VoCo, 'Crop4Vox2Vec': Crop4Vox2Vec, 'Crop4VolumeFusion': Crop4VolumeFusion, 'GrayscaleToRGB': GrayscaleToRGB, @@ -83,6 +86,7 @@ 'RandomTranspose': RandomTranspose, 'RandomFlip': RandomFlip, 'RandomRotate': RandomRotate, + 'RandomRot90': RandomRot90, 'ReduceLabelDim': ReduceLabelDim, 'Rescale': Rescale, 'Resample': Resample, From 4baccae73d1c7eeeaf0353c8a53602f05ea5687a Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 11 Aug 2025 15:32:05 +0800 Subject: [PATCH 81/86] update files for DMSPS --- pymic/io/image_read_write.py | 8 ++- pymic/loss/loss_dict_seg.py | 2 + pymic/loss/seg/ce.py | 4 +- pymic/net/net2d/unet2d_multi_decoder.py | 19 +++---- pymic/net_run/weak_sup/wsl_dmsps.py | 70 ++++++++++++++++++++++++- pymic/transform/rescale.py | 20 +++++-- pymic/util/evaluation_seg.py | 24 ++++----- pymic/util/parse_config.py | 23 +++++++- 8 files changed, 137 insertions(+), 33 deletions(-) diff --git a/pymic/io/image_read_write.py b/pymic/io/image_read_write.py index cb17259..f570628 100644 --- a/pymic/io/image_read_write.py +++ b/pymic/io/image_read_write.py @@ -99,7 +99,7 @@ def save_array_as_nifty_volume(data, image_name, reference_name = None, spacing :param spacing: (list or tuple) the spacing of a volume data when `reference_name` is not provided. """ img = sitk.GetImageFromArray(data) - if(reference_name is not None): + if((reference_name is not None) and (not reference_name.endswith(".h5"))): img_ref = sitk.ReadImage(reference_name) #img.CopyInformation(img_ref) img.SetSpacing(img_ref.GetSpacing()) @@ -141,11 +141,15 @@ def save_nd_array_as_image(data, image_name, reference_name = None, spacing = [1 """ data_dim = len(data.shape) assert(data_dim == 2 or data_dim == 3) + if(image_name.endswith(".h5")): + if(data_dim == 3): + image_name = image_name.replace(".h5", ".nii.gz") + else: + image_name = image_name.replace(".h5", ".png") if (image_name.endswith(".nii.gz") or image_name.endswith(".nii") or image_name.endswith(".mha")): assert(data_dim == 3) save_array_as_nifty_volume(data, image_name, reference_name, spacing) - elif(image_name.endswith(".jpg") or image_name.endswith(".jpeg") or image_name.endswith(".tif") or image_name.endswith(".png")): assert(data_dim == 2) diff --git a/pymic/loss/loss_dict_seg.py b/pymic/loss/loss_dict_seg.py index fd72ce4..36e6a21 100644 --- a/pymic/loss/loss_dict_seg.py +++ b/pymic/loss/loss_dict_seg.py @@ -26,6 +26,7 @@ from pymic.loss.seg.dice import DiceLoss, FocalDiceLoss, \ NoiseRobustDiceLoss, BinaryDiceLoss, GroupDiceLoss from pymic.loss.seg.exp_log import ExpLogLoss +from pymic.loss.seg.ars_tversky import ARSTverskyLoss from pymic.loss.seg.mse import MSELoss, MAELoss from pymic.loss.seg.slsr import SLSRLoss @@ -35,6 +36,7 @@ 'DiceLoss': DiceLoss, 'BinaryDiceLoss': BinaryDiceLoss, 'FocalDiceLoss': FocalDiceLoss, + 'ARSTverskyLoss': ARSTverskyLoss, 'NoiseRobustDiceLoss': NoiseRobustDiceLoss, 'GroupDiceLoss': GroupDiceLoss, 'ExpLogLoss': ExpLogLoss, diff --git a/pymic/loss/seg/ce.py b/pymic/loss/seg/ce.py index 4edbbc3..bf036a3 100644 --- a/pymic/loss/seg/ce.py +++ b/pymic/loss/seg/ce.py @@ -36,7 +36,7 @@ def forward(self, loss_input_dict): soft_y = reshape_tensor_to_2D(soft_y) # for numeric stability - predict = predict * 0.999 + 5e-4 + # predict = predict * (1-1e-10) + 0.5e-10 ce = - soft_y* torch.log(predict) if(cls_w is not None): ce = torch.sum(ce*cls_w, dim = 1) @@ -46,7 +46,7 @@ def forward(self, loss_input_dict): ce = torch.mean(ce) else: pix_w = torch.squeeze(reshape_tensor_to_2D(pix_w)) - ce = torch.sum(pix_w * ce) / (pix_w.sum() + 1e-5) + ce = torch.sum(pix_w * ce) / (pix_w.sum() + 1e-10) return ce class GeneralizedCELoss(AbstractSegLoss): diff --git a/pymic/net/net2d/unet2d_multi_decoder.py b/pymic/net/net2d/unet2d_multi_decoder.py index 2c7b3c4..03bd99f 100644 --- a/pymic/net/net2d/unet2d_multi_decoder.py +++ b/pymic/net/net2d/unet2d_multi_decoder.py @@ -63,15 +63,16 @@ def forward(self, x): output2 = torch.reshape(output2, new_shape) output2 = torch.transpose(output2, 1, 2) - if(self.training): - return output1, output2 - else: - if(self.output_mode == "average"): - return (output1 + output2)/2 - elif(self.output_mode == "first"): - return output1 - else: - return output2 + return output1, output2 + # if(self.training): + # return output1, output2 + # else: + # if(self.output_mode == "average"): + # return (output1 + output2)/2 + # elif(self.output_mode == "first"): + # return output1 + # else: + # return output2 class UNet2D_TriBranch(nn.Module): """ diff --git a/pymic/net_run/weak_sup/wsl_dmsps.py b/pymic/net_run/weak_sup/wsl_dmsps.py index 5eec2a4..b610f72 100644 --- a/pymic/net_run/weak_sup/wsl_dmsps.py +++ b/pymic/net_run/weak_sup/wsl_dmsps.py @@ -1,11 +1,13 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import logging +import os import numpy as np import random import time import torch -from PIL import Image +import scipy +from pymic.io.image_read_write import save_nd_array_as_image from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice @@ -133,4 +135,68 @@ def training(self): 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, 'data_time': data_time, 'forward_time':gpu_time, 'loss_time':loss_time, 'backward_time':back_time} - return train_scalers \ No newline at end of file + return train_scalers + + def save_outputs(self, data): + """ + Save prediction output. + + :param data: (dictionary) A data dictionary with prediciton result and other + information such as input image name. + """ + output_dir = self.config['testing']['output_dir'] + test_mode = self.config['testing'].get('dmsps_test_mode', 0) + uct_threshold = self.config['testing'].get('dmsps_uncertainty_threshold', 0.1) + # DMSPS_test_mode == 0: only save the segmentation label for the main decoder + # DMSPS_test_mode == 1: save all the results, including the the probability map of each decoder, + # the uncertainty map, and the confident predictions + if(not os.path.exists(output_dir)): + os.makedirs(output_dir, exist_ok=True) + + names, pred = data['names'], data['predict'] + pred0, pred1 = pred + prob0 = scipy.special.softmax(pred0, axis = 1) + prob1 = scipy.special.softmax(pred1, axis = 1) + prob_mean = (prob0 + prob1) / 2 + lab0 = np.asarray(np.argmax(prob0, axis = 1), np.uint8) + lab1 = np.asarray(np.argmax(prob1, axis = 1), np.uint8) + lab_mean = np.asarray(np.argmax(prob_mean, axis = 1), np.uint8) + + # save the output and (optionally) probability predictions + test_dir = self.config['dataset'].get('test_dir', None) + if(test_dir is None): + test_dir = self.config['dataset']['train_dir'] + img_name = names[0][0].split('/')[-1] + print(img_name) + lab0_name = img_name + if(".h5" in lab0_name): + lab0_name = lab0_name.replace(".h5", ".nii.gz") + save_nd_array_as_image(lab0[0], output_dir + "/" + lab0_name, test_dir + '/' + names[0][0]) + if(test_mode == 1): + lab1_name = lab0_name.replace(".nii.gz", "_predaux.nii.gz") + save_nd_array_as_image(lab1[0], output_dir + "/" + lab1_name, test_dir + '/' + names[0][0]) + C = pred0.shape[1] + uct = -1.0 * np.sum(prob_mean * np.log(prob_mean), axis=1, keepdims=False)/ np.log(C) + uct_name = lab0_name.replace(".nii.gz", "_uncertainty.nii.gz") + save_nd_array_as_image(uct[0], output_dir + "/" + uct_name, test_dir + '/' + names[0][0]) + conf_mask = uct < uct_threshold + conf_lab = conf_mask * lab_mean + (1 - conf_mask)*4 + conf_lab_name = lab0_name.replace(".nii.gz", "_seeds_expand.nii.gz") + + # get the largest connected component in each slice for each class + D, H, W = conf_lab[0].shape + from pymic.util.image_process import get_largest_k_components + for d in range(D): + lab2d = conf_lab[0][d] + for c in range(C): + lab2d_c = lab2d == c + mask_c = get_largest_k_components(lab2d_c, k = 1) + diff = lab2d_c != mask_c + if(np.sum(diff) > 0): + lab2d[diff] = C + conf_lab[0][d] = lab2d + save_nd_array_as_image(conf_lab[0], output_dir + "/" + conf_lab_name, test_dir + '/' + img_name) + + + + \ No newline at end of file diff --git a/pymic/transform/rescale.py b/pymic/transform/rescale.py index 154b1e0..ba519c7 100644 --- a/pymic/transform/rescale.py +++ b/pymic/transform/rescale.py @@ -72,12 +72,22 @@ def inverse_transform_for_prediction(self, sample): origin_shape = json.loads(sample['Rescale_origin_shape']) origin_dim = len(origin_shape) - 1 predict = sample['predict'] - input_shape = predict.shape - scale = [(origin_shape[1:][i] + 0.0)/input_shape[2:][i] for \ - i in range(origin_dim)] - scale = [1.0, 1.0] + scale - output_predict = ndimage.interpolation.zoom(predict, scale, order = 1) + if(isinstance(predict, tuple) or isinstance(predict, list)): + output_predict = [] + for predict_i in predict: + input_shape = predict_i.shape + scale = [(origin_shape[1:][i] + 0.0)/input_shape[2:][i] for \ + i in range(origin_dim)] + scale = [1.0, 1.0] + scale + output_predict_i = ndimage.interpolation.zoom(predict_i, scale, order = 1) + output_predict.append(output_predict_i) + else: + input_shape = predict.shape + scale = [(origin_shape[1:][i] + 0.0)/input_shape[2:][i] for \ + i in range(origin_dim)] + scale = [1.0, 1.0] + scale + output_predict = ndimage.interpolation.zoom(predict, scale, order = 1) sample['predict'] = output_predict return sample diff --git a/pymic/util/evaluation_seg.py b/pymic/util/evaluation_seg.py index 926ff0e..5099d3b 100644 --- a/pymic/util/evaluation_seg.py +++ b/pymic/util/evaluation_seg.py @@ -212,8 +212,10 @@ def get_binary_evaluation_score(s_volume, g_volume, spacing, metric): score = binary_iou(s_volume,g_volume) elif(metric_lower == 'assd'): score = binary_assd(s_volume, g_volume, spacing) + score = min(score, 20) # to reject outliers elif(metric_lower == "hd95"): score = binary_hd95(s_volume, g_volume, spacing) + score = min(score, 50) # to reject outliers elif(metric_lower == "rve"): score = binary_relative_volume_error(s_volume, g_volume) elif(metric_lower == "volume"): @@ -269,8 +271,8 @@ def evaluation(config): :param label_fuse: (option, bool) If true, fuse the labels in the `label_list` as the foreground, and other labels as the background. Default is False. :param organ_name: (str) The name of the organ for segmentation. - :param ground_truth_folder_root: (str) The root dir of ground truth images. - :param segmentation_folder_root: (str or list) The root dir of segmentation images. + :param ground_truth_folder: (str) The root dir of ground truth images. + :param segmentation_folder: (str or list) The root dir of segmentation images. When a list is given, each list element should be the root dir of the results of one method. :param evaluation_image_pair: (str) The csv file that provide the segmentation images and the corresponding ground truth images. @@ -366,23 +368,23 @@ def main(): """ parser = argparse.ArgumentParser() - parser.add_argument("-cfg", help="configuration file for evaluation", + parser.add_argument("--cfg", help="configuration file for evaluation", required=False, default=None) - parser.add_argument("-metric", help="evaluation metrics, e.g., dice, or [dice, assd]", + parser.add_argument("--metric", help="evaluation metrics, e.g., dice, or [dice, assd]", required=False, default=None) - parser.add_argument("-cls_num", help="number of classes", + parser.add_argument("--cls_num", help="number of classes", required=False, default=None) - parser.add_argument("-cls_index", help="The class index for evaluation, e.g., 255, [1, 2]", + parser.add_argument("--cls_index", help="The class index for evaluation, e.g., 255, [1, 2]", required=False, default=None) - parser.add_argument("-gt_dir", help="path of folder for ground truth", + parser.add_argument("--gt_dir", help="path of folder for ground truth", required=False, default=None) - parser.add_argument("-seg_dir", help="path of folder for segmentation", + parser.add_argument("--seg_dir", help="path of folder for segmentation", required=False, default=None) - parser.add_argument("-name_pair", help="the .csv file for name mapping in case" + parser.add_argument("--name_pair", help="the .csv file for name mapping in case" " the names of one case are different in the gt_dir " " and seg_dir", required=False, default=None) - parser.add_argument("-out", help="the output .csv file name", + parser.add_argument("--out", help="the output .csv file name", required=False, default=None) args = parser.parse_args() print(args) @@ -402,5 +404,3 @@ def main(): if __name__ == '__main__': main() - - main() diff --git a/pymic/util/parse_config.py b/pymic/util/parse_config.py index 09f6db0..3be02cb 100644 --- a/pymic/util/parse_config.py +++ b/pymic/util/parse_config.py @@ -96,12 +96,22 @@ def parse_config(args): args_key = getattr(args, key) if(args_key is not None): val_str = args_key + print(section, key, val_str) if(len(val_str)>0): val = parse_value_from_string(val_str) output[section][key] = val else: val = None - print(section, key, val) + + for key in ["train_dir", "train_csv", "valid_csv", "test_dir", "test_csv"]: + if key in args and getattr(args, key) is not None: + output["dataset"][key] = parse_value_from_string(getattr(args, key)) + for key in ["ckpt_dir", "iter_max", "gpus"]: + if key in args and getattr(args, key) is not None: + output["training"][key] = parse_value_from_string(getattr(args, key)) + for key in ["output_dir", "ckpt_mode", "ckpt_name"]: + if key in args and getattr(args, key) is not None: + output["testing"][key] = parse_value_from_string(getattr(args, key)) return output def synchronize_config(config): @@ -133,6 +143,17 @@ def synchronize_config(config): if('RandomResizedCrop' in transform and \ 'RandomResizedCrop_output_size'.lower() not in data_cfg): data_cfg['RandomResizedCrop_output_size'.lower()] = patch_size + if('testing' in config): + test_cfg = config['testing'] + sliding_window_enable = test_cfg.get("sliding_window_enable", False) + if(sliding_window_enable): + sliding_window_size = test_cfg.get("sliding_window_size", None) + if(sliding_window_size is None): + test_cfg["sliding_window_size"] = patch_size + sliding_window_stride = test_cfg.get("sliding_window_stride", None) + if(sliding_window_stride is None): + test_cfg["sliding_window_stride"] = [item // 2 for item in patch_size] + config['testing'] = test_cfg config['dataset'] = data_cfg # config['network'] = net_cfg return config From 254491e3deaeb2df7f47078846dcb6ba36937345 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 12 Aug 2025 12:44:23 +0800 Subject: [PATCH 82/86] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index cbf7355..51d939d 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.5.0", + version = "0.5.1", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, From 92a9bf9e7fc9745cf6a153d85409d5b4671cc632 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 12 Aug 2025 13:08:04 +0800 Subject: [PATCH 83/86] Create crop4voco.py --- pymic/transform/crop4voco.py | 107 +++++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 pymic/transform/crop4voco.py diff --git a/pymic/transform/crop4voco.py b/pymic/transform/crop4voco.py new file mode 100644 index 0000000..6c52ca7 --- /dev/null +++ b/pymic/transform/crop4voco.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import json +import math +import random +import numpy as np +from scipy import ndimage +from pymic import TaskType +from pymic.transform.abstract_transform import AbstractTransform +from pymic.transform.crop import CenterCrop +from pymic.transform.intensity import * +from pymic.util.image_process import * + +def get_position_label(roi=96, num_crops=4): + half = roi // 2 + max_roi = roi * num_crops + center_x, center_y = np.random.randint(low=half, high=max_roi - half), \ + np.random.randint(low=half, high=max_roi - half) + + x_min, x_max = center_x - half, center_x + half + y_min, y_max = center_y - half, center_y + half + + total_area = roi * roi + labels = [] + for j in range(num_crops): + for i in range(num_crops): + crop_x_min, crop_x_max = i * roi, (i + 1) * roi + crop_y_min, crop_y_max = j * roi, (j + 1) * roi + + dx = min(crop_x_max, x_max) - max(crop_x_min, x_min) + dy = min(crop_y_max, y_max) - max(crop_y_min, y_min) + if dx <= 0 or dy <= 0: + area = 0 + else: + area = (dx * dy) / total_area + labels.append(area) + + labels = np.asarray(labels).reshape(1, num_crops * num_crops) + return x_min, y_min, labels + +class Crop4VoCo(CenterCrop): + """ + Randomly crop an volume into two views with augmentation. This is used for + self-supervised pretraining such as DeSD. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `DualViewCrop_output_size`: (list/tuple) Desired output size [D, H, W]. + The output channel is the same as the input channel. + :param `DualViewCrop_scale_lower_bound`: (list/tuple) Lower bound of the range of scale + for each dimension. e.g. (1.0, 0.5, 0.5). + param `DualViewCrop_scale_upper_bound`: (list/tuple) Upper bound of the range of scale + for each dimension. e.g. (1.0, 2.0, 2.0). + :param `DualViewCrop_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `False`. Currently, the inverse transform is not supported, and + this transform is assumed to be used only during training stage. + """ + def __init__(self, params): + roi_size = params.get('Crop4VoCo_roi_size'.lower(), 64) + if isinstance(roi_size, int): + self.roi_size = [roi_size] * 3 + else: + self.roi_size = roi_size + self.roi_num = params.get('Crop4VoCo_roi_num'.lower(), 2) + self.base_num = params.get('Crop4VoCo_base_num'.lower(), 4) + + self.inverse = params.get('Crop4VoCo_inverse'.lower(), False) + self.task = params['Task'.lower()] + + def __call__(self, sample): + image = sample['image'] + channel, input_size = image.shape[0], image.shape[1:] + input_dim = len(input_size) + # print(input_size, self.roi_size) + assert(input_size[0] == self.roi_size[0]) + assert(input_size[1] == self.roi_size[1] * self.base_num) + assert(input_size[2] == self.roi_size[2] * self.base_num) + + base_num, roi_num, roi_size = self.base_num, self.roi_num, self.roi_size + base_crops, roi_crops, roi_labels = [], [], [] + crop_size = [channel] + list(roi_size) + for j in range(base_num): + for i in range(base_num): + crop_min = [0, 0, roi_size[1]*j, roi_size[2]*i] + crop_max = [crop_min[d] + crop_size[d] for d in range(4)] + crop_out = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + base_crops.append(crop_out) + + for i in range(roi_num): + x_min, y_min, label = get_position_label(self.roi_size[2], base_num) + # print('label', label) + crop_min = [0, 0, y_min, x_min] + crop_max = [crop_min[d] + crop_size[d] for d in range(4)] + crop_out = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + roi_crops.append(crop_out) + roi_labels.append(label) + roi_labels = np.concatenate(roi_labels, 0).reshape(roi_num, base_num * base_num) + + base_crops = np.stack(base_crops, 0) + roi_crops = np.stack(roi_crops, 0) + sample['image'] = base_crops, roi_crops, roi_labels + return sample + + \ No newline at end of file From 48adf1d330d7dd9a2cd476b46c639d1b2858ea85 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 12 Aug 2025 14:42:29 +0800 Subject: [PATCH 84/86] update dataset and requirement version --- pymic/net_run/noisy_label/nll_clslsr.py | 17 +++++++++++------ pymic/net_run/noisy_label/nll_dast.py | 16 ++++++++++------ pymic/net_run/predict.py | 17 ++++++++++------- pymic/net_run/semi_sup/ssl_abstract.py | 10 +++++++--- pymic/net_run/train.py | 10 +++++----- pymic/util/evaluation_cls.py | 8 ++++---- requirements.txt | 18 ++++++++++-------- 7 files changed, 57 insertions(+), 39 deletions(-) diff --git a/pymic/net_run/noisy_label/nll_clslsr.py b/pymic/net_run/noisy_label/nll_clslsr.py index c977eba..3f1059e 100644 --- a/pymic/net_run/noisy_label/nll_clslsr.py +++ b/pymic/net_run/noisy_label/nll_clslsr.py @@ -165,15 +165,20 @@ def get_confidence_map(cfg_file): transform_list.append(one_transform) data_transform = transforms.Compose(transform_list) + stage_dir = config['dataset']['train_dir'] csv_file = config['dataset']['train_csv'] modal_num = config['dataset'].get('modal_num', 1) - stage_dir = config['dataset']['train_dir'] + stage_dim = config['dataset'].get('train_dim', 3) + lab_key = config['dataset'].get('train_label_key', 'label') + dataset = NiftyDataset(root_dir = stage_dir, - csv_file = csv_file, - modal_num = modal_num, - with_label= True, - transform = data_transform, - task = agent.task_type) + csv_file = csv_file, + modal_num = modal_num, + image_dim = stage_dim, + allow_missing_modal = False, + label_key = lab_key, + transform = data_transform, + task = agent.task_type) agent.set_datasets(None, None, dataset) agent.transform_list = transform_list diff --git a/pymic/net_run/noisy_label/nll_dast.py b/pymic/net_run/noisy_label/nll_dast.py index 938e10a..95203ba 100644 --- a/pymic/net_run/noisy_label/nll_dast.py +++ b/pymic/net_run/noisy_label/nll_dast.py @@ -129,13 +129,17 @@ def get_noisy_dataset_from_config(self): data_transform = transforms.Compose(transform_list) modal_num = self.config['dataset'].get('modal_num', 1) - csv_file = self.config['dataset'].get('train_csv_noise', None) + stage_dim = self.config['dataset'].get('train_dim', 3) + lab_key = self.config['dataset'].get('train_label_key', 'label') + csv_file = self.config['dataset'].get('train_csv_noise', None) dataset = NiftyDataset(root_dir = self.config['dataset']['train_dir'], - csv_file = csv_file, - modal_num = modal_num, - with_label= True, - transform = data_transform , - task = self.task_type) + csv_file = csv_file, + modal_num = modal_num, + image_dim = stage_dim, + allow_missing_modal = False, + label_key = lab_key, + transform = data_transform, + task = self.task_type) return dataset diff --git a/pymic/net_run/predict.py b/pymic/net_run/predict.py index e618be6..d63cbad 100644 --- a/pymic/net_run/predict.py +++ b/pymic/net_run/predict.py @@ -21,23 +21,26 @@ def main(): exit() parser = argparse.ArgumentParser() parser.add_argument("cfg", help="configuration file for testing") - parser.add_argument("-test_csv", help="the csv file for testing images", - required=False, default=None) - parser.add_argument("-output_dir", help="the output dir for inference results", + parser.add_argument("--test_csv", help="the csv file for testing images", required=False, default=None) - parser.add_argument("-ckpt_dir", help="the dir for trained model", + parser.add_argument("--test_dir", help="the dir for testing images", required=False, default=None) - parser.add_argument("-ckpt_mode", help="the mode for chekpoint: 0-latest, 1-best, 2-customized", + parser.add_argument("--output_dir", help="the output dir for inference results", required=False, default=None) - parser.add_argument("-ckpt_name", help="the name chekpoint if ckpt_mode = 2", + parser.add_argument("--ckpt_dir", help="the dir for trained model", required=False, default=None) - parser.add_argument("-gpus", help="the gpus for runing, e.g., [0]", + parser.add_argument("--ckpt_mode", help="the mode for chekpoint: 0-latest, 1-best, 2-customized", + required=False, default=None) + parser.add_argument("--ckpt_name", help="the name chekpoint if ckpt_mode = 2", + required=False, default=None) + parser.add_argument("--gpus", help="the gpus for runing, e.g., [0]", required=False, default=None) args = parser.parse_args() if(not os.path.isfile(args.cfg)): raise ValueError("The config file does not exist: " + args.cfg) config = parse_config(args) config = synchronize_config(config) + print(config) log_dir = config['testing']['output_dir'] if(not os.path.exists(log_dir)): os.makedirs(log_dir, exist_ok=True) diff --git a/pymic/net_run/semi_sup/ssl_abstract.py b/pymic/net_run/semi_sup/ssl_abstract.py index 69a09fd..4925859 100644 --- a/pymic/net_run/semi_sup/ssl_abstract.py +++ b/pymic/net_run/semi_sup/ssl_abstract.py @@ -52,12 +52,16 @@ def get_unlabeled_dataset_from_config(self): self.transform_list.append(one_transform) data_transform = transforms.Compose(self.transform_list) - csv_file = self.config['dataset'].get('train_csv_unlab', None) + csv_file = self.config['dataset'].get('train_csv_unlab', None) + stage_dim = self.config['dataset'].get('train_dim', 3) dataset = NiftyDataset(root_dir = train_dir, csv_file = csv_file, modal_num = modal_num, - with_label= False, - transform = data_transform ) + image_dim = stage_dim, + allow_missing_modal = False, + label_key = None, + transform = data_transform, + task = self.task_type) return dataset def create_dataset(self): diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index ec0002f..d98145a 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -54,15 +54,15 @@ def main(): exit() parser = argparse.ArgumentParser() parser.add_argument("cfg", help="configuration file for training") - parser.add_argument("-train_csv", help="the csv file for training images", + parser.add_argument("--train_csv", help="the csv file for training images", required=False, default=None) - parser.add_argument("-valid_csv", help="the csv file for validation images", + parser.add_argument("--valid_csv", help="the csv file for validation images", required=False, default=None) - parser.add_argument("-ckpt_dir", help="the output dir for trained model", + parser.add_argument("--ckpt_dir", help="the output dir for trained model", required=False, default=None) - parser.add_argument("-iter_max", help="the maximal iteration number for training", + parser.add_argument("--iter_max", help="the maximal iteration number for training", required=False, default=None) - parser.add_argument("-gpus", help="the gpus for runing, e.g., [0]", + parser.add_argument("--gpus", help="the gpus for runing, e.g., [0]", required=False, default=None) args = parser.parse_args() if(not os.path.isfile(args.cfg)): diff --git a/pymic/util/evaluation_cls.py b/pymic/util/evaluation_cls.py index a65953a..686a811 100644 --- a/pymic/util/evaluation_cls.py +++ b/pymic/util/evaluation_cls.py @@ -176,13 +176,13 @@ def main(): :param pred_prob_csv: (str) The csv file for prediction probability. """ parser = argparse.ArgumentParser() - parser.add_argument("-cfg", help="configuration file for evaluation", + parser.add_argument("--cfg", help="configuration file for evaluation", required=False, default=None) - parser.add_argument("-metric", help="evaluation metrics, e.g., accuracy, or [accuracy, auc]", + parser.add_argument("--metric", help="evaluation metrics, e.g., accuracy, or [accuracy, auc]", required=False, default=None) - parser.add_argument("-gt_csv", help="csv file for ground truth", + parser.add_argument("--gt_csv", help="csv file for ground truth", required=False, default=None) - parser.add_argument("-pred_prob_csv", help="csv file for probability prediction", + parser.add_argument("--pred_prob_csv", help="csv file for probability prediction", required=False, default=None) args = parser.parse_args() print(args) diff --git a/requirements.txt b/requirements.txt index cac47f3..e70fc83 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,14 @@ h5py matplotlib>=3.1.2 -numpy>=1.17.4 -pandas>=0.25.3 -scikit-image>=0.16.2 -scikit-learn>=0.22 -scipy>=1.3.3 -SimpleITK>=2.0.0 +numpy>=1.23.5 +pandas>=1.5.2 +scikit-image>=0.19.3 +scikit-learn>=1.2.0 +scipy>=1.10.0 +SimpleITK>=2.0.2 tensorboard tensorboardX -torch>=1.1.12 -torchvision>=0.13.0 +torch>=1.13.1 +torchvision>=0.14.1 +causal-conv1d>=1.5.0 +mamba-ssm>=2.2.4 From 46835c73382c12480950d7714e2e6acd8acb9882 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 12 Aug 2025 15:24:13 +0800 Subject: [PATCH 85/86] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 51d939d..879ee6c 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.5.1", + version = "0.5.4", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, From b27763cf6f78569a0748f74c2922894be024a7bd Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 12 Aug 2025 15:52:56 +0800 Subject: [PATCH 86/86] Update README.md --- README.md | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index bedeaae..e7757c4 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,16 @@ BibTeX entry: pages = {107398}, } +# News +* 2025/08 PyMIC has contained the implementation of [`DMSPS`][dmsps_paper], a state-of-the-art weakly supervised segmentation method by learning from scribble annotations. +* 2025/05 Several self-supervised learning methods have been provided in PyMIC, including [`VolF`][volf_paper], [`VoCo`][voco_paper] and [`Vox2Vec`][vox2vec_paper]. +* 2025/01 Novel architectures are available now, such as `UMamba`, `VMUNet`, `SwinUNet`, `TransUNet` and `UNETR++`. + +[dmsps_paper]: https://www.sciencedirect.com/science/article/pii/S1361841524001993 +[volf_paper]: https://arxiv.org/abs/2306.16925 +[voco_paper]: https://arxiv.org/abs/2402.17300 +[vox2vec_paper]:https://conferences.miccai.org/2023/papers/712-Paper3421.html + # Features PyMIC provides flixible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions: * Support for annotation-efficient image segmentation, especially for semi-supervised, self-supervised, self-supervised, weakly-supervised and noisy-label learning. @@ -33,9 +43,10 @@ PyMIC provides flixible modules for medical image computing tasks including clas # Usage ## Requirement -* [Pytorch][torch_link] version >=1.0.1 +* [Pytorch][torch_link] version >=1.13.1 * [TensorboardX][tbx_link] to visualize training performance * Some common python packages such as Numpy, Pandas, SimpleITK +* causal-conv1d>=1.5.0 and mamba-ssm>=2.2.4 are required if you want to use Mamba in PyMIC. * See `requirements.txt` for details. [torch_link]:https://pytorch.org/ @@ -47,10 +58,10 @@ Run the following command to install the latest released version of PyMIC: ```bash pip install PYMIC ``` -To install a specific version of PYMIC such as 0.5.0, run: +To install a specific version of PYMIC such as 0.5.4, run: ```bash -pip install PYMIC==0.5.0 +pip install PYMIC==0.5.4 ``` Alternatively, you can download the source code for the latest version. Run the following command to compile and install: @@ -76,8 +87,11 @@ Using PyMIC, it becomes easy to develop deep learning models for different proje 4, [UGIR][ugir] (MICCAI 2020) Uncertainty-guided interactive refinement for medical image segmentation. +5, [DMSPS][dmsps] (MedIA 2024) Weakly supervised segmentation by learning from scribbles. + [myops]: https://github.com/HiLab-git/MyoPS2020 [coplenet]:https://github.com/HiLab-git/COPLE-Net [hn_gtv]: https://github.com/HiLab-git/Head-Neck-GTV [ugir]: https://github.com/HiLab-git/UGIR +[dmsps]: https://github.com/HiLab-git/DMSPS