diff --git a/.gitignore b/.gitignore index f7da7ac..639f873 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ dist/* *egg*/* *stop* files.txt +pymic/test/runs/* # Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks # Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks diff --git a/README.md b/README.md index ac7f9b6..e7757c4 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # PyMIC: A Pytorch-Based Toolkit for Medical Image Computing -PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with high dimension and large volume, multiple modalities and difficulies in annotating. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations. +PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with high dimension and large volume, multiple modalities and difficulies in annotating. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised, self-supervised, and weakly supervised learning, and learning with noisy annotations. Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. If you use this toolkit, please cite the following paper: @@ -21,9 +21,19 @@ BibTeX entry: pages = {107398}, } +# News +* 2025/08 PyMIC has contained the implementation of [`DMSPS`][dmsps_paper], a state-of-the-art weakly supervised segmentation method by learning from scribble annotations. +* 2025/05 Several self-supervised learning methods have been provided in PyMIC, including [`VolF`][volf_paper], [`VoCo`][voco_paper] and [`Vox2Vec`][vox2vec_paper]. +* 2025/01 Novel architectures are available now, such as `UMamba`, `VMUNet`, `SwinUNet`, `TransUNet` and `UNETR++`. + +[dmsps_paper]: https://www.sciencedirect.com/science/article/pii/S1361841524001993 +[volf_paper]: https://arxiv.org/abs/2306.16925 +[voco_paper]: https://arxiv.org/abs/2402.17300 +[vox2vec_paper]:https://conferences.miccai.org/2023/papers/712-Paper3421.html + # Features PyMIC provides flixible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions: -* Support for annotation-efficient image segmentation, especially for semi-supervised, self-supervised, weakly-supervised and noisy-label learning. +* Support for annotation-efficient image segmentation, especially for semi-supervised, self-supervised, self-supervised, weakly-supervised and noisy-label learning. * User friendly: For beginners, you only need to edit the configuration files for model training and inference, without writing code. For advanced users, you can customize different modules (networks, loss functions, training pipeline, etc) and easily integrate them into PyMIC. * Easy-to-use I/O interface to read and write different 2D and 3D images. * Various data pre-processing/transformation methods before sending a tensor into a network. @@ -33,9 +43,10 @@ PyMIC provides flixible modules for medical image computing tasks including clas # Usage ## Requirement -* [Pytorch][torch_link] version >=1.0.1 +* [Pytorch][torch_link] version >=1.13.1 * [TensorboardX][tbx_link] to visualize training performance * Some common python packages such as Numpy, Pandas, SimpleITK +* causal-conv1d>=1.5.0 and mamba-ssm>=2.2.4 are required if you want to use Mamba in PyMIC. * See `requirements.txt` for details. [torch_link]:https://pytorch.org/ @@ -47,10 +58,10 @@ Run the following command to install the latest released version of PyMIC: ```bash pip install PYMIC ``` -To install a specific version of PYMIC such as 0.4.0, run: +To install a specific version of PYMIC such as 0.5.4, run: ```bash -pip install PYMIC==0.4.0 +pip install PYMIC==0.5.4 ``` Alternatively, you can download the source code for the latest version. Run the following command to compile and install: @@ -76,8 +87,11 @@ Using PyMIC, it becomes easy to develop deep learning models for different proje 4, [UGIR][ugir] (MICCAI 2020) Uncertainty-guided interactive refinement for medical image segmentation. +5, [DMSPS][dmsps] (MedIA 2024) Weakly supervised segmentation by learning from scribbles. + [myops]: https://github.com/HiLab-git/MyoPS2020 [coplenet]:https://github.com/HiLab-git/COPLE-Net [hn_gtv]: https://github.com/HiLab-git/Head-Neck-GTV [ugir]: https://github.com/HiLab-git/UGIR +[dmsps]: https://github.com/HiLab-git/DMSPS diff --git a/docs/history.txt b/docs/history.txt new file mode 100644 index 0000000..3ca7b28 --- /dev/null +++ b/docs/history.txt @@ -0,0 +1 @@ +2025.8.1 Add code of DMSPS \ No newline at end of file diff --git a/docs/source/api.rst b/docs/source/api.rst index 206000d..d09809c 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -9,8 +9,5 @@ API pymic.loss pymic.net pymic.net_run - pymic.net_run_nll - pymic.net_run_ssl - pymic.net_run_wsl pymic.transform pymic.util \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index cf1b568..09f2b50 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -9,8 +9,8 @@ copyright = '2021, HiLab' author = 'HiLab' -release = '0.1' -version = '0.1.0' +release = '0.4' +version = '0.4.0' # -- General configuration diff --git a/docs/source/index.rst b/docs/source/index.rst index f9b62ba..c1b6523 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -28,9 +28,9 @@ Citation If you use PyMIC for your research, please acknowledge it accordingly by citing our paper: -`G. Wang, X. Luo, R. Gu, S. Yang, Y. Qu, S. Zhai, Q. Zhao, K. Li, S. Zhang. (2022). -PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation. -arXiv, 2208.09350. `_ +`G. Wang, X. Luo, R. Gu, S. Yang, Y. Qu, S. Zhai, Q. Zhao, K. Li, S. Zhang. +PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation. +Computer Methods and Programs in Biomedicine (CMPB). 231 (2023): 107398. `_ BibTeX entry: @@ -41,8 +41,8 @@ BibTeX entry: author = {Guotai Wang and Xiangde Luo and Ran Gu and Shuojue Yang and Yijie Qu and Shuwei Zhai and Qianfei Zhao and Kang Li and Shaoting Zhang}, title = {{PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation}}, year = {2022}, - url = {http://arxiv.org/abs/2208.09350}, - journal = {arXiv}, - volume = {2208.09350}, - pages = {1-10}, + url = {https://doi.org/10.1016/j.cmpb.2023.107398}, + journal = {Computer Methods and Programs in Biomedicine}, + volume = {231}, + pages = {107398}, } diff --git a/docs/source/installation.rst b/docs/source/installation.rst index ced640f..c1055e1 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -17,7 +17,8 @@ Alternatively, you can download or clone the code from `GitHub `_ - `h5py `_ diff --git a/docs/source/pymic.net.net2d.rst b/docs/source/pymic.net.net2d.rst index d978dfe..bd54bc6 100644 --- a/docs/source/pymic.net.net2d.rst +++ b/docs/source/pymic.net.net2d.rst @@ -52,6 +52,14 @@ pymic.net.net2d.unet2d\_dual\_branch module :undoc-members: :show-inheritance: +pymic.net.net2d.unet2d\_mcnet module +------------------------------------------- + +.. automodule:: pymic.net.net2d.unet2d_mcnet + :members: + :undoc-members: + :show-inheritance: + pymic.net.net2d.unet2d\_nest module ----------------------------------- diff --git a/docs/source/pymic.net_run.noisy_label.rst b/docs/source/pymic.net_run.noisy_label.rst new file mode 100644 index 0000000..04d38bb --- /dev/null +++ b/docs/source/pymic.net_run.noisy_label.rst @@ -0,0 +1,45 @@ +pymic.net\_run.noisy\_label package +=================================== + +Submodules +---------- + +pymic.net\_run.noisy\_label.nll\_clslsr module +---------------------------------------------- + +.. automodule:: pymic.net_run.noisy_label.nll_clslsr + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.noisy\_label.nll\_co\_teaching module +---------------------------------------------------- + +.. automodule:: pymic.net_run.noisy_label.nll_co_teaching + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.noisy\_label.nll\_dast module +-------------------------------------------- + +.. automodule:: pymic.net_run.noisy_label.nll_dast + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.noisy\_label.nll\_trinet module +---------------------------------------------- + +.. automodule:: pymic.net_run.noisy_label.nll_trinet + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.net_run.noisy_label + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.net_run.rst b/docs/source/pymic.net_run.rst index 74aab12..9f3bc26 100644 --- a/docs/source/pymic.net_run.rst +++ b/docs/source/pymic.net_run.rst @@ -1,6 +1,16 @@ pymic.net\_run package ====================== +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + pymic.net_run.semi_sup + pymic.net_run.weak_sup + pymic.net_run.noisy_label + Submodules ---------- @@ -44,14 +54,6 @@ pymic.net\_run.infer\_func module :undoc-members: :show-inheritance: -pymic.net\_run.net\_run module ------------------------------- - -.. automodule:: pymic.net_run.net_run - :members: - :undoc-members: - :show-inheritance: - Module contents --------------- diff --git a/docs/source/pymic.net_run.self_sup.rst b/docs/source/pymic.net_run.self_sup.rst new file mode 100644 index 0000000..a4568d1 --- /dev/null +++ b/docs/source/pymic.net_run.self_sup.rst @@ -0,0 +1,21 @@ +pymic.net\_run.self\_sup package +================================ + +Submodules +---------- + +pymic.net\_run.self\_sup.self\_sl\_agent module +----------------------------------------------- + +.. automodule:: pymic.net_run.self_sup.self_sl_agent + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.net_run.self_sup + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.net_run.semi_sup.rst b/docs/source/pymic.net_run.semi_sup.rst new file mode 100644 index 0000000..15692b2 --- /dev/null +++ b/docs/source/pymic.net_run.semi_sup.rst @@ -0,0 +1,77 @@ +pymic.net\_run.semi\_sup package +================================ + +Submodules +---------- + +pymic.net\_run.semi\_sup.ssl\_abstract module +--------------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_abstract + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.semi\_sup.ssl\_cct module +---------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_cct + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.semi\_sup.ssl\_cps module +---------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_cps + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.semi\_sup.ssl\_mcnet module +---------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_mcnet + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.semi\_sup.ssl\_em module +--------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_em + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.semi\_sup.ssl\_mt module +--------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_mt + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.semi\_sup.ssl\_uamt module +----------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_uamt + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.semi\_sup.ssl\_urpc module +----------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_urpc + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.net_run.semi_sup + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.net_run.weak_sup.rst b/docs/source/pymic.net_run.weak_sup.rst new file mode 100644 index 0000000..b906f72 --- /dev/null +++ b/docs/source/pymic.net_run.weak_sup.rst @@ -0,0 +1,69 @@ +pymic.net\_run.weak\_sup package +================================ + +Submodules +---------- + +pymic.net\_run.weak\_sup.wsl\_abstract module +--------------------------------------------- + +.. automodule:: pymic.net_run.weak_sup.wsl_abstract + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.weak\_sup.wsl\_dmpls module +------------------------------------------ + +.. automodule:: pymic.net_run.weak_sup.wsl_dmpls + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.weak\_sup.wsl\_em module +--------------------------------------- + +.. automodule:: pymic.net_run.weak_sup.wsl_em + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.weak\_sup.wsl\_gatedcrf module +--------------------------------------------- + +.. automodule:: pymic.net_run.weak_sup.wsl_gatedcrf + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.weak\_sup.wsl\_mumford\_shah module +-------------------------------------------------- + +.. automodule:: pymic.net_run.weak_sup.wsl_mumford_shah + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.weak\_sup.wsl\_tv module +--------------------------------------- + +.. automodule:: pymic.net_run.weak_sup.wsl_tv + :members: + :undoc-members: + :show-inheritance: + +pymic.net\_run.weak\_sup.wsl\_ustm module +----------------------------------------- + +.. automodule:: pymic.net_run.weak_sup.wsl_ustm + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: pymic.net_run.weak_sup + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pymic.net_run_nll.rst b/docs/source/pymic.net_run_nll.rst deleted file mode 100644 index 40ee9ef..0000000 --- a/docs/source/pymic.net_run_nll.rst +++ /dev/null @@ -1,53 +0,0 @@ -pymic.net\_run\_nll package -=========================== - -Submodules ----------- - -pymic.net\_run\_nll.nll\_clslsr module --------------------------------------- - -.. automodule:: pymic.net_run_nll.nll_clslsr - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_nll.nll\_co\_teaching module --------------------------------------------- - -.. automodule:: pymic.net_run_nll.nll_co_teaching - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_nll.nll\_dast module ------------------------------------- - -.. automodule:: pymic.net_run_nll.nll_dast - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_nll.nll\_main module ------------------------------------- - -.. automodule:: pymic.net_run_nll.nll_main - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_nll.nll\_trinet module --------------------------------------- - -.. automodule:: pymic.net_run_nll.nll_trinet - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: pymic.net_run_nll - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/pymic.net_run_ssl.rst b/docs/source/pymic.net_run_ssl.rst deleted file mode 100644 index 236e2d6..0000000 --- a/docs/source/pymic.net_run_ssl.rst +++ /dev/null @@ -1,77 +0,0 @@ -pymic.net\_run\_ssl package -=========================== - -Submodules ----------- - -pymic.net\_run\_ssl.ssl\_abstract module ----------------------------------------- - -.. automodule:: pymic.net_run_ssl.ssl_abstract - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_ssl.ssl\_cct module ------------------------------------ - -.. automodule:: pymic.net_run_ssl.ssl_cct - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_ssl.ssl\_cps module ------------------------------------ - -.. automodule:: pymic.net_run_ssl.ssl_cps - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_ssl.ssl\_em module ----------------------------------- - -.. automodule:: pymic.net_run_ssl.ssl_em - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_ssl.ssl\_main module ------------------------------------- - -.. automodule:: pymic.net_run_ssl.ssl_main - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_ssl.ssl\_mt module ----------------------------------- - -.. automodule:: pymic.net_run_ssl.ssl_mt - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_ssl.ssl\_uamt module ------------------------------------- - -.. automodule:: pymic.net_run_ssl.ssl_uamt - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_ssl.ssl\_urpc module ------------------------------------- - -.. automodule:: pymic.net_run_ssl.ssl_urpc - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: pymic.net_run_ssl - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/pymic.net_run_wsl.rst b/docs/source/pymic.net_run_wsl.rst deleted file mode 100644 index 5eda921..0000000 --- a/docs/source/pymic.net_run_wsl.rst +++ /dev/null @@ -1,77 +0,0 @@ -pymic.net\_run\_wsl package -=========================== - -Submodules ----------- - -pymic.net\_run\_wsl.wsl\_abstract module ----------------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_abstract - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_wsl.wsl\_dmpls module -------------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_dmpls - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_wsl.wsl\_em module ----------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_em - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_wsl.wsl\_gatedcrf module ----------------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_gatedcrf - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_wsl.wsl\_main module ------------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_main - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_wsl.wsl\_mumford\_shah module ---------------------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_mumford_shah - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_wsl.wsl\_tv module ----------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_tv - :members: - :undoc-members: - :show-inheritance: - -pymic.net\_run\_wsl.wsl\_ustm module ------------------------------------- - -.. automodule:: pymic.net_run_wsl.wsl_ustm - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: pymic.net_run_wsl - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/pymic.rst b/docs/source/pymic.rst index 7545740..dd180f0 100644 --- a/docs/source/pymic.rst +++ b/docs/source/pymic.rst @@ -12,9 +12,6 @@ Subpackages pymic.loss pymic.net pymic.net_run - pymic.net_run_nll - pymic.net_run_ssl - pymic.net_run_wsl pymic.transform pymic.util diff --git a/docs/source/setup.rst b/docs/source/setup.rst new file mode 100644 index 0000000..552eb49 --- /dev/null +++ b/docs/source/setup.rst @@ -0,0 +1,7 @@ +setup module +============ + +.. automodule:: setup + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/usage.fsl.rst b/docs/source/usage.fsl.rst index 053daf7..937bb01 100644 --- a/docs/source/usage.fsl.rst +++ b/docs/source/usage.fsl.rst @@ -28,8 +28,8 @@ configuration file for running. .. tip:: If you use the built-in modules such as ``UNet`` and ``Dice`` + ``CrossEntropy`` loss - for segmentation, you don't need to write the above code. Just just use the ``pymic_run`` - command. + for segmentation, you don't need to write the above code. Just just use the ``pymic_train`` + command. See examples in `PyMIC_examples/segmentation/ `_. Dataset ------- @@ -207,7 +207,7 @@ hyper-parameters. For example, the following is a configuration for using ``2DUN feature_chns = [16, 32, 64, 128, 256] dropout = [0, 0, 0.3, 0.4, 0.5] bilinear = False - deep_supervise= False + multiscale_pred = False The ``SegNetDict`` in :mod:`pymic.net.net_dict_seg` lists all the built-in network structures currently implemented in PyMIC. @@ -299,9 +299,6 @@ Itreations For training iterations, the following parameters need to be specified in the configuration file: -* ``iter_start``: the start iteration, by default is 0. None zero value means the - iteration where a pre-trained model stopped for continuing with the trainnig. - * ``iter_max``: the maximal allowed iteration for training. * ``iter_valid``: if the value is K, it means evaluating the performance on the @@ -321,9 +318,9 @@ Optimizer For optimizer, users need to set ``optimizer``, ``learning_rate``, ``momentum`` and ``weight_decay``. The built-in optimizers include ``SGD``, ``Adam``, ``SparseAdam``, ``Adadelta``, ``Adagrad``, ``Adamax``, ``ASGD``, -``LBFGS``, ``RMSprop`` and ``Rprop`` that are implemented in :mod:`torch.optim`. +``LBFGS``, ``RMSprop`` and ``Rprop`` that are implemented in `torch.optim`. -You can also use customized optimizers via :mod:`SegmentationAgent.set_optimizer()`. +You can also use customized optimizers via `SegmentationAgent.set_optimizer()`. Learning Rate Scheduler ^^^^^^^^^^^^^^^^^^^^^^^ @@ -335,7 +332,7 @@ the configuration file. Parameters related to ``ReduceLROnPlateau`` include ``lr_gamma``. Parameters related to ``MultiStepLR`` include ``lr_gamma`` and ``lr_milestones``. -You can also use customized lr schedulers via :mod:`SegmentationAgent.set_scheduler()`. +You can also use customized lr schedulers via `SegmentationAgent.set_scheduler()`. Other Options ^^^^^^^^^^^^^ @@ -373,8 +370,8 @@ test-time augmentation, etc. The following is a list of options availble for inf * ``ckpt_name`` (string, optinal): the full path to the checkpoint if ckpt_mode = 2. * ``post_process`` (string, default is None): the post process method after inference. - The current available post processing is :mod:`PostKeepLargestComponent`. Uses can also - specify customized post process methods via :mod:`SegmentationAgent.set_postprocessor()`. + The current available post processing is :mod:`pymic.util.post_process.PostKeepLargestComponent`. + Uses can also specify customized post process methods via `SegmentationAgent.set_postprocessor()`. * ``sliding_window_enable`` (bool, default is False): use sliding window for inference or not. @@ -390,14 +387,14 @@ test-time augmentation, etc. The following is a list of options availble for inf * ``ignore_dir`` (bool, default is True): if the input image name has a `/`, it will be replaced with `_` in the output file name. -* ``save_probability`` (boold, default is False): save the output probability for each class. +* ``save_probability`` (bool, default is False): save the output probability for each class. * ``label_source`` (list, default is None): a list of label to be converted after prediction. For example, - :mod:`label_source` = [0, 1] and :mod:`label_target` = [0, 255] will convert label value from 1 to 255. + `label_source` = [0, 1] and `label_target` = [0, 255] will convert label value from 1 to 255. -* ``label_target`` (list, default is None): a list of label after conversion. Use this with :mod:`label_source`. +* ``label_target`` (list, default is None): a list of label after conversion. Use this with `label_source`. * ``filename_replace_source`` (string, default is None): the substring in the filename will be replaced with - a new substring specified by :mod:`filename_replace_target`. + a new substring specified by `filename_replace_target`. -* ``filename_replace_target`` (string, default is None): work with :mod:`filename_replace_source`. \ No newline at end of file +* ``filename_replace_target`` (string, default is None): work with `filename_replace_source`. \ No newline at end of file diff --git a/docs/source/usage.nll.rst b/docs/source/usage.nll.rst index f6edd33..0a87be9 100644 --- a/docs/source/usage.nll.rst +++ b/docs/source/usage.nll.rst @@ -3,43 +3,28 @@ Noisy Label Learning ==================== -pymic_nll ---------- - -:mod:`pymic_nll` is the command for using built-in NLL methods for training. -Similarly to :mod:`pymic_run`, it should be followed by two parameters, specifying the -stage and configuration file, respectively. The training and testing commands are: - -.. code-block:: bash - - pymic_nll train myconfig_nll.cfg - pymic_nll test myconfig_nll.cfg - -.. tip:: - - If the NLL method only involves one network, either ``pymic_nll`` or ``pymic_run`` - can be used for inference. Their difference only exists in the training stage. - .. note:: Some NLL methods only use noise-robust loss functions without complex training process, and just combining the standard :mod:`SegmentationAgent` with such - loss function works for training. ``pymic_run`` instead of ``pymic_nll`` should - be used for these methods. + loss function works for training. NLL Configurations ------------------ -In the configuration file for ``pymic_nll``, in addition to those used in standard fully +In the configuration file for noisy label learning, in addition to those used in standard fully supervised learning, there is a ``noisy_label_learning`` section that is specifically designed -for NLL methods. In that section, users need to specify the ``nll_method`` and configurations -related to the NLL method. For example, the correspoinding configuration for CoTeaching is: +for NLL methods. In that section, users need to specify the ``method_name`` and configurations +related to the NLL method. ``supervise_type`` should be set as "`noisy_label`" in the ``dataset`` section. + For example, the correspoinding configuration for CoTeaching is: .. code-block:: none [dataset] ... + supervise_type = noisy_label + ... [network] ... @@ -48,7 +33,7 @@ related to the NLL method. For example, the correspoinding configuration for CoT ... [noisy_label_learning] - nll_method = CoTeaching + method_name = CoTeaching co_teaching_select_ratio = 0.8 rampup_start = 1000 rampup_end = 8000 @@ -60,13 +45,14 @@ related to the NLL method. For example, the correspoinding configuration for CoT The configuration items vary with different NLL methods. Please refer to the API of each built-in NLL method for details of the correspoinding configuration. + See examples in `PyMIC_examples/seg_nll/ `_. + Built-in NLL Methods -------------------- -Some NLL methods only use noise-robust loss functions. They are used with ``pymic_run`` -for training. Just set ``loss_type`` to one of them in the configuration file, similarly -to the fully supervised learning. +Some NLL methods only use noise-robust loss functions. They are used with a standard fully supervised training +paradigm. Just set ``supervise_type`` = `fully_sup`, and use ``loss_type`` to one of them in the configuration file: * ``GCELoss``: (`NeurIPS 2018 `_) Generalized cross entropy loss. @@ -78,7 +64,7 @@ to the fully supervised learning. Noise-robust Dice loss. The other NLL methods are implemented in child classes of -:mod:`pymic.net_run_nll.nll_abstract.NLLSegAgent`, and they are: +:mod:`pymic.net_run.agent_seg.SegmentationAgent`, and they are: * ``CLSLSR``: (`MICCAI 2020 `_) Confident learning with spatial label smoothing regularization. @@ -95,15 +81,15 @@ The other NLL methods are implemented in child classes of Customized NLL Methods ---------------------- -PyMIC alo supports customizing NLL methods by inheriting the :mod:`NLLSegAgent` class. -You may only need to rewrite the :mod:`training()` method and reuse most part of the +PyMIC alo supports customized NLL methods by inheriting the `SegmentationAgent` class. +You may only need to rewrite the `training()` method and reuse most part of the existing pipeline, such as data loading, validation and inference methods. For example: .. code-block:: none - from pymic.net_run_nll.nll_abstract import NLLSegAgent + from pymic.net_run.agent_seg import SegmentationAgent - class MyNLLMethod(NLLSegAgent): + class MyNLLMethod(SegmentationAgent): def __init__(self, config, stage = 'train'): super(MyNLLMethod, self).__init__(config, stage) ... diff --git a/docs/source/usage.quickstart.rst b/docs/source/usage.quickstart.rst index 95cf20d..bced277 100644 --- a/docs/source/usage.quickstart.rst +++ b/docs/source/usage.quickstart.rst @@ -12,13 +12,13 @@ for segmentation with full supervision, run the fullowing command: .. code-block:: bash - pymic_run train myconfig.cfg + pymic_train myconfig.cfg After training, run the following command for testing: .. code-block:: bash - pymic_run test myconfig.cfg + pymic_test myconfig.cfg .. tip:: @@ -51,11 +51,13 @@ file used for segmentation of lung from radiograph, which can be find in [dataset] # tensor type (float or double) tensor_type = float + task_type = seg root_dir = ../../PyMIC_data/JSRT train_csv = config/jsrt_train.csv valid_csv = config/jsrt_valid.csv test_csv = config/jsrt_test.csv + train_batch_size = 4 # data transforms @@ -69,19 +71,26 @@ file used for segmentation of lung from radiograph, which can be find in LabelConvert_source_list = [0, 255] LabelConvert_target_list = [0, 1] + [network] + # this section gives parameters for network + # the keys may be different for different networks + + # type of network net_type = UNet2D - # Parameters for UNet2D + + # number of class, required for segmentation task class_num = 2 in_chns = 1 feature_chns = [16, 32, 64, 128, 256] dropout = [0, 0, 0.3, 0.4, 0.5] bilinear = False - deep_supervise= False + multiscale_pred = False [training] # list of gpus gpus = [0] + loss_type = DiceLoss # for optimizers @@ -95,8 +104,8 @@ file used for segmentation of lung from radiograph, which can be find in lr_gamma = 0.5 lr_milestones = [2000, 4000, 6000] - ckpt_save_dir = model/unet_dice_loss - ckpt_prefix = unet + ckpt_save_dir = model/unet + ckpt_prefix = unet # start iter iter_start = 0 @@ -107,9 +116,10 @@ file used for segmentation of lung from radiograph, which can be find in [testing] # list of gpus gpus = [0] + # checkpoint mode can be [0-latest, 1-best, 2-specified] - ckpt_mode = 0 - output_dir = result + ckpt_mode = 0 + output_dir = result/unet # convert the label of prediction output label_source = [0, 1] @@ -131,17 +141,18 @@ For example, for segmentation tasks, run: pymic_eval_seg evaluation.cfg -The configuration file is like (an example from ``PYMIC_examples/seg_ssl/ACDC``): +The configuration file is like (an example from +`PyMIC_examples/seg_ssl/ACDC `_): .. code-block:: none [evaluation] - metric = dice + metric_list = [dice, hd95] label_list = [1,2,3] organ_name = heart ground_truth_folder_root = ../../PyMIC_data/ACDC/preprocess - segmentation_folder_root = result/unet2d_em + segmentation_folder_root = result/unet2d_urpc evaluation_image_pair = config/data/image_test_gt_seg.csv See :mod:`pymic.util.evaluation_seg.evaluation` for details of the configuration required. @@ -152,7 +163,8 @@ For classification tasks, run: pymic_eval_cls evaluation.cfg -The configuration file is like (an example from ``PYMIC_examples/classification/CHNCXR``): +The configuration file is like (an example from +`PyMIC_examples/classification/CHNCXR `_): .. code-block:: none diff --git a/docs/source/usage.ssl.rst b/docs/source/usage.ssl.rst index 143d3f8..0fd8dc6 100644 --- a/docs/source/usage.ssl.rst +++ b/docs/source/usage.ssl.rst @@ -3,37 +3,22 @@ Semi-Supervised Learning ========================= -pymic_ssl ---------- - -:mod:`pymic_ssl` is the command for using built-in semi-supervised methods for training. -Similarly to :mod:`pymic_run`, it should be followed by two parameters, specifying the -stage and configuration file, respectively. The training and testing commands are: - -.. code-block:: bash - - pymic_ssl train myconfig_ssl.cfg - pymic_ssl test myconfig_ssl.cfg - -.. tip:: - - If the SSL method only involves one network, either ``pymic_ssl`` or ``pymic_run`` - can be used for inference. Their difference only exists in the training stage. - SSL Configurations ------------------ -In the configuration file for ``pymic_ssl``, in addition to those used in fully +In the configuration file for semi-supervised segmentation, in addition to those used in fully supervised learning, there are some items specificalized for semi-supervised learning. Users should provide values for the following items in ``dataset`` section of the configuration file: +* ``supervise_type`` (string): The value should be "`semi_sup`". + * ``train_csv_unlab`` (string): the csv file for unlabeled dataset. Note that ``train_csv`` is only used for labeled dataset. * ``train_batch_size_unlab`` (int): the batch size for unlabeled dataset. - Note that ``train_batch_size`` means the batch size for the labeled dataset. + Note that `train_batch_size` means the batch size for the labeled dataset. * ``train_transform_unlab`` (list): a list of transforms used for unlabeled data. @@ -43,7 +28,12 @@ The following is an example of the ``dataset`` section for semi-supervised learn .. code-block:: none ... - root_dir =../../PyMIC_data/ACDC/preprocess/ + + tensor_type = float + task_type = seg + supervise_type = semi_sup + + root_dir = ../../PyMIC_data/ACDC/preprocess/ train_csv = config/data/image_train_r10_lab.csv train_csv_unlab = config/data/image_train_r10_unlab.csv valid_csv = config/data/image_valid.csv @@ -60,14 +50,14 @@ The following is an example of the ``dataset`` section for semi-supervised learn ... In addition, there is a ``semi_supervised_learning`` section that is specifically designed -for SSL methods. In that section, users need to specify the ``ssl_method`` and configurations +for SSL methods. In that section, users need to specify the ``method_name`` and configurations related to the SSL method. For example, the correspoinding configuration for CPS is: .. code-block:: none ... [semi_supervised_learning] - ssl_method = CPS + method_name = CPS regularize_w = 0.1 rampup_start = 1000 rampup_end = 20000 @@ -76,14 +66,15 @@ related to the SSL method. For example, the correspoinding configuration for CPS .. note:: The configuration items vary with different SSL methods. Please refer to the API - of each built-in SSL method for details of the correspoinding configuration. + of each built-in SSL method for details of the correspoinding configuration. + See examples in `PyMIC_examples/seg_ssl/ `_. Built-in SSL Methods -------------------- -:mod:`pymic.net_run_ssl.ssl_abstract.SSLSegAgent` is the abstract class used for -semi-supervised learning. The built-in SSL methods are child classes of :mod:`SSLSegAgent`. -The available SSL methods implemnted in PyMIC are listed in :mod:`pymic.net_run_ssl.ssl_main.SSLMethodDict`, +:mod:`pymic.net_run.semi_sup.ssl_abstract.SSLSegAgent` is the abstract class used for +semi-supervised learning. The built-in SSL methods are child classes of `SSLSegAgent`. +The available SSL methods implemnted in PyMIC are listed in `pymic.net_run.semi_sup.SSLMethodDict`, and they are: * ``EntropyMinimization``: (`NeurIPS 2005 `_) @@ -103,13 +94,13 @@ and they are: Customized SSL Methods ---------------------- -PyMIC alo supports customizing SSL methods by inheriting the :mod:`SSLSegAgent` class. -You may only need to rewrite the :mod:`training()` method and reuse most part of the +PyMIC alo supports customizing SSL methods by inheriting the `SSLSegAgent` class. +You may only need to rewrite the `training()` method and reuse most part of the existing pipeline, such as data loading, validation and inference methods. For example: .. code-block:: none - from pymic.net_run_ssl.ssl_abstract import SSLSegAgent + from pymic.net_run.semi_sup import SSLSegAgent class MySSLMethod(SSLSegAgent): def __init__(self, config, stage = 'train'): diff --git a/docs/source/usage.wsl.rst b/docs/source/usage.wsl.rst index 00471f6..10d6da2 100644 --- a/docs/source/usage.wsl.rst +++ b/docs/source/usage.wsl.rst @@ -3,23 +3,6 @@ Weakly-Supervised Learning ========================== -pymic_wsl ---------- - -:mod:`pymic_wsl` is the command for using built-in weakly-supervised methods for training. -Similarly to :mod:`pymic_run`, it should be followed by two parameters, specifying the -stage and configuration file, respectively. The training and testing commands are: - -.. code-block:: bash - - pymic_wsl train myconfig_wsl.cfg - pymic_wsl test myconfig_wsl.cfg - -.. tip:: - - If the WSL method only involves one network, either ``pymic_wsl`` or ``pymic_run`` - can be used for inference. Their difference only exists in the training stage. - .. note:: Currently, the weakly supervised methods supported by PyMIC are only for learning @@ -31,17 +14,19 @@ stage and configuration file, respectively. The training and testing commands ar WSL Configurations ------------------ -In the configuration file for ``pymic_wsl``, in addition to those used in fully +In the configuration file for weakly supervised learning, in addition to those used in fully supervised learning, there are some items specificalized for weakly-supervised learning. -First, in the :mod:`train_transform` list, a special transform named :mod:`PartialLabelToProbability` +First, ``supervise_type`` should be set as "`weak_sup`" in the ``dataset`` section. + +Second, in the ``train_transform`` list, a special transform named `PartialLabelToProbability` should be used to transform patial labels into a one-hot probability map and a weighting map of pixels (i.e., the weight of a pixel is 1 if labeled and 0 otherwise). The patial cross entropy loss on labeled pixels is actually implemented by a weighted cross entropy loss. The loss setting is `loss_type = CrossEntropyLoss`. -Second, there is a ``weakly_supervised_learning`` section that is specifically designed -for WSL methods. In that section, users need to specify the ``wsl_method`` and configurations +Thirdly, there is a ``weakly_supervised_learning`` section that is specifically designed +for WSL methods. In that section, users need to specify the ``method_name`` and configurations related to the WSL method. For example, the correspoinding configuration for GatedCRF is: @@ -50,6 +35,7 @@ related to the WSL method. For example, the correspoinding configuration for Gat [dataset] ... + supervise_type = weak_sup root_dir = ../../PyMIC_data/ACDC/preprocess train_csv = config/data/image_train.csv valid_csv = config/data/image_valid.csv @@ -72,7 +58,7 @@ related to the WSL method. For example, the correspoinding configuration for Gat ... [weakly_supervised_learning] - wsl_method = GatedCRF + method_name = GatedCRF regularize_w = 0.1 rampup_start = 2000 rampup_end = 15000 @@ -90,13 +76,14 @@ related to the WSL method. For example, the correspoinding configuration for Gat The configuration items vary with different WSL methods. Please refer to the API of each built-in WSL method for details of the correspoinding configuration. + See examples in `PyMIC_examples/seg_wsl/ `_. Built-in WSL Methods -------------------- -:mod:`pymic.net_run_wsl.wsl_abstract.WSLSegAgent` is the abstract class used for -weakly-supervised learning. The built-in WSL methods are child classes of :mod:`WSLSegAgent`. -The available WSL methods implemnted in PyMIC are listed in :mod:`pymic.net_run_wsl.wsl_main.WSLMethodDict`, +:mod:`pymic.net_run.weak_sup.wsl_abstract.WSLSegAgent` is the abstract class used for +weakly-supervised learning. The built-in WSL methods are child classes of `WSLSegAgent`. +The available WSL methods implemnted in PyMIC are listed in `pymic.net_run.weak_sup.WSLMethodDict`, and they are: * ``EntropyMinimization``: (`NeurIPS 2005 `_) @@ -120,13 +107,13 @@ and they are: Customized WSL Methods ---------------------- -PyMIC alo supports customizing WSL methods by inheriting the :mod:`WSLSegAgent` class. -You may only need to rewrite the :mod:`training()` method and reuse most part of the +PyMIC alo supports customizing WSL methods by inheriting the `WSLSegAgent` class. +You may only need to rewrite the `training()` method and reuse most part of the existing pipeline, such as data loading, validation and inference methods. For example: .. code-block:: none - from pymic.net_run_wsl.wsl_abstract import WSLSegAgent + from pymic.net_run.weak_sup import WSLSegAgent class MyWSLMethod(WSLSegAgent): def __init__(self, config, stage = 'train'): diff --git a/pymic/__init__.py b/pymic/__init__.py index cb6356a..ae1775d 100644 --- a/pymic/__init__.py +++ b/pymic/__init__.py @@ -1,2 +1,19 @@ from __future__ import absolute_import -__version__ = "0.4.0" \ No newline at end of file +from enum import Enum + +__version__ = "0.5.0" # 2024.11.15 + +class TaskType(Enum): + CLASSIFICATION_ONE_HOT = 1 + CLASSIFICATION_COEXIST = 2 + REGRESSION = 3 + SEGMENTATION = 4 + RECONSTRUCTION = 5 + +TaskDict = { + 'cls': TaskType.CLASSIFICATION_ONE_HOT, + 'cls_coexist': TaskType.CLASSIFICATION_COEXIST, + 'regress': TaskType.REGRESSION, + 'seg': TaskType.SEGMENTATION, + 'rec': TaskType.RECONSTRUCTION +} \ No newline at end of file diff --git a/pymic/io/h5_dataset.py b/pymic/io/h5_dataset.py index 02f94f3..34fa1a4 100644 --- a/pymic/io/h5_dataset.py +++ b/pymic/io/h5_dataset.py @@ -8,8 +8,9 @@ import pandas as pd from torch.utils.data import Dataset from torch.utils.data.sampler import Sampler +from pymic import TaskType -class H5DataSet(Dataset): +class H5DataSet_backup(Dataset): """ Dataset for loading images stored in h5 format. It generates 4D tensors with dimention order [C, D, H, W] for 3D images, and @@ -39,7 +40,9 @@ def __getitem__(self, idx): if self.transform: sample = self.transform(sample) return sample - + + + class TwoStreamBatchSampler(Sampler): """Iterate two sets of indices diff --git a/pymic/io/image_read_write.py b/pymic/io/image_read_write.py index cb65e19..f570628 100644 --- a/pymic/io/image_read_write.py +++ b/pymic/io/image_read_write.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division - +import logging import os import numpy as np import SimpleITK as sitk @@ -23,11 +23,9 @@ def load_nifty_volume_as_4d_array(filename): spacing = img_obj.GetSpacing() direction = img_obj.GetDirection() shape = data_array.shape - if(len(shape) == 4): - assert(shape[3] == 1) - elif(len(shape) == 3): + if(len(shape) == 3): data_array = np.expand_dims(data_array, axis = 0) - else: + elif(len(shape) > 4 or len(shape) < 3): raise ValueError("unsupported image dim: {0:}".format(len(shape))) output = {} output['data_array'] = data_array @@ -55,10 +53,16 @@ def load_rgb_image_as_3d_array(filename): image = np.expand_dims(image, axis = 0) else: # transpose rgb image from [H, W, C] to [C, H, W] - assert(image_shape[2] == 3 or image_shape[2] == 4) + # logging.warning("The image is expected to have 1 or three channels, but it has a different channel number") + # logging.warning("({0:} {1:}".format(filename, image_shape)) if(image_shape[2] == 4): image = image[:, :, range(3)] + elif(image_shape[2] == 2): + image = image[:, :, 0:1] + elif(image_shape[2] != 3): + raise ValueError("invalid channel number {0:}", image_shape[2]) image = np.transpose(image, axes = [2, 0, 1]) + output = {} output['data_array'] = image output['origin'] = (0, 0) @@ -77,14 +81,14 @@ def load_image_as_nd_array(image_name): if (image_name.endswith(".nii.gz") or image_name.endswith(".nii") or image_name.endswith(".mha")): image_dict = load_nifty_volume_as_4d_array(image_name) - elif(image_name.endswith(".jpg") or image_name.endswith(".jpeg") or - image_name.endswith(".tif") or image_name.endswith(".png")): + elif(image_name.lower().endswith(".jpg") or image_name.lower().endswith(".jpeg") or + image_name.lower().endswith(".tif") or image_name.lower().endswith(".png")): image_dict = load_rgb_image_as_3d_array(image_name) else: - raise ValueError("unsupported image format") + raise ValueError("unsupported image format: {0:}".format(image_name)) return image_dict -def save_array_as_nifty_volume(data, image_name, reference_name = None): +def save_array_as_nifty_volume(data, image_name, reference_name = None, spacing = [1.0,1.0,1.0]): """ Save a numpy array as nifty image @@ -92,14 +96,21 @@ def save_array_as_nifty_volume(data, image_name, reference_name = None): :param image_name: (str) The ouput file name. :param reference_name: (str) File name of the reference image of which meta information is used. + :param spacing: (list or tuple) the spacing of a volume data when `reference_name` is not provided. """ img = sitk.GetImageFromArray(data) - if(reference_name is not None): + if((reference_name is not None) and (not reference_name.endswith(".h5"))): img_ref = sitk.ReadImage(reference_name) #img.CopyInformation(img_ref) img.SetSpacing(img_ref.GetSpacing()) img.SetOrigin(img_ref.GetOrigin()) - img.SetDirection(img_ref.GetDirection()) + direction0 = img_ref.GetDirection() + direction1 = img.GetDirection() + if(len(direction0) == len(direction1)): + img.SetDirection(direction0) + else: + nifty_spacing = spacing[1:] + spacing[:1] + img.SetSpacing(nifty_spacing) sitk.WriteImage(img, image_name) def save_array_as_rgb_image(data, image_name): @@ -118,7 +129,7 @@ def save_array_as_rgb_image(data, image_name): img = Image.fromarray(data) img.save(image_name) -def save_nd_array_as_image(data, image_name, reference_name = None): +def save_nd_array_as_image(data, image_name, reference_name = None, spacing = [1.0,1.0,1.0]): """ Save a 3D or 2D numpy array as medical image or RGB image @@ -126,14 +137,19 @@ def save_nd_array_as_image(data, image_name, reference_name = None): [H, W, 3] or [H, W]. :param reference_name: (str) File name of the reference image of which meta information is used. + :param spacing: (list or tuple) the spacing of a volume data when `reference_name` is not provided. """ data_dim = len(data.shape) assert(data_dim == 2 or data_dim == 3) + if(image_name.endswith(".h5")): + if(data_dim == 3): + image_name = image_name.replace(".h5", ".nii.gz") + else: + image_name = image_name.replace(".h5", ".png") if (image_name.endswith(".nii.gz") or image_name.endswith(".nii") or image_name.endswith(".mha")): assert(data_dim == 3) - save_array_as_nifty_volume(data, image_name, reference_name) - + save_array_as_nifty_volume(data, image_name, reference_name, spacing) elif(image_name.endswith(".jpg") or image_name.endswith(".jpeg") or image_name.endswith(".tif") or image_name.endswith(".png")): assert(data_dim == 2) diff --git a/pymic/io/nifty_dataset.py b/pymic/io/nifty_dataset.py index bb1ff23..5424e91 100644 --- a/pymic/io/nifty_dataset.py +++ b/pymic/io/nifty_dataset.py @@ -1,14 +1,28 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import logging import os -import torch +import h5py import pandas as pd import numpy as np -from torch.utils.data import Dataset, DataLoader -from torchvision import transforms, utils +from torch.utils.data import Dataset +from pymic import TaskType from pymic.io.image_read_write import load_image_as_nd_array +def check_and_expand_dim(x, img_dim): + """ + check the input dim and expand it with a channel dimension if necessary. + For 2D images, return a 3D numpy array with a shape of [C, H, W] + for 3D images, return a 3D numpy array with a shape of [C, D, H, W] + """ + input_dim = len(x.shape) + if(input_dim == 2 and img_dim == 2): + x = np.expand_dims(x, axis = 0) + elif(input_dim == 3 and img_dim == 3): + x = np.expand_dims(x, axis = 0) + return x + class NiftyDataset(Dataset): """ Dataset for loading images for segmentation. It generates 4D tensors with @@ -16,39 +30,81 @@ class NiftyDataset(Dataset): with dimention order [C, H, W] for 2D images. :param root_dir: (str) Directory with all the images. - :param csv_file: (str) Path to the csv file with image names. - :param modal_num: (int) Number of modalities. + :param csv: (str) Path to the csv file with image names. If it is None, + the images will be those under root_dir. This only works for testing with + a single input modality. If the images are stored in h5 files, the *.csv file + only has one column, while for other types of images such as .nii.gz and.png, + each column is for an input modality, and the last column is for label. + :param modal_num: (int) Number of modalities. This is only used if the data_file is *.csv. + :param image_dim: (int) Spacial dimension of the input image. This is ony used for h5 files. :param with_label: (bool) Load the data with segmentation ground truth or not. :param transform: (list) List of transforms to be applied on a sample. The built-in transforms can listed in :mod:`pymic.transform.trans_dict`. """ - def __init__(self, root_dir, csv_file, modal_num = 1, - with_label = False, transform=None): + def __init__(self, root_dir, csv_file, modal_num = 1, image_dim = 3, + allow_missing_modal = False, label_key = "label", + transform=None, task = TaskType.SEGMENTATION): self.root_dir = root_dir - self.csv_items = pd.read_csv(csv_file) + if(csv_file is not None): + self.csv_items = pd.read_csv(csv_file) + else: + img_names = os.listdir(root_dir) + img_names = [item for item in img_names if ("nii" in item or "jpg" in item or + "jpeg" in item or "bmp" in item or "png" in item)] + csv_dict = {"image":img_names} + self.csv_items = pd.DataFrame.from_dict(csv_dict) + self.modal_num = modal_num - self.with_label = with_label + self.image_dim = image_dim + self.allow_emtpy= allow_missing_modal + self.label_key = label_key self.transform = transform + self.task = task + self.h5files = False + self.with_label = True + assert self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION] - csv_keys = list(self.csv_items.keys()) - self.image_weight_idx = None - self.pixel_weight_idx = None - if('image_weight' in csv_keys): - self.image_weight_idx = csv_keys.index('image_weight') - if('pixel_weight' in csv_keys): - self.pixel_weight_idx = csv_keys.index('pixel_weight') + # check if the files are h5 images, and if the labels are provided. + temp_name = self.csv_items.iloc[0, 0] + logging.warning(temp_name) + if(temp_name.endswith(".h5")): + self.h5files = True + temp_full_name = "{0:}/{1:}".format(self.root_dir, temp_name) + h5f = h5py.File(temp_full_name, 'r') + if(self.label_key not in h5f): + self.with_label = False + else: + csv_keys = list(self.csv_items.keys()) + if(self.label_key not in csv_keys): + self.with_label = False + + self.image_weight_idx = None + self.pixel_weight_idx = None + if('image_weight' in csv_keys): + self.image_weight_idx = csv_keys.index('image_weight') + if('pixel_weight' in csv_keys): + self.pixel_weight_idx = csv_keys.index('pixel_weight') + if(not self.with_label): + logging.warning("`label` section is not found in the csv file {0:}".format( + csv_file) + " or the corresponding h5 file." + + "\n -- This is only allowed for self-supervised learning" + + "\n -- when `SelfSuperviseLabel` is used in the transform, or when" + + "\n -- loading the unlabeled data for preprocessing.") def __len__(self): return len(self.csv_items) def __getlabel__(self, idx): - csv_keys = list(self.csv_items.keys()) - label_idx = csv_keys.index('label') - label_name = "{0:}/{1:}".format(self.root_dir, - self.csv_items.iloc[idx, label_idx]) - label = load_image_as_nd_array(label_name)['data_array'] - label = np.asarray(label, np.int32) - return label + csv_keys = list(self.csv_items.keys()) + label_idx = csv_keys.index(self.label_key) + label_name = self.csv_items.iloc[idx, label_idx] + label_name_full = "{0:}/{1:}".format(self.root_dir, label_name) + label = load_image_as_nd_array(label_name_full)['data_array'] + if(self.task == TaskType.SEGMENTATION): + label = np.asarray(label, np.int32) + elif(self.task == TaskType.RECONSTRUCTION): + label = np.asarray(label, np.float32) + return label, label_name def __get_pixel_weight__(self, idx): weight_name = "{0:}/{1:}".format(self.root_dir, @@ -57,36 +113,75 @@ def __get_pixel_weight__(self, idx): weight = np.asarray(weight, np.float32) return weight + # def __getitem__(self, idx): + # sample_name = self.csv_items.iloc[idx, 0] + # h5f = h5py.File(self.root_dir + '/' + sample_name, 'r') + # image = np.asarray(h5f['image'][:], np.float32) + + # # this a temporaory process, will be delieted later + # if(len(image.shape) == 3 and image.shape[0] > 1): + # image = np.expand_dims(image, 0) + # sample = {'image': image, 'names':sample_name} + + # if('label' in h5f): + # label = np.asarray(h5f['label'][:], np.uint8) + # if(len(label.shape) == 3 and label.shape[0] > 1): + # label = np.expand_dims(label, 0) + # sample['label'] = label + # if self.transform: + # sample = self.transform(sample) + # return sample + def __getitem__(self, idx): names_list, image_list = [], [] - for i in range (self.modal_num): - image_name = self.csv_items.iloc[idx, i] - image_full_name = "{0:}/{1:}".format(self.root_dir, image_name) - image_dict = load_image_as_nd_array(image_full_name) - image_data = image_dict['data_array'] - names_list.append(image_name) - image_list.append(image_data) - image = np.concatenate(image_list, axis = 0) - image = np.asarray(image, np.float32) - sample = {'image': image, 'names' : names_list[0], - 'origin':image_dict['origin'], - 'spacing': image_dict['spacing'], - 'direction':image_dict['direction']} - if (self.with_label): - sample['label'] = self.__getlabel__(idx) - assert(image.shape[1:] == sample['label'].shape[1:]) - if (self.image_weight_idx is not None): - sample['image_weight'] = self.csv_items.iloc[idx, self.image_weight_idx] - if (self.pixel_weight_idx is not None): - sample['pixel_weight'] = self.__get_pixel_weight__(idx) - assert(image.shape[1:] == sample['pixel_weight'].shape[1:]) + image_shape = None + if(self.h5files): + sample_name = self.csv_items.iloc[idx, 0] + h5f = h5py.File(self.root_dir + '/' + sample_name, 'r') + img = check_and_expand_dim(h5f['image'][:], self.image_dim) + sample = {'image':img} + if(self.with_label): + lab = check_and_expand_dim(h5f[self.label_key][:], self.image_dim) + sample['label'] = np.asarray(lab, np.float32) + sample['names'] = [sample_name] + else: + for i in range (self.modal_num): + image_name = self.csv_items.iloc[idx, i] + image_full_name = "{0:}/{1:}".format(self.root_dir, image_name) + if(os.path.exists(image_full_name)): + image_dict = load_image_as_nd_array(image_full_name) + image_data = image_dict['data_array'] + elif(self.allow_emtpy and image_shape is not None): + image_data = np.zeros(image_shape) + else: + raise KeyError("File not found: {0:}".format(image_full_name)) + if(i == 0): + image_shape = image_data.shape + names_list.append(image_name) + image_list.append(image_data) + image = np.concatenate(image_list, axis = 0) + image = np.asarray(image, np.float32) + + sample = {'image': image, 'names' : names_list, + 'origin':image_dict['origin'], + 'spacing': image_dict['spacing'], + 'direction':image_dict['direction']} + if (self.with_label): + sample['label'], label_name = self.__getlabel__(idx) + sample['names'].append(label_name) + assert(image.shape[1:] == sample['label'].shape[1:]) + if (self.image_weight_idx is not None): + sample['image_weight'] = self.csv_items.iloc[idx, self.image_weight_idx] + if (self.pixel_weight_idx is not None): + sample['pixel_weight'] = self.__get_pixel_weight__(idx) + assert(image.shape[1:] == sample['pixel_weight'].shape[1:]) if self.transform: sample = self.transform(sample) return sample -class ClassificationDataset(NiftyDataset): +class ClassificationDataset(Dataset): """ Dataset for loading images for classification. It generates 4D tensors with dimention order [C, D, H, W] for 3D images, and 3D tensors @@ -101,15 +196,33 @@ class ClassificationDataset(NiftyDataset): The built-in transforms can listed in :mod:`pymic.transform.trans_dict`. """ def __init__(self, root_dir, csv_file, modal_num = 1, class_num = 2, - with_label = False, transform=None): - super(ClassificationDataset, self).__init__(root_dir, - csv_file, modal_num, with_label, transform) + with_label = False, transform=None, task = TaskType.CLASSIFICATION_ONE_HOT): + # super(ClassificationDataset, self).__init__(root_dir, + # csv_file, modal_num, with_label, transform, task) + self.root_dir = root_dir + self.csv_items = pd.read_csv(csv_file) + self.modal_num = modal_num + self.with_label = with_label + self.transform = transform self.class_num = class_num + self.task = task + assert self.task in [TaskType.CLASSIFICATION_ONE_HOT, TaskType.CLASSIFICATION_COEXIST] + + csv_keys = list(self.csv_items.keys()) + self.image_weight_idx = None + if('image_weight' in csv_keys): + self.image_weight_idx = csv_keys.index('image_weight') + + def __len__(self): + return len(self.csv_items) def __getlabel__(self, idx): csv_keys = list(self.csv_items.keys()) - label_idx = csv_keys.index('label') - label = self.csv_items.iloc[idx, label_idx] + if self.task == TaskType.CLASSIFICATION_ONE_HOT: + label_idx = csv_keys.index('label') + label = self.csv_items.iloc[idx, label_idx] + else: + label = np.asarray(self.csv_items.iloc[idx, 1:self.class_num + 1], np.float32) return label def __getweight__(self, idx): @@ -127,13 +240,15 @@ def __getitem__(self, idx): names_list.append(image_name) image_list.append(image_data) image = np.concatenate(image_list, axis = 0) - image = np.asarray(image, np.float32) + image = np.asarray(image, np.float32) sample = {'image': image, 'names' : names_list[0], 'origin':image_dict['origin'], 'spacing': image_dict['spacing'], 'direction':image_dict['direction']} + if (self.with_label): - sample['label'] = self.__getlabel__(idx) + label = self.__getlabel__(idx) + sample['label'] = label #np.asarray(label, np.float32) if (self.image_weight_idx is not None): sample['image_weight'] = self.__getweight__(idx) if self.transform: diff --git a/pymic/loss/cls/basic.py b/pymic/loss/cls/basic.py index 56925fc..6b71e23 100644 --- a/pymic/loss/cls/basic.py +++ b/pymic/loss/cls/basic.py @@ -39,7 +39,7 @@ def forward(self, loss_input_dict): class SigmoidCELoss(AbstractClassificationLoss): """ - Sigmoid-based CE loss. + Sigmoid-based CE loss, should be used when task_type = cls_coexist """ def __init__(self, params = None): super(SigmoidCELoss, self).__init__(params) diff --git a/pymic/loss/cls/infoNCE.py b/pymic/loss/cls/infoNCE.py new file mode 100644 index 0000000..fb6f1c1 --- /dev/null +++ b/pymic/loss/cls/infoNCE.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import torch.nn as nn + +class InfoNCELoss(nn.Module): + """ + Abstract Classification Loss. + """ + def __init__(self, params = None): + super(InfoNCELoss, self).__init__() + self.temp = params.get("temperature", 0.1) + + def forward(self, input_1, input_2): + """ + The arguments should be written in the `loss_input_dict` dictionary, and it has the + following fields. + + :param prediction: A prediction with shape of [N, C] where C is the class number. + :param ground_truth: The corresponding ground truth, with shape of [N, 1]. + + Note that `prediction` is the digit output of a network, before using softmax. + """ + B = list(input_1.shape)[0] + loss = 0.0 + for b in range(B): + embeds_1 = input_1[b] + embeds_2 = input_2[b] + logits_11 = torch.matmul(embeds_1, embeds_1.T) / self.temp + logits_11.fill_diagonal_(float('-inf')) + logits_12 = torch.matmul(embeds_1, embeds_2.T) / self.temp + logits_22 = torch.matmul(embeds_2, embeds_2.T) / self.temp + logits_22.fill_diagonal_(float('-inf')) + loss_1 = torch.mean(-logits_12.diag() + torch.logsumexp(torch.cat([logits_11, logits_12], dim=1), dim=1)) + loss_2 = torch.mean(-logits_12.diag() + torch.logsumexp(torch.cat([logits_12.T, logits_22], dim=1), dim=1)) + loss = loss + (loss_1 + loss_2) / 2 + loss = loss / B + return loss \ No newline at end of file diff --git a/pymic/loss/loss_dict_cls.py b/pymic/loss/loss_dict_cls.py index e07f46b..44744cb 100644 --- a/pymic/loss/loss_dict_cls.py +++ b/pymic/loss/loss_dict_cls.py @@ -11,9 +11,10 @@ """ from __future__ import print_function, division from pymic.loss.cls.basic import * - +from pymic.loss.cls.infoNCE import InfoNCELoss PyMICClsLossDict = {"CrossEntropyLoss": CrossEntropyLoss, "SigmoidCELoss": SigmoidCELoss, + 'InfoNCELoss': InfoNCELoss, "L1Loss": L1Loss, "MSELoss": MSELoss, "NLLLoss": NLLLoss} diff --git a/pymic/loss/loss_dict_seg.py b/pymic/loss/loss_dict_seg.py index 97c537e..36e6a21 100644 --- a/pymic/loss/loss_dict_seg.py +++ b/pymic/loss/loss_dict_seg.py @@ -23,8 +23,10 @@ from __future__ import print_function, division import torch.nn as nn from pymic.loss.seg.ce import CrossEntropyLoss, GeneralizedCELoss -from pymic.loss.seg.dice import DiceLoss, FocalDiceLoss, NoiseRobustDiceLoss +from pymic.loss.seg.dice import DiceLoss, FocalDiceLoss, \ + NoiseRobustDiceLoss, BinaryDiceLoss, GroupDiceLoss from pymic.loss.seg.exp_log import ExpLogLoss +from pymic.loss.seg.ars_tversky import ARSTverskyLoss from pymic.loss.seg.mse import MSELoss, MAELoss from pymic.loss.seg.slsr import SLSRLoss @@ -32,8 +34,11 @@ 'CrossEntropyLoss': CrossEntropyLoss, 'GeneralizedCELoss': GeneralizedCELoss, 'DiceLoss': DiceLoss, + 'BinaryDiceLoss': BinaryDiceLoss, 'FocalDiceLoss': FocalDiceLoss, + 'ARSTverskyLoss': ARSTverskyLoss, 'NoiseRobustDiceLoss': NoiseRobustDiceLoss, + 'GroupDiceLoss': GroupDiceLoss, 'ExpLogLoss': ExpLogLoss, 'MAELoss': MAELoss, 'MSELoss': MSELoss, diff --git a/pymic/loss/seg/abstract.py b/pymic/loss/seg/abstract.py index f42d816..68643e8 100644 --- a/pymic/loss/seg/abstract.py +++ b/pymic/loss/seg/abstract.py @@ -16,9 +16,20 @@ class AbstractSegLoss(nn.Module): def __init__(self, params = None): super(AbstractSegLoss, self).__init__() if(params is None): - self.softmax = True + self.acti_func = 'softmax' else: - self.softmax = params.get('loss_softmax', True) + self.acti_func = params.get('loss_acti_func', 'softmax') + + def get_activated_prediction(self, p, acti_func = 'softmax'): + if(acti_func == "softmax"): + p = nn.Softmax(dim = 1)(p) + elif(acti_func == "tanh"): + p = nn.Tanh()(p) + elif(acti_func == "sigmoid"): + p = nn.Sigmoid()(p) + else: + raise ValueError("activation for output is not supported: {0:}".format(acti_func)) + return p def forward(self, loss_input_dict): """ diff --git a/pymic/loss/seg/ars_tversky.py b/pymic/loss/seg/ars_tversky.py new file mode 100644 index 0000000..4fafeae --- /dev/null +++ b/pymic/loss/seg/ars_tversky.py @@ -0,0 +1,67 @@ + +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch.nn as nn +from pymic.loss.seg.abstract import AbstractSegLoss + +class ARSTverskyLoss(AbstractSegLoss): + """ + The Adaptive Region-Specific Loss in this paper: + + * Y. Chen et al.: Adaptive Region-Specific Loss for Improved Medical Image Segmentation. + `IEEE TPAMI 2023. `_ + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `ARSTversky_patch_size`: (list) the patch size. + :param `A`: the lowest weight for FP or FN (default 0.3) + :param `B`: the gap between lowest and highest weight (default 0.4) + """ + def __init__(self, params): + super(ARSTverskyLoss, self).__init__(params) + self.patch_size = params['ARSTversky_patch_size'.lower()] + self.a = params.get('ARSTversky_a'.lower(), 0.3) + self.b = params.get('ARSTversky_b'.lower(), 0.4) + + self.dim = len(self.patch_size) + assert self.dim in [2, 3], "The num of dim must be 2 or 3." + if self.dim == 3: + self.pool = nn.AvgPool3d(kernel_size=self.patch_size, stride=self.patch_size) + elif self.dim == 2: + self.pool = nn.AvgPool2d(kernel_size=self.patch_size, stride=self.patch_size) + + def forward(self, loss_input_dict): + predict = loss_input_dict['prediction'] + soft_y = loss_input_dict['ground_truth'] + + if(isinstance(predict, (list, tuple))): + predict = predict[0] + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) + + smooth = 1e-5 + if self.dim == 2: + assert predict.shape[-2] % self.patch_size[0] == 0, "image size % patch size must be 0 in dimension y" + assert predict.shape[-1] % self.patch_size[1] == 0, "image size % patch size must be 0 in dimension x" + elif self.dim == 3: + assert predict.shape[-3] % self.patch_size[0] == 0, "image size % patch size must be 0 in dimension z" + assert predict.shape[-2] % self.patch_size[1] == 0, "image size % patch size must be 0 in dimension y" + assert predict.shape[-1] % self.patch_size[2] == 0, "image size % patch size must be 0 in dimension x" + + tp = predict * soft_y + fp = predict * (1 - soft_y) + fn = (1 - predict) * soft_y + + region_tp = self.pool(tp) + region_fp = self.pool(fp) + region_fn = self.pool(fn) + + alpha = self.a + self.b * (region_fp + smooth) / (region_fp + region_fn + smooth) + beta = self.a + self.b * (region_fn + smooth) / (region_fp + region_fn + smooth) + + region_tversky = (region_tp + smooth) / (region_tp + alpha * region_fp + beta * region_fn + smooth) + region_tversky = 1 - region_tversky + loss = region_tversky.mean() + return loss \ No newline at end of file diff --git a/pymic/loss/seg/ce.py b/pymic/loss/seg/ce.py index 529482b..bf036a3 100644 --- a/pymic/loss/seg/ce.py +++ b/pymic/loss/seg/ce.py @@ -13,8 +13,10 @@ class CrossEntropyLoss(AbstractSegLoss): The parameters should be written in the `params` dictionary, and it has the following fields: - :param `loss_softmax`: (optional, bool) - Apply softmax to the prediction of network or not. Default is True. + :param `loss_acti_func`: (optional, string) + Apply an activation function to the prediction of network or not, for example, + 'softmax' for image segmentation tasks, 'tanh' for reconstruction tasks, and None + means no activation is used. """ def __init__(self, params = None): super(CrossEntropyLoss, self).__init__(params) @@ -23,23 +25,28 @@ def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] soft_y = loss_input_dict['ground_truth'] pix_w = loss_input_dict.get('pixel_weight', None) + cls_w = loss_input_dict.get('class_weight', None) if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) + predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) # for numeric stability - predict = predict * 0.999 + 5e-4 + # predict = predict * (1-1e-10) + 0.5e-10 ce = - soft_y* torch.log(predict) - ce = torch.sum(ce, dim = 1) # shape is [N] + if(cls_w is not None): + ce = torch.sum(ce*cls_w, dim = 1) + else: + ce = torch.sum(ce, dim = 1) # shape is [N] if(pix_w is None): ce = torch.mean(ce) else: pix_w = torch.squeeze(reshape_tensor_to_2D(pix_w)) - ce = torch.sum(pix_w * ce) / (pix_w.sum() + 1e-5) + ce = torch.sum(pix_w * ce) / (pix_w.sum() + 1e-10) return ce class GeneralizedCELoss(AbstractSegLoss): @@ -61,32 +68,29 @@ class GeneralizedCELoss(AbstractSegLoss): def __init__(self, params): super(GeneralizedCELoss, self).__init__(params) self.q = params.get('loss_gce_q', 0.5) - self.enable_pix_weight = params.get('loss_with_pixel_weight', False) - self.cls_weight = params.get('loss_class_weight', None) def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] - soft_y = loss_input_dict['ground_truth'] + soft_y = loss_input_dict['ground_truth'] + pix_w = loss_input_dict.get('pixel_weight', None) + cls_w = loss_input_dict.get('class_weight', None) if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) gce = (1.0 - torch.pow(predict, self.q)) / self.q * soft_y - if(self.cls_weight is not None): - gce = torch.sum(gce * self.cls_w, dim = 1) + if(cls_w is not None): + gce = torch.sum(gce * cls_w, dim = 1) else: gce = torch.sum(gce, dim = 1) - if(self.enable_pix_weight): - pix_w = loss_input_dict.get('pixel_weight', None) - if(pix_w is None): - raise ValueError("Pixel weight is enabled but not defined") - pix_w = reshape_tensor_to_2D(pix_w) - gce = torch.sum(gce * pix_w) / torch.sum(pix_w) + if(pix_w is not None): + pix_w = torch.squeeze(reshape_tensor_to_2D(pix_w)) + gce = torch.sum(gce * pix_w) / torch.sum(pix_w) else: gce = torch.mean(gce) return gce diff --git a/pymic/loss/seg/deep_sup.py b/pymic/loss/seg/deep_sup.py index da6d9ef..6669486 100644 --- a/pymic/loss/seg/deep_sup.py +++ b/pymic/loss/seg/deep_sup.py @@ -2,6 +2,7 @@ from __future__ import print_function, division import torch.nn as nn +import numpy as np from torch.nn.functional import interpolate from pymic.loss.seg.abstract import AbstractSegLoss @@ -69,7 +70,7 @@ def forward(self, loss_input_dict): be a list or a tuple""") pred_num = len(pred) if(self.deep_sup_weight is None): - self.deep_sup_weight = [1.0] * pred_num + self.deep_sup_weight = [1.0 / pow(2, i) for i in range(pred_num)] else: assert(pred_num == len(self.deep_sup_weight)) loss_sum, weight_sum = 0.0, 0.0 diff --git a/pymic/loss/seg/dice.py b/pymic/loss/seg/dice.py index c3a1134..c423c2c 100644 --- a/pymic/loss/seg/dice.py +++ b/pymic/loss/seg/dice.py @@ -20,17 +20,78 @@ def __init__(self, params = None): def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] soft_y = loss_input_dict['ground_truth'] + pix_w = loss_input_dict.get('pixel_weight', None) + cls_w = loss_input_dict.get('class_weight', None) if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) + predict = reshape_tensor_to_2D(predict) + soft_y = reshape_tensor_to_2D(soft_y) + if(pix_w is not None): + pix_w = reshape_tensor_to_2D(pix_w) + dice_loss = 1.0 - get_classwise_dice(predict, soft_y, pix_w) + if(cls_w is not None): + weighted_loss = dice_loss * cls_w + avg_loss = weighted_loss.sum() / cls_w.sum() + else: + avg_loss = dice_loss.mean() + return avg_loss + +class BinaryDiceLoss(AbstractSegLoss): + ''' + Fuse all the foreground classes together and calculate the Dice value. + ''' + def __init__(self, params = None): + super(BinaryDiceLoss, self).__init__(params) + + def forward(self, loss_input_dict): + predict = loss_input_dict['prediction'] + soft_y = loss_input_dict['ground_truth'] + + if(isinstance(predict, (list, tuple))): + predict = predict[0] + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) + predict = 1.0 - predict[:, :1, :, :, :] + soft_y = 1.0 - soft_y[:, :1, :, :, :] predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) dice_score = get_classwise_dice(predict, soft_y) dice_loss = 1.0 - dice_score.mean() return dice_loss +class GroupDiceLoss(AbstractSegLoss): + ''' + Fuse all the foreground classes together and calculate the Dice value. + ''' + def __init__(self, params = None): + super(GroupDiceLoss, self).__init__(params) + self.group = 2 + + def forward(self, loss_input_dict): + predict = loss_input_dict['prediction'] + soft_y = loss_input_dict['ground_truth'] + + if(isinstance(predict, (list, tuple))): + predict = predict[0] + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) + predict = reshape_tensor_to_2D(predict) + soft_y = reshape_tensor_to_2D(soft_y) + num_class = list(predict.size())[1] + cls_per_group = (num_class - 1) // self.group + loss_all = 0.0 + for g in range(self.group): + c0 = 1 + g*cls_per_group + c1 = min(num_class, c0 + cls_per_group) + pred_g = torch.sum(predict[:, c0:c1], dim = 1, keepdim = True) + y_g = torch.sum( soft_y[:, c0:c1], dim = 1, keepdim = True) + loss_all += 1.0 - get_classwise_dice(pred_g, y_g)[0] + avg_loss = loss_all / self.group + return avg_loss + class FocalDiceLoss(AbstractSegLoss): """ Focal Dice according to the following paper: @@ -54,8 +115,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) @@ -88,8 +149,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) diff --git a/pymic/loss/seg/exp_log.py b/pymic/loss/seg/exp_log.py index c1b3f00..8c0d494 100644 --- a/pymic/loss/seg/exp_log.py +++ b/pymic/loss/seg/exp_log.py @@ -32,8 +32,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) diff --git a/pymic/loss/seg/mse.py b/pymic/loss/seg/mse.py index ad83899..eb53af4 100644 --- a/pymic/loss/seg/mse.py +++ b/pymic/loss/seg/mse.py @@ -19,8 +19,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) mse = torch.square(predict - soft_y) mse = torch.mean(mse) return mse @@ -40,11 +40,15 @@ def __init__(self, params = None): def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] soft_y = loss_input_dict['ground_truth'] + weight = loss_input_dict.get('pixel_weight', None) if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) mae = torch.abs(predict - soft_y) - mae = torch.mean(mae) + if(weight is None): + mae = torch.mean(mae) + else: + mae = torch.sum(mae * weight) / weight.sum() return mae diff --git a/pymic/loss/seg/mumford_shah.py b/pymic/loss/seg/mumford_shah.py index 6da51b5..db9368a 100644 --- a/pymic/loss/seg/mumford_shah.py +++ b/pymic/loss/seg/mumford_shah.py @@ -3,8 +3,9 @@ import torch import torch.nn as nn +from pymic.loss.seg.abstract import AbstractSegLoss -class MumfordShahLoss(nn.Module): +class MumfordShahLoss(AbstractSegLoss): """ Implementation of Mumford Shah Loss for weakly supervised learning. @@ -76,8 +77,8 @@ def forward(self, loss_input_dict): image = loss_input_dict['image'] if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) pred_shape = list(predict.shape) if(len(pred_shape) == 5): diff --git a/pymic/loss/seg/slsr.py b/pymic/loss/seg/slsr.py index d5c4151..92adea7 100644 --- a/pymic/loss/seg/slsr.py +++ b/pymic/loss/seg/slsr.py @@ -38,8 +38,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) if(pix_w is not None): diff --git a/pymic/loss/seg/ssl.py b/pymic/loss/seg/ssl.py index 0bf276f..3a7430a 100644 --- a/pymic/loss/seg/ssl.py +++ b/pymic/loss/seg/ssl.py @@ -34,8 +34,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) # for numeric stability predict = predict * 0.999 + 5e-4 @@ -70,8 +70,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) # for numeric stability predict = predict * 0.999 + 5e-4 diff --git a/pymic/net/cls/torch_pretrained_net.py b/pymic/net/cls/torch_pretrained_net.py index 5017f72..9d05c28 100644 --- a/pymic/net/cls/torch_pretrained_net.py +++ b/pymic/net/cls/torch_pretrained_net.py @@ -2,7 +2,6 @@ from __future__ import print_function, division import itertools -import torch import torch.nn as nn import torchvision.models as models @@ -61,7 +60,49 @@ class ResNet18(BuiltInNet): """ def __init__(self, params): super(ResNet18, self).__init__(params) - self.net = models.resnet18(pretrained = self.pretrain) + weights = 'IMAGENET1K_V1' if self.pretrain else None + self.net = models.resnet18(weights = weights) + + # replace the last layer + num_ftrs = self.net.fc.in_features + self.net.fc = nn.Linear(num_ftrs, params['class_num']) + + # replace the first layer when in_chns is not 3 + if(self.in_chns != 3): + self.net.conv1 = nn.Conv2d(self.in_chns, 64, kernel_size=(7, 7), + stride=(2, 2), padding=(3, 3), bias=False) + + def get_parameters_to_update(self): + if(self.update_mode == "all"): + return self.net.parameters() + elif(self.update_mode == "last"): + params = self.net.fc.parameters() + if(self.in_chns !=3): + # combining the two iterables into a single one + # see: https://dzone.com/articles/python-joining-multiple + params = itertools.chain() + for pram in [self.net.fc.parameters(), self.net.conv1.parameters()]: + params = itertools.chain(params, pram) + return params + else: + raise(ValueError("update_mode can only be 'all' or 'last'.")) + +class ResNet50(BuiltInNet): + """ + ResNet18 for classification. + Parameters should be set in the `params` dictionary that contains the + following fields: + + :param input_chns: (int) Input channel number, default is 3. + :param pretrain: (bool) Using pretrained model or not, default is True. + :param update_mode: (str) The strategy for updating layers: "`all`" means updating + all the layers, and "`last`" (by default) means updating the last layer, + as well as the first layer when `input_chns` is not 3. + """ + def __init__(self, params): + super(ResNet50, self).__init__(params) + weights = 'IMAGENET1K_V1' if self.pretrain else None + self.net = models.resnet50(weights = weights) # replace the last layer num_ftrs = self.net.fc.in_features @@ -101,7 +142,8 @@ class VGG16(BuiltInNet): """ def __init__(self, params): super(VGG16, self).__init__(params) - self.net = models.vgg16(pretrained = self.pretrain) + weights = 'IMAGENET1K_V1' if self.pretrain else None + self.net = models.vgg16(weights = weights) # replace the last layer num_ftrs = self.net.classifier[-1].in_features @@ -139,7 +181,8 @@ class MobileNetV2(BuiltInNet): """ def __init__(self, params): super(MobileNetV2, self).__init__(params) - self.net = models.mobilenet_v2(pretrained = self.pretrain) + weights = 'IMAGENET1K_V1' if self.pretrain else None + self.net = models.mobilenet_v2(weights = weights) # replace the last layer num_ftrs = self.net.last_channel @@ -162,6 +205,38 @@ def get_parameters_to_update(self): return params else: raise(ValueError("update_mode can only be 'all' or 'last'.")) + +class ViTB16(BuiltInNet): + """ + ViTB16 for classification. + Parameters should be set in the `params` dictionary that contains the + following fields: + + :param input_chns: (int) Input channel number, default is 3. + :param pretrain: (bool) Using pretrained model or not, default is True. + :param update_mode: (str) The strategy for updating layers: "`all`" means updating + all the layers, and "`last`" (by default) means updating the last layer, + as well as the first layer when `input_chns` is not 3. + """ + def __init__(self, params): + super(ViTB16, self).__init__(params) + weights = 'IMAGENET1K_V1' if self.pretrain else None + self.net = models.vit_b_16(weights = weights) + + # replace the last layer + num_ftrs = self.net.representation_size + if(num_ftrs is None): + num_ftrs = self.net.hidden_dim + self.net.heads[-1] = nn.Linear(num_ftrs, params['class_num']) + + def get_parameters_to_update(self): + if(self.update_mode == "all"): + return self.net.parameters() + elif(self.update_mode == "last"): + params = self.net.heads[-1].parameters() + return params + else: + raise(ValueError("update_mode can only be 'all' or 'last'.")) if __name__ == "__main__": params = {"class_num": 2, "pretrain": False, "input_chns": 3} diff --git a/pymic/net/multi_net.py b/pymic/net/multi_net.py new file mode 100644 index 0000000..78209b1 --- /dev/null +++ b/pymic/net/multi_net.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import torch.nn as nn + +class MultiNet(nn.Module): + ''' + A combination of multiple networks. + Parameters should be saved in the `params` dictionary. + + :param `net_names`: (list) A list of network class name. + :param `infer_mode`: (int) Mode for inference. 0: only use the first network. + 1: taking an average of all the networks. + ''' + def __init__(self, net_dict, params): + super(MultiNet, self).__init__() + net_names = params['net_type'] # should be a list of network class name + self.output_mode = params.get('infer_mode', 0) + self.networks = nn.ModuleList([net_dict[item](params) for item in net_names]) + + def forward(self, x): + if(self.training): + output = [net(x) for net in self.networks] + else: + output = self.networks[0](x) + if(self.output_mode == 1): + for i in range(1, len(self.networks)): + output += self.networks[i](x) + output = output / len(self.networks) + return output + \ No newline at end of file diff --git a/pymic/net/net2d/cople_net.py b/pymic/net/net2d/cople_net.py index bd54fd6..046dee8 100644 --- a/pymic/net/net2d/cople_net.py +++ b/pymic/net/net2d/cople_net.py @@ -120,20 +120,30 @@ class UpBlock(nn.Module): Upssampling followed by ConvBNActBlock. """ def __init__(self, in_channels1, in_channels2, out_channels, - bilinear=True, dropout_p = 0.5): + up_mode = 2, dropout_p = 0.5): super(UpBlock, self).__init__() - self.bilinear = bilinear - if bilinear: - self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + if(isinstance(up_mode, int)): + up_mode_values = ["transconv", "nearest", "bilinear", "bicubic"] + if(up_mode > 3): + raise ValueError("The upsample mode should be 0-3, but {0:} is given.".format(up_mode)) + self.up_mode = up_mode_values[up_mode] else: + self.up_mode = up_mode.lower() + + if (self.up_mode == "transconv"): self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2) + else: + self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) + if(self.up_mode == "nearest"): + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode) + else: + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode, align_corners=True) self.conv = ConvBNActBlock(in_channels2 * 2, out_channels, dropout_p) def forward(self, x1, x2): - if self.bilinear: + if self.up_mode != "transconv": x1 = self.conv1x1(x1) - x1 = self.up(x1) + x1 = self.up(x1) x_cat = torch.cat([x2, x1], dim=1) y = self.conv(x_cat) return y + x_cat @@ -165,7 +175,7 @@ def __init__(self, params): self.ft_chns = self.params['feature_chns'] self.dropout = self.params['dropout'] self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] + self.up_mode = self.params.get('up_mode', 2) assert(len(self.ft_chns) == 5) f0_half = int(self.ft_chns[0] / 2) @@ -183,10 +193,10 @@ def __init__(self, params): self.bridge2= ConvLayer(self.ft_chns[2], f2_half) self.bridge3= ConvLayer(self.ft_chns[3], f3_half) - self.up1 = UpBlock(self.ft_chns[4], f3_half, self.ft_chns[3], dropout_p = self.dropout[3]) - self.up2 = UpBlock(self.ft_chns[3], f2_half, self.ft_chns[2], dropout_p = self.dropout[2]) - self.up3 = UpBlock(self.ft_chns[2], f1_half, self.ft_chns[1], dropout_p = self.dropout[1]) - self.up4 = UpBlock(self.ft_chns[1], f0_half, self.ft_chns[0], dropout_p = self.dropout[0]) + self.up1 = UpBlock(self.ft_chns[4], f3_half, self.ft_chns[3], self.up_mode, dropout_p = self.dropout[3]) + self.up2 = UpBlock(self.ft_chns[3], f2_half, self.ft_chns[2], self.up_mode, dropout_p = self.dropout[2]) + self.up3 = UpBlock(self.ft_chns[2], f1_half, self.ft_chns[1], self.up_mode, dropout_p = self.dropout[1]) + self.up4 = UpBlock(self.ft_chns[1], f0_half, self.ft_chns[0], self.up_mode, dropout_p = self.dropout[0]) f4 = self.ft_chns[4] aspp_chns = [int(f4 / 4), int(f4 / 4), int(f4 / 4), int(f4 / 4)] diff --git a/pymic/net/net2d/trans2d/__init__.py b/pymic/net/net2d/trans2d/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pymic/net/net2d/trans2d/swinunet.py b/pymic/net/net2d/trans2d/swinunet.py new file mode 100644 index 0000000..f35539a --- /dev/null +++ b/pymic/net/net2d/trans2d/swinunet.py @@ -0,0 +1,124 @@ +# -*- coding: utf-8 -*- +""" +code adapted from: https://github.com/HuCaoFighting/Swin-Unet + +""" +from __future__ import print_function, division + +import copy +import numpy as np +import torch +import torch.nn as nn + +from pymic.net.net2d.trans2d.swinunet_sys import SwinTransformerSys + +class SwinUNet(nn.Module): + """ + Implementatin of Swin-UNet. + + * Reference: Hu Cao, Yueyue Wang et al: + Swin-Unet: Unet-Like Pure Transformer for Medical Image Segmentation. + `ECCV 2022 Workshops. `_ + + Note that the input channel can only be 1 or 3, and the input image size should be 224x224. + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param img_size: (tuple) The input image size, should be [224, 224]. + :param class_num: (int) The class number for segmentation task. + """ + def __init__(self, params): + super(SwinUNet, self).__init__() + img_size = params['img_size'] + if(isinstance(img_size, tuple) or isinstance(img_size, list)): + img_size = img_size[0] + self.num_classes = params['class_num'] + self.swin_unet = SwinTransformerSys(img_size = img_size, num_classes=self.num_classes) + # self.swin_unet = SwinTransformerSys(img_size=config.DATA.IMG_SIZE, + # patch_size=config.MODEL.SWIN.PATCH_SIZE, + # in_chans=config.MODEL.SWIN.IN_CHANS, + # num_classes=self.num_classes, + # embed_dim=config.MODEL.SWIN.EMBED_DIM, + # depths=config.MODEL.SWIN.DEPTHS, + # num_heads=config.MODEL.SWIN.NUM_HEADS, + # window_size=config.MODEL.SWIN.WINDOW_SIZE, + # mlp_ratio=config.MODEL.SWIN.MLP_RATIO, + # qkv_bias=config.MODEL.SWIN.QKV_BIAS, + # qk_scale=config.MODEL.SWIN.QK_SCALE, + # drop_rate=config.MODEL.DROP_RATE, + # drop_path_rate=config.MODEL.DROP_PATH_RATE, + # ape=config.MODEL.SWIN.APE, + # patch_norm=config.MODEL.SWIN.PATCH_NORM, + # use_checkpoint=config.TRAIN.USE_CHECKPOINT) + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + if x.size()[1] == 1: + x = x.repeat(1,3,1,1) + logits = self.swin_unet(x) + + if(len(x_shape) == 5): + new_shape = [N, D] + list(logits.shape)[1:] + logits = torch.reshape(logits, new_shape) + logits = torch.transpose(logits, 1, 2) + + return logits + + def load_from(self, config): + pretrained_path = config.MODEL.PRETRAIN_CKPT + if pretrained_path is not None: + print("pretrained_path:{}".format(pretrained_path)) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + pretrained_dict = torch.load(pretrained_path, map_location=device) + if "model" not in pretrained_dict: + print("---start load pretrained modle by splitting---") + pretrained_dict = {k[17:]:v for k,v in pretrained_dict.items()} + for k in list(pretrained_dict.keys()): + if "output" in k: + print("delete key:{}".format(k)) + del pretrained_dict[k] + msg = self.swin_unet.load_state_dict(pretrained_dict,strict=False) + # print(msg) + return + pretrained_dict = pretrained_dict['model'] + print("---start load pretrained modle of swin encoder---") + + model_dict = self.swin_unet.state_dict() + full_dict = copy.deepcopy(pretrained_dict) + for k, v in pretrained_dict.items(): + if "layers." in k: + current_layer_num = 3-int(k[7:8]) + current_k = "layers_up." + str(current_layer_num) + k[8:] + full_dict.update({current_k:v}) + for k in list(full_dict.keys()): + if k in model_dict: + if full_dict[k].shape != model_dict[k].shape: + print("delete:{};shape pretrain:{};shape model:{}".format(k,v.shape,model_dict[k].shape)) + del full_dict[k] + + msg = self.swin_unet.load_state_dict(full_dict, strict=False) + # print(msg) + else: + print("none pretrain") + + +if __name__ == "__main__": + params = {'img_size': [224, 224], + 'class_num': 2} + net = SwinUNet(params) + net.double() + + x = np.random.rand(4, 3, 224, 224) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = net(xt) + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) \ No newline at end of file diff --git a/pymic/net/net2d/trans2d/swinunet_sys.py b/pymic/net/net2d/trans2d/swinunet_sys.py new file mode 100644 index 0000000..a6e3552 --- /dev/null +++ b/pymic/net/net2d/trans2d/swinunet_sys.py @@ -0,0 +1,749 @@ +# -*- coding: utf-8 -*- +""" +code adapted from: https://github.com/HuCaoFighting/Swin-Unet + +""" +from __future__ import print_function, division + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from einops import rearrange +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + +class PatchExpand(nn.Module): + def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity() + self.norm = norm_layer(dim // dim_scale) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + x = self.expand(x) + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) + x = x.view(B,-1,C//4) + x= self.norm(x) + + return x + +class FinalPatchExpand_X4(nn.Module): + def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.dim_scale = dim_scale + self.expand = nn.Linear(dim, 16*dim, bias=False) + self.output_dim = dim + self.norm = norm_layer(self.output_dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + x = self.expand(x) + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2)) + x = x.view(B,-1,self.output_dim) + x= self.norm(x) + + return x + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + +class BasicLayer_up(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if upsample is not None: + self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) + else: + self.upsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.upsample is not None: + x = self.upsample(x) + return x + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformerSys(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, final_upsample="expand_first", **kwargs): + super().__init__() + + print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(depths, + depths_decoder,drop_path_rate,num_classes)) + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.num_features_up = int(embed_dim * 2) + self.mlp_ratio = mlp_ratio + self.final_upsample = final_upsample + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build encoder and bottleneck layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + # build decoder layers + self.layers_up = nn.ModuleList() + self.concat_back_dim = nn.ModuleList() + for i_layer in range(self.num_layers): + concat_linear = nn.Linear(2*int(embed_dim*2**(self.num_layers-1-i_layer)), + int(embed_dim*2**(self.num_layers-1-i_layer))) if i_layer > 0 else nn.Identity() + if i_layer ==0 : + layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), + patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer) + else: + layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), + input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), + patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), + depth=depths[(self.num_layers-1-i_layer)], + num_heads=num_heads[(self.num_layers-1-i_layer)], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])], + norm_layer=norm_layer, + upsample=PatchExpand if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers_up.append(layer_up) + self.concat_back_dim.append(concat_linear) + + self.norm = norm_layer(self.num_features) + self.norm_up= norm_layer(self.embed_dim) + + if self.final_upsample == "expand_first": + print("---final upsample expand_first---") + self.up = FinalPatchExpand_X4(input_resolution=(img_size//patch_size,img_size//patch_size),dim_scale=4,dim=embed_dim) + self.output = nn.Conv2d(in_channels=embed_dim,out_channels=self.num_classes,kernel_size=1,bias=False) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + #Encoder and Bottleneck + def forward_features(self, x): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + x_downsample = [] + + for layer in self.layers: + x_downsample.append(x) + x = layer(x) + + x = self.norm(x) # B L C + + return x, x_downsample + + #Dencoder and Skip connection + def forward_up_features(self, x, x_downsample): + for inx, layer_up in enumerate(self.layers_up): + if inx == 0: + x = layer_up(x) + else: + x = torch.cat([x,x_downsample[3-inx]],-1) + x = self.concat_back_dim[inx](x) + x = layer_up(x) + + x = self.norm_up(x) # B L C + + return x + + def up_x4(self, x): + H, W = self.patches_resolution + B, L, C = x.shape + assert L == H*W, "input features has wrong size" + + if self.final_upsample=="expand_first": + x = self.up(x) + x = x.view(B,4*H,4*W,-1) + x = x.permute(0,3,1,2) #B,C,H,W + x = self.output(x) + + return x + + def forward(self, x): + x, x_downsample = self.forward_features(x) + x = self.forward_up_features(x,x_downsample) + x = self.up_x4(x) + + return x + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return \ No newline at end of file diff --git a/pymic/net/net2d/trans2d/transunet.py b/pymic/net/net2d/trans2d/transunet.py new file mode 100644 index 0000000..9db5d2d --- /dev/null +++ b/pymic/net/net2d/trans2d/transunet.py @@ -0,0 +1,491 @@ +# -*- coding: utf-8 -*- +""" +code adapted from: https://github.com/Beckschen/TransUNet +""" +from __future__ import print_function, division + +import copy +# import logging +import math +import torch +import torch.nn as nn +from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm +from torch.nn.modules.utils import _pair + +import numpy as np +from scipy import ndimage +from os.path import join as pjoin +import pymic.net.net2d.trans2d.transunet_cfg as configs +from pymic.net.net2d.trans2d.transunet_resnet import ResNetV2 + + +VIT_CONFIGS = { + 'ViT-B_16': configs.get_b16_config(), + 'ViT-B_32': configs.get_b32_config(), + 'ViT-L_16': configs.get_l16_config(), + 'ViT-L_32': configs.get_l32_config(), + 'ViT-H_14': configs.get_h14_config(), + 'R50-ViT-B_16': configs.get_r50_b16_config(), + 'R50-ViT-L_16': configs.get_r50_l16_config(), + 'testing': configs.get_testing(), +} + +ATTENTION_Q = "MultiHeadDotProductAttention_1/query" +ATTENTION_K = "MultiHeadDotProductAttention_1/key" +ATTENTION_V = "MultiHeadDotProductAttention_1/value" +ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" +FC_0 = "MlpBlock_3/Dense_0" +FC_1 = "MlpBlock_3/Dense_1" +ATTENTION_NORM = "LayerNorm_0" +MLP_NORM = "LayerNorm_2" + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} + +class Attention(nn.Module): + def __init__(self, config, vis): + super(Attention, self).__init__() + self.vis = vis + self.num_attention_heads = config.transformer["num_heads"] + self.attention_head_size = int(config.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = Linear(config.hidden_size, self.all_head_size) + self.key = Linear(config.hidden_size, self.all_head_size) + self.value = Linear(config.hidden_size, self.all_head_size) + + self.out = Linear(config.hidden_size, config.hidden_size) + self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) + self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) + + self.softmax = Softmax(dim=-1) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.softmax(attention_scores) + weights = attention_probs if self.vis else None + attention_probs = self.attn_dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + attention_output = self.out(context_layer) + attention_output = self.proj_dropout(attention_output) + return attention_output, weights + +class Mlp(nn.Module): + def __init__(self, config): + super(Mlp, self).__init__() + self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) + self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) + self.act_fn = ACT2FN["gelu"] + self.dropout = Dropout(config.transformer["dropout_rate"]) + + self._init_weights() + + def _init_weights(self): + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + nn.init.normal_(self.fc1.bias, std=1e-6) + nn.init.normal_(self.fc2.bias, std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x = self.act_fn(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + +class Embeddings(nn.Module): + """Construct the embeddings from patch, position embeddings. + """ + def __init__(self, config, img_size, in_channels=3): + super(Embeddings, self).__init__() + self.hybrid = None + self.config = config + img_size = _pair(img_size) + + if config.patches.get("grid") is not None: # ResNet + grid_size = config.patches["grid"] + patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1]) + patch_size_real = (patch_size[0] * 16, patch_size[1] * 16) + n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1]) + self.hybrid = True + else: + patch_size = _pair(config.patches["size"]) + n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) + self.hybrid = False + + if self.hybrid: + self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor) + in_channels = self.hybrid_model.width * 16 + self.patch_embeddings = Conv2d(in_channels=in_channels, + out_channels=config.hidden_size, + kernel_size=patch_size, + stride=patch_size) + self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size)) + + self.dropout = Dropout(config.transformer["dropout_rate"]) + + + def forward(self, x): + if self.hybrid: + x, features = self.hybrid_model(x) + else: + features = None + x = self.patch_embeddings(x) # (B, hidden. n_patches^(1/2), n_patches^(1/2)) + x = x.flatten(2) + x = x.transpose(-1, -2) # (B, n_patches, hidden) + + embeddings = x + self.position_embeddings + embeddings = self.dropout(embeddings) + return embeddings, features + +class Block(nn.Module): + def __init__(self, config, vis): + super(Block, self).__init__() + self.hidden_size = config.hidden_size + self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn = Mlp(config) + self.attn = Attention(config, vis) + + def forward(self, x): + h = x + x = self.attention_norm(x) + x, weights = self.attn(x) + x = x + h + + h = x + x = self.ffn_norm(x) + x = self.ffn(x) + x = x + h + return x, weights + + def load_from(self, weights, n_block): + ROOT = f"Transformer/encoderblock_{n_block}" + with torch.no_grad(): + query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t() + key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() + value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t() + out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t() + + query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) + key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) + value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) + out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) + + self.attn.query.weight.copy_(query_weight) + self.attn.key.weight.copy_(key_weight) + self.attn.value.weight.copy_(value_weight) + self.attn.out.weight.copy_(out_weight) + self.attn.query.bias.copy_(query_bias) + self.attn.key.bias.copy_(key_bias) + self.attn.value.bias.copy_(value_bias) + self.attn.out.bias.copy_(out_bias) + + mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() + mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() + mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() + mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() + + self.ffn.fc1.weight.copy_(mlp_weight_0) + self.ffn.fc2.weight.copy_(mlp_weight_1) + self.ffn.fc1.bias.copy_(mlp_bias_0) + self.ffn.fc2.bias.copy_(mlp_bias_1) + + self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) + self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) + self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) + self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) + +class Encoder(nn.Module): + def __init__(self, config, vis): + super(Encoder, self).__init__() + self.vis = vis + self.layer = nn.ModuleList() + self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) + for _ in range(config.transformer["num_layers"]): + layer = Block(config, vis) + self.layer.append(copy.deepcopy(layer)) + + def forward(self, hidden_states): + attn_weights = [] + for layer_block in self.layer: + hidden_states, weights = layer_block(hidden_states) + if self.vis: + attn_weights.append(weights) + encoded = self.encoder_norm(hidden_states) + return encoded, attn_weights + +class Transformer(nn.Module): + def __init__(self, config, img_size, vis): + super(Transformer, self).__init__() + self.embeddings = Embeddings(config, img_size=img_size) + self.encoder = Encoder(config, vis) + + def forward(self, input_ids): + embedding_output, features = self.embeddings(input_ids) + encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden) + return encoded, attn_weights, features + +class Conv2dReLU(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding=0, + stride=1, + use_batchnorm=True, + ): + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=not (use_batchnorm), + ) + relu = nn.ReLU(inplace=True) + + bn = nn.BatchNorm2d(out_channels) + + super(Conv2dReLU, self).__init__(conv, bn, relu) + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + skip_channels=0, + use_batchnorm=True, + ): + super().__init__() + self.conv1 = Conv2dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.conv2 = Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.up = nn.UpsamplingBilinear2d(scale_factor=2) + + def forward(self, x, skip=None): + x = self.up(x) + if skip is not None: + x = torch.cat([x, skip], dim=1) + x = self.conv1(x) + x = self.conv2(x) + return x + + +class SegmentationHead(nn.Sequential): + + def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1): + conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) + upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() + super().__init__(conv2d, upsampling) + + +class DecoderCup(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + head_channels = 512 + self.conv_more = Conv2dReLU( + config.hidden_size, + head_channels, + kernel_size=3, + padding=1, + use_batchnorm=True, + ) + decoder_channels = config.decoder_channels + in_channels = [head_channels] + list(decoder_channels[:-1]) + out_channels = decoder_channels + + if self.config.n_skip != 0: + skip_channels = self.config.skip_channels + for i in range(4-self.config.n_skip): # re-select the skip channels according to n_skip + skip_channels[3-i]=0 + + else: + skip_channels=[0,0,0,0] + + blocks = [ + DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels) + ] + self.blocks = nn.ModuleList(blocks) + + def forward(self, hidden_states, features=None): + B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden) + h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch)) + x = hidden_states.permute(0, 2, 1) + x = x.contiguous().view(B, hidden, h, w) + x = self.conv_more(x) + for i, decoder_block in enumerate(self.blocks): + if features is not None: + skip = features[i] if (i < self.config.n_skip) else None + else: + skip = None + x = decoder_block(x, skip=skip) + return x + +class TransUNet(nn.Module): + """ + Implementatin of TransUNet. + + * Reference: Jieneng Chen, Yongyi Lu et al: + TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation. + `Arxiv 2021. `_ + + Note that the input channel can only be 1 or 3, and the input image size should be 256x256. + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param img_size: (tuple) The input image size, should be [256, 256]. + :param class_num: (int) The class number for segmentation task. + :param vit_name: (string) The name for vit backbone. It can be one of the following: 'ViT-B_16', + 'ViT-B_32','ViT-L_16', 'ViT-L_32', 'ViT-H_14'. 'R50-ViT-B_16', 'R50-ViT-L_16'. + By default, it is 'R50-ViT-B_16'. + """ + def __init__(self, params): + super(TransUNet, self).__init__() + vit_name = params.get("vit_name", 'R50-ViT-B_16') + img_size = params['img_size'] + vis = params.get("vis", False) + self.config = VIT_CONFIGS[vit_name] + self.num_classes = params['class_num'] + self.zero_head = params.get("zero_head", False) + + self.classifier = self.config.classifier + self.transformer = Transformer(self.config, img_size, vis) + self.decoder = DecoderCup(self.config) + self.segmentation_head = SegmentationHead( + in_channels=self.config['decoder_channels'][-1], + out_channels=self.num_classes, + kernel_size=3, + ) + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + if x.size()[1] == 1: + x = x.repeat(1,3,1,1) + elif(x.size()[1] !=3): + raise ValueError("The input channel number should be 1 or 3 for TransUNet") + x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden) + x = self.decoder(x, features) + logits = self.segmentation_head(x) + + if(len(x_shape) == 5): + new_shape = [N, D] + list(logits.shape)[1:] + logits = torch.reshape(logits, new_shape) + logits = torch.transpose(logits, 1, 2) + + return logits + + def load_from(self, weights): + with torch.no_grad(): + + res_weight = weights + self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) + self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) + + self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) + self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) + + posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) + + posemb_new = self.transformer.embeddings.position_embeddings + if posemb.size() == posemb_new.size(): + self.transformer.embeddings.position_embeddings.copy_(posemb) + elif posemb.size()[1]-1 == posemb_new.size()[1]: + posemb = posemb[:, 1:] + self.transformer.embeddings.position_embeddings.copy_(posemb) + else: + # logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size())) + ntok_new = posemb_new.size(1) + if self.classifier == "seg": + _, posemb_grid = posemb[:, :1], posemb[0, 1:] + gs_old = int(np.sqrt(len(posemb_grid))) + gs_new = int(np.sqrt(ntok_new)) + print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) + posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) + zoom = (gs_new / gs_old, gs_new / gs_old, 1) + posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) # th2np + posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) + posemb = posemb_grid + self.transformer.embeddings.position_embeddings.copy_(np2th(posemb)) + + # Encoder whole + for bname, block in self.transformer.encoder.named_children(): + for uname, unit in block.named_children(): + unit.load_from(weights, n_block=uname) + + if self.transformer.embeddings.hybrid: + self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True)) + gn_weight = np2th(res_weight["gn_root/scale"]).view(-1) + gn_bias = np2th(res_weight["gn_root/bias"]).view(-1) + self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight) + self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias) + + for bname, block in self.transformer.embeddings.hybrid_model.body.named_children(): + for uname, unit in block.named_children(): + unit.load_from(res_weight, n_block=bname, n_unit=uname) + +if __name__ == "__main__": + params = {'img_size': [256, 256], + 'class_num': 2} + net = TransUNet(params) + net.double() + + for c in [1,3]: + x = np.random.rand(4, c, 256, 256) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = net(xt) + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) \ No newline at end of file diff --git a/pymic/net/net2d/trans2d/transunet_cfg.py b/pymic/net/net2d/trans2d/transunet_cfg.py new file mode 100644 index 0000000..aab62d4 --- /dev/null +++ b/pymic/net/net2d/trans2d/transunet_cfg.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +""" +code adapted from: https://github.com/Beckschen/TransUNet +""" +import ml_collections + +def get_b16_config(): + """Returns the ViT-B/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 768 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 3072 + config.transformer.num_heads = 12 + config.transformer.num_layers = 12 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + + config.classifier = 'seg' + config.representation_size = None + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz' + config.patch_size = 16 + + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_testing(): + """Returns a minimal configuration for testing.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 1 + config.transformer.num_heads = 1 + config.transformer.num_layers = 1 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + return config + +def get_r50_b16_config(): + """Returns the Resnet50 + ViT-B/16 configuration.""" + config = get_b16_config() + config.patches.grid = (16, 16) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'seg' + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 2 + config.n_skip = 3 + config.activation = 'softmax' + + return config + + +def get_b32_config(): + """Returns the ViT-B/32 configuration.""" + config = get_b16_config() + config.patches.size = (32, 32) + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz' + return config + + +def get_l16_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1024 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 4096 + config.transformer.num_heads = 16 + config.transformer.num_layers = 24 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.representation_size = None + + # custom + config.classifier = 'seg' + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_r50_l16_config(): + """Returns the Resnet50 + ViT-L/16 configuration. customized """ + config = get_l16_config() + config.patches.grid = (16, 16) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'seg' + config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_l32_config(): + """Returns the ViT-L/32 configuration.""" + config = get_l16_config() + config.patches.size = (32, 32) + return config + + +def get_h14_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (14, 14)}) + config.hidden_size = 1280 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 5120 + config.transformer.num_heads = 16 + config.transformer.num_layers = 32 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + + return config \ No newline at end of file diff --git a/pymic/net/net2d/trans2d/transunet_resnet.py b/pymic/net/net2d/trans2d/transunet_resnet.py new file mode 100644 index 0000000..144a268 --- /dev/null +++ b/pymic/net/net2d/trans2d/transunet_resnet.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +""" +code adapted from: https://github.com/Beckschen/TransUNet +""" +from __future__ import print_function, division + +from os.path import join as pjoin +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +class StdConv2d(nn.Conv2d): + + def forward(self, x): + w = self.weight + v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) + w = (w - m) / torch.sqrt(v + 1e-5) + return F.conv2d(x, w, self.bias, self.stride, self.padding, + self.dilation, self.groups) + + +def conv3x3(cin, cout, stride=1, groups=1, bias=False): + return StdConv2d(cin, cout, kernel_size=3, stride=stride, + padding=1, bias=bias, groups=groups) + + +def conv1x1(cin, cout, stride=1, bias=False): + return StdConv2d(cin, cout, kernel_size=1, stride=stride, + padding=0, bias=bias) + + +class PreActBottleneck(nn.Module): + """Pre-activation (v2) bottleneck block. + """ + + def __init__(self, cin, cout=None, cmid=None, stride=1): + super().__init__() + cout = cout or cin + cmid = cmid or cout//4 + + self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv1 = conv1x1(cin, cmid, bias=False) + self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!! + self.gn3 = nn.GroupNorm(32, cout, eps=1e-6) + self.conv3 = conv1x1(cmid, cout, bias=False) + self.relu = nn.ReLU(inplace=True) + + if (stride != 1 or cin != cout): + # Projection also with pre-activation according to paper. + self.downsample = conv1x1(cin, cout, stride, bias=False) + self.gn_proj = nn.GroupNorm(cout, cout) + + def forward(self, x): + + # Residual branch + residual = x + if hasattr(self, 'downsample'): + residual = self.downsample(x) + residual = self.gn_proj(residual) + + # Unit's branch + y = self.relu(self.gn1(self.conv1(x))) + y = self.relu(self.gn2(self.conv2(y))) + y = self.gn3(self.conv3(y)) + + y = self.relu(residual + y) + return y + + def load_from(self, weights, n_block, n_unit): + conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True) + conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True) + conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True) + + gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")]) + gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")]) + + gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")]) + gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")]) + + gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")]) + gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")]) + + self.conv1.weight.copy_(conv1_weight) + self.conv2.weight.copy_(conv2_weight) + self.conv3.weight.copy_(conv3_weight) + + self.gn1.weight.copy_(gn1_weight.view(-1)) + self.gn1.bias.copy_(gn1_bias.view(-1)) + + self.gn2.weight.copy_(gn2_weight.view(-1)) + self.gn2.bias.copy_(gn2_bias.view(-1)) + + self.gn3.weight.copy_(gn3_weight.view(-1)) + self.gn3.bias.copy_(gn3_bias.view(-1)) + + if hasattr(self, 'downsample'): + proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True) + proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")]) + proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")]) + + self.downsample.weight.copy_(proj_conv_weight) + self.gn_proj.weight.copy_(proj_gn_weight.view(-1)) + self.gn_proj.bias.copy_(proj_gn_bias.view(-1)) + +class ResNetV2(nn.Module): + """Implementation of Pre-activation (v2) ResNet mode.""" + + def __init__(self, block_units, width_factor): + super().__init__() + width = int(64 * width_factor) + self.width = width + + self.root = nn.Sequential(OrderedDict([ + ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)), + ('gn', nn.GroupNorm(32, width, eps=1e-6)), + ('relu', nn.ReLU(inplace=True)), + # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)) + ])) + + self.body = nn.Sequential(OrderedDict([ + ('block1', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)], + ))), + ('block2', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)], + ))), + ('block3', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)], + ))), + ])) + + def forward(self, x): + features = [] + b, c, in_size, _ = x.size() + x = self.root(x) + features.append(x) + x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x) + for i in range(len(self.body)-1): + x = self.body[i](x) + right_size = int(in_size / 4 / (i+1)) + if x.size()[2] != right_size: + pad = right_size - x.size()[2] + assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size) + feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device) + feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:] + else: + feat = x + features.append(feat) + x = self.body[-1](x) + return x, features[::-1] \ No newline at end of file diff --git a/pymic/net/net2d/umamba.py b/pymic/net/net2d/umamba.py new file mode 100644 index 0000000..63ccacf --- /dev/null +++ b/pymic/net/net2d/umamba.py @@ -0,0 +1,1234 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import logging +import numpy as np +import math +import torch +from torch import nn +from torch.nn import functional as F +from typing import Union, Type, List, Tuple + +from torch.nn.modules.conv import _ConvNd +from torch.nn.modules.dropout import _DropoutNd +from mamba_ssm import Mamba +from torch.cuda.amp import autocast + +# from dynamic_network_architectures.building_blocks.helper import convert_conv_op_to_dim +# from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager +# from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op +# from nnunetv2.utilities.network_initialization import InitWeights_He +# from dynamic_network_architectures.building_blocks.helper import maybe_convert_scalar_to_list, get_matching_pool_op + +def dim_of_conv_op(conv_op: Type[_ConvNd]) -> int: + """ + :param conv_op: conv class + :return: dimension: 1, 2 or 3 + """ + if conv_op == nn.Conv1d: + return 1 + elif conv_op == nn.Conv2d: + return 2 + elif conv_op == nn.Conv3d: + return 3 + else: + raise ValueError("Unknown dimension. Only 1d 2d and 3d conv are supported. got %s" % str(conv_op)) + +def get_matching_pool_op(conv_op: Type[_ConvNd] = None, + dimension: int = None, + adaptive=False, + pool_type: str = 'avg') -> Type[torch.nn.Module]: + """ + You MUST set EITHER conv_op OR dimension. Do not set both! + :param conv_op: + :param dimension: + :param adaptive: + :param pool_type: either 'avg' or 'max' + :return: + """ + assert not ((conv_op is not None) and (dimension is not None)), \ + "You MUST set EITHER conv_op OR dimension. Do not set both!" + assert pool_type in ['avg', 'max'], 'pool_type must be either avg or max' + if conv_op is not None: + dimension = dim_of_conv_op(conv_op) + assert dimension in [1, 2, 3], 'Dimension must be 1, 2 or 3' + + if conv_op is not None: + dimension = dim_of_conv_op(conv_op) + + if dimension == 1: + if pool_type == 'avg': + if adaptive: + return nn.AdaptiveAvgPool1d + else: + return nn.AvgPool1d + elif pool_type == 'max': + if adaptive: + return nn.AdaptiveMaxPool1d + else: + return nn.MaxPool1d + elif dimension == 2: + if pool_type == 'avg': + if adaptive: + return nn.AdaptiveAvgPool2d + else: + return nn.AvgPool2d + elif pool_type == 'max': + if adaptive: + return nn.AdaptiveMaxPool2d + else: + return nn.MaxPool2d + elif dimension == 3: + if pool_type == 'avg': + if adaptive: + return nn.AdaptiveAvgPool3d + else: + return nn.AvgPool3d + elif pool_type == 'max': + if adaptive: + return nn.AdaptiveMaxPool3d + else: + return nn.MaxPool3d + +def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): + """ + This function is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py). + + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + +class DropPath(nn.Module): + """ + This class is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py). + + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + +def make_divisible(v, divisor=8, min_value=None, round_limit=.9): + """ + This function is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/b7cb8d0337b3e7b50516849805ddb9be5fc11644/timm/models/layers/helpers.py#L25) + """ + min_value = min_value or divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < round_limit * v: + new_v += divisor + return new_v + +class SqueezeExcite(nn.Module): + """ + This class is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/squeeze_excite.py) + and slightly modified so that the convolution type can be adapted. + + SE Module as defined in original SE-Nets with a few additions + Additions include: + * divisor can be specified to keep channels % div == 0 (default: 8) + * reduction channels can be specified directly by arg (if rd_channels is set) + * reduction channels can be specified by float rd_ratio (default: 1/16) + * global max pooling can be added to the squeeze aggregation + * customizable activation, normalization, and gate layer + """ + def __init__( + self, channels, conv_op, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False, + act_layer=nn.ReLU, norm_layer=None, gate_layer=nn.Sigmoid): + super(SqueezeExcite, self).__init__() + self.add_maxpool = add_maxpool + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + self.fc1 = conv_op(channels, rd_channels, kernel_size=1, bias=True) + self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity() + self.act = act_layer(inplace=True) + self.fc2 = conv_op(rd_channels, channels, kernel_size=1, bias=True) + self.gate = gate_layer() + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) + x_se = self.fc1(x_se) + x_se = self.act(self.bn(x_se)) + x_se = self.fc2(x_se) + return x * self.gate(x_se) + + +class ConvDropoutNormReLU(nn.Module): + def __init__(self, + conv_op: Type[_ConvNd], + input_channels: int, + output_channels: int, + kernel_size: Union[int, List[int], Tuple[int, ...]], + stride: Union[int, List[int], Tuple[int, ...]], + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: dict = None, + dropout_op: Union[None, Type[_DropoutNd]] = None, + dropout_op_kwargs: dict = None, + nonlin: Union[None, Type[torch.nn.Module]] = None, + nonlin_kwargs: dict = None, + nonlin_first: bool = False + ): + super(ConvDropoutNormReLU, self).__init__() + self.input_channels = input_channels + self.output_channels = output_channels + if not isinstance(stride, (tuple, list, np.ndarray)): + stride = [stride] * dim_of_conv_op(conv_op) + self.stride = stride + + if not isinstance(kernel_size, (tuple, list, np.ndarray)): + kernel_size = [kernel_size] * dim_of_conv_op(conv_op) + if norm_op_kwargs is None: + norm_op_kwargs = {} + if nonlin_kwargs is None: + nonlin_kwargs = {} + + ops = [] + + self.conv = conv_op( + input_channels, + output_channels, + kernel_size, + stride, + padding=[(i - 1) // 2 for i in kernel_size], + dilation=1, + bias=conv_bias, + ) + ops.append(self.conv) + + if dropout_op is not None: + self.dropout = dropout_op(**dropout_op_kwargs) + ops.append(self.dropout) + + if norm_op is not None: + self.norm = norm_op(output_channels, **norm_op_kwargs) + ops.append(self.norm) + + if nonlin is not None: + self.nonlin = nonlin(**nonlin_kwargs) + ops.append(self.nonlin) + + if nonlin_first and (norm_op is not None and nonlin is not None): + ops[-1], ops[-2] = ops[-2], ops[-1] + + self.all_modules = nn.Sequential(*ops) + + def forward(self, x): + return self.all_modules(x) + + def compute_conv_feature_map_size(self, input_size): + assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \ + "batch channel. Do not give input_size=(b, c, x, y(, z)). " \ + "Give input_size=(x, y(, z))!" + output_size = [i // j for i, j in zip(input_size, self.stride)] # we always do same padding + return np.prod([self.output_channels, *output_size], dtype=np.int64) + +# # from dynamic_network_architectures.building_blocks.residual import BasicBlockD +class BasicBlockD(nn.Module): + def __init__(self, + conv_op: Type[_ConvNd], + input_channels: int, + output_channels: int, + kernel_size: Union[int, List[int], Tuple[int, ...]], + stride: Union[int, List[int], Tuple[int, ...]], + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: dict = None, + dropout_op: Union[None, Type[_DropoutNd]] = None, + dropout_op_kwargs: dict = None, + nonlin: Union[None, Type[torch.nn.Module]] = None, + nonlin_kwargs: dict = None, + stochastic_depth_p: float = 0.0, + squeeze_excitation: bool = False, + squeeze_excitation_reduction_ratio: float = 1. / 16, + # todo wideresnet? + ): + """ + This implementation follows ResNet-D: + + He, Tong, et al. "Bag of tricks for image classification with convolutional neural networks." + Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019. + + The skip has an avgpool (if needed) followed by 1x1 conv instead of just a strided 1x1 conv + + :param conv_op: + :param input_channels: + :param output_channels: + :param kernel_size: refers only to convs in feature extraction path, not to 1x1x1 conv in skip + :param stride: only applies to first conv (and skip). Second conv always has stride 1 + :param conv_bias: + :param norm_op: + :param norm_op_kwargs: + :param dropout_op: only the first conv can have dropout. The second never has + :param dropout_op_kwargs: + :param nonlin: + :param nonlin_kwargs: + :param stochastic_depth_p: + :param squeeze_excitation: + :param squeeze_excitation_reduction_ratio: + """ + super().__init__() + self.input_channels = input_channels + self.output_channels = output_channels + if not isinstance(stride, (tuple, list, np.ndarray)): + stride = [stride] * dim_of_conv_op(conv_op) + self.stride = stride + + if not isinstance(kernel_size, (tuple, list, np.ndarray)): + kernel_size = [kernel_size] * dim_of_conv_op(conv_op) + + if norm_op_kwargs is None: + norm_op_kwargs = {} + if nonlin_kwargs is None: + nonlin_kwargs = {} + + self.conv1 = ConvDropoutNormReLU(conv_op, input_channels, output_channels, kernel_size, stride, conv_bias, + norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs) + self.conv2 = ConvDropoutNormReLU(conv_op, output_channels, output_channels, kernel_size, 1, conv_bias, norm_op, + norm_op_kwargs, None, None, None, None) + + self.nonlin2 = nonlin(**nonlin_kwargs) if nonlin is not None else lambda x: x + + # Stochastic Depth + self.apply_stochastic_depth = False if stochastic_depth_p == 0.0 else True + if self.apply_stochastic_depth: + self.drop_path = DropPath(drop_prob=stochastic_depth_p) + + # Squeeze Excitation + self.apply_se = squeeze_excitation + if self.apply_se: + self.squeeze_excitation = SqueezeExcite(self.output_channels, conv_op, + rd_ratio=squeeze_excitation_reduction_ratio, rd_divisor=8) + + has_stride = (isinstance(stride, int) and stride != 1) or any([i != 1 for i in stride]) + requires_projection = (input_channels != output_channels) + + if has_stride or requires_projection: + ops = [] + if has_stride: + ops.append(get_matching_pool_op(conv_op=conv_op, adaptive=False, pool_type='avg')(stride, stride)) + if requires_projection: + ops.append( + ConvDropoutNormReLU(conv_op, input_channels, output_channels, 1, 1, False, norm_op, + norm_op_kwargs, None, None, None, None + ) + ) + self.skip = nn.Sequential(*ops) + else: + self.skip = lambda x: x + + def forward(self, x): + residual = self.skip(x) + out = self.conv2(self.conv1(x)) + if self.apply_stochastic_depth: + out = self.drop_path(out) + if self.apply_se: + out = self.squeeze_excitation(out) + out += residual + return self.nonlin2(out) + + def compute_conv_feature_map_size(self, input_size): + assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \ + "batch channel. Do not give input_size=(b, c, x, y(, z)). " \ + "Give input_size=(x, y(, z))!" + size_after_stride = [i // j for i, j in zip(input_size, self.stride)] + # conv1 + output_size_conv1 = np.prod([self.output_channels, *size_after_stride], dtype=np.int64) + # conv2 + output_size_conv2 = np.prod([self.output_channels, *size_after_stride], dtype=np.int64) + # skip conv (if applicable) + if (self.input_channels != self.output_channels) or any([i != j for i, j in zip(input_size, size_after_stride)]): + assert isinstance(self.skip, nn.Sequential) + output_size_skip = np.prod([self.output_channels, *size_after_stride], dtype=np.int64) + else: + assert not isinstance(self.skip, nn.Sequential) + output_size_skip = 0 + return output_size_conv1 + output_size_conv2 + output_size_skip + +class UpsampleLayer(nn.Module): + def __init__( + self, + conv_op, + input_channels, + output_channels, + pool_op_kernel_size, + mode='nearest' + ): + super().__init__() + self.conv = conv_op(input_channels, output_channels, kernel_size=1) + self.pool_op_kernel_size = pool_op_kernel_size + self.mode = mode + + def forward(self, x): + x = F.interpolate(x, scale_factor=self.pool_op_kernel_size, mode=self.mode) + x = self.conv(x) + return x + +# class MambaLayer(nn.Module): +# def __init__(self, dim, d_state = 16, d_conv = 4, expand = 2): +# super().__init__() +# self.dim = dim +# self.norm = nn.LayerNorm(dim) +# self.mamba = Mamba( +# d_model=dim, # Model dimension d_model +# d_state=d_state, # SSM state expansion factor +# d_conv=d_conv, # Local convolution width +# expand=expand, # Block expansion factor +# ) + +# @autocast(enabled=False) +# def forward(self, x): +# if x.dtype == torch.float16: +# x = x.type(torch.float32) +# B, C = x.shape[:2] +# assert C == self.dim +# n_tokens = x.shape[2:].numel() +# img_dims = x.shape[2:] +# x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2) +# x_norm = self.norm(x_flat) +# x_mamba = self.mamba(x_norm) +# out = x_mamba.transpose(-1, -2).reshape(B, C, *img_dims) + +# return out + +class MambaLayer(nn.Module): + def __init__(self, dim, d_state = 16, d_conv = 4, expand = 2, channel_token = False): + super().__init__() + self.dim = dim + self.norm = nn.LayerNorm(dim) + self.mamba = Mamba( + d_model=dim, # Model dimension d_model + d_state=d_state, # SSM state expansion factor + d_conv=d_conv, # Local convolution width + expand=expand, # Block expansion factor + ) + self.channel_token = channel_token ## whether to use channel as tokens + + def forward_patch_token(self, x): + B, d_model = x.shape[:2] + assert d_model == self.dim + n_tokens = x.shape[2:].numel() + img_dims = x.shape[2:] + x_flat = x.reshape(B, d_model, n_tokens).transpose(-1, -2) + x_norm = self.norm(x_flat) + x_mamba = self.mamba(x_norm) + out = x_mamba.transpose(-1, -2).reshape(B, d_model, *img_dims) + + return out + + def forward_channel_token(self, x): + B, n_tokens = x.shape[:2] + d_model = x.shape[2:].numel() + assert d_model == self.dim, f"d_model: {d_model}, self.dim: {self.dim}" + img_dims = x.shape[2:] + x_flat = x.flatten(2) + assert x_flat.shape[2] == d_model, f"x_flat.shape[2]: {x_flat.shape[2]}, d_model: {d_model}" + x_norm = self.norm(x_flat) + x_mamba = self.mamba(x_norm) + out = x_mamba.reshape(B, n_tokens, *img_dims) + + return out + + @autocast(enabled=False) + def forward(self, x): + if x.dtype == torch.float16: + x = x.type(torch.float32) + + if self.channel_token: + out = self.forward_channel_token(x) + else: + out = self.forward_patch_token(x) + + return out + + +class BasicResBlock(nn.Module): + def __init__( + self, + conv_op, + input_channels, + output_channels, + norm_op, + norm_op_kwargs, + kernel_size=3, + padding=1, + stride=1, + use_1x1conv=False, + nonlin=nn.LeakyReLU, + nonlin_kwargs={'inplace': True} + ): + super().__init__() + + self.conv1 = conv_op(input_channels, output_channels, kernel_size, stride=stride, padding=padding) + self.norm1 = norm_op(output_channels, **norm_op_kwargs) + self.act1 = nonlin(**nonlin_kwargs) + + self.conv2 = conv_op(output_channels, output_channels, kernel_size, padding=padding) + self.norm2 = norm_op(output_channels, **norm_op_kwargs) + self.act2 = nonlin(**nonlin_kwargs) + + if use_1x1conv: + self.conv3 = conv_op(input_channels, output_channels, kernel_size=1, stride=stride) + else: + self.conv3 = None + + def forward(self, x): + y = self.conv1(x) + y = self.act1(self.norm1(y)) + y = self.norm2(self.conv2(y)) + if self.conv3: + x = self.conv3(x) + y += x + return self.act2(y) + +class UNetResEncoder(nn.Module): + def __init__(self, + input_channels: int, + n_stages: int, + features_per_stage: Union[int, List[int], Tuple[int, ...]], + conv_op: Type[_ConvNd], + kernel_sizes: Union[int, List[int], Tuple[int, ...]], + strides: Union[int, List[int], Tuple[int, ...], Tuple[Tuple[int, ...], ...]], + n_blocks_per_stage: Union[int, List[int], Tuple[int, ...]], + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: dict = None, + nonlin: Union[None, Type[torch.nn.Module]] = None, + nonlin_kwargs: dict = None, + return_skips: bool = False, + stem_channels: int = None, + pool_type: str = 'conv', + ): + super().__init__() + if isinstance(kernel_sizes, int): + kernel_sizes = [kernel_sizes] * n_stages + if isinstance(features_per_stage, int): + features_per_stage = [features_per_stage] * n_stages + if isinstance(n_blocks_per_stage, int): + n_blocks_per_stage = [n_blocks_per_stage] * n_stages + if isinstance(strides, int): + strides = [strides] * n_stages + + assert len( + kernel_sizes) == n_stages, "kernel_sizes must have as many entries as we have resolution stages (n_stages)" + assert len( + n_blocks_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have resolution stages (n_stages)" + assert len( + features_per_stage) == n_stages, "features_per_stage must have as many entries as we have resolution stages (n_stages)" + assert len(strides) == n_stages, "strides must have as many entries as we have resolution stages (n_stages). " \ + "Important: first entry is recommended to be 1, else we run strided conv drectly on the input" + + pool_op = get_matching_pool_op(conv_op, pool_type=pool_type) if pool_type != 'conv' else None + + self.conv_pad_sizes = [] + for krnl in kernel_sizes: + self.conv_pad_sizes.append([i // 2 for i in krnl]) + + stem_channels = features_per_stage[0] + + self.stem = nn.Sequential( + BasicResBlock( + conv_op = conv_op, + input_channels = input_channels, + output_channels = stem_channels, + norm_op=norm_op, + norm_op_kwargs=norm_op_kwargs, + kernel_size=kernel_sizes[0], + padding=self.conv_pad_sizes[0], + stride=1, + nonlin=nonlin, + nonlin_kwargs=nonlin_kwargs, + use_1x1conv=True + ), + *[ + BasicBlockD( + conv_op = conv_op, + input_channels = stem_channels, + output_channels = stem_channels, + kernel_size = kernel_sizes[0], + stride = 1, + conv_bias = conv_bias, + norm_op = norm_op, + norm_op_kwargs = norm_op_kwargs, + nonlin = nonlin, + nonlin_kwargs = nonlin_kwargs, + ) for _ in range(n_blocks_per_stage[0] - 1) + ] + ) + + + input_channels = stem_channels + + # now build the network + stages = [] + for s in range(n_stages): + stage = nn.Sequential( + BasicResBlock( + conv_op = conv_op, + norm_op = norm_op, + norm_op_kwargs = norm_op_kwargs, + input_channels = input_channels, + output_channels = features_per_stage[s], + kernel_size = kernel_sizes[s], + padding=self.conv_pad_sizes[s], + stride=strides[s], + use_1x1conv=True, + nonlin = nonlin, + nonlin_kwargs = nonlin_kwargs + ), + *[ + BasicBlockD( + conv_op = conv_op, + input_channels = features_per_stage[s], + output_channels = features_per_stage[s], + kernel_size = kernel_sizes[s], + stride = 1, + conv_bias = conv_bias, + norm_op = norm_op, + norm_op_kwargs = norm_op_kwargs, + nonlin = nonlin, + nonlin_kwargs = nonlin_kwargs, + ) for _ in range(n_blocks_per_stage[s] - 1) + ] + ) + + stages.append(stage) + input_channels = features_per_stage[s] + + self.stages = nn.Sequential(*stages) + self.output_channels = features_per_stage + self.strides = [[item] * dim_of_conv_op(conv_op) if not isinstance(item, (tuple, list, np.ndarray)) \ + else item for item in strides] + # self.strides = [maybe_convert_scalar_to_list(conv_op, i) for i in strides] + + self.return_skips = return_skips + + self.conv_op = conv_op + self.norm_op = norm_op + self.norm_op_kwargs = norm_op_kwargs + self.nonlin = nonlin + self.nonlin_kwargs = nonlin_kwargs + #self.dropout_op = dropout_op + #self.dropout_op_kwargs = dropout_op_kwargs + self.conv_bias = conv_bias + self.kernel_sizes = kernel_sizes + + def forward(self, x): + if self.stem is not None: + x = self.stem(x) + ret = [] + for s in self.stages: + x = s(x) + ret.append(x) + if self.return_skips: + return ret + else: + return ret[-1] + + def compute_conv_feature_map_size(self, input_size): + if self.stem is not None: + output = self.stem.compute_conv_feature_map_size(input_size) + else: + output = np.int64(0) + + for s in range(len(self.stages)): + output += self.stages[s].compute_conv_feature_map_size(input_size) + input_size = [i // j for i, j in zip(input_size, self.strides[s])] + + return output + + +class UNetResDecoder(nn.Module): + def __init__(self, + encoder, + num_classes, + n_conv_per_stage: Union[int, Tuple[int, ...], List[int]], + deep_supervision, nonlin_first: bool = False): + + super().__init__() + self.deep_supervision = deep_supervision + self.encoder = encoder + self.num_classes = num_classes + n_stages_encoder = len(encoder.output_channels) + if isinstance(n_conv_per_stage, int): + n_conv_per_stage = [n_conv_per_stage] * (n_stages_encoder - 1) + assert len(n_conv_per_stage) == n_stages_encoder - 1, "n_conv_per_stage must have as many entries as we have " \ + "resolution stages - 1 (n_stages in encoder - 1), " \ + "here: %d" % n_stages_encoder + + stages = [] + upsample_layers = [] + + seg_layers = [] + for s in range(1, n_stages_encoder): + input_features_below = encoder.output_channels[-s] + input_features_skip = encoder.output_channels[-(s + 1)] + stride_for_upsampling = encoder.strides[-s] + upsample_layers.append(UpsampleLayer( + conv_op = encoder.conv_op, + input_channels = input_features_below, + output_channels = input_features_skip, + pool_op_kernel_size = stride_for_upsampling, + mode='nearest' + )) + + stages.append(nn.Sequential( + BasicResBlock( + conv_op = encoder.conv_op, + norm_op = encoder.norm_op, + norm_op_kwargs = encoder.norm_op_kwargs, + nonlin = encoder.nonlin, + nonlin_kwargs = encoder.nonlin_kwargs, + input_channels = 2 * input_features_skip if s < n_stages_encoder - 1 else input_features_skip, + output_channels = input_features_skip, + kernel_size = encoder.kernel_sizes[-(s + 1)], + padding=encoder.conv_pad_sizes[-(s + 1)], + stride=1, + use_1x1conv=True + ), + *[ + BasicBlockD( + conv_op = encoder.conv_op, + input_channels = input_features_skip, + output_channels = input_features_skip, + kernel_size = encoder.kernel_sizes[-(s + 1)], + stride = 1, + conv_bias = encoder.conv_bias, + norm_op = encoder.norm_op, + norm_op_kwargs = encoder.norm_op_kwargs, + nonlin = encoder.nonlin, + nonlin_kwargs = encoder.nonlin_kwargs, + ) for _ in range(n_conv_per_stage[s-1] - 1) + ] + )) + seg_layers.append(encoder.conv_op(input_features_skip, num_classes, 1, 1, 0, bias=True)) + + self.stages = nn.ModuleList(stages) + self.upsample_layers = nn.ModuleList(upsample_layers) + self.seg_layers = nn.ModuleList(seg_layers) + + def forward(self, skips): + lres_input = skips[-1] + seg_outputs = [] + for s in range(len(self.stages)): + x = self.upsample_layers[s](lres_input) + if s < (len(self.stages) - 1): + x = torch.cat((x, skips[-(s+2)]), 1) + x = self.stages[s](x) + if self.deep_supervision: + seg_outputs.append(self.seg_layers[s](x)) + elif s == (len(self.stages) - 1): + seg_outputs.append(self.seg_layers[-1](x)) + lres_input = x + seg_outputs = seg_outputs[::-1] + + if not self.deep_supervision: + r = seg_outputs[0] + else: + r = seg_outputs + return r + + def compute_conv_feature_map_size(self, input_size): + skip_sizes = [] + for s in range(len(self.encoder.strides) - 1): + skip_sizes.append([i // j for i, j in zip(input_size, self.encoder.strides[s])]) + input_size = skip_sizes[-1] + + assert len(skip_sizes) == len(self.stages) + + output = np.int64(0) + for s in range(len(self.stages)): + output += self.stages[s].compute_conv_feature_map_size(skip_sizes[-(s+1)]) + output += np.prod([self.encoder.output_channels[-(s+2)], *skip_sizes[-(s+1)]], dtype=np.int64) + if self.deep_supervision or (s == (len(self.stages) - 1)): + output += np.prod([self.num_classes, *skip_sizes[-(s+1)]], dtype=np.int64) + return output + + +class UMambaBot(nn.Module): + """ + UMambaBot that uses Mamba block at the bottleneck of UNet. + + * Reference: Jun Ma, Feifei Li, Bo Wang. + U-Mamba: Enhancing long-range dependency for biomedical image segmentation. + arxiv 2403.20035, 2024. + + The implementation is based on the code at: + https://github.com/bowang-lab/U-Mamba. + + The parameters for the backbone should be given in the `params` dictionary. + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param class_num: (int) The class number for segmentation task. + :param n_blocks_per_stage: (int) the number of con blocks at each stage. + """ + def __init__(self, params): + super(UMambaBot, self).__init__() + params = self.get_default_parameters(params) + for p in params: + print(p, params[p]) + + # def __init__(self, + # input_channels: int, + # n_stages: int, + # features_per_stage: Union[int, List[int], Tuple[int, ...]], + # conv_op: Type[_ConvNd], + # kernel_sizes: Union[int, List[int], Tuple[int, ...]], + # strides: Union[int, List[int], Tuple[int, ...]], + # n_conv_per_stage: Union[int, List[int], Tuple[int, ...]], + # num_classes: int, + # n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]], + # conv_bias: bool = False, + # norm_op: Union[None, Type[nn.Module]] = None, + # norm_op_kwargs: dict = None, + # dropout_op: Union[None, Type[_DropoutNd]] = None, + # dropout_op_kwargs: dict = None, + # nonlin: Union[None, Type[torch.nn.Module]] = None, + # nonlin_kwargs: dict = None, + # deep_supervision: bool = False, + # stem_channels: int = None + # ): + # super().__init__() + + input_channels = params['in_chns'] + features_per_stage = params['feature_chns'] + num_classes = params['class_num'] + n_blocks_per_stage = params['n_blocks_per_stage'] + n_conv_per_stage_decoder = n_blocks_per_stage + n_stages = len(features_per_stage) + conv_op = nn.Conv2d + kernel_sizes = [(3,3)] * len(features_per_stage) + strides = [(1, 1)] + [(2,2)] * (len(features_per_stage) - 1) + # strides = [(1,1)] * len(features_per_stage) + + conv_bias = True + norm_op = nn.InstanceNorm2d + norm_op_kwargs = {"affine":True} + nonlin=nn.LeakyReLU + nonlin_kwargs={'inplace': True} + deep_supervision = False + stem_channels = None + + if isinstance(n_blocks_per_stage, int): + n_blocks_per_stage = [n_blocks_per_stage] * n_stages + if isinstance(n_conv_per_stage_decoder, int): + n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1) + + for s in range(math.ceil(n_stages / 2), n_stages): + n_blocks_per_stage[s] = 1 + + for s in range(math.ceil((n_stages - 1) / 2 + 0.5), n_stages - 1): + n_conv_per_stage_decoder[s] = 1 + + + assert len(n_blocks_per_stage) == n_stages, "n_blocks_per_stage must have as many entries as we have " \ + f"resolution stages. here: {n_stages}. " \ + f"n_blocks_per_stage: {n_blocks_per_stage}" + assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one less entries " \ + f"as we have resolution stages. here: {n_stages} " \ + f"stages, so it should have {n_stages - 1} entries. " \ + f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}" + + + self.encoder = UNetResEncoder( + input_channels, + n_stages, + features_per_stage, + conv_op, + kernel_sizes, + strides, + n_blocks_per_stage, + conv_bias, + norm_op, + norm_op_kwargs, + nonlin, + nonlin_kwargs, + return_skips=True, + stem_channels=stem_channels + ) + + self.mamba_layer = MambaLayer(dim = features_per_stage[-1]) + + self.decoder = UNetResDecoder(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'n_blocks_per_stage': 2 + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + skips = self.encoder(x) + # for skip in skips: + # print(skip.shape) + skips[-1] = self.mamba_layer(skips[-1]) + output = self.decoder(skips) + if(len(x_shape) == 5): + if(isinstance(output, (list,tuple))): + for i in range(len(output)): + new_shape = [N, D] + list(output[i].shape)[1:] + output[i] = torch.transpose(torch.reshape(output[i], new_shape), 1, 2) + else: + new_shape = [N, D] + list(output.shape)[1:] + output = torch.transpose(torch.reshape(output, new_shape), 1, 2) + return output + + def compute_conv_feature_map_size(self, input_size): + assert len(input_size) == dim_of_conv_op(self.encoder.conv_op), "just give the image size without color/feature channels or " \ + "batch channel. Do not give input_size=(b, c, x, y(, z)). " \ + "Give input_size=(x, y(, z))!" + return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size) + +class ResidualMambaEncoder(nn.Module): + def __init__(self, + input_size: Tuple[int, ...], + input_channels: int, + n_stages: int, + features_per_stage: Union[int, List[int], Tuple[int, ...]], + conv_op: Type[_ConvNd], + kernel_sizes: Union[int, List[int], Tuple[int, ...]], + strides: Union[int, List[int], Tuple[int, ...], Tuple[Tuple[int, ...], ...]], + n_blocks_per_stage: Union[int, List[int], Tuple[int, ...]], + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: dict = None, + nonlin: Union[None, Type[torch.nn.Module]] = None, + nonlin_kwargs: dict = None, + return_skips: bool = False, + stem_channels: int = None, + pool_type: str = 'conv', + ): + super().__init__() + if isinstance(kernel_sizes, int): + kernel_sizes = [kernel_sizes] * n_stages + if isinstance(features_per_stage, int): + features_per_stage = [features_per_stage] * n_stages + if isinstance(n_blocks_per_stage, int): + n_blocks_per_stage = [n_blocks_per_stage] * n_stages + if isinstance(strides, int): + strides = [strides] * n_stages + assert len( + kernel_sizes) == n_stages, "kernel_sizes must have as many entries as we have resolution stages (n_stages)" + assert len( + n_blocks_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have resolution stages (n_stages)" + assert len( + features_per_stage) == n_stages, "features_per_stage must have as many entries as we have resolution stages (n_stages)" + assert len(strides) == n_stages, "strides must have as many entries as we have resolution stages (n_stages). " \ + "Important: first entry is recommended to be 1, else we run strided conv drectly on the input" + + pool_op = get_matching_pool_op(conv_op, pool_type=pool_type) if pool_type != 'conv' else None + + do_channel_token = [False] * n_stages + feature_map_sizes = [] + feature_map_size = input_size + for s in range(n_stages): + feature_map_sizes.append([i // j for i, j in zip(feature_map_size, strides[s])]) + feature_map_size = feature_map_sizes[-1] + if np.prod(feature_map_size) <= features_per_stage[s]: + do_channel_token[s] = True + + + print(f"feature_map_sizes: {feature_map_sizes}") + print(f"do_channel_token: {do_channel_token}") + + self.conv_pad_sizes = [] + for krnl in kernel_sizes: + self.conv_pad_sizes.append([i // 2 for i in krnl]) + + stem_channels = features_per_stage[0] + self.stem = nn.Sequential( + BasicResBlock( + conv_op = conv_op, + input_channels = input_channels, + output_channels = stem_channels, + norm_op=norm_op, + norm_op_kwargs=norm_op_kwargs, + kernel_size=kernel_sizes[0], + padding=self.conv_pad_sizes[0], + stride=1, + nonlin=nonlin, + nonlin_kwargs=nonlin_kwargs, + use_1x1conv=True + ), + *[ + BasicBlockD( + conv_op = conv_op, + input_channels = stem_channels, + output_channels = stem_channels, + kernel_size = kernel_sizes[0], + stride = 1, + conv_bias = conv_bias, + norm_op = norm_op, + norm_op_kwargs = norm_op_kwargs, + nonlin = nonlin, + nonlin_kwargs = nonlin_kwargs, + ) for _ in range(n_blocks_per_stage[0] - 1) + ] + ) + + input_channels = stem_channels + + stages = [] + mamba_layers = [] + for s in range(n_stages): + stage = nn.Sequential( + BasicResBlock( + conv_op = conv_op, + norm_op = norm_op, + norm_op_kwargs = norm_op_kwargs, + input_channels = input_channels, + output_channels = features_per_stage[s], + kernel_size = kernel_sizes[s], + padding=self.conv_pad_sizes[s], + stride=strides[s], + use_1x1conv=True, + nonlin = nonlin, + nonlin_kwargs = nonlin_kwargs + ), + *[ + BasicBlockD( + conv_op = conv_op, + input_channels = features_per_stage[s], + output_channels = features_per_stage[s], + kernel_size = kernel_sizes[s], + stride = 1, + conv_bias = conv_bias, + norm_op = norm_op, + norm_op_kwargs = norm_op_kwargs, + nonlin = nonlin, + nonlin_kwargs = nonlin_kwargs, + ) for _ in range(n_blocks_per_stage[s] - 1) + ] + ) + + if bool(s % 2) ^ bool(n_stages % 2): ## gurantee the last stage has mamaba layer + mamba_layers.append( + MambaLayer( + dim = np.prod(feature_map_sizes[s]) if do_channel_token[s] else features_per_stage[s], + channel_token = do_channel_token[s] + ) + ) + else: + mamba_layers.append(nn.Identity()) + + stages.append(stage) + input_channels = features_per_stage[s] + + self.mamba_layers = nn.ModuleList(mamba_layers) + self.stages = nn.ModuleList(stages) + self.output_channels = features_per_stage + self.strides = [[item] * dim_of_conv_op(conv_op) if not isinstance(item, (tuple, list, np.ndarray)) \ + else item for item in strides] + # self.strides = [maybe_convert_scalar_to_list(conv_op, i) for i in strides] + self.return_skips = return_skips + + self.conv_op = conv_op + self.norm_op = norm_op + self.norm_op_kwargs = norm_op_kwargs + self.nonlin = nonlin + self.nonlin_kwargs = nonlin_kwargs + #self.dropout_op = dropout_op + #self.dropout_op_kwargs = dropout_op_kwargs + self.conv_bias = conv_bias + self.kernel_sizes = kernel_sizes + + def forward(self, x): + if self.stem is not None: + x = self.stem(x) + ret = [] + for s in range(len(self.stages)): + x = self.stages[s](x) + x = self.mamba_layers[s](x) + ret.append(x) + if self.return_skips: + return ret + else: + return ret[-1] + + def compute_conv_feature_map_size(self, input_size): + if self.stem is not None: + output = self.stem.compute_conv_feature_map_size(input_size) + else: + output = np.int64(0) + + for s in range(len(self.stages)): + output += self.stages[s].compute_conv_feature_map_size(input_size) + input_size = [i // j for i, j in zip(input_size, self.strides[s])] + + return output + + +class UMambaEnc(nn.Module): + """ + UMambaEnc that uses Mamba block at the encoder and bottleneck of UNet. + + * Reference: Jun Ma, Feifei Li, Bo Wang. + U-Mamba: Enhancing long-range dependency for biomedical image segmentation. + arxiv 2403.20035, 2024. + + The implementation is based on the code at: + https://github.com/bowang-lab/U-Mamba. + + The parameters for the backbone should be given in the `params` dictionary. + + :param input_size: (list) the size of input image, such as [256, 256] + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param class_num: (int) The class number for segmentation task. + :param n_blocks_per_stage: (int) the number of con blocks at each stage. + """ + def __init__(self, params): + super(UMambaEnc, self).__init__() + params = self.get_default_parameters(params) + for p in params: + print(p, params[p]) + # def __init__(self, + # input_size: Tuple[int, ...], + # input_channels: int, + # n_stages: int, + # features_per_stage: Union[int, List[int], Tuple[int, ...]], + # conv_op: Type[_ConvNd], + # kernel_sizes: Union[int, List[int], Tuple[int, ...]], + # strides: Union[int, List[int], Tuple[int, ...]], + # n_conv_per_stage: Union[int, List[int], Tuple[int, ...]], + # num_classes: int, + # n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]], + # conv_bias: bool = False, + # norm_op: Union[None, Type[nn.Module]] = None, + # norm_op_kwargs: dict = None, + # dropout_op: Union[None, Type[_DropoutNd]] = None, + # dropout_op_kwargs: dict = None, + # nonlin: Union[None, Type[torch.nn.Module]] = None, + # nonlin_kwargs: dict = None, + # deep_supervision: bool = False, + # stem_channels: int = None + # ): + # super().__init__() + + input_size = params['input_size'] + input_channels = params['in_chns'] + features_per_stage = params['feature_chns'] + num_classes = params['class_num'] + n_blocks_per_stage = params['n_blocks_per_stage'] + n_conv_per_stage_decoder = n_blocks_per_stage + n_stages = len(features_per_stage) + conv_op = nn.Conv2d + kernel_sizes = [(3,3)] * len(features_per_stage) + strides = [(1, 1)] + [(2,2)] * (len(features_per_stage) - 1) + + conv_bias = True + norm_op = nn.InstanceNorm2d + norm_op_kwargs = {"affine":True} + nonlin=nn.LeakyReLU + nonlin_kwargs={'inplace': True} + deep_supervision = False + stem_channels = None + + if isinstance(n_blocks_per_stage, int): + n_blocks_per_stage = [n_blocks_per_stage] * n_stages + if isinstance(n_conv_per_stage_decoder, int): + n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1) + + for s in range(math.ceil(n_stages / 2), n_stages): + n_blocks_per_stage[s] = 1 + + for s in range(math.ceil((n_stages - 1) / 2 + 0.5), n_stages - 1): + n_conv_per_stage_decoder[s] = 1 + + + assert len(n_blocks_per_stage) == n_stages, "n_blocks_per_stage must have as many entries as we have " \ + f"resolution stages. here: {n_stages}. " \ + f"n_blocks_per_stage: {n_blocks_per_stage}" + assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one less entries " \ + f"as we have resolution stages. here: {n_stages} " \ + f"stages, so it should have {n_stages - 1} entries. " \ + f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}" + self.encoder = ResidualMambaEncoder( + input_size, + input_channels, + n_stages, + features_per_stage, + conv_op, + kernel_sizes, + strides, + n_blocks_per_stage, + conv_bias, + norm_op, + norm_op_kwargs, + nonlin, + nonlin_kwargs, + return_skips=True, + stem_channels=stem_channels + ) + + self.decoder = UNetResDecoder(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'n_blocks_per_stage': 2 + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + skips = self.encoder(x) + output = self.decoder(skips) + if(len(x_shape) == 5): + if(isinstance(output, (list,tuple))): + for i in range(len(output)): + new_shape = [N, D] + list(output[i].shape)[1:] + output[i] = torch.transpose(torch.reshape(output[i], new_shape), 1, 2) + else: + new_shape = [N, D] + list(output.shape)[1:] + output = torch.transpose(torch.reshape(output, new_shape), 1, 2) + return output + + def compute_conv_feature_map_size(self, input_size): + assert len(input_size) == dim_of_conv_op(self.encoder.conv_op), "just give the image size without color/feature channels or " \ + "batch channel. Do not give input_size=(b, c, x, y(, z)). " \ + "Give input_size=(x, y(, z))!" + return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size) + diff --git a/pymic/net/net2d/unet2d.py b/pymic/net/net2d/unet2d.py index 9acc0ad..be69f0d 100644 --- a/pymic/net/net2d/unet2d.py +++ b/pymic/net/net2d/unet2d.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import logging import torch import torch.nn as nn import numpy as np -from torch.nn.functional import interpolate class ConvBlock(nn.Module): """ @@ -56,22 +56,32 @@ class UpBlock(nn.Module): :param in_channels2: (int) Channel number of low-level features. :param out_channels: (int) Output channel number. :param dropout_p: (int) Dropout probability. - :param bilinear: (bool) Use bilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Bilinear`), 3 (`Bicubic`). The default value + is 2 (`Bilinear`). """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - bilinear=True): + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, up_mode = 2): super(UpBlock, self).__init__() - self.bilinear = bilinear - if bilinear: - self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + if(isinstance(up_mode, int)): + up_mode_values = ["transconv", "nearest", "bilinear", "bicubic"] + if(up_mode > 3): + raise ValueError("The upsample mode should be 0-3, but {0:} is given.".format(up_mode)) + self.up_mode = up_mode_values[up_mode] else: + self.up_mode = up_mode.lower() + + if (self.up_mode == "transconv"): self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2) + else: + self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) + if(self.up_mode == "nearest"): + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode) + else: + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode, align_corners=True) self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) def forward(self, x1, x2): - if self.bilinear: + if self.up_mode != "transconv": x1 = self.conv1x1(x1) x1 = self.up(x1) x = torch.cat([x2, x1], dim=1) @@ -129,8 +139,10 @@ class Decoder(nn.Module): :param dropout: (list) The dropout ratio for each resolution level. The length should be the same as that of `feature_chns`. :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (or `Nearest`), 2 (or `Bilinear`), 3 (or `Bicubic`). + The default value is 2 (or `Bilinear`). + :param multiscale_pred: (bool) Get multi-scale prediction. """ def __init__(self, params): super(Decoder, self).__init__() @@ -139,17 +151,27 @@ def __init__(self, params): self.ft_chns = self.params['feature_chns'] self.dropout = self.params['dropout'] self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] + self.up_mode = self.params.get('up_mode', 2) + self.mul_pred = self.params.get('multiscale_pred', False) assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) if(len(self.ft_chns) == 5): - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear) + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.up_mode) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.up_mode) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.up_mode) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.up_mode) self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1) + if(self.mul_pred and (self.training or self.mul_infer)): + self.out_conv1 = nn.Conv2d(self.ft_chns[1], self.n_class, kernel_size = 1) + self.out_conv2 = nn.Conv2d(self.ft_chns[2], self.n_class, kernel_size = 1) + self.out_conv3 = nn.Conv2d(self.ft_chns[3], self.n_class, kernel_size = 1) + self.stage = 'train' + + def set_stage(self, stage): + self.stage = stage + def forward(self, x): if(len(self.ft_chns) == 5): assert(len(x) == 5) @@ -163,6 +185,11 @@ def forward(self, x): x_d1 = self.up3(x_d2, x1) x_d0 = self.up4(x_d1, x0) output = self.out_conv(x_d0) + if(self.mul_pred and self.stage == 'train'): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) + output = [output, output1, output2, output3] return output class UNet2D(nn.Module): @@ -180,43 +207,43 @@ class UNet2D(nn.Module): following fields: :param in_chns: (int) Input channel number. + :param class_num: (int) The class number for segmentation task. + + Optional parameters: + :param feature_chns: (list) Feature channel for each resolution level. The length should be 4 or 5, such as [16, 32, 64, 128, 256]. :param dropout: (list) The dropout ratio for each resolution level. The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (or `Nearest`), 2 (or `Bilinear`), 3 (or `Bicubic`). + The default value is 2 (or `Bilinear`). :param multiscale_pred: (bool) Get multiscale prediction. """ def __init__(self, params): super(UNet2D, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] - self.mul_pred = self.params['multiscale_pred'] + params = self.get_default_parameters(params) + for p in params: + print(p, params[p]) + self.encoder = Encoder(params) + self.decoder = Decoder(params) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], + 'up_mode': 2, + 'multiscale_pred': False + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - if(len(self.ft_chns) == 5): - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear) - - self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1) - if(self.mul_pred): - self.out_conv1 = nn.Conv2d(self.ft_chns[1], self.n_class, kernel_size = 1) - self.out_conv2 = nn.Conv2d(self.ft_chns[2], self.n_class, kernel_size = 1) - self.out_conv3 = nn.Conv2d(self.ft_chns[3], self.n_class, kernel_size = 1) + def set_stage(self, stage): + self.stage = stage + self.decoder.set_stage(stage) def forward(self, x): x_shape = list(x.shape) @@ -226,51 +253,15 @@ def forward(self, x): x = torch.transpose(x, 1, 2) x = torch.reshape(x, new_shape) - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - if(len(self.ft_chns) == 5): - x4 = self.down4(x3) - x_d3 = self.up1(x4, x3) - else: - x_d3 = x3 - x_d2 = self.up2(x_d3, x2) - x_d1 = self.up3(x_d2, x1) - x_d0 = self.up4(x_d1, x0) - output = self.out_conv(x_d0) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - - if(len(x_shape) == 5): + f = self.encoder(x) + output = self.decoder(f) + if(len(x_shape) == 5): + if(isinstance(output, (list,tuple))): for i in range(len(output)): new_shape = [N, D] + list(output[i].shape)[1:] output[i] = torch.transpose(torch.reshape(output[i], new_shape), 1, 2) - elif(len(x_shape) == 5): - new_shape = [N, D] + list(output.shape)[1:] - output = torch.reshape(output, new_shape) - output = torch.transpose(output, 1, 2) - - return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'feature_chns':[2, 8, 32, 48, 64], - 'dropout': [0, 0, 0.3, 0.4, 0.5], - 'class_num': 2, - 'bilinear': True, - 'multiscale_pred': False} - Net = UNet2D(params) - Net = Net.double() - - x = np.random.rand(4, 4, 10, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print(len(y.size())) - y = y.detach().numpy() - print(y.shape) + else: + new_shape = [N, D] + list(output.shape)[1:] + output = torch.transpose(torch.reshape(output, new_shape), 1, 2) + + return output \ No newline at end of file diff --git a/pymic/net/net2d/unet2d_attention.py b/pymic/net/net2d/unet2d_attention.py index 6afdfdc..36faec8 100644 --- a/pymic/net/net2d/unet2d_attention.py +++ b/pymic/net/net2d/unet2d_attention.py @@ -4,14 +4,7 @@ import torch import torch.nn as nn from pymic.net.net2d.unet2d import * -""" -A Reimplementation of the attention U-Net paper: - Ozan Oktay, Jo Schlemper et al.: - Attentin U-Net: Looking Where to Look for the Pancreas. MIDL, 2018. -Note that there are some modifications from the original paper, such as -the use of batch normalization, dropout, and leaky relu here. -""" class AttentionGateBlock(nn.Module): def __init__(self, chns_l, chns_h): """ @@ -80,6 +73,14 @@ def forward(self, x1, x2): return self.conv(x) class AttentionUNet2D(UNet2D): + """ + A Reimplementation of the attention U-Net paper: + Ozan Oktay, Jo Schlemper et al.: + Attentin U-Net: Looking Where to Look for the Pancreas. MIDL, 2018. + + Note that there are some modifications from the original paper, such as + the use of batch normalization, dropout, and leaky relu here. + """ def __init__(self, params): super(AttentionUNet2D, self).__init__(params) self.up1 = UpBlockWithAttention(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = 0.0) diff --git a/pymic/net/net2d/unet2d_canet.py b/pymic/net/net2d/unet2d_canet.py new file mode 100644 index 0000000..f578025 --- /dev/null +++ b/pymic/net/net2d/unet2d_canet.py @@ -0,0 +1,1224 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import numpy as np +import torch +import torch.nn as nn +from torch.nn import init +from torch.nn import functional as F + +## init +def weights_init_normal(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('Linear') != -1: + init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm') != -1: + init.normal_(m.weight.data, 1.0, 0.02) + init.constant_(m.bias.data, 0.0) + + +def weights_init_xavier(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + init.xavier_normal_(m.weight.data, gain=1) + elif classname.find('Linear') != -1: + init.xavier_normal_(m.weight.data, gain=1) + elif classname.find('BatchNorm') != -1: + init.normal_(m.weight.data, 1.0, 0.02) + init.constant_(m.bias.data, 0.0) + + +def weights_init_kaiming(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif classname.find('Linear') != -1: + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif classname.find('BatchNorm') != -1: + init.normal_(m.weight.data, 1.0, 0.02) + init.constant_(m.bias.data, 0.0) + + +def weights_init_orthogonal(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + init.orthogonal_(m.weight.data, gain=1) + elif classname.find('Linear') != -1: + init.orthogonal_(m.weight.data, gain=1) + elif classname.find('BatchNorm') != -1: + init.normal_(m.weight.data, 1.0, 0.02) + init.constant_(m.bias.data, 0.0) + + +def init_weights(net, init_type='normal'): + #print('initialization method [%s]' % init_type) + if init_type == 'normal': + net.apply(weights_init_normal) + elif init_type == 'xavier': + net.apply(weights_init_xavier) + elif init_type == 'kaiming': + net.apply(weights_init_kaiming) + elif init_type == 'orthogonal': + net.apply(weights_init_orthogonal) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + +## 1, modules +def conv1x1(in_planes, out_planes, stride=1, bias=False): + "1x1 convolution" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, + padding=0, bias=bias) + + +def conv3x3(in_planes, out_planes, stride=1, bias=False, group=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,padding=1, groups=group, bias=bias) + +# conv_block(nn.Module) for U-net convolution block +class conv_block(nn.Module): + def __init__(self, ch_in, ch_out, drop_out=False): + super(conv_block, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), + nn.BatchNorm2d(ch_out), + nn.ReLU(inplace=True), + nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True), + nn.BatchNorm2d(ch_out), + nn.ReLU(inplace=True), + ) + self.dropout = drop_out + + def forward(self, x): + x = self.conv(x) + if self.dropout: + x = nn.Dropout2d(0.5)(x) + return x + + +# # UpCat(nn.Module) for U-net UP convolution +class UpCat(nn.Module): + def __init__(self, in_feat, out_feat, is_deconv=True): + super(UpCat, self).__init__() + + if is_deconv: + self.up = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2) + else: + self.up = nn.Upsample(scale_factor=2, mode='bilinear') + + def forward(self, inputs, down_outputs): + # TODO: Upsampling required after deconv? + outputs = self.up(down_outputs) + offset = inputs.size()[3] - outputs.size()[3] + if offset == 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2]), out=None).unsqueeze( + 3).cuda() + outputs = torch.cat([outputs, addition], dim=3) + elif offset > 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2], offset), out=None).cuda() + outputs = torch.cat([outputs, addition], dim=3) + out = torch.cat([inputs, outputs], dim=1) + + return out + + +# # UpCatconv(nn.Module) for up convolution +class UpCatconv(nn.Module): + def __init__(self, in_feat, out_feat, is_deconv=True, drop_out=False): + super(UpCatconv, self).__init__() + + if is_deconv: + self.conv = conv_block(in_feat, out_feat, drop_out=drop_out) + self.up = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2) + else: + self.conv = conv_block(in_feat + out_feat, out_feat, drop_out=drop_out) + self.up = nn.Upsample(scale_factor=2, mode='bilinear') + + def forward(self, inputs, down_outputs): + # TODO: Upsampling required after deconv + outputs = self.up(down_outputs) + offset = inputs.size()[3] - outputs.size()[3] + if offset == 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2]), out=None).unsqueeze( + 3).cuda() + outputs = torch.cat([outputs, addition], dim=3) + elif offset > 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2], offset), out=None).cuda() + outputs = torch.cat([outputs, addition], dim=3) + out = self.conv(torch.cat([inputs, outputs], dim=1)) + + return out + + + +class _GridAttentionBlockND(nn.Module): + def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation', + sub_sample_factor=(2,2,2)): + super(_GridAttentionBlockND, self).__init__() + + assert dimension in [2, 3] + assert mode in ['concatenation', 'concatenation_debug', 'concatenation_residual'] + + # Downsampling rate for the input featuremap + if isinstance(sub_sample_factor, tuple): self.sub_sample_factor = sub_sample_factor + elif isinstance(sub_sample_factor, list): self.sub_sample_factor = tuple(sub_sample_factor) + else: self.sub_sample_factor = tuple([sub_sample_factor]) * dimension + + # Default parameter set + self.mode = mode + self.dimension = dimension + self.sub_sample_kernel_size = self.sub_sample_factor + + # Number of channels (pixel dimensions) + self.in_channels = in_channels + self.gating_channels = gating_channels + self.inter_channels = inter_channels + + if self.inter_channels is None: + self.inter_channels = in_channels // 2 + if self.inter_channels == 0: + self.inter_channels = 1 + + if dimension == 3: + conv_nd = nn.Conv3d + bn = nn.BatchNorm3d + self.upsample_mode = 'trilinear' + elif dimension == 2: + conv_nd = nn.Conv2d + bn = nn.BatchNorm2d + self.upsample_mode = 'bilinear' + else: + raise NotImplemented + + # Output transform + self.W = nn.Sequential( + conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), + bn(self.in_channels), + ) + + # Theta^T * x_ij + Phi^T * gating_signal + bias + self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=True) + self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels, + kernel_size=(1, 1), stride=1, padding=0, bias=True) + self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True) + + # Initialise weights + for m in self.children(): + init_weights(m, init_type='kaiming') + + # Define the operation + if mode == 'concatenation': + self.operation_function = self._concatenation + elif mode == 'concatenation_debug': + self.operation_function = self._concatenation_debug + elif mode == 'concatenation_residual': + self.operation_function = self._concatenation_residual + else: + raise NotImplementedError('Unknown operation function.') + + + def forward(self, x, g): + ''' + :param x: (b, c, t, h, w) + :param g: (b, g_d) + :return: + ''' + + output = self.operation_function(x, g) + return output + + def _concatenation(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.relu(theta_x + phi_g, inplace=True) + + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + sigm_psi_f = torch.sigmoid(self.psi(f)) + + # upsample the attentions and multiply + sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + def _concatenation_debug(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.softplus(theta_x + phi_g) + + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + sigm_psi_f = torch.sigmoid(self.psi(f)) + + # upsample the attentions and multiply + sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + + def _concatenation_residual(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.relu(theta_x + phi_g, inplace=True) + + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + f = self.psi(f).view(batch_size, 1, -1) + sigm_psi_f = F.softmax(f, dim=2).view(batch_size, 1, *theta_x.size()[2:]) + + # upsample the attentions and multiply + sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + +class GridAttentionBlock2D(_GridAttentionBlockND): + def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', + sub_sample_factor=(2, 2)): + super(GridAttentionBlock2D, self).__init__(in_channels, + inter_channels=inter_channels, + gating_channels=gating_channels, + dimension=2, mode=mode, + sub_sample_factor=sub_sample_factor, + ) + + +## 2, channel attention +class SE_Conv_Block(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, drop_out=False): + super(SE_Conv_Block, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes * 2) + self.bn2 = nn.BatchNorm2d(planes * 2) + self.conv3 = conv3x3(planes * 2, planes) + self.bn3 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + self.dropout = drop_out + + self.globalAvgPool = nn.AdaptiveAvgPool2d(1) + self.globalMaxPool = nn.AdaptiveMaxPool2d(1) + + self.fc1 = nn.Linear(in_features=planes * 2, out_features=round(planes / 2)) + self.fc2 = nn.Linear(in_features=round(planes / 2), out_features=planes * 2) + self.sigmoid = nn.Sigmoid() + + self.downchannel = None + if inplanes != planes: + self.downchannel = nn.Sequential(nn.Conv2d(inplanes, planes * 2, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * 2),) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downchannel is not None: + residual = self.downchannel(x) + + original_out = out + out1 = out + # For global average pool + out = self.globalAvgPool(out) + out = out.view(out.size(0), -1) + out = self.fc1(out) + out = self.relu(out) + out = self.fc2(out) + out = self.sigmoid(out) + out = out.view(out.size(0), out.size(1), 1, 1) + avg_att = out + out = out * original_out + # For global maximum pool + out1 = self.globalMaxPool(out1) + out1 = out1.view(out1.size(0), -1) + out1 = self.fc1(out1) + out1 = self.relu(out1) + out1 = self.fc2(out1) + out1 = self.sigmoid(out1) + out1 = out1.view(out1.size(0), out1.size(1), 1, 1) + max_att = out1 + out1 = out1 * original_out + + att_weight = avg_att + max_att + out += out1 + out += residual + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.relu(out) + if self.dropout: + out = nn.Dropout2d(0.5)(out) + + return out, att_weight + +## 3, grid attention +class _GridAttentionBlockND(nn.Module): + def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation', + sub_sample_factor=(2,2,2)): + super(_GridAttentionBlockND, self).__init__() + + assert dimension in [2, 3] + assert mode in ['concatenation', 'concatenation_debug', 'concatenation_residual'] + + # Downsampling rate for the input featuremap + if isinstance(sub_sample_factor, tuple): self.sub_sample_factor = sub_sample_factor + elif isinstance(sub_sample_factor, list): self.sub_sample_factor = tuple(sub_sample_factor) + else: self.sub_sample_factor = tuple([sub_sample_factor]) * dimension + + # Default parameter set + self.mode = mode + self.dimension = dimension + self.sub_sample_kernel_size = self.sub_sample_factor + + # Number of channels (pixel dimensions) + self.in_channels = in_channels + self.gating_channels = gating_channels + self.inter_channels = inter_channels + + if self.inter_channels is None: + self.inter_channels = in_channels // 2 + if self.inter_channels == 0: + self.inter_channels = 1 + + if dimension == 3: + conv_nd = nn.Conv3d + bn = nn.BatchNorm3d + self.upsample_mode = 'trilinear' + elif dimension == 2: + conv_nd = nn.Conv2d + bn = nn.BatchNorm2d + self.upsample_mode = 'bilinear' + else: + raise NotImplemented + + # Output transform + self.W = nn.Sequential( + conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), + bn(self.in_channels), + ) + + # Theta^T * x_ij + Phi^T * gating_signal + bias + self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=True) + self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels, + kernel_size=(1, 1), stride=1, padding=0, bias=True) + self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True) + + # Initialise weights + for m in self.children(): + init_weights(m, init_type='kaiming') + + # Define the operation + if mode == 'concatenation': + self.operation_function = self._concatenation + elif mode == 'concatenation_debug': + self.operation_function = self._concatenation_debug + elif mode == 'concatenation_residual': + self.operation_function = self._concatenation_residual + else: + raise NotImplementedError('Unknown operation function.') + + + def forward(self, x, g): + ''' + :param x: (b, c, t, h, w) + :param g: (b, g_d) + :return: + ''' + + output = self.operation_function(x, g) + return output + + def _concatenation(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.relu(theta_x + phi_g, inplace=True) + + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + sigm_psi_f = torch.sigmoid(self.psi(f)) + + # upsample the attentions and multiply + sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + def _concatenation_debug(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.softplus(theta_x + phi_g) + + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + sigm_psi_f = torch.sigmoid(self.psi(f)) + + # upsample the attentions and multiply + sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + + def _concatenation_residual(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.relu(theta_x + phi_g, inplace=True) + + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + f = self.psi(f).view(batch_size, 1, -1) + sigm_psi_f = F.softmax(f, dim=2).view(batch_size, 1, *theta_x.size()[2:]) + + # upsample the attentions and multiply + sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + +class GridAttentionBlock2D(_GridAttentionBlockND): + def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', + sub_sample_factor=(2, 2)): + super(GridAttentionBlock2D, self).__init__(in_channels, + inter_channels=inter_channels, + gating_channels=gating_channels, + dimension=2, mode=mode, + sub_sample_factor=sub_sample_factor, + ) + +class MultiAttentionBlock(nn.Module): + def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor): + super(MultiAttentionBlock, self).__init__() + self.gate_block_1 = GridAttentionBlock2D(in_channels=in_size, gating_channels=gate_size, + inter_channels=inter_size, mode=nonlocal_mode, + sub_sample_factor=sub_sample_factor) + self.gate_block_2 = GridAttentionBlock2D(in_channels=in_size, gating_channels=gate_size, + inter_channels=inter_size, mode=nonlocal_mode, + sub_sample_factor=sub_sample_factor) + self.combine_gates = nn.Sequential(nn.Conv2d(in_size*2, in_size, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(in_size), + nn.ReLU(inplace=True)) + + # initialise the blocks + for m in self.children(): + if m.__class__.__name__.find('GridAttentionBlock2D') != -1: continue + init_weights(m, init_type='kaiming') + + def forward(self, input, gating_signal): + gate_1, attention_1 = self.gate_block_1(input, gating_signal) + gate_2, attention_2 = self.gate_block_2(input, gating_signal) + + return self.combine_gates(torch.cat([gate_1, gate_2], 1)), torch.cat([attention_1, attention_2], 1) + +## 4, Non-local layers +class _NonLocalBlockND(nn.Module): + def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian', + sub_sample_factor=4, bn_layer=True): + super(_NonLocalBlockND, self).__init__() + + assert dimension in [1, 2, 3] + assert mode in ['embedded_gaussian', 'gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down'] + + # print('Dimension: %d, mode: %s' % (dimension, mode)) + + self.mode = mode + self.dimension = dimension + self.sub_sample_factor = sub_sample_factor if isinstance(sub_sample_factor, list) else [sub_sample_factor] + + self.in_channels = in_channels + self.inter_channels = inter_channels + + if self.inter_channels is None: + self.inter_channels = in_channels // 2 + if self.inter_channels == 0: + self.inter_channels = 1 + + if dimension == 3: + conv_nd = nn.Conv3d + max_pool = nn.MaxPool3d + bn = nn.BatchNorm3d + elif dimension == 2: + conv_nd = nn.Conv2d + max_pool = nn.MaxPool2d + bn = nn.BatchNorm2d + else: + conv_nd = nn.Conv1d + max_pool = nn.MaxPool1d + bn = nn.BatchNorm1d + + self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=1, stride=1, padding=0) + + if bn_layer: + self.W = nn.Sequential( + conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, + kernel_size=1, stride=1, padding=0), + bn(self.in_channels) + ) + nn.init.constant_(self.W[1].weight, 0) + nn.init.constant_(self.W[1].bias, 0) + else: + self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, + kernel_size=1, stride=1, padding=0) + nn.init.constant_(self.W.weight, 0) + nn.init.constant_(self.W.bias, 0) + + self.theta = None + self.phi = None + + if mode in ['embedded_gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down']: + self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=1, stride=1, padding=0) + self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=1, stride=1, padding=0) + + if mode in ['concatenation']: + self.wf_phi = nn.Linear(self.inter_channels, 1, bias=False) + self.wf_theta = nn.Linear(self.inter_channels, 1, bias=False) + elif mode in ['concat_proper', 'concat_proper_down']: + self.psi = nn.Conv2d(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, + padding=0, bias=True) + + if mode == 'embedded_gaussian': + self.operation_function = self._embedded_gaussian + elif mode == 'dot_product': + self.operation_function = self._dot_product + elif mode == 'gaussian': + self.operation_function = self._gaussian + elif mode == 'concatenation': + self.operation_function = self._concatenation + elif mode == 'concat_proper': + self.operation_function = self._concatenation_proper + elif mode == 'concat_proper_down': + self.operation_function = self._concatenation_proper_down + else: + raise NotImplementedError('Unknown operation function.') + + if any(ss > 1 for ss in self.sub_sample_factor): + self.g = nn.Sequential(self.g, max_pool(kernel_size=sub_sample_factor)) + if self.phi is None: + self.phi = max_pool(kernel_size=sub_sample_factor) + else: + self.phi = nn.Sequential(self.phi, max_pool(kernel_size=sub_sample_factor)) + if mode == 'concat_proper_down': + self.theta = nn.Sequential(self.theta, max_pool(kernel_size=sub_sample_factor)) + + # Initialise weights + for m in self.children(): + init_weights(m, init_type='kaiming') + + def forward(self, x): + ''' + :param x: (b, c, t, h, w) + :return: + ''' + + output = self.operation_function(x) + return output + + def _embedded_gaussian(self, x): + batch_size = x.size(0) + + # g=>(b, c, t, h, w)->(b, 0.5c, t, h, w)->(b, thw, 0.5c) + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + g_x = g_x.permute(0, 2, 1) + + # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c) + # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw) + # f=>(b, thw, 0.5c)dot(b, 0.5c, twh) = (b, thw, thw) + theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) + f = torch.matmul(theta_x, phi_x) + f_div_C = F.softmax(f, dim=-1) + + # (b, thw, thw)dot(b, thw, 0.5c) = (b, thw, 0.5c)->(b, 0.5c, t, h, w)->(b, c, t, h, w) + y = torch.matmul(f_div_C, g_x) + y = y.permute(0, 2, 1).contiguous() + y = y.view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + z = W_y + x + + return z + + def _gaussian(self, x): + batch_size = x.size(0) + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + g_x = g_x.permute(0, 2, 1) + + theta_x = x.view(batch_size, self.in_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + + if self.sub_sample_factor > 1: + phi_x = self.phi(x).view(batch_size, self.in_channels, -1) + else: + phi_x = x.view(batch_size, self.in_channels, -1) + + f = torch.matmul(theta_x, phi_x) + f_div_C = F.softmax(f, dim=-1) + + y = torch.matmul(f_div_C, g_x) + y = y.permute(0, 2, 1).contiguous() + y = y.view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + z = W_y + x + + return z + + def _dot_product(self, x): + batch_size = x.size(0) + + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + g_x = g_x.permute(0, 2, 1) + + theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) + f = torch.matmul(theta_x, phi_x) + N = f.size(-1) + f_div_C = f / N + + y = torch.matmul(f_div_C, g_x) + y = y.permute(0, 2, 1).contiguous() + y = y.view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + z = W_y + x + + return z + + def _concatenation(self, x): + batch_size = x.size(0) + + # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2) + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + + # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c) + # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw/s**2, 0.5c) + theta_x = self.theta(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1) + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1) + + # theta => (b, thw, 0.5c) -> (b, thw, 1) -> (b, 1, thw) -> (expand) (b, thw/s**2, thw) + # phi => (b, thw/s**2, 0.5c) -> (b, thw/s**2, 1) -> (expand) (b, thw/s**2, thw) + # f=> RELU[(b, thw/s**2, thw) + (b, thw/s**2, thw)] = (b, thw/s**2, thw) + f = self.wf_theta(theta_x).permute(0, 2, 1).repeat(1, phi_x.size(1), 1) + \ + self.wf_phi(phi_x).repeat(1, 1, theta_x.size(1)) + f = F.relu(f, inplace=True) + + # Normalise the relations + N = f.size(-1) + f_div_c = f / N + + # g(x_j) * f(x_j, x_i) + # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw) + y = torch.matmul(g_x, f_div_c) + y = y.contiguous().view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + z = W_y + x + + return z + + def _concatenation_proper(self, x): + batch_size = x.size(0) + + # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2) + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + + # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw) + # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw/s**2) + theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) + + # theta => (b, 0.5c, thw) -> (expand) (b, 0.5c, thw/s**2, thw) + # phi => (b, 0.5c, thw/s**2) -> (expand) (b, 0.5c, thw/s**2, thw) + # f=> RELU[(b, 0.5c, thw/s**2, thw) + (b, 0.5c, thw/s**2, thw)] = (b, 0.5c, thw/s**2, thw) + f = theta_x.unsqueeze(dim=2).repeat(1,1,phi_x.size(2),1) + \ + phi_x.unsqueeze(dim=3).repeat(1,1,1,theta_x.size(2)) + f = F.relu(f, inplace=True) + + # psi -> W_psi^t * f -> (b, 1, thw/s**2, thw) -> (b, thw/s**2, thw) + f = torch.squeeze(self.psi(f), dim=1) + + # Normalise the relations + f_div_c = F.softmax(f, dim=1) + + # g(x_j) * f(x_j, x_i) + # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw) + y = torch.matmul(g_x, f_div_c) + y = y.contiguous().view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + z = W_y + x + + return z + + def _concatenation_proper_down(self, x): + batch_size = x.size(0) + + # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2) + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + + # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw) + # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw/s**2) + theta_x = self.theta(x) + downsampled_size = theta_x.size() + theta_x = theta_x.view(batch_size, self.inter_channels, -1) + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) + + # theta => (b, 0.5c, thw) -> (expand) (b, 0.5c, thw/s**2, thw) + # phi => (b, 0.5, thw/s**2) -> (expand) (b, 0.5c, thw/s**2, thw) + # f=> RELU[(b, 0.5c, thw/s**2, thw) + (b, 0.5c, thw/s**2, thw)] = (b, 0.5c, thw/s**2, thw) + f = theta_x.unsqueeze(dim=2).repeat(1,1,phi_x.size(2),1) + \ + phi_x.unsqueeze(dim=3).repeat(1,1,1,theta_x.size(2)) + f = F.relu(f, inplace=True) + + # psi -> W_psi^t * f -> (b, 0.5c, thw/s**2, thw) -> (b, 1, thw/s**2, thw) -> (b, thw/s**2, thw) + f = torch.squeeze(self.psi(f), dim=1) + + # Normalise the relations + f_div_c = F.softmax(f, dim=1) + + # g(x_j) * f(x_j, x_i) + # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw) + y = torch.matmul(g_x, f_div_c) + y = y.contiguous().view(batch_size, self.inter_channels, *downsampled_size[2:]) + + # upsample the final featuremaps # (b,0.5c,t/s1,h/s2,w/s3) + y = F.interpolate(y, size=x.size()[2:], mode='trilinear') + + # attention block output + W_y = self.W(y) + z = W_y + x + + return z + + +class NONLocalBlock2D(_NonLocalBlockND): + def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample_factor=2, bn_layer=True): + super(NONLocalBlock2D, self).__init__(in_channels, + inter_channels=inter_channels, + dimension=2, mode=mode, + sub_sample_factor=sub_sample_factor, + bn_layer=bn_layer) + +## 5, scale attention +class BasicConv(nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, + relu=True, bn=True, bias=False): + super(BasicConv, self).__init__() + self.out_channels = out_planes + self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, groups=groups, bias=bias) + self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None + self.relu = nn.ReLU() if relu else None + + def forward(self, x): + x = self.conv(x) + if self.bn is not None: + x = self.bn(x) + if self.relu is not None: + x = self.relu(x) + return x + + +class Flatten(nn.Module): + def forward(self, x): + return x.view(x.size(0), -1) + + +class ChannelGate(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): + super(ChannelGate, self).__init__() + self.gate_channels = gate_channels + self.mlp = nn.Sequential( + Flatten(), + nn.Linear(gate_channels, gate_channels // reduction_ratio), + nn.ReLU(), + nn.Linear(gate_channels // reduction_ratio, gate_channels) + ) + self.pool_types = pool_types + + def forward(self, x): + channel_att_sum = None + for pool_type in self.pool_types: + if pool_type == 'avg': + avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp(avg_pool) + elif pool_type == 'max': + max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp(max_pool) + elif pool_type == 'lp': + lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp(lp_pool) + elif pool_type == 'lse': + # LSE pool only + lse_pool = logsumexp_2d(x) + channel_att_raw = self.mlp(lse_pool) + + if channel_att_sum is None: + channel_att_sum = channel_att_raw + else: + channel_att_sum = channel_att_sum + channel_att_raw + + # scalecoe = F.sigmoid(channel_att_sum) + channel_att_sum = channel_att_sum.reshape(channel_att_sum.shape[0], 4, 4) + avg_weight = torch.mean(channel_att_sum, dim=2).unsqueeze(2) + avg_weight = avg_weight.expand(channel_att_sum.shape[0], 4, 4).reshape(channel_att_sum.shape[0], 16) + scale = torch.sigmoid(avg_weight).unsqueeze(2).unsqueeze(3).expand_as(x) + + return x * scale, scale + + +def logsumexp_2d(tensor): + tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) + s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) + outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() + return outputs + + +class ChannelPool(nn.Module): + def forward(self, x): + return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1) + + +class SpatialGate(nn.Module): + def __init__(self): + super(SpatialGate, self).__init__() + kernel_size = 7 + self.compress = ChannelPool() + self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) + + def forward(self, x): + x_compress = self.compress(x) + x_out = self.spatial(x_compress) + scale = torch.sigmoid(x_out) # broadcasting + # spa_scale = scale.expand_as(x) + # print(spa_scale.shape) + return x * scale, scale + +class SpatialAtten(nn.Module): + def __init__(self, in_size, out_size, kernel_size=3, stride=1): + super(SpatialAtten, self).__init__() + self.conv1 = BasicConv(in_size, out_size, kernel_size, stride=stride, + padding=(kernel_size-1) // 2, relu=True) + self.conv2 = BasicConv(out_size, out_size, kernel_size=1, stride=stride, + padding=0, relu=True, bn=False) + + def forward(self, x): + residual = x + x_out = self.conv1(x) + x_out = self.conv2(x_out) + spatial_att = torch.sigmoid(x_out).unsqueeze(4).permute(0, 1, 4, 2, 3) + spatial_att = spatial_att.expand(spatial_att.shape[0], 4, 4, spatial_att.shape[3], spatial_att.shape[4]).reshape( + spatial_att.shape[0], 16, spatial_att.shape[3], spatial_att.shape[4]) + x_out = residual * spatial_att + + x_out += residual + + return x_out, spatial_att + +class Scale_atten_block(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): + super(Scale_atten_block, self).__init__() + self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) + self.no_spatial = no_spatial + if not no_spatial: + self.SpatialGate = SpatialAtten(gate_channels, gate_channels //reduction_ratio) + + def forward(self, x): + x_out, ca_atten = self.ChannelGate(x) + if not self.no_spatial: + x_out, sa_atten = self.SpatialGate(x_out) + + return x_out, ca_atten, sa_atten + + +class scale_atten_convblock(nn.Module): + def __init__(self, in_size, out_size, stride=1, downsample=None, use_cbam=True, no_spatial=False, drop_out=False): + super(scale_atten_convblock, self).__init__() + # if stride != 1 or in_size != out_size: + # downsample = nn.Sequential( + # nn.Conv2d(in_size, out_size, + # kernel_size=1, stride=stride, bias=False), + # nn.BatchNorm2d(out_size), + # ) + self.downsample = downsample + self.stride = stride + self.no_spatial = no_spatial + self.dropout = drop_out + + self.relu = nn.ReLU(inplace=True) + self.conv3 = conv3x3(in_size, out_size) + self.bn3 = nn.BatchNorm2d(out_size) + + if use_cbam: + self.cbam = Scale_atten_block(in_size, reduction_ratio=4, no_spatial=self.no_spatial) # out_size + else: + self.cbam = None + + def forward(self, x): + residual = x + + if self.downsample is not None: + residual = self.downsample(x) + + if not self.cbam is None: + out, scale_c_atten, scale_s_atten = self.cbam(x) + + # scale_c_atten = nn.Sigmoid()(scale_c_atten) + # scale_s_atten = nn.Sigmoid()(scale_s_atten) + # scale_atten = channel_atten_c * spatial_atten_s + + # scale_max = torch.argmax(scale_atten, dim=1, keepdim=True) + # scale_max_soft = get_soft_label(input_tensor=scale_max, num_class=8) + # scale_max_soft = scale_max_soft.permute(0, 3, 1, 2) + # scale_atten_soft = scale_atten * scale_max_soft + + out += residual + out = self.relu(out) + out = self.conv3(out) + out = self.bn3(out) + out = self.relu(out) + + if self.dropout: + out = nn.Dropout2d(0.5)(out) + + return out + +## 6, CANet +class CANet(nn.Module): + """ + Implementation of CANet (Comprehensive Attention Network) for image segmentation. + + * Reference: R. Gu et al. `CA-Net: Comprehensive Attention Convolutional Neural Networks + for Explainable Medical Image Segmentation `_. + IEEE Transactions on Medical Imaging, 40(2),2021:699-711. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 5, such as [16, 32, 64, 128, 256]. + :param class_num: (int) The class number for segmentation task. + :param is_deconv: (bool) Using deconvolution for up-sampling or not. + If False, bilinear interpolation will be used for up-sampling. Default is True. + :param is_batchnorm: (bool) If batch normalization is or not. Default is True. + :param feature_scale: (int) The scale of resolution levels. Default is 4. + """ + def __init__(self, params): #args, in_ch=3, n_classes=2, feature_scale=4, is_deconv=True, is_batchnorm=True, + # nonlocal_mode='concatenation', attention_dsample=(1, 1)): + super(CANet, self).__init__() + self.in_channels = params['in_chns'] + self.num_classes = params['class_num'] + self.feature_chns= params.get('feature_chns', [32, 64, 128, 256, 512]) + self.is_deconv = params.get('is_deconv', True) + self.is_batchnorm = params.get('is_batchnorm', True) + self.feature_scale = params.get('feature_scale', 4) + nonlocal_mode = 'concatenation' + attention_dsample = (1, 1) + + filters = self.feature_chns + filters = [int(x / self.feature_scale) for x in filters] + + # downsampling + self.conv1 = conv_block(self.in_channels, filters[0]) + self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2)) + + self.conv2 = conv_block(filters[0], filters[1]) + self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2)) + + self.conv3 = conv_block(filters[1], filters[2]) + self.maxpool3 = nn.MaxPool2d(kernel_size=(2, 2)) + + self.conv4 = conv_block(filters[2], filters[3], drop_out=True) + self.maxpool4 = nn.MaxPool2d(kernel_size=(2, 2)) + + self.center = conv_block(filters[3], filters[4], drop_out=True) + + # attention blocks + # self.attentionblock1 = GridAttentionBlock2D(in_channels=filters[0], gating_channels=filters[1], + # inter_channels=filters[0]) + self.attentionblock2 = MultiAttentionBlock(in_size=filters[1], gate_size=filters[2], inter_size=filters[1], + nonlocal_mode=nonlocal_mode, sub_sample_factor=attention_dsample) + self.attentionblock3 = MultiAttentionBlock(in_size=filters[2], gate_size=filters[3], inter_size=filters[2], + nonlocal_mode=nonlocal_mode, sub_sample_factor=attention_dsample) + self.nonlocal4_2 = NONLocalBlock2D(in_channels=filters[4], inter_channels=filters[4] // 4) + + # upsampling + self.up_concat4 = UpCat(filters[4], filters[3], self.is_deconv) + self.up_concat3 = UpCat(filters[3], filters[2], self.is_deconv) + self.up_concat2 = UpCat(filters[2], filters[1], self.is_deconv) + self.up_concat1 = UpCat(filters[1], filters[0], self.is_deconv) + self.up4 = SE_Conv_Block(filters[4], filters[3], drop_out=True) + self.up3 = SE_Conv_Block(filters[3], filters[2]) + self.up2 = SE_Conv_Block(filters[2], filters[1]) + self.up1 = SE_Conv_Block(filters[1], filters[0]) + + # For deep supervision, project the multi-scale feature maps to the same number of channels + self.dsv1 = nn.Conv2d(in_channels=filters[0], out_channels=filters[0]//2, kernel_size=1) + self.dsv2 = nn.Conv2d(in_channels=filters[1], out_channels=filters[0]//2, kernel_size=1) + self.dsv3 = nn.Conv2d(in_channels=filters[2], out_channels=filters[0]//2, kernel_size=1) + self.dsv4 = nn.Conv2d(in_channels=filters[3], out_channels=filters[0]//2, kernel_size=1) + + self.scale_att = scale_atten_convblock(in_size=filters[0]//2 * 4, out_size=filters[0]) + self.final = nn.Conv2d(filters[0], self.num_classes, kernel_size=1) + + def forward(self, inputs): + x_shape = list(inputs.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + inputs = torch.transpose(inputs, 1, 2) + inputs = torch.reshape(inputs, new_shape) + + # Feature Extraction + conv1 = self.conv1(inputs) + maxpool1 = self.maxpool1(conv1) + + conv2 = self.conv2(maxpool1) + maxpool2 = self.maxpool2(conv2) + + conv3 = self.conv3(maxpool2) + maxpool3 = self.maxpool3(conv3) + + conv4 = self.conv4(maxpool3) + maxpool4 = self.maxpool4(conv4) + + # Gating Signal Generation + center = self.center(maxpool4) + + # Attention Mechanism + # Upscaling Part (Decoder) + up4 = self.up_concat4(conv4, center) + g_conv4 = self.nonlocal4_2(up4) + + up4, att_weight4 = self.up4(g_conv4) + g_conv3, att3 = self.attentionblock3(conv3, up4) + + # atten3_map = att3.cpu().detach().numpy().astype(np.float) + # atten3_map = ndimage.interpolation.zoom(atten3_map, [1.0, 1.0, 224 / atten3_map.shape[2], + # 300 / atten3_map.shape[3]], order=0) + + up3 = self.up_concat3(g_conv3, up4) + up3, att_weight3 = self.up3(up3) + g_conv2, att2 = self.attentionblock2(conv2, up3) + + up2 = self.up_concat2(g_conv2, up3) + up2, att_weight2 = self.up2(up2) + + up1 = self.up_concat1(conv1, up2) + up1, att_weight1 = self.up1(up1) + + # Deep Supervision + dsv1 = self.dsv1(up1) + dsv2 = F.interpolate(self.dsv2(up2), dsv1.shape[2:], mode = 'bilinear') + dsv3 = F.interpolate(self.dsv3(up3), dsv1.shape[2:], mode = 'bilinear') + dsv4 = F.interpolate(self.dsv4(up4), dsv1.shape[2:], mode = 'bilinear') + + dsv_cat = torch.cat([dsv1, dsv2, dsv3, dsv4], dim=1) + out = self.scale_att(dsv_cat) + + out = self.final(out) + if(len(x_shape) == 5): + if(isinstance(out, (list,tuple))): + for i in range(len(out)): + new_shape = [N, D] + list(out[i].shape)[1:] + out[i] = torch.transpose(torch.reshape(out[i], new_shape), 1, 2) + else: + new_shape = [N, D] + list(out.shape)[1:] + out = torch.transpose(torch.reshape(out, new_shape), 1, 2) + + return out + +if __name__ == "__main__": + params = {'in_chns':3, + 'class_num':2} + Net = CANet(params) + Net = Net.double() + + x = np.random.rand(4, 3, 224, 224) + xt = torch.from_numpy(x) + xt = xt.clone().detach() + + y = Net(xt) + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) diff --git a/pymic/net/net2d/unet2d_dual_branch.py b/pymic/net/net2d/unet2d_dual_branch.py index 828bdfe..19a0788 100644 --- a/pymic/net/net2d/unet2d_dual_branch.py +++ b/pymic/net/net2d/unet2d_dual_branch.py @@ -25,11 +25,26 @@ class UNet2D_DualBranch(nn.Module): """ def __init__(self, params): super(UNet2D_DualBranch, self).__init__() - self.output_mode = params.get("output_mode", "average") + params = self.get_default_parameters(params) + self.output_mode = params["output_mode"] self.encoder = Encoder(params) self.decoder1 = Decoder(params) self.decoder2 = Decoder(params) + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], + 'up_mode': 2, + 'multiscale_pred': False, + 'output_mode': "average" + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + def forward(self, x): x_shape = list(x.shape) if(len(x_shape) == 5): diff --git a/pymic/net/net2d/unet2d_mcnet.py b/pymic/net/net2d/unet2d_mcnet.py new file mode 100644 index 0000000..be5b16b --- /dev/null +++ b/pymic/net/net2d/unet2d_mcnet.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch.nn as nn +from pymic.net.net2d.unet2d import * + +class MCNet2D(nn.Module): + """ + A tri-branch network using UNet2D as backbone. + + * Reference: Yicheng Wu, Zongyuan Ge et al. Mutual consistency learning for + semi-supervised medical image segmentation. + `Medical Image Analysis 2022. `_ + + The original code is at: https://github.com/ycwu1997/MC-Net + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.UNet2D` for details. + """ + def __init__(self, params): + super(MCNet2D, self).__init__() + in_chns = params['in_chns'] + class_num = params['class_num'] + params1 = {'in_chns': in_chns, + 'feature_chns': [16, 32, 64, 128, 256], + 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], + 'class_num': class_num, + 'up_mode': 0, + 'multiscale_pred': False } + params2 = {'in_chns': in_chns, + 'feature_chns': [16, 32, 64, 128, 256], + 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], + 'class_num': class_num, + 'up_mode': 1, + 'multiscale_pred': False} + params3 = {'in_chns': in_chns, + 'feature_chns': [16, 32, 64, 128, 256], + 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], + 'class_num': class_num, + 'up_mode': 2, + 'multiscale_pred': False} + self.encoder = Encoder(params1) + self.decoder1 = Decoder(params1) + self.decoder2 = Decoder(params2) + self.decoder3 = Decoder(params3) + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + feature = self.encoder(x) + output1 = self.decoder1(feature) + new_shape = [N, D] + list(output1.shape)[1:] + output1 = torch.transpose(torch.reshape(output1, new_shape), 1, 2) + if(not self.training): + return output1 + output2 = self.decoder2(feature) + output3 = self.decoder3(feature) + if(len(x_shape) == 5): + output2 = torch.transpose(torch.reshape(output2, new_shape), 1, 2) + output3 = torch.transpose(torch.reshape(output3, new_shape), 1, 2) + return output1, output2, output3 diff --git a/pymic/net/net2d/unet2d_multi_decoder.py b/pymic/net/net2d/unet2d_multi_decoder.py new file mode 100644 index 0000000..03bd99f --- /dev/null +++ b/pymic/net/net2d/unet2d_multi_decoder.py @@ -0,0 +1,163 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import torch.nn as nn +from pymic.net.net2d.unet2d import * + +class UNet2D_DualBranch(nn.Module): + """ + A dual branch network using UNet2D as backbone. + + * Reference: Xiangde Luo, Minhao Hu, Wenjun Liao, Shuwei Zhai, Tao Song, Guotai Wang, + Shaoting Zhang. ScribblScribble-Supervised Medical Image Segmentation via + Dual-Branch Network and Dynamically Mixed Pseudo Labels Supervision. + `MICCAI 2022. `_ + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.UNet2D` for details. + In addition, the following field should be included: + + :param output_mode: (str) How to obtain the result during the inference. + `average`: taking average of the two branches. + `first`: takeing the result in the first branch. + `second`: taking the result in the second branch. + """ + def __init__(self, params): + super(UNet2D_DualBranch, self).__init__() + params = self.get_default_parameters(params) + self.output_mode = params["output_mode"] + self.encoder = Encoder(params) + self.decoder1 = Decoder(params) + self.decoder2 = Decoder(params) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], + 'up_mode': 2, + 'multiscale_pred': False, + 'output_mode': "average" + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + f = self.encoder(x) + output1 = self.decoder1(f) + output2 = self.decoder2(f) + if(len(x_shape) == 5): + new_shape = [N, D] + list(output1.shape)[1:] + output1 = torch.reshape(output1, new_shape) + output1 = torch.transpose(output1, 1, 2) + output2 = torch.reshape(output2, new_shape) + output2 = torch.transpose(output2, 1, 2) + + return output1, output2 + # if(self.training): + # return output1, output2 + # else: + # if(self.output_mode == "average"): + # return (output1 + output2)/2 + # elif(self.output_mode == "first"): + # return output1 + # else: + # return output2 + +class UNet2D_TriBranch(nn.Module): + """ + A tri-branch network using UNet2D as backbone. The super class for MCNet2D and MTNet2D. + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.UNet2D` for details. + """ + def __init__(self, params): + super(UNet2D_TriBranch, self).__init__() + params = self.get_default_parameters(params) + self.encoder = Encoder(params) + self.decoder1 = Decoder(params) + self.decoder2 = Decoder(params) + self.decoder3 = Decoder(params) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], + 'up_mode': 2, + 'multiscale_pred': False, + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + feature = self.encoder(x) + output1 = self.decoder1(feature) + new_shape = [N, D] + list(output1.shape)[1:] + output1 = torch.transpose(torch.reshape(output1, new_shape), 1, 2) + if(not self.training): + return output1 + output2 = self.decoder2(feature) + output3 = self.decoder3(feature) + if(len(x_shape) == 5): + output2 = torch.transpose(torch.reshape(output2, new_shape), 1, 2) + output3 = torch.transpose(torch.reshape(output3, new_shape), 1, 2) + return output1, output2, output3 + +class MCNet2D(UNet2D_TriBranch): + """ + A tri-branch network using UNet2D as backbone. + + * Reference: Yicheng Wu, Zongyuan Ge et al. Mutual consistency learning for + semi-supervised medical image segmentation. + `Medical Image Analysis 2022. `_ + + The original code is at: https://github.com/ycwu1997/MC-Net + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.UNet2D` for details. + """ + def __init__(self, params): + super(MCNet2D, self).__init__(params) + in_chns = params['in_chns'] + class_num = params['class_num'] + ft_chns = params['feature_chns'] + dropout = params['dropout'] + params1 = {'in_chns': in_chns, + 'feature_chns': ft_chns, + 'dropout': dropout, + 'class_num': class_num, + 'up_mode': 0 } + params2 = {'in_chns': in_chns, + 'feature_chns': ft_chns, + 'dropout': dropout, + 'class_num': class_num, + 'up_mode': 1 } + params3 = {'in_chns': in_chns, + 'feature_chns': ft_chns, + 'dropout': dropout, + 'class_num': class_num, + 'up_mode': 2 } + self.encoder = Encoder(params1) + self.decoder1 = Decoder(params1) + self.decoder2 = Decoder(params2) + self.decoder3 = Decoder(params3) \ No newline at end of file diff --git a/pymic/net/net2d/unet2d_nest.py b/pymic/net/net2d/unet2d_pp.py similarity index 96% rename from pymic/net/net2d/unet2d_nest.py rename to pymic/net/net2d/unet2d_pp.py index efa048f..f2a003b 100644 --- a/pymic/net/net2d/unet2d_nest.py +++ b/pymic/net/net2d/unet2d_pp.py @@ -3,9 +3,9 @@ import torch.nn as nn from pymic.net.net2d.unet2d import * -class NestedUNet2D(nn.Module): +class UNet2Dpp(nn.Module): """ - An implementation of the Nested U-Net. + An implementation of the U-Net++. * Reference: Zongwei Zhou, et al.: `UNet++: A Nested U-Net Architecture for Medical Image Segmentation. `_ @@ -25,7 +25,7 @@ class NestedUNet2D(nn.Module): :param class_num: (int) The class number for segmentation task. """ def __init__(self, params): - super(NestedUNet2D, self).__init__() + super(UNet2Dpp, self).__init__() self.params = params self.in_chns = self.params['in_chns'] self.filters = self.params['feature_chns'] @@ -96,7 +96,7 @@ def forward(self, x): 'feature_chns':[2, 8, 32, 48, 64], 'dropout': [0, 0, 0.3, 0.4, 0.5], 'class_num': 2} - Net = NestedUNet2D(params) + Net = UNet2Dpp(params) Net = Net.double() x = np.random.rand(4, 4, 10, 96, 96) diff --git a/pymic/net/net2d/unet2d_scse.py b/pymic/net/net2d/unet2d_scse.py index 125843e..54a5d2f 100644 --- a/pymic/net/net2d/unet2d_scse.py +++ b/pymic/net/net2d/unet2d_scse.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn import numpy as np +from pymic.net.net2d.unet2d import UpBlock, Encoder, Decoder, UNet2D from pymic.net.net2d.scse2d import * class ConvScSEBlock(nn.Module): @@ -50,116 +51,64 @@ def __init__(self, in_channels, out_channels, dropout_p): def forward(self, x): return self.maxpool_conv(x) -class UpBlock(nn.Module): +class UpBlockScSE(UpBlock): """Up-sampling followed by `ConvScSEBlock` in U-Net structure. - :param in_channels1: (int) Input channel number for low-resolution feature map. - :param in_channels2: (int) Input channel number for high-resolution feature map. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - :param bilinear: (bool) Use bilinear for up-sampling or not. + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.UpBlock` for details. """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - bilinear=True): - super(UpBlock, self).__init__() - self.bilinear = bilinear - if bilinear: - self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) - else: - self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2) + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, up_mode = 2): + super(UpBlockScSE, self).__init__(in_channels1, in_channels2, out_channels, dropout_p, up_mode) self.conv = ConvScSEBlock(in_channels2 * 2, out_channels, dropout_p) - def forward(self, x1, x2): - if self.bilinear: - x1 = self.conv1x1(x1) - x1 = self.up(x1) - x = torch.cat([x2, x1], dim=1) - return self.conv(x) - -class UNet2D_ScSE(nn.Module): +class EncoderScSE(Encoder): """ - Combining 2D U-Net with SCSE module. + Encoder of 2D UNet with ScSE. - * Reference: Abhijit Guha Roy, Nassir Navab, Christian Wachinger: - Recalibrating Fully Convolutional Networks With Spatial and Channel - "Squeeze and Excitation" Blocks. - `IEEE Trans. Med. Imaging 38(2): 540-549 (2019). `_ - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.Encoder` for details. """ def __init__(self, params): - super(UNet2D_ScSE, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] - assert(len(self.ft_chns) == 5) + super(EncoderScSE, self).__init__(params) self.in_conv= ConvScSEBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = self.dropout[3]) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p = self.dropout[2]) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p = self.dropout[1]) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p = self.dropout[0]) - - self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, - kernel_size = 3, padding = 1) + if(len(self.ft_chns) == 5): + self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - def forward(self, x): - x_shape = list(x.shape) - if(len(x_shape) == 5): - [N, C, D, H, W] = x_shape - new_shape = [N*D, C, H, W] - x = torch.transpose(x, 1, 2) - x = torch.reshape(x, new_shape) - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - x4 = self.down4(x3) +class DecoderScSE(Decoder): + """ + Decoder of 2D UNet with ScSE. + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.Decoder` for details. + """ + def __init__(self, params): + super(DecoderScSE, self).__init__(params) - x = self.up1(x4, x3) - x = self.up2(x, x2) - x = self.up3(x, x1) - x = self.up4(x, x0) - output = self.out_conv(x) - - if(len(x_shape) == 5): - new_shape = [N, D] + list(output.shape)[1:] - output = torch.reshape(output, new_shape) - output = torch.transpose(output, 1, 2) - return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'feature_chns':[2, 8, 32, 48, 64], - 'dropout': [0, 0, 0.3, 0.4, 0.5], - 'class_num': 2, - 'bilinear': True} - Net = UNet2D_ScSE(params) - Net = Net.double() - - x = np.random.rand(4, 4, 10, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print(len(y.size())) - y = y.detach().numpy() - print(y.shape) \ No newline at end of file + + if(len(self.ft_chns) == 5): + self.up1 = UpBlockScSE(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.up_mode) + self.up2 = UpBlockScSE(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.up_mode) + self.up3 = UpBlockScSE(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.up_mode) + self.up4 = UpBlockScSE(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.up_mode) + + +class UNet2D_ScSE(UNet2D): + """ + Combining 2D U-Net with SCSE module. + + * Reference: Abhijit Guha Roy, Nassir Navab, Christian Wachinger: + Recalibrating Fully Convolutional Networks With Spatial and Channel + "Squeeze and Excitation" Blocks. + `IEEE Trans. Med. Imaging 38(2): 540-549 (2019). `_ + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.unet2d` for details. + """ + def __init__(self, params): + super(UNet2D_ScSE, self).__init__(params) + self.encoder = Encoder(params) + self.decoder = Decoder(params) diff --git a/pymic/net/net2d/unet2d_vm.py b/pymic/net/net2d/unet2d_vm.py new file mode 100644 index 0000000..126fa30 --- /dev/null +++ b/pymic/net/net2d/unet2d_vm.py @@ -0,0 +1,820 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import math +from functools import partial +from typing import Optional, Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from einops import rearrange, repeat +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +try: + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref +except: + pass + +# an alternative for mamba_ssm (in which causal_conv1d is needed) +try: + from selective_scan import selective_scan_fn as selective_scan_fn_v1 + from selective_scan import selective_scan_ref as selective_scan_ref_v1 +except: + pass + +DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})" + + +def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False): + """ + u: r(B D L) + delta: r(B D L) + A: r(D N) + B: r(B N L) + C: r(B N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + + ignores: + [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] + """ + import numpy as np + + # fvcore.nn.jit_handles + def get_flops_einsum(input_shapes, equation): + np_arrs = [np.zeros(s) for s in input_shapes] + optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1] + for line in optim.split("\n"): + if "optimized flop" in line.lower(): + # divided by 2 because we count MAC (multiply-add counted as one flop) + flop = float(np.floor(float(line.split(":")[-1]) / 2)) + return flop + + + assert not with_complex + + flops = 0 # below code flops = 0 + if False: + ... + """ + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + if A.is_complex(): + if is_variable_B: + B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) + if is_variable_C: + C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2)) + else: + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) + ys = [] + """ + + flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln") + if with_Group: + flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln") + else: + flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln") + if False: + ... + """ + deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + """ + + in_for_flops = B * D * N + if with_Group: + in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd") + else: + in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd") + flops += L * in_for_flops + if False: + ... + """ + for i in range(u.shape[2]): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum('bdn,dn->bd', x, C) + else: + if C.dim() == 3: + y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + else: + y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + if y.is_complex(): + y = y.real * 2 + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + """ + + if with_D: + flops += B * D * L + if with_Z: + flops += B * D * L + if False: + ... + """ + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + """ + + return flops + + +class PatchEmbed2D(nn.Module): + r""" Image to Patch Embedding + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs): + super().__init__() + if isinstance(patch_size, int): + patch_size = (patch_size, patch_size) + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = self.proj(x).permute(0, 2, 3, 1) + if self.norm is not None: + x = self.norm(x) + return x + + +class PatchMerging2D(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + B, H, W, C = x.shape + + SHAPE_FIX = [-1, -1] + if (W % 2 != 0) or (H % 2 != 0): + print(f"Warning, x.shape {x.shape} is not match even ===========", flush=True) + SHAPE_FIX[0] = H // 2 + SHAPE_FIX[1] = W // 2 + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + + if SHAPE_FIX[0] > 0: + x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] + x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] + x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] + x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :] + + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, H//2, W//2, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class PatchExpand2D(nn.Module): + def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim*2 + self.dim_scale = dim_scale + self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False) + self.norm = norm_layer(self.dim // dim_scale) + + def forward(self, x): + B, H, W, C = x.shape + x = self.expand(x) + + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale) + x= self.norm(x) + + return x + + +class Final_PatchExpand2D(nn.Module): + def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.dim_scale = dim_scale + self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False) + self.norm = norm_layer(self.dim // dim_scale) + + def forward(self, x): + B, H, W, C = x.shape + x = self.expand(x) + + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale) + x= self.norm(x) + + return x + + +class SS2D(nn.Module): + def __init__( + self, + d_model, + d_state=16, + # d_state="auto", # 20240109 + d_conv=3, + expand=2, + dt_rank="auto", + dt_min=0.001, + dt_max=0.1, + dt_init="random", + dt_scale=1.0, + dt_init_floor=1e-4, + dropout=0., + conv_bias=True, + bias=False, + device=None, + dtype=None, + **kwargs, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + # self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109 + self.d_conv = d_conv + self.expand = expand + self.d_inner = int(self.expand * self.d_model) + self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank + + self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) + self.conv2d = nn.Conv2d( + in_channels=self.d_inner, + out_channels=self.d_inner, + groups=self.d_inner, + bias=conv_bias, + kernel_size=d_conv, + padding=(d_conv - 1) // 2, + **factory_kwargs, + ) + self.act = nn.SiLU() + + self.x_proj = ( + nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), + nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), + nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), + nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), + ) + self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner) + del self.x_proj + + self.dt_projs = ( + self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), + self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), + self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), + self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), + ) + self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank) + self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner) + del self.dt_projs + + self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K=4, D, N) + self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K=4, D, N) + + # self.selective_scan = selective_scan_fn + self.forward_core = self.forward_corev0 + + self.out_norm = nn.LayerNorm(self.d_inner) + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + self.dropout = nn.Dropout(dropout) if dropout > 0. else None + + @staticmethod + def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs): + dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) + + # Initialize special dt projection to preserve variance at initialization + dt_init_std = dt_rank**-0.5 * dt_scale + if dt_init == "constant": + nn.init.constant_(dt_proj.weight, dt_init_std) + elif dt_init == "random": + nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) + else: + raise NotImplementedError + + # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max + dt = torch.exp( + torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ).clamp(min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + dt_proj.bias.copy_(inv_dt) + # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit + dt_proj.bias._no_reinit = True + + return dt_proj + + @staticmethod + def A_log_init(d_state, d_inner, copies=1, device=None, merge=True): + # S4D real initialization + A = repeat( + torch.arange(1, d_state + 1, dtype=torch.float32, device=device), + "n -> d n", + d=d_inner, + ).contiguous() + A_log = torch.log(A) # Keep A_log in fp32 + if copies > 1: + A_log = repeat(A_log, "d n -> r d n", r=copies) + if merge: + A_log = A_log.flatten(0, 1) + A_log = nn.Parameter(A_log) + A_log._no_weight_decay = True + return A_log + + @staticmethod + def D_init(d_inner, copies=1, device=None, merge=True): + # D "skip" parameter + D = torch.ones(d_inner, device=device) + if copies > 1: + D = repeat(D, "n1 -> r n1", r=copies) + if merge: + D = D.flatten(0, 1) + D = nn.Parameter(D) # Keep in fp32 + D._no_weight_decay = True + return D + + def forward_corev0(self, x: torch.Tensor): + self.selective_scan = selective_scan_fn + + B, C, H, W = x.shape + L = H * W + K = 4 + + x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) + xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) + + x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight) + # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) + dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) + dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight) + # dts = dts + self.dt_projs_bias.view(1, K, -1, 1) + + xs = xs.float().view(B, -1, L) # (b, k * d, l) + dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) + Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l) + Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l) + Ds = self.Ds.float().view(-1) # (k * d) + As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state) + dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) + + out_y = self.selective_scan( + xs, dts, + As, Bs, Cs, Ds, z=None, + delta_bias=dt_projs_bias, + delta_softplus=True, + return_last_state=False, + ).view(B, K, -1, L) + assert out_y.dtype == torch.float + + inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) + wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) + invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) + + return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y + + # an alternative to forward_corev1 + def forward_corev1(self, x: torch.Tensor): + self.selective_scan = selective_scan_fn_v1 + + B, C, H, W = x.shape + L = H * W + K = 4 + + x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) + xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) + + x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight) + # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) + dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) + dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight) + # dts = dts + self.dt_projs_bias.view(1, K, -1, 1) + + xs = xs.float().view(B, -1, L) # (b, k * d, l) + dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) + Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l) + Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l) + Ds = self.Ds.float().view(-1) # (k * d) + As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state) + dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) + + out_y = self.selective_scan( + xs, dts, + As, Bs, Cs, Ds, + delta_bias=dt_projs_bias, + delta_softplus=True, + ).view(B, K, -1, L) + assert out_y.dtype == torch.float + + inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) + wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) + invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) + + return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y + + def forward(self, x: torch.Tensor, **kwargs): + B, H, W, C = x.shape + + xz = self.in_proj(x) + x, z = xz.chunk(2, dim=-1) # (b, h, w, d) + + x = x.permute(0, 3, 1, 2).contiguous() + x = self.act(self.conv2d(x)) # (b, d, h, w) + y1, y2, y3, y4 = self.forward_core(x) + assert y1.dtype == torch.float32 + y = y1 + y2 + y3 + y4 + y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) + y = self.out_norm(y) + y = y * F.silu(z) + out = self.out_proj(y) + if self.dropout is not None: + out = self.dropout(out) + return out + + +class VSSBlock(nn.Module): + def __init__( + self, + hidden_dim: int = 0, + drop_path: float = 0, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + attn_drop_rate: float = 0, + d_state: int = 16, + **kwargs, + ): + super().__init__() + self.ln_1 = norm_layer(hidden_dim) + self.self_attention = SS2D(d_model=hidden_dim, dropout=attn_drop_rate, d_state=d_state, **kwargs) + self.drop_path = DropPath(drop_path) + + def forward(self, input: torch.Tensor): + x = input + self.drop_path(self.self_attention(self.ln_1(input))) + return x + + +class VSSLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + depth (int): Number of blocks. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim, + depth, + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + d_state=16, + **kwargs, + ): + super().__init__() + self.dim = dim + self.use_checkpoint = use_checkpoint + + self.blocks = nn.ModuleList([ + VSSBlock( + hidden_dim=dim, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + attn_drop_rate=attn_drop, + d_state=d_state, + ) + for i in range(depth)]) + + if True: # is this really applied? Yes, but been overriden later in VSSM! + def _init_weights(module: nn.Module): + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + p = p.clone().detach_() # fake init, just to keep the seed .... + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + self.apply(_init_weights) + + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + + if self.downsample is not None: + x = self.downsample(x) + + return x + + + +class VSSLayer_up(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + depth (int): Number of blocks. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim, + depth, + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + upsample=None, + use_checkpoint=False, + d_state=16, + **kwargs, + ): + super().__init__() + self.dim = dim + self.use_checkpoint = use_checkpoint + + self.blocks = nn.ModuleList([ + VSSBlock( + hidden_dim=dim, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + attn_drop_rate=attn_drop, + d_state=d_state, + ) + for i in range(depth)]) + + if True: # is this really applied? Yes, but been overriden later in VSSM! + def _init_weights(module: nn.Module): + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + p = p.clone().detach_() # fake init, just to keep the seed .... + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + self.apply(_init_weights) + + if upsample is not None: + self.upsample = upsample(dim=dim, norm_layer=norm_layer) + else: + self.upsample = None + + + def forward(self, x): + if self.upsample is not None: + x = self.upsample(x) + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + return x + + + +# class VSSM(nn.Module): +class VMUNet(nn.Module): + """ + VM_UNet that is a UNet-like pure vision mambal network for segmentation. + + * Reference: Jiacheng Ruan et al., VM-UNet: Vision Mamba UNet for Medical Image Segmentation. + arxiv 2403.09157, 2024. + + The implementation is based on the code at: + https://github.com/JCruan519/VM-UNet. + + The parameters for the backbone should be given in the `params` dictionary. + + :param in_chns: (int) Input channel number. + :param class_num: (int) The class number for segmentation task. + :param depths: (list) The depth of VSS block at each resolution level. + The length should be 4, by default it is [2, 2, 9, 2]. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4, by default it is [96, 192, 384, 768]. + """ + # def __init__(self, c=4, in_chans=3, num_classes=1000, depths=[2, 2, 9, 2], depths_decoder=[2, 9, 2, 2], + # dims=[96, 192, 384, 768], dims_decoder=[768, 384, 192, 96], d_state=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + # norm_layer=nn.LayerNorm, patch_norm=True, + # use_checkpoint=False, **kwargs): + # super().__init__() + def __init__(self, params): + super(VMUNet, self).__init__() + in_chans = params['in_chns'] + num_classes = params['class_num'] + patch_size = params.get('patch_size', 4) + depths = params.get('depths', [2, 2, 9, 2]) + depths_decoder = depths.copy() + depths_decoder.reverse() + dims = params.get('feature_chns', [96, 192, 384, 768]) + dims_decoder = dims.copy() + dims_decoder.reverse() + d_state = params.get('d_state', 16) + drop_rate = params.get('drop_rate', 0.) + attn_drop_rate = params.get('att_drop_rate', 0.) + drop_path_rate = params.get('path_drop_rate', 0.1) + + norm_layer = nn.LayerNorm + patch_norm = True + use_checkpoint = False + + self.num_classes = num_classes + self.num_layers = len(depths) + if isinstance(dims, int): + dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)] + self.embed_dim = dims[0] + self.num_features = dims[-1] + self.dims = dims + + self.patch_embed = PatchEmbed2D(patch_size=patch_size, in_chans=in_chans, embed_dim=self.embed_dim, + norm_layer=norm_layer if patch_norm else None) + + # WASTED absolute position embedding ====================== + self.ape = False + # self.ape = False + # drop_rate = 0.0 + if self.ape: + self.patches_resolution = self.patch_embed.patches_resolution + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, *self.patches_resolution, self.embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_decoder))][::-1] + + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = VSSLayer( + dim=dims[i_layer], + depth=depths[i_layer], + d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109 + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging2D if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + ) + self.layers.append(layer) + + self.layers_up = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = VSSLayer_up( + dim=dims_decoder[i_layer], + depth=depths_decoder[i_layer], + d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109 + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr_decoder[sum(depths_decoder[:i_layer]):sum(depths_decoder[:i_layer + 1])], + norm_layer=norm_layer, + upsample=PatchExpand2D if (i_layer != 0) else None, + use_checkpoint=use_checkpoint, + ) + self.layers_up.append(layer) + + self.final_up = Final_PatchExpand2D(dim=dims_decoder[-1], dim_scale=4, norm_layer=norm_layer) + self.final_conv = nn.Conv2d(dims_decoder[-1]//4, num_classes, 1) + + # self.norm = norm_layer(self.num_features) + # self.avgpool = nn.AdaptiveAvgPool1d(1) + # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m: nn.Module): + """ + out_proj.weight which is previously initilized in VSSBlock, would be cleared in nn.Linear + no fc.weight found in the any of the model parameters + no nn.Embedding found in the any of the model parameters + so the thing is, VSSBlock initialization is useless + + Conv2D is not intialized !!! + """ + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward_features(self, x): + skip_list = [] + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + skip_list.append(x) + x = layer(x) + return x, skip_list + + def forward_features_up(self, x, skip_list): + for inx, layer_up in enumerate(self.layers_up): + if inx == 0: + x = layer_up(x) + else: + x = layer_up(x+skip_list[-inx]) + + return x + + def forward_final(self, x): + x = self.final_up(x) + x = x.permute(0,3,1,2) + x = self.final_conv(x) + return x + + def forward_backbone(self, x): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + return x + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + x, skip_list = self.forward_features(x) + x = self.forward_features_up(x, skip_list) + x = self.forward_final(x) + + if(len(x_shape) == 5): + new_shape = [N, D] + list(x.shape)[1:] + x = torch.transpose(torch.reshape(x, new_shape), 1, 2) + + return x + + + + + + diff --git a/pymic/net/net2d/unet2d_vm_light.py b/pymic/net/net2d/unet2d_vm_light.py new file mode 100644 index 0000000..ab1de76 --- /dev/null +++ b/pymic/net/net2d/unet2d_vm_light.py @@ -0,0 +1,284 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import math +import torch +from torch import nn +import torch.nn.functional as F + +from timm.models.layers import trunc_normal_ +from mamba_ssm import Mamba + + +class PVMLayer(nn.Module): + def __init__(self, input_dim, output_dim, d_state = 16, d_conv = 4, expand = 2): + super().__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.norm = nn.LayerNorm(input_dim) + self.mamba = Mamba( + d_model=input_dim//4, # Model dimension d_model + d_state=d_state, # SSM state expansion factor + d_conv=d_conv, # Local convolution width + expand=expand, # Block expansion factor + ) + self.proj = nn.Linear(input_dim, output_dim) + self.skip_scale= nn.Parameter(torch.ones(1)) + + def forward(self, x): + if x.dtype == torch.float16: + x = x.type(torch.float32) + B, C = x.shape[:2] + assert C == self.input_dim + n_tokens = x.shape[2:].numel() + img_dims = x.shape[2:] + x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2) + x_norm = self.norm(x_flat) + + x1, x2, x3, x4 = torch.chunk(x_norm, 4, dim=2) + x_mamba1 = self.mamba(x1) + self.skip_scale * x1 + x_mamba2 = self.mamba(x2) + self.skip_scale * x2 + x_mamba3 = self.mamba(x3) + self.skip_scale * x3 + x_mamba4 = self.mamba(x4) + self.skip_scale * x4 + x_mamba = torch.cat([x_mamba1, x_mamba2,x_mamba3,x_mamba4], dim=2) + + x_mamba = self.norm(x_mamba) + x_mamba = self.proj(x_mamba) + out = x_mamba.transpose(-1, -2).reshape(B, self.output_dim, *img_dims) + return out + + +class Channel_Att_Bridge(nn.Module): + def __init__(self, c_list, split_att='fc'): + super().__init__() + c_list_sum = sum(c_list) - c_list[-1] + self.split_att = split_att + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.get_all_att = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False) + self.att1 = nn.Linear(c_list_sum, c_list[0]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[0], 1) + self.att2 = nn.Linear(c_list_sum, c_list[1]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[1], 1) + self.att3 = nn.Linear(c_list_sum, c_list[2]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[2], 1) + self.att4 = nn.Linear(c_list_sum, c_list[3]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[3], 1) + self.att5 = nn.Linear(c_list_sum, c_list[4]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[4], 1) + self.sigmoid = nn.Sigmoid() + + def forward(self, t1, t2, t3, t4, t5): + att = torch.cat((self.avgpool(t1), + self.avgpool(t2), + self.avgpool(t3), + self.avgpool(t4), + self.avgpool(t5)), dim=1) + att = self.get_all_att(att.squeeze(-1).transpose(-1, -2)) + if self.split_att != 'fc': + att = att.transpose(-1, -2) + att1 = self.sigmoid(self.att1(att)) + att2 = self.sigmoid(self.att2(att)) + att3 = self.sigmoid(self.att3(att)) + att4 = self.sigmoid(self.att4(att)) + att5 = self.sigmoid(self.att5(att)) + if self.split_att == 'fc': + att1 = att1.transpose(-1, -2).unsqueeze(-1).expand_as(t1) + att2 = att2.transpose(-1, -2).unsqueeze(-1).expand_as(t2) + att3 = att3.transpose(-1, -2).unsqueeze(-1).expand_as(t3) + att4 = att4.transpose(-1, -2).unsqueeze(-1).expand_as(t4) + att5 = att5.transpose(-1, -2).unsqueeze(-1).expand_as(t5) + else: + att1 = att1.unsqueeze(-1).expand_as(t1) + att2 = att2.unsqueeze(-1).expand_as(t2) + att3 = att3.unsqueeze(-1).expand_as(t3) + att4 = att4.unsqueeze(-1).expand_as(t4) + att5 = att5.unsqueeze(-1).expand_as(t5) + + return att1, att2, att3, att4, att5 + + +class Spatial_Att_Bridge(nn.Module): + def __init__(self): + super().__init__() + self.shared_conv2d = nn.Sequential(nn.Conv2d(2, 1, 7, stride=1, padding=9, dilation=3), + nn.Sigmoid()) + + def forward(self, t1, t2, t3, t4, t5): + t_list = [t1, t2, t3, t4, t5] + att_list = [] + for t in t_list: + avg_out = torch.mean(t, dim=1, keepdim=True) + max_out, _ = torch.max(t, dim=1, keepdim=True) + att = torch.cat([avg_out, max_out], dim=1) + att = self.shared_conv2d(att) + att_list.append(att) + return att_list[0], att_list[1], att_list[2], att_list[3], att_list[4] + + +class SC_Att_Bridge(nn.Module): + def __init__(self, c_list, split_att='fc'): + super().__init__() + + self.catt = Channel_Att_Bridge(c_list, split_att=split_att) + self.satt = Spatial_Att_Bridge() + + def forward(self, t1, t2, t3, t4, t5): + r1, r2, r3, r4, r5 = t1, t2, t3, t4, t5 + + satt1, satt2, satt3, satt4, satt5 = self.satt(t1, t2, t3, t4, t5) + t1, t2, t3, t4, t5 = satt1 * t1, satt2 * t2, satt3 * t3, satt4 * t4, satt5 * t5 + + r1_, r2_, r3_, r4_, r5_ = t1, t2, t3, t4, t5 + t1, t2, t3, t4, t5 = t1 + r1, t2 + r2, t3 + r3, t4 + r4, t5 + r5 + + catt1, catt2, catt3, catt4, catt5 = self.catt(t1, t2, t3, t4, t5) + t1, t2, t3, t4, t5 = catt1 * t1, catt2 * t2, catt3 * t3, catt4 * t4, catt5 * t5 + + return t1 + r1_, t2 + r2_, t3 + r3_, t4 + r4_, t5 + r5_ + + +class UltraLight_VM_UNet(nn.Module): + def __init__(self, params): + """ + UltraLight_VM_UNet that is a lightweight model using CNN and Mamba. + + * Reference: Renkai Wu, Yinghao Liu, Pengchen Liang, Qing Chang. + UltraLight VM-UNet: Parallel Vision Mamba Significantly Reduces Parameters for Skin Lesion Segmentation. + arxiv 2403.20035, 2024. + + The implementation is based on the code at: + https://github.com/wurenkai/UltraLight-VM-UNet. + + The parameters for the backbone should be given in the `params` dictionary. + + :param in_chns: (int) Input channel number. + :param class_num: (int) The class number for segmentation task. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 6, by default it is [8, 16, 24, 32, 48, 64]. + :param bridge: (bool) If the bridge based on spatial and channel attentions is used or not. + By default it is True. + """ + super(UltraLight_VM_UNet, self).__init__() + + input_channels = params['in_chns'] + num_classes = params['class_num'] + c_list = params.get('feature_chns', [8, 16, 24, 32, 48, 64]) + self.bridge = params.get('bridge', True) + split_att = 'fc' + # def __init__(self, num_classes=1, input_channels=3, c_list=[8,16,24,32,48,64], + # split_att='fc', bridge=True): + # super().__init__() + # self.bridge = bridge + + self.encoder1 = nn.Sequential( + nn.Conv2d(input_channels, c_list[0], 3, stride=1, padding=1), + ) + self.encoder2 =nn.Sequential( + nn.Conv2d(c_list[0], c_list[1], 3, stride=1, padding=1), + ) + self.encoder3 = nn.Sequential( + nn.Conv2d(c_list[1], c_list[2], 3, stride=1, padding=1), + ) + self.encoder4 = nn.Sequential( + PVMLayer(input_dim=c_list[2], output_dim=c_list[3]) + ) + self.encoder5 = nn.Sequential( + PVMLayer(input_dim=c_list[3], output_dim=c_list[4]) + ) + self.encoder6 = nn.Sequential( + PVMLayer(input_dim=c_list[4], output_dim=c_list[5]) + ) + + if self.bridge: + self.scab = SC_Att_Bridge(c_list, split_att) + print('SC_Att_Bridge was used') + + self.decoder1 = nn.Sequential( + PVMLayer(input_dim=c_list[5], output_dim=c_list[4]) + ) + self.decoder2 = nn.Sequential( + PVMLayer(input_dim=c_list[4], output_dim=c_list[3]) + ) + self.decoder3 = nn.Sequential( + PVMLayer(input_dim=c_list[3], output_dim=c_list[2]) + ) + self.decoder4 = nn.Sequential( + nn.Conv2d(c_list[2], c_list[1], 3, stride=1, padding=1), + ) + self.decoder5 = nn.Sequential( + nn.Conv2d(c_list[1], c_list[0], 3, stride=1, padding=1), + ) + self.ebn1 = nn.GroupNorm(4, c_list[0]) + self.ebn2 = nn.GroupNorm(4, c_list[1]) + self.ebn3 = nn.GroupNorm(4, c_list[2]) + self.ebn4 = nn.GroupNorm(4, c_list[3]) + self.ebn5 = nn.GroupNorm(4, c_list[4]) + self.dbn1 = nn.GroupNorm(4, c_list[4]) + self.dbn2 = nn.GroupNorm(4, c_list[3]) + self.dbn3 = nn.GroupNorm(4, c_list[2]) + self.dbn4 = nn.GroupNorm(4, c_list[1]) + self.dbn5 = nn.GroupNorm(4, c_list[0]) + + self.final = nn.Conv2d(c_list[0], num_classes, kernel_size=1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv1d): + n = m.kernel_size[0] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + out = F.gelu(F.max_pool2d(self.ebn1(self.encoder1(x)),2,2)) + t1 = out # b, c0, H/2, W/2 + + out = F.gelu(F.max_pool2d(self.ebn2(self.encoder2(out)),2,2)) + t2 = out # b, c1, H/4, W/4 + + out = F.gelu(F.max_pool2d(self.ebn3(self.encoder3(out)),2,2)) + t3 = out # b, c2, H/8, W/8 + + out = F.gelu(F.max_pool2d(self.ebn4(self.encoder4(out)),2,2)) + t4 = out # b, c3, H/16, W/16 + + out = F.gelu(F.max_pool2d(self.ebn5(self.encoder5(out)),2,2)) + t5 = out # b, c4, H/32, W/32 + + if self.bridge: t1, t2, t3, t4, t5 = self.scab(t1, t2, t3, t4, t5) + + out = F.gelu(self.encoder6(out)) # b, c5, H/32, W/32 + + out5 = F.gelu(self.dbn1(self.decoder1(out))) # b, c4, H/32, W/32 + out5 = torch.add(out5, t5) # b, c4, H/32, W/32 + + out4 = F.gelu(F.interpolate(self.dbn2(self.decoder2(out5)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c3, H/16, W/16 + out4 = torch.add(out4, t4) # b, c3, H/16, W/16 + + out3 = F.gelu(F.interpolate(self.dbn3(self.decoder3(out4)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c2, H/8, W/8 + out3 = torch.add(out3, t3) # b, c2, H/8, W/8 + + out2 = F.gelu(F.interpolate(self.dbn4(self.decoder4(out3)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c1, H/4, W/4 + out2 = torch.add(out2, t2) # b, c1, H/4, W/4 + + out1 = F.gelu(F.interpolate(self.dbn5(self.decoder5(out2)),scale_factor=(2,2),mode ='bilinear',align_corners=True)) # b, c0, H/2, W/2 + out1 = torch.add(out1, t1) # b, c0, H/2, W/2 + + out0 = F.interpolate(self.final(out1),scale_factor=(2,2),mode ='bilinear',align_corners=True) # b, num_class, H, W + + if(len(x_shape) == 5): + new_shape = [N, D] + list(out0.shape)[1:] + out0 = torch.transpose(torch.reshape(out0, new_shape), 1, 2) + return out0 + diff --git a/pymic/net/net3d/fmunet.py b/pymic/net/net3d/fmunet.py new file mode 100644 index 0000000..84ee385 --- /dev/null +++ b/pymic/net/net3d/fmunet.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import itertools +import logging +import torch +import torch.nn as nn +from pymic.net.net_init import Initialization_He, Initialization_XavierUniform + +''' +A copy of fmunetv3, and rename the class as FMUNet. +''' +dim0 = {0:3, 1:2, 2:2} +dim1 = {0:3, 1:3, 2:2} +conv_knl = {2: (1, 3, 3), 3: 3} +conv_pad = {2: (0, 1, 1), 3: 1} +pool_knl = {2: (1, 2, 2), 3: 2} +down_stride = {2: (1, 2, 2), 3: 2} + +class ResConv(nn.Module): + def __init__(self, out_channels, dim = 3, dropout_p = 0.0, depth = 2): + super(ResConv, self).__init__() + assert(dim == 2 or dim == 3) + self.out_channels = out_channels + self.conv_list = nn.ModuleList([nn.Sequential( + nn.InstanceNorm3d(out_channels, affine = True), + nn.LeakyReLU(), + nn.Dropout(dropout_p), + nn.Conv3d(out_channels, out_channels, kernel_size=conv_knl[dim], padding=conv_pad[dim])) + for i in range(depth)]) + + def forward(self, x): + for conv in self.conv_list: + x = conv(x) + x + return x + +class DownSample(nn.Module): + """downsampling based on convolution + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + :param downsample: (bool) Use downsample or not after convolution. + """ + def __init__(self, in_channels, out_channels, dim = 3): + super(DownSample, self).__init__() + self.down = nn.Sequential( + nn.InstanceNorm3d(in_channels, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=conv_knl[dim], + padding=conv_pad[dim], stride = down_stride[dim]) + ) + + def forward(self, x): + return self.down(x) + +class UpCatConv(nn.Module): + """Upsampling followed by `ResConv` block + + :param in_channels1: (int) Input channel number for low-resolution feature map. + :param in_channels2: (int) Input channel number for high-resolution feature map. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear` for 3D and `Bilinear` for 2D). + The default value is 2. + """ + def __init__(self, in_channels1, in_channels2, out_channels, dim = 3): + super(UpCatConv, self).__init__() + + self.up = nn.Sequential( + nn.InstanceNorm3d(in_channels1, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels1, in_channels2, kernel_size=1, padding=0), + nn.Upsample(scale_factor=pool_knl[dim], mode='trilinear', align_corners=True) + ) + + self.conv = nn.Sequential( + nn.InstanceNorm3d(in_channels2*2, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels2 * 2, out_channels, kernel_size=conv_knl[dim], padding=conv_pad[dim]) + ) + + def forward(self, x_l, x_h): + """ + x_l: low-resolution feature map. + x_h: high-resolution feature map. + """ + y = torch.cat([x_h, self.up(x_l)], dim=1) + return self.conv(y) + +class Encoder(nn.Module): + """ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. + + res_mode: resolution mode: 0-- isotrpic, 1-- near isotrpic, 2-- isotropic + """ + def __init__(self, ft_chns, res_mode = 0, dropout_p = 0, depth = 2): + super(Encoder, self).__init__() + d0, d1 = dim0[res_mode], dim1[res_mode] + + self.en_conv0 = ResConv(ft_chns[0], d0, 0, depth) + self.en_conv1 = ResConv(ft_chns[1], d1, 0, depth) + self.en_conv2 = ResConv(ft_chns[2], 3, dropout_p, depth) + self.en_conv3 = ResConv(ft_chns[3], 3, dropout_p, depth) + self.en_conv4 = ResConv(ft_chns[4], 3, dropout_p, depth) + + self.down0 = DownSample(ft_chns[0], ft_chns[1], d0) + self.down1 = DownSample(ft_chns[1], ft_chns[2], d1) + self.down2 = DownSample(ft_chns[2], ft_chns[3], 3) + self.down3 = DownSample(ft_chns[3], ft_chns[4], 3) + + def forward(self, x): + x0 = self.en_conv0(x) + x1 = self.en_conv1(self.down0(x0)) + x2 = self.en_conv2(self.down1(x1)) + x3 = self.en_conv3(self.down2(x2)) + x4 = self.en_conv4(self.down3(x3)) + return [x0, x1, x2, x3, x4] + +class Decoder(nn.Module): + """ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. + """ + def __init__(self, ft_chns, res_mode = 0, dropout_p = 0, depth = 2): + super(Decoder, self).__init__() + d0, d1 = dim0[res_mode], dim1[res_mode] + + self.upcat0 = UpCatConv(ft_chns[1], ft_chns[0], ft_chns[0], d0) + self.upcat1 = UpCatConv(ft_chns[2], ft_chns[1], ft_chns[1], d1) + self.upcat2 = UpCatConv(ft_chns[3], ft_chns[2], ft_chns[2], 3) + self.upcat3 = UpCatConv(ft_chns[4], ft_chns[3], ft_chns[3], 3) + + self.de_conv0 = ResConv(ft_chns[0], d0, 0, depth) + self.de_conv1 = ResConv(ft_chns[1], d1, 0, depth) + self.de_conv2 = ResConv(ft_chns[2], 3, dropout_p, depth) + self.de_conv3 = ResConv(ft_chns[3], 3, dropout_p, depth) + self.de_conv4 = ResConv(ft_chns[4], 3, dropout_p, depth) + + def forward(self, x): + x0, x1, x2, x3, x4 = x + x4_de = self.de_conv4(x4) + x3_de = self.de_conv3(self.upcat3(x4_de, x3)) + x2_de = self.de_conv2(self.upcat2(x3_de, x2)) + x1_de = self.de_conv1(self.upcat1(x2_de, x1)) + x0_de = self.de_conv0(self.upcat0(x1_de, x0)) + return [x0_de, x1_de, x2_de, x3_de] + +class FMUNet(nn.Module): + """ + A 2.5D network combining 3D convolutions with 2D convolutions. + + * Reference: Guotai Wang, Jonathan Shapey, Wenqi Li, Reuben Dorent, Alex Demitriadis, + Sotirios Bisdas, Ian Paddick, Robert Bradford, Shaoting Zhang, Sébastien Ourselin, + Tom Vercauteren: Automatic Segmentation of Vestibular Schwannoma from T2-Weighted + MRI by Deep Spatial Attention with Hardness-Weighted Loss. + `MICCAI (2) 2019: 264-272. `_ + + Note that the attention module in the orininal paper is not used here. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param conv_dims: (list) The convolution dimension (2 or 3) for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, params): + super(FMUNet, self).__init__() + params = self.get_default_parameters(params) + + self.stage = 'train' + in_chns = params['in_chns'] + ft_chns = params['feature_chns'] + res_mode = params['res_mode'] + dropout = params['dropout'] + depth = params['depth'] + cls_num = params['class_num'] + self.mul_pred = params.get('multiscale_pred', True) + self.tune_mode= params.get('finetune_mode', 'all') + self.load_mode= params.get('weights_load_mode', 'all') + + d0 = dim0[res_mode] + self.project = nn.Conv3d(in_chns, ft_chns[0], kernel_size=conv_knl[d0], padding=conv_pad[d0]) + self.encoder = Encoder(ft_chns, res_mode, dropout, depth) + # self.decoder = Decoder(ft_chns, res_mode, dropout, depth = 2) + self.decoder = Decoder(ft_chns, res_mode, dropout, depth) + + self.out_layers = nn.ModuleList() + dims = [dim0[res_mode], dim1[res_mode], 3, 3] + for i in range(4): + out_layer = nn.Sequential( + nn.InstanceNorm3d(ft_chns[i], affine = True), + nn.LeakyReLU(), + nn.Conv3d(ft_chns[i], cls_num, kernel_size=conv_knl[dims[i]], padding=conv_pad[dims[i]])) + self.out_layers.append(out_layer) + + init = params['initialization'].lower() + weightInitializer = Initialization_He(1e-2) if init == 'he' else Initialization_XavierUniform() + self.apply(weightInitializer) + + def get_default_parameters(self, params): + default_param = { + 'finetune_mode': 'all', + 'initialization': 'he', + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': 0.2, + 'res_mode': 0, + 'depth': 2, + 'multiscale_pred': True + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def set_stage(self, stage): + self.stage = stage + + def forward(self, x): + x_en = self.encoder(self.project(x)) + x_de = self.decoder(x_en) + output = self.out_layers[0](x_de[0]) + if(self.mul_pred and self.stage == 'train'): + output = [output] + for i in range(1, len(x_de)): + output.append(self.out_layers[i](x_de[i])) + return output + + def get_parameters_to_update(self): + if(self.tune_mode == 'all'): + return self.parameters() + + up_params = itertools.chain() + if(self.tune_mode == 'decoder'): + up_blocks = [self.decoder, self.out_layers] + else: + raise ValueError("undefined fine-tune mode for FMUNet: {0:}".format(self.tune_mode)) + for block in up_blocks: + up_params = itertools.chain(up_params, block.parameters()) + return up_params + + def get_parameters_to_load(self): + state_dict = self.state_dict() + if(self.load_mode == 'encoder'): + state_dict = {k:v for k, v in state_dict.items() if "project" in k or "encoder" in k } + return state_dict \ No newline at end of file diff --git a/pymic/net/net3d/fmunetv3.py b/pymic/net/net3d/fmunetv3.py new file mode 100644 index 0000000..1367209 --- /dev/null +++ b/pymic/net/net3d/fmunetv3.py @@ -0,0 +1,262 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import itertools +import logging +import torch +import torch.nn as nn +from pymic.net.net_init import Initialization_He, Initialization_XavierUniform + +dim0 = {0:3, 1:2, 2:2} +dim1 = {0:3, 1:3, 2:2} +conv_knl = {2: (1, 3, 3), 3: 3} +conv_pad = {2: (0, 1, 1), 3: 1} +pool_knl = {2: (1, 2, 2), 3: 2} +down_stride = {2: (1, 2, 2), 3: 2} + +class ResConv(nn.Module): + def __init__(self, out_channels, dim = 3, dropout_p = 0.0, depth = 2): + super(ResConv, self).__init__() + assert(dim == 2 or dim == 3) + self.out_channels = out_channels + self.conv_list = nn.ModuleList([nn.Sequential( + nn.InstanceNorm3d(out_channels, affine = True), + nn.LeakyReLU(), + nn.Dropout(dropout_p), + nn.Conv3d(out_channels, out_channels, kernel_size=conv_knl[dim], padding=conv_pad[dim])) + for i in range(depth)]) + + def forward(self, x): + for conv in self.conv_list: + x = conv(x) + x + return x + +class DownSample(nn.Module): + """downsampling based on convolution + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + :param downsample: (bool) Use downsample or not after convolution. + """ + def __init__(self, in_channels, out_channels, dim = 3): + super(DownSample, self).__init__() + self.down = nn.Sequential( + nn.InstanceNorm3d(in_channels, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=conv_knl[dim], + padding=conv_pad[dim], stride = down_stride[dim]) + ) + + def forward(self, x): + return self.down(x) + +class UpCatConv(nn.Module): + """Upsampling followed by `ResConv` block + + :param in_channels1: (int) Input channel number for low-resolution feature map. + :param in_channels2: (int) Input channel number for high-resolution feature map. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear` for 3D and `Bilinear` for 2D). + The default value is 2. + """ + def __init__(self, in_channels1, in_channels2, out_channels, dim = 3): + super(UpCatConv, self).__init__() + + self.up = nn.Sequential( + nn.InstanceNorm3d(in_channels1, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels1, in_channels2, kernel_size=1, padding=0), + nn.Upsample(scale_factor=pool_knl[dim], mode='trilinear', align_corners=True) + ) + + self.conv = nn.Sequential( + nn.InstanceNorm3d(in_channels2*2, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels2 * 2, out_channels, kernel_size=conv_knl[dim], padding=conv_pad[dim]) + ) + + def forward(self, x_l, x_h): + """ + x_l: low-resolution feature map. + x_h: high-resolution feature map. + """ + y = torch.cat([x_h, self.up(x_l)], dim=1) + return self.conv(y) + +class Encoder(nn.Module): + """ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. + + res_mode: resolution mode: 0-- isotrpic, 1-- near isotrpic, 2-- isotropic + """ + def __init__(self, ft_chns, res_mode = 0, dropout_p = 0, depth = 2): + super(Encoder, self).__init__() + d0, d1 = dim0[res_mode], dim1[res_mode] + + self.en_conv0 = ResConv(ft_chns[0], d0, 0, depth) + self.en_conv1 = ResConv(ft_chns[1], d1, 0, depth) + self.en_conv2 = ResConv(ft_chns[2], 3, dropout_p, depth) + self.en_conv3 = ResConv(ft_chns[3], 3, dropout_p, depth) + self.en_conv4 = ResConv(ft_chns[4], 3, dropout_p, depth) + + self.down0 = DownSample(ft_chns[0], ft_chns[1], d0) + self.down1 = DownSample(ft_chns[1], ft_chns[2], d1) + self.down2 = DownSample(ft_chns[2], ft_chns[3], 3) + self.down3 = DownSample(ft_chns[3], ft_chns[4], 3) + + def forward(self, x): + x0 = self.en_conv0(x) + x1 = self.en_conv1(self.down0(x0)) + x2 = self.en_conv2(self.down1(x1)) + x3 = self.en_conv3(self.down2(x2)) + x4 = self.en_conv4(self.down3(x3)) + return [x0, x1, x2, x3, x4] + +class Decoder(nn.Module): + """ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. + """ + def __init__(self, ft_chns, res_mode = 0, dropout_p = 0, depth = 2): + super(Decoder, self).__init__() + d0, d1 = dim0[res_mode], dim1[res_mode] + + self.upcat0 = UpCatConv(ft_chns[1], ft_chns[0], ft_chns[0], d0) + self.upcat1 = UpCatConv(ft_chns[2], ft_chns[1], ft_chns[1], d1) + self.upcat2 = UpCatConv(ft_chns[3], ft_chns[2], ft_chns[2], 3) + self.upcat3 = UpCatConv(ft_chns[4], ft_chns[3], ft_chns[3], 3) + + self.de_conv0 = ResConv(ft_chns[0], d0, 0, depth) + self.de_conv1 = ResConv(ft_chns[1], d1, 0, depth) + self.de_conv2 = ResConv(ft_chns[2], 3, dropout_p, depth) + self.de_conv3 = ResConv(ft_chns[3], 3, dropout_p, depth) + self.de_conv4 = ResConv(ft_chns[4], 3, dropout_p, depth) + + def forward(self, x): + x0, x1, x2, x3, x4 = x + x4_de = self.de_conv4(x4) + x3_de = self.de_conv3(self.upcat3(x4_de, x3)) + x2_de = self.de_conv2(self.upcat2(x3_de, x2)) + x1_de = self.de_conv1(self.upcat1(x2_de, x1)) + x0_de = self.de_conv0(self.upcat0(x1_de, x0)) + return [x0_de, x1_de, x2_de, x3_de] + +class FMUNetV3(nn.Module): + """ + A 2.5D network combining 3D convolutions with 2D convolutions. + + * Reference: Guotai Wang, Jonathan Shapey, Wenqi Li, Reuben Dorent, Alex Demitriadis, + Sotirios Bisdas, Ian Paddick, Robert Bradford, Shaoting Zhang, Sébastien Ourselin, + Tom Vercauteren: Automatic Segmentation of Vestibular Schwannoma from T2-Weighted + MRI by Deep Spatial Attention with Hardness-Weighted Loss. + `MICCAI (2) 2019: 264-272. `_ + + Note that the attention module in the orininal paper is not used here. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param conv_dims: (list) The convolution dimension (2 or 3) for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, params): + super(FMUNetV3, self).__init__() + params = self.get_default_parameters(params) + + self.stage = 'train' + in_chns = params['in_chns'] + ft_chns = params['feature_chns'] + res_mode = params['res_mode'] + dropout = params['dropout'] + depth = params['depth'] + cls_num = params['class_num'] + self.mul_pred = params.get('multiscale_pred', True) + self.tune_mode= params.get('finetune_mode', 'all') + self.load_mode= params.get('weights_load_mode', 'all') + + d0 = dim0[res_mode] + self.project = nn.Conv3d(in_chns, ft_chns[0], kernel_size=conv_knl[d0], padding=conv_pad[d0]) + self.encoder = Encoder(ft_chns, res_mode, dropout, depth) + # self.decoder = Decoder(ft_chns, res_mode, dropout, depth = 2) + self.decoder = Decoder(ft_chns, res_mode, dropout, depth) + + self.out_layers = nn.ModuleList() + dims = [dim0[res_mode], dim1[res_mode], 3, 3] + for i in range(4): + out_layer = nn.Sequential( + nn.InstanceNorm3d(ft_chns[i], affine = True), + nn.LeakyReLU(), + nn.Conv3d(ft_chns[i], cls_num, kernel_size=conv_knl[dims[i]], padding=conv_pad[dims[i]])) + self.out_layers.append(out_layer) + + init = params['initialization'].lower() + weightInitializer = Initialization_He(1e-2) if init == 'he' else Initialization_XavierUniform() + self.apply(weightInitializer) + + def get_default_parameters(self, params): + default_param = { + 'finetune_mode': 'all', + 'initialization': 'he', + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': 0.2, + 'res_mode': 0, + 'depth': 2, + 'multiscale_pred': True + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def set_stage(self, stage): + self.stage = stage + + def forward(self, x): + x_en = self.encoder(self.project(x)) + x_de = self.decoder(x_en) + output = self.out_layers[0](x_de[0]) + if(self.mul_pred and self.stage == 'train'): + output = [output] + for i in range(1, len(x_de)): + output.append(self.out_layers[i](x_de[i])) + return output + + def get_parameters_to_update(self): + if(self.tune_mode == 'all'): + return self.parameters() + + up_params = itertools.chain() + if(self.tune_mode == 'decoder'): + up_blocks = [self.decoder, self.out_layers] + else: + raise ValueError("undefined fine-tune mode for FMUNet: {0:}".format(self.tune_mode)) + for block in up_blocks: + up_params = itertools.chain(up_params, block.parameters()) + return up_params + + def get_parameters_to_load(self): + state_dict = self.state_dict() + if(self.load_mode == 'encoder'): + state_dict = {k:v for k, v in state_dict.items() if "project" in k or "encoder" in k } + return state_dict \ No newline at end of file diff --git a/pymic/net/net3d/grunet.py b/pymic/net/net3d/grunet.py new file mode 100644 index 0000000..bae6447 --- /dev/null +++ b/pymic/net/net3d/grunet.py @@ -0,0 +1,264 @@ +# -*- coding: utf-8 -*- +# Note: this is a renamed version of fmunetv3. fmunetv3 will be removed in a later version. +# This network is originally used in the VolF paper. +from __future__ import print_function, division + +import itertools +import logging +import torch +import torch.nn as nn +from pymic.net.net_init import Initialization_He, Initialization_XavierUniform + +dim0 = {0:3, 1:2, 2:2} +dim1 = {0:3, 1:3, 2:2} +conv_knl = {2: (1, 3, 3), 3: 3} +conv_pad = {2: (0, 1, 1), 3: 1} +pool_knl = {2: (1, 2, 2), 3: 2} +down_stride = {2: (1, 2, 2), 3: 2} + +class ResConv(nn.Module): + def __init__(self, out_channels, dim = 3, dropout_p = 0.0, depth = 2): + super(ResConv, self).__init__() + assert(dim == 2 or dim == 3) + self.out_channels = out_channels + self.conv_list = nn.ModuleList([nn.Sequential( + nn.InstanceNorm3d(out_channels, affine = True), + nn.LeakyReLU(), + nn.Dropout(dropout_p), + nn.Conv3d(out_channels, out_channels, kernel_size=conv_knl[dim], padding=conv_pad[dim])) + for i in range(depth)]) + + def forward(self, x): + for conv in self.conv_list: + x = conv(x) + x + return x + +class DownSample(nn.Module): + """downsampling based on convolution + + :param in_channels: (int) Input channel number. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + :param downsample: (bool) Use downsample or not after convolution. + """ + def __init__(self, in_channels, out_channels, dim = 3): + super(DownSample, self).__init__() + self.down = nn.Sequential( + nn.InstanceNorm3d(in_channels, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=conv_knl[dim], + padding=conv_pad[dim], stride = down_stride[dim]) + ) + + def forward(self, x): + return self.down(x) + +class UpCatConv(nn.Module): + """Upsampling followed by `ResConv` block + + :param in_channels1: (int) Input channel number for low-resolution feature map. + :param in_channels2: (int) Input channel number for high-resolution feature map. + :param out_channels: (int) Output channel number. + :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. + :param dropout_p: (int) Dropout probability. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear` for 3D and `Bilinear` for 2D). + The default value is 2. + """ + def __init__(self, in_channels1, in_channels2, out_channels, dim = 3): + super(UpCatConv, self).__init__() + + self.up = nn.Sequential( + nn.InstanceNorm3d(in_channels1, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels1, in_channels2, kernel_size=1, padding=0), + nn.Upsample(scale_factor=pool_knl[dim], mode='trilinear', align_corners=True) + ) + + self.conv = nn.Sequential( + nn.InstanceNorm3d(in_channels2*2, affine = True), + nn.LeakyReLU(), + nn.Conv3d(in_channels2 * 2, out_channels, kernel_size=conv_knl[dim], padding=conv_pad[dim]) + ) + + def forward(self, x_l, x_h): + """ + x_l: low-resolution feature map. + x_h: high-resolution feature map. + """ + y = torch.cat([x_h, self.up(x_l)], dim=1) + return self.conv(y) + +class Encoder(nn.Module): + """ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. + + res_mode: resolution mode: 0-- isotrpic, 1-- near isotrpic, 2-- isotropic + """ + def __init__(self, ft_chns, res_mode = 0, dropout_p = 0, depth = 2): + super(Encoder, self).__init__() + d0, d1 = dim0[res_mode], dim1[res_mode] + + self.en_conv0 = ResConv(ft_chns[0], d0, 0, depth) + self.en_conv1 = ResConv(ft_chns[1], d1, 0, depth) + self.en_conv2 = ResConv(ft_chns[2], 3, dropout_p, depth) + self.en_conv3 = ResConv(ft_chns[3], 3, dropout_p, depth) + self.en_conv4 = ResConv(ft_chns[4], 3, dropout_p, depth) + + self.down0 = DownSample(ft_chns[0], ft_chns[1], d0) + self.down1 = DownSample(ft_chns[1], ft_chns[2], d1) + self.down2 = DownSample(ft_chns[2], ft_chns[3], 3) + self.down3 = DownSample(ft_chns[3], ft_chns[4], 3) + + def forward(self, x): + x0 = self.en_conv0(x) + x1 = self.en_conv1(self.down0(x0)) + x2 = self.en_conv2(self.down1(x1)) + x3 = self.en_conv3(self.down2(x2)) + x4 = self.en_conv4(self.down3(x3)) + return [x0, x1, x2, x3, x4] + +class Decoder(nn.Module): + """ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. + """ + def __init__(self, ft_chns, res_mode = 0, dropout_p = 0, depth = 2): + super(Decoder, self).__init__() + d0, d1 = dim0[res_mode], dim1[res_mode] + + self.upcat0 = UpCatConv(ft_chns[1], ft_chns[0], ft_chns[0], d0) + self.upcat1 = UpCatConv(ft_chns[2], ft_chns[1], ft_chns[1], d1) + self.upcat2 = UpCatConv(ft_chns[3], ft_chns[2], ft_chns[2], 3) + self.upcat3 = UpCatConv(ft_chns[4], ft_chns[3], ft_chns[3], 3) + + self.de_conv0 = ResConv(ft_chns[0], d0, 0, depth) + self.de_conv1 = ResConv(ft_chns[1], d1, 0, depth) + self.de_conv2 = ResConv(ft_chns[2], 3, dropout_p, depth) + self.de_conv3 = ResConv(ft_chns[3], 3, dropout_p, depth) + self.de_conv4 = ResConv(ft_chns[4], 3, dropout_p, depth) + + def forward(self, x): + x0, x1, x2, x3, x4 = x + x4_de = self.de_conv4(x4) + x3_de = self.de_conv3(self.upcat3(x4_de, x3)) + x2_de = self.de_conv2(self.upcat2(x3_de, x2)) + x1_de = self.de_conv1(self.upcat1(x2_de, x1)) + x0_de = self.de_conv0(self.upcat0(x1_de, x0)) + return [x0_de, x1_de, x2_de, x3_de] + +class GRUNet(nn.Module): + """ + A General Residual UNet. + + * Reference: Guotai Wang, Jonathan Shapey, Wenqi Li, Reuben Dorent, Alex Demitriadis, + Sotirios Bisdas, Ian Paddick, Robert Bradford, Shaoting Zhang, Sébastien Ourselin, + Tom Vercauteren: Automatic Segmentation of Vestibular Schwannoma from T2-Weighted + MRI by Deep Spatial Attention with Hardness-Weighted Loss. + `MICCAI (2) 2019: 264-272. `_ + + Note that the attention module in the orininal paper is not used here. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param conv_dims: (list) The convolution dimension (2 or 3) for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, params): + super(GRUNet, self).__init__() + params = self.get_default_parameters(params) + + self.stage = 'train' + in_chns = params['in_chns'] + ft_chns = params['feature_chns'] + res_mode = params['res_mode'] + dropout = params['dropout'] + depth = params['depth'] + cls_num = params['class_num'] + self.mul_pred = params.get('multiscale_pred', True) + self.tune_mode= params.get('finetune_mode', 'all') + self.load_mode= params.get('weights_load_mode', 'all') + + d0 = dim0[res_mode] + self.project = nn.Conv3d(in_chns, ft_chns[0], kernel_size=conv_knl[d0], padding=conv_pad[d0]) + self.encoder = Encoder(ft_chns, res_mode, dropout, depth) + # self.decoder = Decoder(ft_chns, res_mode, dropout, depth = 2) + self.decoder = Decoder(ft_chns, res_mode, dropout, depth) + + self.out_layers = nn.ModuleList() + dims = [dim0[res_mode], dim1[res_mode], 3, 3] + for i in range(4): + out_layer = nn.Sequential( + nn.InstanceNorm3d(ft_chns[i], affine = True), + nn.LeakyReLU(), + nn.Conv3d(ft_chns[i], cls_num, kernel_size=conv_knl[dims[i]], padding=conv_pad[dims[i]])) + self.out_layers.append(out_layer) + + init = params['initialization'].lower() + weightInitializer = Initialization_He(1e-2) if init == 'he' else Initialization_XavierUniform() + self.apply(weightInitializer) + + def get_default_parameters(self, params): + default_param = { + 'finetune_mode': 'all', + 'initialization': 'he', + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': 0.2, + 'res_mode': 0, + 'depth': 2, + 'multiscale_pred': True + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def set_stage(self, stage): + self.stage = stage + + def forward(self, x): + x_en = self.encoder(self.project(x)) + x_de = self.decoder(x_en) + output = self.out_layers[0](x_de[0]) + if(self.mul_pred and self.stage == 'train'): + output = [output] + for i in range(1, len(x_de)): + output.append(self.out_layers[i](x_de[i])) + return output + + def get_parameters_to_update(self): + if(self.tune_mode == 'all'): + return self.parameters() + + up_params = itertools.chain() + if(self.tune_mode == 'decoder'): + up_blocks = [self.decoder, self.out_layers] + else: + raise ValueError("undefined fine-tune mode for GRUNet: {0:}".format(self.tune_mode)) + for block in up_blocks: + up_params = itertools.chain(up_params, block.parameters()) + return up_params + + def get_parameters_to_load(self): + state_dict = self.state_dict() + if(self.load_mode == 'encoder'): + state_dict = {k:v for k, v in state_dict.items() if "project" in k or "encoder" in k } + return state_dict \ No newline at end of file diff --git a/pymic/net/net3d/lcovnet.py b/pymic/net/net3d/lcovnet.py new file mode 100644 index 0000000..bd91878 --- /dev/null +++ b/pymic/net/net3d/lcovnet.py @@ -0,0 +1,246 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import logging +import torch +import torch.nn as nn +import numpy as np +from pymic.net.net_init import Initialization_He, Initialization_XavierUniform + +class UnetBlock_Encode(nn.Module): + def __init__(self, in_channels, out_channel): + super(UnetBlock_Encode, self).__init__() + + self.in_chns = in_channels + self.out_chns = out_channel + + self.conv1 = nn.Sequential( + nn.Conv3d(self.in_chns, self.out_chns, kernel_size=(1, 1, 3), + padding=(0, 0, 1)), + nn.BatchNorm3d(self.out_chns), + nn.ReLU6(inplace=True) + ) + + self.conv2_1 = nn.Sequential( + nn.Conv3d(self.out_chns, self.out_chns, kernel_size=(3, 3, 1), + padding=(1, 1, 0), groups=1), + nn.BatchNorm3d(self.out_chns), + nn.ReLU6(inplace=True), + nn.Dropout(p=0.2) + ) + + self.conv2_2 = nn.Sequential( + nn.AvgPool3d(kernel_size=4, stride=2, padding=1), + nn.Conv3d(self.out_chns, self.out_chns, kernel_size=1, + padding=0), + nn.BatchNorm3d(self.out_chns), + nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False) + ) + + def forward(self, x): + # print(x.shape) + x = self.conv1(x) + + x1 = self.conv2_1(x) + x2 = self.conv2_2(x) + x2 = torch.sigmoid(x2) + x = x1 + x2 * x + return x + + +class UnetBlock_Encode_BottleNeck(nn.Module): + def __init__(self, in_channels, out_channel): + super(UnetBlock_Encode_BottleNeck, self).__init__() + + self.in_chns = in_channels + self.out_chns = out_channel + + self.conv1 = nn.Sequential( + nn.Conv3d(self.in_chns, self.out_chns, kernel_size=(1, 1, 3), + padding=(0, 0, 1)), + nn.BatchNorm3d(self.out_chns), + nn.ReLU6(inplace=True) + ) + + self.conv2_1 = nn.Sequential( + nn.Conv3d(self.out_chns, self.out_chns, kernel_size=(3, 3, 1), + padding=(1, 1, 0), groups=self.out_chns), + nn.BatchNorm3d(self.out_chns), + nn.ReLU6(inplace=True), + nn.Dropout(p=0.2) + ) + + self.conv2_2 = nn.Sequential( + # nn.AvgPool3d(kernel_size=4, stride=2), + nn.Conv3d(self.out_chns, self.out_chns, kernel_size=1, + padding=0), + nn.BatchNorm3d(self.out_chns), + nn.ReLU6(inplace=True), + nn.Dropout(p=0.2) + ) + + def forward(self, x): + x = self.conv1(x) + + x1 = self.conv2_1(x) + x2 = self.conv2_2(x) + x2 = torch.sigmoid(x2) + x = x1 + x2 * x + return x + + +class UnetBlock_Down(nn.Module): + def __init__(self): + super(UnetBlock_Down, self).__init__() + self.avg_pool = nn.MaxPool3d(kernel_size=2, stride=2) + + def forward(self, x): + x = self.avg_pool(x) + return x + + +class UnetBlock_Up(nn.Module): + def __init__(self, in_channels, out_channel): + super(UnetBlock_Up, self).__init__() + self.conv = self.conv1 = nn.Sequential( + nn.Conv3d(in_channels, out_channel, kernel_size=1, + padding=0, groups=1), + nn.BatchNorm3d(out_channel), + nn.ReLU6(inplace=True), + nn.Dropout(p=0.2) + ) + + self.up = nn.Upsample( + scale_factor=2, mode='trilinear', align_corners=False) + + def forward(self, x): + x = self.conv(x) + x = self.up(x) + return x + + +class LCOVNet(nn.Module): + """ + An implementation of the LCOVNet. + + * Reference: Q. Zhao, L. Zhong, J. Xiao, J. Zhang, Y. Chen , W. Liao, S. Zhang, and G. Wang: + Efficient Multi-Organ Segmentation From 3D Abdominal CT Images With Lightweight Network and Knowledge Distillation. + `IEEE TMI 42(9) 2023: 2513 - 2523. `_ + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, params): + super(LCOVNet, self).__init__() + params = self.get_default_parameters(params) + for p in params: + print(p, params[p]) + self.stage = 'train' + # C_in=32, n_classes=17, m=1, is_ds=True): + + in_chns = params['in_chns'] + n_class = params['class_num'] + self.ft_chns = params['feature_chns'] + self.mul_pred = params.get('multiscale_pred', False) + + self.Encode_block1 = UnetBlock_Encode(in_chns, self.ft_chns[0]) + self.down1 = UnetBlock_Down() + + self.Encode_block2 = UnetBlock_Encode(self.ft_chns[0], self.ft_chns[1]) + self.down2 = UnetBlock_Down() + + self.Encode_block3 = UnetBlock_Encode(self.ft_chns[1], self.ft_chns[2]) + self.down3 = UnetBlock_Down() + + self.Encode_block4 = UnetBlock_Encode(self.ft_chns[2], self.ft_chns[3]) + self.down4 = UnetBlock_Down() + + self.Encode_BottleNeck_block5 = UnetBlock_Encode_BottleNeck( + self.ft_chns[3], self.ft_chns[4]) + + self.up1 = UnetBlock_Up(self.ft_chns[4], self.ft_chns[3]) + self.Decode_block1 = UnetBlock_Encode( + self.ft_chns[3]*2, self.ft_chns[3]) + self.segout1 = nn.Conv3d( + self.ft_chns[3], n_class, kernel_size=1, padding=0) + + self.up2 = UnetBlock_Up(self.ft_chns[3], self.ft_chns[2]) + self.Decode_block2 = UnetBlock_Encode( + self.ft_chns[2]*2, self.ft_chns[2]) + self.segout2 = nn.Conv3d( + self.ft_chns[2], n_class, kernel_size=1, padding=0) + + self.up3 = UnetBlock_Up(self.ft_chns[2], self.ft_chns[1]) + self.Decode_block3 = UnetBlock_Encode( + self.ft_chns[1]*2, self.ft_chns[1]) + self.segout3 = nn.Conv3d( + self.ft_chns[1], n_class, kernel_size=1, padding=0) + + self.up4 = UnetBlock_Up(self.ft_chns[1], self.ft_chns[0]) + self.Decode_block4 = UnetBlock_Encode( + self.ft_chns[0]*2, self.ft_chns[0]) + self.segout4 = nn.Conv3d( + self.ft_chns[0], n_class, kernel_size=1, padding=0) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'initialization': 'he', + 'multiscale_pred': False + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def forward(self, x): + _x1 = self.Encode_block1(x) + x1 = self.down1(_x1) + + _x2 = self.Encode_block2(x1) + x2 = self.down2(_x2) + + _x3 = self.Encode_block3(x2) + x3 = self.down2(_x3) + + _x4 = self.Encode_block4(x3) + x4 = self.down2(_x4) + + x5 = self.Encode_BottleNeck_block5(x4) + + x6 = self.up1(x5) + x6 = torch.cat((x6, _x4), dim=1) + x6 = self.Decode_block1(x6) + segout1 = self.segout1(x6) + + x7 = self.up2(x6) + x7 = torch.cat((x7, _x3), dim=1) + x7 = self.Decode_block2(x7) + segout2 = self.segout2(x7) + + x8 = self.up3(x7) + x8 = torch.cat((x8, _x2), dim=1) + x8 = self.Decode_block3(x8) + segout3 = self.segout3(x8) + + x9 = self.up4(x8) + x9 = torch.cat((x9, _x1), dim=1) + x9 = self.Decode_block4(x9) + segout4 = self.segout4(x9) + + if (self.mul_pred == True and self.stage == 'train'): + return [segout4, segout3, segout2, segout1] + else: + return segout4 \ No newline at end of file diff --git a/pymic/net/net3d/trans3d/__init__.py b/pymic/net/net3d/trans3d/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pymic/net/net3d/trans3d/transunet3d.py b/pymic/net/net3d/trans3d/transunet3d.py new file mode 100644 index 0000000..183834e --- /dev/null +++ b/pymic/net/net3d/trans3d/transunet3d.py @@ -0,0 +1,1053 @@ +# 3D version of TransUNet; Copyright Johns Hopkins University +# Modified from nnUNet + + + +import torch +import numpy as np +import torch.nn.functional +import torch.nn.functional as F + +from copy import deepcopy +from torch import nn +from torch.cuda.amp import autocast +from scipy.optimize import linear_sum_assignment + +from ..networks.neural_network import SegmentationNetwork +from .vit_modeling import Transformer +from .vit_modeling import CONFIGS as CONFIGS_ViT + +softmax_helper = lambda x: F.softmax(x, 1) + +class InitWeights_He(object): + def __init__(self, neg_slope=1e-2): + self.neg_slope = neg_slope + + def __call__(self, module): + if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d): + module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) + if module.bias is not None: + module.bias = nn.init.constant_(module.bias, 0) + +class ConvDropoutNormNonlin(nn.Module): + """ + fixes a bug in ConvDropoutNormNonlin where lrelu was used regardless of nonlin. Bad. + """ + + def __init__(self, input_channels, output_channels, + conv_op=nn.Conv2d, conv_kwargs=None, + norm_op=nn.BatchNorm2d, norm_op_kwargs=None, + dropout_op=nn.Dropout2d, dropout_op_kwargs=None, + nonlin=nn.LeakyReLU, nonlin_kwargs=None): + super(ConvDropoutNormNonlin, self).__init__() + if nonlin_kwargs is None: + nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True} + if dropout_op_kwargs is None: + dropout_op_kwargs = {'p': 0.5, 'inplace': True} + if norm_op_kwargs is None: + norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1} + if conv_kwargs is None: + conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True} + + self.nonlin_kwargs = nonlin_kwargs + self.nonlin = nonlin + self.dropout_op = dropout_op + self.dropout_op_kwargs = dropout_op_kwargs + self.norm_op_kwargs = norm_op_kwargs + self.conv_kwargs = conv_kwargs + self.conv_op = conv_op + self.norm_op = norm_op + + self.conv = self.conv_op(input_channels, output_channels, **self.conv_kwargs) + if self.dropout_op is not None and self.dropout_op_kwargs['p'] is not None and self.dropout_op_kwargs[ + 'p'] > 0: + self.dropout = self.dropout_op(**self.dropout_op_kwargs) + else: + self.dropout = None + self.instnorm = self.norm_op(output_channels, **self.norm_op_kwargs) + self.lrelu = self.nonlin(**self.nonlin_kwargs) + + def forward(self, x): + x = self.conv(x) + if self.dropout is not None: + x = self.dropout(x) + return self.lrelu(self.instnorm(x)) + + +class ConvDropoutNonlinNorm(ConvDropoutNormNonlin): + def forward(self, x): + x = self.conv(x) + if self.dropout is not None: + x = self.dropout(x) + return self.instnorm(self.lrelu(x)) + + +class StackedConvLayers(nn.Module): + def __init__(self, input_feature_channels, output_feature_channels, num_convs, + conv_op=nn.Conv2d, conv_kwargs=None, + norm_op=nn.BatchNorm2d, norm_op_kwargs=None, + dropout_op=nn.Dropout2d, dropout_op_kwargs=None, + nonlin=nn.LeakyReLU, nonlin_kwargs=None, first_stride=None, basic_block=ConvDropoutNormNonlin): + ''' + stacks ConvDropoutNormLReLU layers. initial_stride will only be applied to first layer in the stack. The other parameters affect all layers + :param input_feature_channels: + :param output_feature_channels: + :param num_convs: + :param dilation: + :param kernel_size: + :param padding: + :param dropout: + :param initial_stride: + :param conv_op: + :param norm_op: + :param dropout_op: + :param inplace: + :param neg_slope: + :param norm_affine: + :param conv_bias: + ''' + self.input_channels = input_feature_channels + self.output_channels = output_feature_channels + + if nonlin_kwargs is None: + nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True} + if dropout_op_kwargs is None: + dropout_op_kwargs = {'p': 0.5, 'inplace': True} + if norm_op_kwargs is None: + norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1} + if conv_kwargs is None: + conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True} + + self.nonlin_kwargs = nonlin_kwargs + self.nonlin = nonlin + self.dropout_op = dropout_op + self.dropout_op_kwargs = dropout_op_kwargs + self.norm_op_kwargs = norm_op_kwargs + self.conv_kwargs = conv_kwargs + self.conv_op = conv_op + self.norm_op = norm_op + + if first_stride is not None: + self.conv_kwargs_first_conv = deepcopy(conv_kwargs) + self.conv_kwargs_first_conv['stride'] = first_stride + else: + self.conv_kwargs_first_conv = conv_kwargs + + super(StackedConvLayers, self).__init__() + self.blocks = nn.Sequential( + *([basic_block(input_feature_channels, output_feature_channels, self.conv_op, + self.conv_kwargs_first_conv, + self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, + self.nonlin, self.nonlin_kwargs)] + + [basic_block(output_feature_channels, output_feature_channels, self.conv_op, + self.conv_kwargs, + self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, + self.nonlin, self.nonlin_kwargs) for _ in range(num_convs - 1)])) + + def forward(self, x): + return self.blocks(x) + + +def print_module_training_status(module): + if isinstance(module, nn.Conv2d) or isinstance(module, nn.Conv3d) or isinstance(module, nn.Dropout3d) or \ + isinstance(module, nn.Dropout2d) or isinstance(module, nn.Dropout) or isinstance(module, nn.InstanceNorm3d) \ + or isinstance(module, nn.InstanceNorm2d) or isinstance(module, nn.InstanceNorm1d) \ + or isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d) or isinstance(module, + nn.BatchNorm1d): + print(str(module), module.training) + + +class Upsample(nn.Module): + def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=False): + super(Upsample, self).__init__() + self.align_corners = align_corners + self.mode = mode + self.scale_factor = scale_factor + self.size = size + + def forward(self, x): + return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, + align_corners=self.align_corners) + +def c2_xavier_fill(module: nn.Module) -> None: + """ + Initialize `module.weight` using the "XavierFill" implemented in Caffe2. + Also initializes `module.bias` to 0. + Args: + module (torch.nn.Module): module to initialize. + """ + # Caffe2 implementation of XavierFill in fact + # corresponds to kaiming_uniform_ in PyTorch + nn.init.kaiming_uniform_(module.weight, a=1) + if module.bias is not None: + # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module, + # torch.Tensor]`. + nn.init.constant_(module.bias, 0) + +class Generic_TransUNet_max_ppbp(SegmentationNetwork): + DEFAULT_BATCH_SIZE_3D = 2 + DEFAULT_PATCH_SIZE_3D = (64, 192, 160) + SPACING_FACTOR_BETWEEN_STAGES = 2 + BASE_NUM_FEATURES_3D = 30 + MAX_NUMPOOL_3D = 999 + MAX_NUM_FILTERS_3D = 320 + + DEFAULT_PATCH_SIZE_2D = (256, 256) + BASE_NUM_FEATURES_2D = 30 + DEFAULT_BATCH_SIZE_2D = 50 + MAX_NUMPOOL_2D = 999 + MAX_FILTERS_2D = 480 + + use_this_for_batch_size_computation_2D = 19739648 + use_this_for_batch_size_computation_3D = 520000000 # 505789440 + + def __init__(self, input_channels, base_num_features, num_classes, num_pool, num_conv_per_stage=2, + feat_map_mul_on_downscale=2, conv_op=nn.Conv2d, + norm_op=nn.BatchNorm2d, norm_op_kwargs=None, + dropout_op=nn.Dropout2d, dropout_op_kwargs=None, + nonlin=nn.LeakyReLU, nonlin_kwargs=None, deep_supervision=True, dropout_in_localization=False, + final_nonlin=softmax_helper, weightInitializer=InitWeights_He(1e-2), pool_op_kernel_sizes=None, + conv_kernel_sizes=None, + upscale_logits=False, convolutional_pooling=False, convolutional_upsampling=False, # TODO default False + max_num_features=None, basic_block=ConvDropoutNormNonlin, + seg_output_use_bias=False, + patch_size=None, is_vit_pretrain=False, + vit_depth=12, vit_hidden_size=768, vit_mlp_dim=3072, vit_num_heads=12, + max_msda='', is_max_ms=True, is_max_ms_fpn=False, max_n_fpn=4, max_ms_idxs=[-4,-3,-2], max_ss_idx=0, + is_max_bottleneck_transformer=False, max_seg_weight=1.0, max_hidden_dim=256, max_dec_layers=10, + mw = 0.5, + is_max=True, is_masked_attn=False, is_max_ds=False, is_masking=False, is_masking_argmax=False, + is_fam=False, fam_k=5, fam_reduct_ratio=8, + is_max_hungarian=False, num_queries=None, is_max_cls=False, + point_rend=False, num_point_rend=None, no_object_weight=None, is_mhsa_float32=False, no_max_hw_pe=False, + max_infer=None, cost_weight=[2.0, 5.0, 5.0], vit_layer_scale=False, decoder_layer_scale=False): + + super(Generic_TransUNet_max_ppbp, self).__init__() + + # newly added + self.is_fam = is_fam + self.is_max, self.max_msda, self.is_max_ms, self.is_max_ms_fpn, self.max_n_fpn, self.max_ss_idx, self.mw = is_max, max_msda, is_max_ms, is_max_ms_fpn, max_n_fpn, max_ss_idx, mw + self.max_ms_idxs = max_ms_idxs + + self.is_max_cls = is_max_cls + self.is_masked_attn, self.is_max_ds = is_masked_attn, is_max_ds + self.is_max_bottleneck_transformer = is_max_bottleneck_transformer + + self.convolutional_upsampling = convolutional_upsampling + self.convolutional_pooling = convolutional_pooling + self.upscale_logits = upscale_logits + if nonlin_kwargs is None: + nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True} + if dropout_op_kwargs is None: + dropout_op_kwargs = {'p': 0.5, 'inplace': True} + if norm_op_kwargs is None: + norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1} + + self.conv_kwargs = {'stride': 1, 'dilation': 1, 'bias': True} + + self.nonlin = nonlin + self.nonlin_kwargs = nonlin_kwargs + self.dropout_op_kwargs = dropout_op_kwargs + self.norm_op_kwargs = norm_op_kwargs + self.weightInitializer = weightInitializer + self.conv_op = conv_op + self.norm_op = norm_op + self.dropout_op = dropout_op + self.num_classes = num_classes + self.final_nonlin = final_nonlin + self._deep_supervision = deep_supervision + self.do_ds = deep_supervision + + if conv_op == nn.Conv2d: + upsample_mode = 'bilinear' + pool_op = nn.MaxPool2d + transpconv = nn.ConvTranspose2d + if pool_op_kernel_sizes is None: + pool_op_kernel_sizes = [(2, 2)] * num_pool + if conv_kernel_sizes is None: + conv_kernel_sizes = [(3, 3)] * (num_pool + 1) + elif conv_op == nn.Conv3d: + upsample_mode = 'trilinear' + pool_op = nn.MaxPool3d + transpconv = nn.ConvTranspose3d + if pool_op_kernel_sizes is None: + pool_op_kernel_sizes = [(2, 2, 2)] * num_pool + if conv_kernel_sizes is None: + conv_kernel_sizes = [(3, 3, 3)] * (num_pool + 1) + else: + raise ValueError("unknown convolution dimensionality, conv op: %s" % str(conv_op)) + + self.input_shape_must_be_divisible_by = np.prod(pool_op_kernel_sizes, 0, dtype=np.int64) + self.pool_op_kernel_sizes = pool_op_kernel_sizes + self.conv_kernel_sizes = conv_kernel_sizes + + self.conv_pad_sizes = [] + for krnl in self.conv_kernel_sizes: + self.conv_pad_sizes.append([1 if i == 3 else 0 for i in krnl]) + + if max_num_features is None: + if self.conv_op == nn.Conv3d: + self.max_num_features = self.MAX_NUM_FILTERS_3D + else: + self.max_num_features = self.MAX_FILTERS_2D + else: + self.max_num_features = max_num_features + + self.conv_blocks_context = [] + self.conv_blocks_localization = [] + self.td = [] + self.tu = [] + + + self.fams = [] + + output_features = base_num_features + input_features = input_channels + + for d in range(num_pool): + # determine the first stride + if d != 0 and self.convolutional_pooling: + first_stride = pool_op_kernel_sizes[d - 1] + else: + first_stride = None + + self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[d] + self.conv_kwargs['padding'] = self.conv_pad_sizes[d] + # add convolutions + self.conv_blocks_context.append(StackedConvLayers(input_features, output_features, num_conv_per_stage, + self.conv_op, self.conv_kwargs, self.norm_op, + self.norm_op_kwargs, self.dropout_op, + self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs, + first_stride, basic_block=basic_block)) + if not self.convolutional_pooling: + self.td.append(pool_op(pool_op_kernel_sizes[d])) + input_features = output_features + output_features = int(np.round(output_features * feat_map_mul_on_downscale)) + + output_features = min(output_features, self.max_num_features) + + # now the bottleneck. + # determine the first stride + if self.convolutional_pooling: + first_stride = pool_op_kernel_sizes[-1] + else: + first_stride = None + + # the output of the last conv must match the number of features from the skip connection if we are not using + # convolutional upsampling. If we use convolutional upsampling then the reduction in feature maps will be + # done by the transposed conv + if self.convolutional_upsampling: + final_num_features = output_features + else: + final_num_features = self.conv_blocks_context[-1].output_channels + + self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[num_pool] + self.conv_kwargs['padding'] = self.conv_pad_sizes[num_pool] + self.conv_blocks_context.append(nn.Sequential( + StackedConvLayers(input_features, output_features, num_conv_per_stage - 1, self.conv_op, self.conv_kwargs, + self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin, + self.nonlin_kwargs, first_stride, basic_block=basic_block), + StackedConvLayers(output_features, final_num_features, 1, self.conv_op, self.conv_kwargs, + self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin, + self.nonlin_kwargs, basic_block=basic_block))) + + # if we don't want to do dropout in the localization pathway then we set the dropout prob to zero here + if not dropout_in_localization: + old_dropout_p = self.dropout_op_kwargs['p'] + self.dropout_op_kwargs['p'] = 0.0 + + # now lets build the localization pathway + for u in range(num_pool): + nfeatures_from_down = final_num_features + nfeatures_from_skip = self.conv_blocks_context[ + -(2 + u)].output_channels # self.conv_blocks_context[-1] is bottleneck, so start with -2 + n_features_after_tu_and_concat = nfeatures_from_skip * 2 + + # the first conv reduces the number of features to match those of skip + # the following convs work on that number of features + # if not convolutional upsampling then the final conv reduces the num of features again + if u != num_pool - 1 and not self.convolutional_upsampling: + final_num_features = self.conv_blocks_context[-(3 + u)].output_channels + else: + final_num_features = nfeatures_from_skip + + if not self.convolutional_upsampling: + self.tu.append(Upsample(scale_factor=pool_op_kernel_sizes[-(u + 1)], mode=upsample_mode)) + else: + self.tu.append(transpconv(nfeatures_from_down, nfeatures_from_skip, pool_op_kernel_sizes[-(u + 1)], + pool_op_kernel_sizes[-(u + 1)], bias=False)) + + self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[- (u + 1)] + self.conv_kwargs['padding'] = self.conv_pad_sizes[- (u + 1)] + self.conv_blocks_localization.append(nn.Sequential( + StackedConvLayers(n_features_after_tu_and_concat, nfeatures_from_skip, num_conv_per_stage - 1, + self.conv_op, self.conv_kwargs, self.norm_op, self.norm_op_kwargs, self.dropout_op, + self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs, basic_block=basic_block), + StackedConvLayers(nfeatures_from_skip, final_num_features, 1, self.conv_op, self.conv_kwargs, + self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, + self.nonlin, self.nonlin_kwargs, basic_block=basic_block) + )) + + + + if self.is_fam: + self.fams = nn.ModuleList(self.fams) + + if self.do_ds: + self.seg_outputs = [] + for ds in range(len(self.conv_blocks_localization)): + self.seg_outputs.append(conv_op(self.conv_blocks_localization[ds][-1].output_channels, num_classes, + 1, 1, 0, 1, 1, seg_output_use_bias)) + self.seg_outputs = nn.ModuleList(self.seg_outputs) + + self.upscale_logits_ops = [] + cum_upsample = np.cumprod(np.vstack(pool_op_kernel_sizes), axis=0)[::-1] + for usl in range(num_pool - 1): + if self.upscale_logits: + self.upscale_logits_ops.append(Upsample(scale_factor=tuple([int(i) for i in cum_upsample[usl + 1]]), + mode=upsample_mode)) + else: + self.upscale_logits_ops.append(lambda x: x) + + if not dropout_in_localization: + self.dropout_op_kwargs['p'] = old_dropout_p + + # register all modules properly + self.conv_blocks_localization = nn.ModuleList(self.conv_blocks_localization) + self.conv_blocks_context = nn.ModuleList(self.conv_blocks_context) + self.td = nn.ModuleList(self.td) + self.tu = nn.ModuleList(self.tu) + + if self.upscale_logits: + self.upscale_logits_ops = nn.ModuleList( + self.upscale_logits_ops) # lambda x:x is not a Module so we need to distinguish here + + if self.weightInitializer is not None: + self.apply(self.weightInitializer) + # self.apply(print_module_training_status) + + # Transformer configuration + if self.is_max_bottleneck_transformer: + self.patch_size = patch_size # e.g. [48, 192, 192] + config_vit = CONFIGS_ViT['R50-ViT-B_16'] + config_vit.transformer.num_layers = vit_depth + config_vit.hidden_size = vit_hidden_size # 768 + config_vit.transformer.mlp_dim = vit_mlp_dim # 3072 + config_vit.transformer.num_heads = vit_num_heads # 12 + self.conv_more = nn.Conv3d(config_vit.hidden_size, output_features, 1) + num_pool_per_axis = np.prod(np.array(pool_op_kernel_sizes), axis=0) + num_pool_per_axis = np.log2(num_pool_per_axis).astype(np.uint8) + feat_size = [int(self.patch_size[0]/2**num_pool_per_axis[0]), int(self.patch_size[1]/2**num_pool_per_axis[1]), int(self.patch_size[2]/2**num_pool_per_axis[2])] + self.transformer = Transformer(config_vit, feat_size=feat_size, vis=False, feat_channels=output_features, use_layer_scale=vit_layer_scale) + if is_vit_pretrain: + self.transformer.load_from(weights=np.load(config_vit.pretrained_path)) + + + if self.is_max: + # Max PPB+ configuration (i.e. MultiScaleStandardTransformerDecoder) + cfg = { + "num_classes": num_classes, + "hidden_dim": max_hidden_dim, + "num_queries": num_classes if num_queries is None else num_queries, # N=K if 'fixed matching', else default=100, + "nheads": 8, + "dim_feedforward": max_hidden_dim * 8, # 2048, + "dec_layers": max_dec_layers, # 9 decoder layers, add one for the loss on learnable query? + "pre_norm": False, + "enforce_input_project": False, + "mask_dim": max_hidden_dim, # input feat of segm head? + "non_object": False, + "use_layer_scale": decoder_layer_scale, + } + cfg['non_object'] = is_max_cls + input_proj_list = [] # from low resolution to high resolution (res4 -> res1), [1, 1024, 14, 14], [1, 512, 28, 28], 1, 256, 56, 56], [1, 64, 112, 112] + decoder_channels = [320, 320, 256, 128, 64, 32] + if self.is_max_ms: # use multi-scale feature as Transformer decoder input + if self.is_max_ms_fpn: + for idx, in_channels in enumerate(decoder_channels[:max_n_fpn]): # max_n_fpn=4: 1/32, 1/16, 1/8, 1/4 + input_proj_list.append(nn.Sequential( + nn.Conv3d(in_channels, max_hidden_dim, kernel_size=1), + nn.GroupNorm(32, max_hidden_dim), + nn.Upsample(size=(int(patch_size[0]/2), int(patch_size[1]/4), int(patch_size[2]/4)), mode='trilinear') + )) # proj to scale (1, 1/2, 1/2), TODO: init + self.input_proj = nn.ModuleList(input_proj_list) + self.linear_encoder_feature = nn.Conv3d(max_hidden_dim * max_n_fpn, max_hidden_dim, 1, 1) # concat four-level feature + else: + for idx, in_channels in enumerate([decoder_channels[i] for i in self.max_ms_idxs]): + input_proj_list.append(nn.Sequential( + nn.Conv3d(in_channels, max_hidden_dim, kernel_size=1), + nn.GroupNorm(32, max_hidden_dim), + )) + self.input_proj = nn.ModuleList(input_proj_list) + + # self.linear_mask_features =nn.Conv3d(decoder_channels[max_n_fpn-1], cfg["mask_dim"], kernel_size=1, stride=1, padding=0,) # low-level feat, dot product Trans-feat + self.linear_mask_features =nn.Conv3d(decoder_channels[-1], cfg["mask_dim"], kernel_size=1, stride=1, padding=0,) # following SingleScale, high-level feat, obtain seg_map + else: + self.linear_encoder_feature = nn.Conv3d(decoder_channels[max_ss_idx], cfg["mask_dim"], kernel_size=1) + self.linear_mask_features = nn.Conv3d(decoder_channels[-1], cfg["mask_dim"], kernel_size=1, stride=1, padding=0,) # low-level feat, dot product Trans-feat + + if self.is_masked_attn: + from .mask2former_modeling.transformer_decoder.mask2former_transformer_decoder3d import MultiScaleMaskedTransformerDecoder3d + cfg['num_feature_levels'] = 1 if not self.is_max_ms or self.is_max_ms_fpn else 3 + cfg["is_masking"] = True if is_masking else False + cfg["is_masking_argmax"] = True if is_masking_argmax else False + cfg["is_mhsa_float32"] = True if is_mhsa_float32 else False + cfg["no_max_hw_pe"] = True if no_max_hw_pe else False + self.predictor = MultiScaleMaskedTransformerDecoder3d(in_channels=max_hidden_dim, mask_classification=is_max_cls, **cfg) + else: + from .mask2former_modeling.transformer_decoder.maskformer_transformer_decoder3d import StandardTransformerDecoder + cfg["dropout"], cfg["enc_layers"], cfg["deep_supervision"] = 0.1, 0, False + self.predictor = StandardTransformerDecoder(in_channels=max_hidden_dim, mask_classification=is_max_cls, **cfg) + + def forward(self, x): + skips = [] + seg_outputs = [] + for d in range(len(self.conv_blocks_context) - 1): + x = self.conv_blocks_context[d](x) + skips.append(x) + if not self.convolutional_pooling: + x = self.td[d](x) + + x = self.conv_blocks_context[-1](x) + ######### TransUNet ######### + if self.is_max_bottleneck_transformer: + x, attn = self.transformer(x) # [b, hidden, d/8, h/16, w/16] + x = self.conv_more(x) + ############################# + + ds_feats = [] # obtain multi-scale feature + ds_feats.append(x) + for u in range(len(self.tu)): + if unm", inputs, targets) + denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :] + loss = 1 - (numerator + 1) / (denominator + 1) + return loss + + +batch_dice_loss_jit = torch.jit.script( + batch_dice_loss +) # type: torch.jit.ScriptModule + + + +def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor): + """ + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + Returns: + Loss tensor + """ + hw = inputs.shape[1] + + pos = F.binary_cross_entropy_with_logits( + inputs, torch.ones_like(inputs), reduction="none" + ) + neg = F.binary_cross_entropy_with_logits( + inputs, torch.zeros_like(inputs), reduction="none" + ) + + loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum( + "nc,mc->nm", neg, (1 - targets) + ) + + return loss / hw + + +batch_sigmoid_ce_loss_jit = torch.jit.script( + batch_sigmoid_ce_loss +) # type: torch.jit.ScriptModule + + +class HungarianMatcher3D(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). + """ + + def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1, ): + """Creates the matcher + + Params: + cost_class: This is the relative weight of the classification error in the matching cost + cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost + cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost + """ + super().__init__() + self.cost_class = cost_class + self.cost_mask = cost_mask + self.cost_dice = cost_dice + + def compute_cls_loss(self, inputs, targets): + """ Classification loss (NLL) + implemented in compute_loss() + """ + raise NotImplementedError + + + def compute_dice_loss(self, inputs, targets): + """ mask dice loss + inputs (B*K, C, H, W) + target (B*K, D, H, W) + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + targets = targets.flatten(1) + num_masks = len(inputs) + + numerator = 2 * (inputs * targets).sum(-1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_masks + + + def compute_ce_loss(self, inputs, targets): + """mask ce loss""" + num_masks = len(inputs) + loss = F.binary_cross_entropy_with_logits(inputs.flatten(1), targets.flatten(1), reduction="none") + loss = loss.mean(1).sum() / num_masks + return loss + + def compute_dice(self, inputs, targets): + """ output (N_q, C, H, W) + target (K, D, H, W) + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + targets = targets.flatten(1) + numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets) + denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :] + loss = 1 - (numerator + 1) / (denominator + 1) + return loss # [N_q, K] + + + def compute_ce(self, inputs, targets): + """ output (N_q, C, H, W) + target (K, D, H, W) + return (N_q, K) + """ + inputs = inputs.flatten(1) + targets = targets.flatten(1) + hw = inputs.shape[1] + + pos = F.binary_cross_entropy_with_logits( + inputs, torch.ones_like(inputs), reduction="none" + ) + + neg = F.binary_cross_entropy_with_logits( + inputs, torch.zeros_like(inputs), reduction="none" + ) + + loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum( + "nc,mc->nm", neg, (1 - targets) + ) + + return loss / hw + + # target_onehot = torch.zeros_like(output, device=output.device) + # target_onehot.scatter_(1, target.long(), 1) + # assert (torch.argmax(target_onehot, dim=1) == target[:, 0].long()).all() + # ce_loss = F.binary_cross_entropy_with_logits(output, target_onehot) + # return ce_loss + + + @torch.no_grad() + def memory_efficient_forward(self, outputs, targets): + """More memory-friendly matching for single aux, outputs: (b, q, d, h, w)""" + """suppose each crop must contain foreground class""" + bs, num_queries = outputs["pred_logits"].shape[:2] + indices = [] + + # Iterate through batch size + for b in range(bs): + out_prob = outputs["pred_logits"][b].softmax(-1) # [num_queries, num_classes+1] + out_mask = outputs["pred_masks"][b] # [num_queries, H_pred, W_pred] + + tgt_ids = targets[b]["labels"] + tgt_mask = targets[b]["masks"].to(out_mask) # [K, D, H, W], K is number of classes shown in this image, and K < n_class + + # target_onehot = torch.zeros_like(tgt_mask, device=out_mask.device) + # target_onehot.scatter_(1, targets.long(), 1) + + cost_class = -out_prob[:, tgt_ids] # [num_queries, K] + + with autocast(enabled=False): + out_mask = out_mask.float() + tgt_mask = tgt_mask.float() + cost_dice = self.compute_dice(out_mask, tgt_mask) + cost_mask = self.compute_ce(out_mask, tgt_mask) + + # Final cost matrix + C = ( + self.cost_class * cost_class + + self.cost_mask * cost_mask + + self.cost_dice * cost_dice + ) + + C = C.reshape(num_queries, -1).cpu() # (num_queries, K) + + # linear_sum_assignment return a tuple of two arrays: row_ind, col_ind, the length of array is min(N_q, K) + # The cost of the assignment can be computed as cost_matrix[row_ind, col_ind].sum() + + indices.append(linear_sum_assignment(C)) + + final_indices = [ + (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) + for i, j in indices + ] + + return final_indices + + @torch.no_grad() + def forward(self, outputs, targets): + """Performs the matching + + Params: + outputs: This is a dict that contains at least these entries: + "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks + + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth + objects in the target) containing the class labels + "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + + return self.memory_efficient_forward(outputs, targets) + + def __repr__(self, _repr_indent=4): + head = "Matcher " + self.__class__.__name__ + body = [ + "cost_class: {}".format(self.cost_class), + "cost_mask: {}".format(self.cost_mask), + "cost_dice: {}".format(self.cost_dice), + ] + lines = [head] + [" " * _repr_indent + line for line in body] + return "\n".join(lines) + + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + +def compute_loss_hungarian(outputs, targets, idx, matcher, num_classes, point_rend=False, num_points=12544, oversample_ratio=3.0, importance_sample_ratio=0.75, no_object_weight=None, cost_weight=[2,5,5]): + """output is a dict only contain keys ['pred_masks', 'pred_logits'] """ + # outputs_without_aux = {k: v for k, v in output.items() if k != "aux_outputs"} + + indices = matcher(outputs, targets) + src_idx = matcher._get_src_permutation_idx(indices) # return a tuple of (batch_idx, src_idx) + tgt_idx = matcher._get_tgt_permutation_idx(indices) # return a tuple of (batch_idx, tgt_idx) + assert len(tgt_idx[0]) == sum([len(t["masks"]) for t in targets]) # verify that all masks of (K1, K2, ..) are used + + # step2 : compute mask loss + src_masks = outputs["pred_masks"] + src_masks = src_masks[src_idx] # [len(src_idx[0]), D, H, W] -> (K1+K2+..., D, H, W) + target_masks = torch.cat([t["masks"] for t in targets], dim=0) # (K1+K2+..., D, H, W) actually + src_masks = src_masks[:, None] # [K..., 1, D, H, W] + target_masks = target_masks[:, None] + + if point_rend: # only calculate hard example + with torch.no_grad(): + # num_points=12544 config in cityscapes + + # sample point_coords + point_coords = get_uncertain_point_coords_with_randomness( + src_masks.float(), + lambda logits: calculate_uncertainty(logits), + num_points, + oversample_ratio, + importance_sample_ratio, + ) # [K, num_points=12544, 3] + + point_labels = point_sample_3d( + target_masks.float(), + point_coords.float(), + align_corners=False, + ).squeeze(1) # [K, 12544] + + point_logits = point_sample_3d( + src_masks.float(), + point_coords.float(), + align_corners=False, + ).squeeze(1) # [K, 12544] + + src_masks, target_masks = point_logits, point_labels + + loss_mask_ce = matcher.compute_ce_loss(src_masks, target_masks) + loss_mask_dice = matcher.compute_dice_loss(src_masks, target_masks) + + # step3: compute class loss + src_logits = outputs["pred_logits"].float() # (B, num_query, num_class+1) + target_classes_o = torch.cat([t["labels"] for t in targets], dim=0) # (K1+K2+, ) + target_classes = torch.full( + src_logits.shape[:2], num_classes, dtype=torch.int64, device=src_logits.device + ) # (B, num_query, num_class+1) + target_classes[src_idx] = target_classes_o + + + if no_object_weight is not None: + empty_weight = torch.ones(num_classes + 1).to(src_logits.device) + empty_weight[-1] = no_object_weight + loss_cls = F.cross_entropy(src_logits.transpose(1, 2), target_classes, empty_weight) + else: + loss_cls = F.cross_entropy(src_logits.transpose(1, 2), target_classes) + + loss = (cost_weight[0]/10)*loss_cls + (cost_weight[1]/10)*loss_mask_ce + (cost_weight[2]/10)*loss_mask_dice # 2:5:5, like hungarian matching + # print("idx {}, loss {}, loss_cls {}, loss_mask_ce {}, loss_mask_dice {}".format(idx, loss, loss_cls, loss_mask_ce, loss_mask_dice)) + return loss + + +def point_sample_3d(input, point_coords, **kwargs): + """ + from detectron2.projects.point_rend.point_features + A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors. + Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside + [0, 1] x [0, 1] square. + Args: + input (Tensor): A tensor of shape (N, C, D, H, W) that contains features map on a D x H x W grid. + point_coords (Tensor): A tensor of shape (N, P, 3) or (N, Dgrid, Hgrid, Wgrid, 3) that contains + [0, 1] x [0, 1] x [0, 1] normalized point coordinates. + Returns: + output (Tensor): A tensor of shape (N, C, P) or (N, C, Dgrid, Hgrid, Wgrid) that contains + features for points in `point_coords`. The features are obtained via bilinear + interplation from `input` the same way as :function:`torch.nn.functional.grid_sample`. + """ + add_dim = False + if point_coords.dim() == 3: + add_dim = True + point_coords = point_coords.unsqueeze(2).unsqueeze(2) # why + + # point_coords should be (N, D, H, W, 3) + output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs) + + if add_dim: + output = output.squeeze(3).squeeze(3) + + return output + + +def calculate_uncertainty(logits): + """ + We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the + foreground class in `classes`. + Args: + logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or + class-agnostic, where R is the total number of predicted masks in all images and C is + the number of foreground classes. The values are logits. + Returns: + scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with + the most uncertain locations having the highest uncertainty score. + """ + assert logits.shape[1] == 1 + gt_class_logits = logits.clone() + return -(torch.abs(gt_class_logits)) + + +# implemented! +def get_uncertain_point_coords_with_randomness( + coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio): + """ + Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The unceratinties + are calculated for each point using 'uncertainty_func' function that takes point's logit + prediction as input. + See PointRend paper for details. + Args: + coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for + class-specific or class-agnostic prediction. + uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that + contains logit predictions for P points and returns their uncertainties as a Tensor of + shape (N, 1, P). + num_points (int): The number of points P to sample. + oversample_ratio (int): Oversampling parameter. + importance_sample_ratio (float): Ratio of points that are sampled via importnace sampling. + Returns: + point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P + sampled points. + """ + assert oversample_ratio >= 1 + assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0 + n_dim = 3 + num_boxes = coarse_logits.shape[0] + num_sampled = int(num_points * oversample_ratio) # 12544 * 3, oversampled + point_coords = torch.rand(num_boxes, num_sampled, n_dim, device=coarse_logits.device) # (K, 37632, 3); uniform dist [0, 1) + point_logits = point_sample_3d(coarse_logits, point_coords, align_corners=False) # (K, 1, 37632) + + # It is crucial to calculate uncertainty based on the sampled prediction value for the points. + # Calculating uncertainties of the coarse predictions first and sampling them for points leads + # to incorrect results. + # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between + # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value. + # However, if we calculate uncertainties for the coarse predictions first, + # both will have -1 uncertainty, and the sampled point will get -1 uncertainty. + point_uncertainties = uncertainty_func(point_logits) + num_uncertain_points = int(importance_sample_ratio * num_points) # 9408 + + num_random_points = num_points - num_uncertain_points # 3136 + idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + + shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device) + idx += shift[:, None] # [K, 9408] + + point_coords = point_coords.view(-1, n_dim)[idx.view(-1), :].view( + num_boxes, num_uncertain_points, n_dim + ) # [K, 9408, 3] + + if num_random_points > 0: + # from detectron2.layers import cat + point_coords = torch.cat( + [ + point_coords, + torch.rand(num_boxes, num_random_points, n_dim, device=coarse_logits.device), + ], + dim=1, + ) # [K, 12544, 3] + + return point_coords \ No newline at end of file diff --git a/pymic/net/net3d/unet2d5.py b/pymic/net/net3d/unet2d5.py index 308fdde..4e70393 100644 --- a/pymic/net/net3d/unet2d5.py +++ b/pymic/net/net3d/unet2d5.py @@ -1,9 +1,16 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division + +import logging import torch import torch.nn as nn import numpy as np +ConvND = {2: nn.Conv2d, 3: nn.Conv3d} +BatchNormND = {2: nn.BatchNorm2d, 3: nn.BatchNorm3d} +MaxPoolND = {2: nn.MaxPool2d, 3: nn.MaxPool3d} +ConvTransND = {2: nn.ConvTranspose2d, 3: nn.ConvTranspose3d} + class ConvBlockND(nn.Module): """ 2D or 3D convolutional block @@ -13,29 +20,17 @@ class ConvBlockND(nn.Module): :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. :param dropout_p: (int) Dropout probability. """ - def __init__(self, in_channels, out_channels, - dim = 2, dropout_p = 0.0): + def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0): super(ConvBlockND, self).__init__() assert(dim == 2 or dim == 3) self.dim = dim - if(self.dim == 2): - self.conv_conv = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), - nn.PReLU(), - nn.Dropout(dropout_p), - nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm2d(out_channels), - nn.PReLU() - ) - else: - self.conv_conv = nn.Sequential( - nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm3d(out_channels), + self.conv_conv = nn.Sequential( + ConvND[dim](in_channels, out_channels, kernel_size=3, padding=1), + BatchNormND[dim](out_channels), nn.PReLU(), nn.Dropout(dropout_p), - nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm3d(out_channels), + ConvND[dim](out_channels, out_channels, kernel_size=3, padding=1), + BatchNormND[dim](out_channels), nn.PReLU() ) @@ -52,17 +47,12 @@ class DownBlock(nn.Module): :param dropout_p: (int) Dropout probability. :param downsample: (bool) Use downsample or not after convolution. """ - def __init__(self, in_channels, out_channels, - dim = 2, dropout_p = 0.0, downsample = True): + def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0, downsample = True): super(DownBlock, self).__init__() self.downsample = downsample self.dim = dim self.conv = ConvBlockND(in_channels, out_channels, dim, dropout_p) - if(downsample): - if(self.dim == 2): - self.down_layer = nn.MaxPool2d(kernel_size = 2, stride = 2) - else: - self.down_layer = nn.MaxPool3d(kernel_size = 2, stride = 2) + self.down_layer = MaxPoolND[dim](kernel_size = 2, stride = 2) def forward(self, x): x_shape = list(x.shape) @@ -95,28 +85,31 @@ class UpBlock(nn.Module): :param out_channels: (int) Output channel number. :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. :param dropout_p: (int) Dropout probability. - :param bilinear: (bool) Use bilinear for up-sampling or not. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear` for 3D and `Bilinear` for 2D). + The default value is 2. """ def __init__(self, in_channels1, in_channels2, out_channels, - dim = 2, dropout_p = 0.0, bilinear=True): + dim = 2, dropout_p = 0.0, up_mode= 2): super(UpBlock, self).__init__() - self.bilinear = bilinear - self.dim = dim - if bilinear: - if(dim == 2): - self.up = nn.Sequential( - nn.Conv2d(in_channels1, in_channels2, kernel_size = 1), - nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)) - else: - self.up = nn.Sequential( - nn.Conv3d(in_channels1, in_channels2, kernel_size = 1), - nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)) + if(isinstance(up_mode, int)): + up_mode_values = ["transconv", "nearest", "trilinear"] + if(up_mode > 2): + raise ValueError("The upsample mode should be 0-2, but {0:} is given.".format(up_mode)) + self.up_mode = up_mode_values[up_mode] else: - if(dim == 2): - self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2) + self.up_mode = up_mode.lower() + + self.dim = dim + if (self.up_mode == "transconv"): + self.up = ConvTransND[dim](in_channels1, in_channels2, kernel_size=2, stride=2) + else: + self.conv1x1 = ConvND[dim](in_channels1, in_channels2, kernel_size = 1) + if(self.up_mode == "nearest"): + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode) else: - self.up = nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) - + mode = "trilinear" if dim == 3 else "bilinear" + self.up = nn.Upsample(scale_factor=2, mode=mode, align_corners=True) self.conv = ConvBlockND(in_channels2 * 2, out_channels, dim, dropout_p) def forward(self, x1, x2): @@ -132,6 +125,8 @@ def forward(self, x1, x2): x2 = torch.transpose(x2, 1, 2) x2 = torch.reshape(x2, new_shape) + if self.up_mode != "transconv": + x1 = self.conv1x1(x1) x1 = self.up(x1) output = torch.cat([x2, x1], dim=1) output = self.conv(output) @@ -141,6 +136,98 @@ def forward(self, x1, x2): output = torch.transpose(output, 1, 2) return output +class Encoder(nn.Module): + """ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D + + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. + """ + def __init__(self, params): + super(Encoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.n_class = self.params['class_num'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + self.dims = self.params['conv_dims'] + + self.block0 = DownBlock(self.in_chns, self.ft_chns[0], self.dims[0], self.dropout[0], True) + self.block1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dims[1], self.dropout[1], True) + self.block2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dims[2], self.dropout[2], True) + self.block3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dims[3], self.dropout[3], True) + self.block4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dims[4], self.dropout[4], False) + + def forward(self, x): + x0, x0_d = self.block0(x) + x1, x1_d = self.block1(x0_d) + x2, x2_d = self.block2(x1_d) + x3, x3_d = self.block3(x2_d) + x4, x4_d = self.block4(x3_d) + return [x0, x1, x2, x3, x4] + +class Decoder(nn.Module): + """ + Decoder of 3D UNet. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 4 or 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear` for 3D and `Bilinear` for 2D). + The default value is 2. + :param multiscale_pred: (bool) Get multi-scale prediction. + """ + def __init__(self, params): + super(Decoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.n_class = self.params['class_num'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + self.dims = self.params['conv_dims'] + self.up_mode = self.params.get('up_mode', 2) + self.mul_pred = self.params.get('multiscale_pred', False) + + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], + self.dims[3], dropout_p = self.dropout[3], up_mode=self.up_mode) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], + self.dims[2], dropout_p = self.dropout[2], up_mode=self.up_mode) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], + self.dims[1], dropout_p = self.dropout[1], up_mode=self.up_mode) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], + self.dims[0], dropout_p = self.dropout[0], up_mode=self.up_mode) + + self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) + if(self.mul_pred): + self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) + self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) + self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) + self.stage = 'train' + + def set_stage(self, stage): + self.stage = stage + + def forward(self, x): + x0, x1, x2, x3, x4 = x + x_d3 = self.up1(x4, x3) + x_d2 = self.up2(x_d3, x2) + x_d1 = self.up3(x_d2, x1) + x_d0 = self.up4(x_d1, x0) + output = self.out_conv(x_d0) + if(self.mul_pred and self.stage == 'train'): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) + output = [output, output1, output2, output3] + return output + class UNet2D5(nn.Module): """ A 2.5D network combining 3D convolutions with 2D convolutions. @@ -164,68 +251,39 @@ class UNet2D5(nn.Module): :param conv_dims: (list) The convolution dimension (2 or 3) for each resolution level. The length should be the same as that of `feature_chns`. :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). + :param multiscale_pred: (bool) Get multi-scale prediction. """ def __init__(self, params): super(UNet2D5, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.dims = self.params['conv_dims'] - self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] - - assert(len(self.ft_chns) == 5) + params = self.get_default_parameters(params) + for p in params: + print(p, params[p]) + self.stage = 'train' + self.encoder = Encoder(params) + self.decoder = Decoder(params) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], + 'conv_dims':[2, 2, 3, 3, 3], + 'up_mode': 2, + 'multiscale_pred': False + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params - self.block0 = DownBlock(self.in_chns, self.ft_chns[0], self.dims[0], self.dropout[0], True) - self.block1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dims[1], self.dropout[1], True) - self.block2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dims[2], self.dropout[2], True) - self.block3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dims[3], self.dropout[3], True) - self.block4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dims[4], self.dropout[4], False) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], - self.dims[3], dropout_p = self.dropout[3], bilinear = self.bilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], - self.dims[2], dropout_p = self.dropout[2], bilinear = self.bilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], - self.dims[1], dropout_p = self.dropout[1], bilinear = self.bilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], - self.dims[0], dropout_p = self.dropout[0], bilinear = self.bilinear) - - self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, - kernel_size = (1, 3, 3), padding = (0, 1, 1)) + def set_stage(self, stage): + self.stage = stage + self.decoder.set_stage(stage) def forward(self, x): - x0, x0_d = self.block0(x) - x1, x1_d = self.block1(x0_d) - x2, x2_d = self.block2(x1_d) - x3, x3_d = self.block3(x2_d) - x4, x4_d = self.block4(x3_d) - - x = self.up1(x4, x3) - x = self.up2(x, x2) - x = self.up3(x, x1) - x = self.up4(x, x0) - output = self.out_conv(x) + f = self.encoder(x) + output = self.decoder(f) return output - - -if __name__ == "__main__": - params = {'in_chns':4, - 'feature_chns':[2, 8, 32, 48, 64], - 'conv_dims': [2, 2, 3, 3, 3], - 'dropout': [0, 0, 0.3, 0.4, 0.5], - 'class_num': 2, - 'bilinear': False} - Net = UNet2D5(params) - Net = Net.double() - - x = np.random.rand(4, 4, 32, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print(len(y.size())) - y = y.detach().numpy() - print(y.shape) diff --git a/pymic/net/net3d/unet3d.py b/pymic/net/net3d/unet3d.py index a17bcb8..5954869 100644 --- a/pymic/net/net3d/unet3d.py +++ b/pymic/net/net3d/unet3d.py @@ -1,10 +1,12 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import logging import torch import torch.nn as nn import numpy as np -from torch.nn.functional import interpolate +from pymic.net.net_init import Initialization_He, Initialization_XavierUniform + class ConvBlock(nn.Module): """ @@ -15,15 +17,23 @@ class ConvBlock(nn.Module): :param out_channels: (int) Output channel number. :param dropout_p: (int) Dropout probability. """ - def __init__(self, in_channels, out_channels, dropout_p): + def __init__(self, in_channels, out_channels, dropout_p, norm_type = 'batch_norm'): super(ConvBlock, self).__init__() + if(norm_type == 'batch_norm'): + norm1 = nn.BatchNorm3d(out_channels, affine = True) + norm2 = nn.BatchNorm3d(out_channels, affine = True) + elif(norm_type == 'instance_norm'): + norm1 = nn.InstanceNorm3d(out_channels, affine = True) + norm2 = nn.InstanceNorm3d(out_channels, affine = True) + else: + raise ValueError("norm_type {0:} not supported, it should be batch_norm or instance_norm".format(norm_type)) self.conv_conv = nn.Sequential( nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm3d(out_channels), + norm1, nn.LeakyReLU(), nn.Dropout(dropout_p), nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1), - nn.BatchNorm3d(out_channels), + norm2, nn.LeakyReLU() ) @@ -38,11 +48,11 @@ class DownBlock(nn.Module): :param out_channels: (int) Output channel number. :param dropout_p: (int) Dropout probability. """ - def __init__(self, in_channels, out_channels, dropout_p): + def __init__(self, in_channels, out_channels, dropout_p, norm_type = 'batch_norm'): super(DownBlock, self).__init__() self.maxpool_conv = nn.Sequential( nn.MaxPool3d(2), - ConvBlock(in_channels, out_channels, dropout_p) + ConvBlock(in_channels, out_channels, dropout_p, norm_type) ) def forward(self, x): @@ -56,22 +66,33 @@ class UpBlock(nn.Module): :param in_channels2: (int) Channel number of low-level features. :param out_channels: (int) Output channel number. :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - trilinear=True): + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, + up_mode=2, norm_type = 'batch_norm'): super(UpBlock, self).__init__() - self.trilinear = trilinear - if trilinear: - self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) + if(isinstance(up_mode, int)): + up_mode_values = ["transconv", "nearest", "trilinear"] + if(up_mode > 2): + raise ValueError("The upsample mode should be 0-2, but {0:} is given.".format(up_mode)) + self.up_mode = up_mode_values[up_mode] else: + self.up_mode = up_mode.lower() + + if (self.up_mode == "transconv"): self.up = nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) - self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) + else: + self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) + if(self.up_mode == "nearest"): + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode) + else: + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode, align_corners=True) + self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p, norm_type) def forward(self, x1, x2): - if self.trilinear: + if self.up_mode != "transconv": x1 = self.conv1x1(x1) x1 = self.up(x1) x = torch.cat([x2, x1], dim=1) @@ -93,17 +114,19 @@ class Encoder(nn.Module): def __init__(self, params): super(Encoder, self).__init__() self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + in_chns = self.params['in_chns'] + ft_chns = self.params['feature_chns'] + dropout = self.params['dropout'] + norm_type = self.params['norm_type'] + assert(len(ft_chns) == 5 or len(ft_chns) == 4) - self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - if(len(self.ft_chns) == 5): - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) + self.ft_chns= ft_chns + self.in_conv= ConvBlock(in_chns, ft_chns[0], dropout[0], norm_type) + self.down1 = DownBlock(ft_chns[0], ft_chns[1], dropout[1], norm_type) + self.down2 = DownBlock(ft_chns[1], ft_chns[2], dropout[2], norm_type) + self.down3 = DownBlock(ft_chns[2], ft_chns[3], dropout[3], norm_type) + if(len(ft_chns) == 5): + self.down4 = DownBlock(ft_chns[3], ft_chns[4], dropout[4]) def forward(self, x): x0 = self.in_conv(x) @@ -129,32 +152,38 @@ class Decoder(nn.Module): :param dropout: (list) The dropout ratio for each resolution level. The length should be the same as that of `feature_chns`. :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - :param multiscale_pred: (bool) Get multi-scale prediction. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). + :param multiscale_pred: (bool) Get multi-scale prediction. """ def __init__(self, params): super(Decoder, self).__init__() self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.trilinear = self.params.get('trilinear', True) - self.mul_pred = self.params.get('multiscale_pred', False) + ft_chns = self.params['feature_chns'] + dropout = self.params['dropout'] + n_class = self.params['class_num'] + norm_type = self.params['norm_type'] + up_mode = self.params.get('up_mode', 2) + self.ft_chns = ft_chns + self.mul_pred = self.params.get('multiscale_pred', False) + assert(len(ft_chns) == 5 or len(ft_chns) == 4) - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + if(len(ft_chns) == 5): + self.up1 = UpBlock(ft_chns[4], ft_chns[3], ft_chns[3], dropout[3], up_mode, norm_type) + self.up2 = UpBlock(ft_chns[3], ft_chns[2], ft_chns[2], dropout[2], up_mode, norm_type) + self.up3 = UpBlock(ft_chns[2], ft_chns[1], ft_chns[1], dropout[1], up_mode, norm_type) + self.up4 = UpBlock(ft_chns[1], ft_chns[0], ft_chns[0], dropout[0], up_mode, norm_type) + self.out_conv = nn.Conv3d(ft_chns[0], n_class, kernel_size = 1) - if(len(self.ft_chns) == 5): - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear) - self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) if(self.mul_pred): - self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) - self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) - self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) + self.out_conv1 = nn.Conv3d(ft_chns[1], n_class, kernel_size = 1) + self.out_conv2 = nn.Conv3d(ft_chns[2], n_class, kernel_size = 1) + self.out_conv3 = nn.Conv3d(ft_chns[3], n_class, kernel_size = 1) + self.stage = 'train' + + def set_stage(self, stage): + self.stage = stage def forward(self, x): if(len(self.ft_chns) == 5): @@ -169,7 +198,7 @@ def forward(self, x): x_d1 = self.up3(x_d2, x1) x_d0 = self.up4(x_d1, x0) output = self.out_conv(x_d0) - if(self.mul_pred): + if(self.mul_pred and self.stage == 'train'): output1 = self.out_conv1(x_d1) output2 = self.out_conv2(x_d2) output3 = self.out_conv3(x_d3) @@ -196,77 +225,63 @@ class UNet3D(nn.Module): :param dropout: (list) The dropout ratio for each resolution level. The length should be the same as that of `feature_chns`. :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using trilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). :param multiscale_pred: (bool) Get multi-scale prediction. """ def __init__(self, params): super(UNet3D, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.trilinear = self.params['trilinear'] - self.mul_pred = self.params['multiscale_pred'] - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + params = self.get_default_parameters(params) + for p in params: + print(p, params[p]) + self.stage = 'train' + self.tune_mode= params.get('finetune_mode', 'all') + self.load_mode= params.get('weights_load_mode', 'all') + self.encoder = Encoder(params) + self.decoder = Decoder(params) - self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - if(len(self.ft_chns) == 5): - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], - dropout_p = self.dropout[3], trilinear=self.trilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], - dropout_p = self.dropout[2], trilinear=self.trilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], - dropout_p = self.dropout[1], trilinear=self.trilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], - dropout_p = self.dropout[0], trilinear=self.trilinear) - - self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) - if(self.mul_pred): - self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) - self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) - self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) + init = params['initialization'].lower() + weightInitializer = Initialization_He(1e-2) if init == 'he' else Initialization_XavierUniform() + self.apply(weightInitializer) + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], + 'up_mode': 2, + 'initialization': 'he', + 'norm_type': 'batch_norm', + 'multiscale_pred': False + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def set_stage(self, stage): + self.stage = stage + self.decoder.set_stage(stage) + + def forward(self, x): - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - if(len(self.ft_chns) == 5): - x4 = self.down4(x3) - x_d3 = self.up1(x4, x3) - else: - x_d3 = x3 - x_d2 = self.up2(x_d3, x2) - x_d1 = self.up3(x_d2, x1) - x_d0 = self.up4(x_d1, x0) - output = self.out_conv(x_d0) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] + f = self.encoder(x) + output = self.decoder(f) return output -if __name__ == "__main__": - params = {'in_chns':4, - 'class_num': 2, - 'feature_chns':[2, 8, 32, 64], - 'dropout' : [0, 0, 0, 0.5], - 'trilinear': True, - 'multiscale_pred': False} - Net = UNet3D(params) - Net = Net.double() + def get_parameters_to_update(self): + if(self.tune_mode == "all"): + return self.parameters() + elif(self.tune_mode == "decoder"): + print("only update parameters in decoder") + params = self.decoder.parameters() + return params + else: + raise(ValueError("update_mode can only be 'all' or 'decoder'.")) - x = np.random.rand(4, 4, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - y = y.detach().numpy() - print(y.shape) + def get_parameters_to_load(self): + state_dict = self.state_dict() + if(self.load_mode == 'encoder'): + state_dict = {k:v for k, v in state_dict.items() if "encoder" in k } + return state_dict diff --git a/pymic/net/net3d/unet3d_dual_branch.py b/pymic/net/net3d/unet3d_dual_branch.py index 3bede4e..54b01a0 100644 --- a/pymic/net/net3d/unet3d_dual_branch.py +++ b/pymic/net/net3d/unet3d_dual_branch.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division -import torch import torch.nn as nn from pymic.net.net3d.unet3d import * @@ -20,7 +19,7 @@ class UNet3D_DualBranch(nn.Module): :param output_mode: (str) How to obtain the result during the inference. `average`: taking average of the two branches. - `first`: takeing the result in the first branch. + `first`: taking the result in the first branch. `second`: taking the result in the second branch. """ def __init__(self, params): diff --git a/pymic/net/net3d/unet3d_scse.py b/pymic/net/net3d/unet3d_scse.py index b2da0dc..49cecc4 100644 --- a/pymic/net/net3d/unet3d_scse.py +++ b/pymic/net/net3d/unet3d_scse.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn import numpy as np +from pymic.net.net3d.unet3d import UpBlock, Encoder, Decoder, UNet3D from pymic.net.net3d.scse3d import * class ConvScSEBlock3D(nn.Module): @@ -48,108 +49,73 @@ def __init__(self, in_channels, out_channels, dropout_p): def forward(self, x): return self.maxpool_conv(x) -class UpBlock(nn.Module): +class UpBlockScSE(UpBlock): """3D Up-sampling followed by `ConvScSEBlock3D` in UNet3D_ScSE. :param in_channels1: (int) Input channel number for low-resolution feature map. :param in_channels2: (int) Input channel number for high-resolution feature map. :param out_channels: (int) Output channel number. :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling or not. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - trilinear=True): - super(UpBlock, self).__init__() - self.trilinear = trilinear - if trilinear: - self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) - else: - self.up = nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, up_mode=2): + super(UpBlockScSE, self).__init__(in_channels1, in_channels2, + out_channels, dropout_p, up_mode) self.conv = ConvScSEBlock3D(in_channels2 * 2, out_channels, dropout_p) - def forward(self, x1, x2): - if self.trilinear: - x1 = self.conv1x1(x1) - x1 = self.up(x1) - x = torch.cat([x2, x1], dim=1) - return self.conv(x) -class UNet3D_ScSE(nn.Module): +class EncoderScSE(Encoder): """ - Combining 3D U-Net with SCSE module. - - * Reference: Abhijit Guha Roy, Nassir Navab, Christian Wachinger: - Recalibrating Fully Convolutional Networks With Spatial and Channel - "Squeeze and Excitation" Blocks. - `IEEE Trans. Med. Imaging 38(2): 540-549 (2019). `_ + A modification of the encoder of 3D UNet by using ConvScSEBlock3D - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using trilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Encoder` for details. """ def __init__(self, params): - super(UNet3D_ScSE, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.bilinear = self.params['trilinear'] + super(EncoderScSE, self).__init__(params) - assert(len(self.ft_chns) == 5) + in_chns = self.params['in_chns'] + dropout = self.params['dropout'] + self.in_conv= ConvScSEBlock3D(in_chns, self.ft_chns[0], dropout[0]) + self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], dropout[1]) + self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], dropout[2]) + self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], dropout[3]) + if(len(self.ft_chns) == 5): + self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], dropout[4]) + +class DecoderScSE(Decoder): + """ + A modification of the decoder of 3D UNet by using ConvScSEBlock3D - self.in_conv= ConvScSEBlock3D(self.in_chns, self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = self.dropout[3]) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p = self.dropout[2]) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p = self.dropout[1]) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p = self.dropout[0]) - - self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, - kernel_size = 3, padding = 1) + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.Decoder` for details. + """ + def __init__(self, params): + super(DecoderScSE, self).__init__(params) + dropout = self.params['dropout'] + up_mode = self.params.get('up_mode', 2) + if(len(self.ft_chns) == 5): + self.up1 = UpBlockScSE(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout[3], up_mode) + self.up2 = UpBlockScSE(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout[2], up_mode) + self.up3 = UpBlockScSE(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout[1], up_mode) + self.up4 = UpBlockScSE(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout[0], up_mode) - def forward(self, x): - - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - x4 = self.down4(x3) - - x = self.up1(x4, x3) - x = self.up2(x, x2) - x = self.up3(x, x1) - x = self.up4(x, x0) - output = self.out_conv(x) - return output +class UNet3D_ScSE(UNet3D): + """ + Combining 3D U-Net with SCSE module. -if __name__ == "__main__": - params = {'in_chns':4, - 'feature_chns':[2, 8, 32, 48, 64], - 'dropout': [0, 0, 0.3, 0.4, 0.5], - 'class_num': 2, - 'trilinear': True} - Net = UNet3D_ScSE(params) - Net = Net.double() + * Reference: Abhijit Guha Roy, Nassir Navab, Christian Wachinger: + Recalibrating Fully Convolutional Networks With Spatial and Channel + "Squeeze and Excitation" Blocks. + `IEEE Trans. Med. Imaging 38(2): 540-549 (2019). `_ - x = np.random.rand(4, 4, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print(len(y.size())) - y = y.detach().numpy() - print(y.shape) \ No newline at end of file + Parameters are given in the `params` dictionary. + See :mod:`pymic.net.net3d.unet3d.UNet3D` for details. + """ + def __init__(self, params): + super(UNet3D_ScSE, self).__init__(params) + self.encoder = EncoderScSE(params) + self.decoder = DecoderScSE(params) diff --git a/pymic/net/net_dict_cls.py b/pymic/net/net_dict_cls.py index 7996e59..a83334a 100644 --- a/pymic/net/net_dict_cls.py +++ b/pymic/net/net_dict_cls.py @@ -3,7 +3,7 @@ Built-in networks for classification. * resnet18 :mod:`pymic.net.cls.torch_pretrained_net.ResNet18` -* vgg16 :mod:`pymic.net.cls.torch_pretrained_net.VGG16` +* vgg16 :mod:`pymic.net.cls.torch_pretrained_net.VGG16` * mobilenetv2 :mod:`pymic.net.cls.torch_pretrained_net.MobileNetV2` """ @@ -13,5 +13,6 @@ TorchClsNetDict = { 'resnet18': ResNet18, 'vgg16': VGG16, - 'mobilenetv2':MobileNetV2 + 'mobilenetv2':MobileNetV2, + 'vitb16': ViTB16 } diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index fc7692f..6f0f0c6 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -7,6 +7,7 @@ * UNet2D_CCT :mod:`pymic.net.net2d.unet2d_cct.UNet2D_CCT` * UNet2D_ScSE :mod:`pymic.net.net2d.unet2d_scse.UNet2D_ScSE` * AttentionUNet2D :mod:`pymic.net.net2d.unet2d_attention.AttentionUNet2D` +* MCNet2D :mod:`pymic.net.net2d.unet2d_mcnet.MCNet2D` * NestedUNet2D :mod:`pymic.net.net2d.unet2d_nest.NestedUNet2D` * COPLENet :mod:`pymic.net.net2d.cople_net.COPLENet` * UNet2D5 :mod:`pymic.net.net3d.unet2d5.UNet2D5` @@ -15,28 +16,72 @@ """ from __future__ import print_function, division from pymic.net.net2d.unet2d import UNet2D -from pymic.net.net2d.unet2d_dual_branch import UNet2D_DualBranch +from pymic.net.net2d.unet2d_multi_decoder import UNet2D_DualBranch, MCNet2D +from pymic.net.net2d.unet2d_canet import CANet from pymic.net.net2d.unet2d_cct import UNet2D_CCT +# from pymic.net.net2d.unet2d_mtnet import MTNet2D from pymic.net.net2d.cople_net import COPLENet from pymic.net.net2d.unet2d_attention import AttentionUNet2D -from pymic.net.net2d.unet2d_nest import NestedUNet2D +from pymic.net.net2d.unet2d_pp import UNet2Dpp from pymic.net.net2d.unet2d_scse import UNet2D_ScSE +from pymic.net.net2d.trans2d.transunet import TransUNet +from pymic.net.net2d.trans2d.swinunet import SwinUNet +from pymic.net.net2d.umamba import UMambaBot, UMambaEnc +from pymic.net.net2d.unet2d_vm import VMUNet +from pymic.net.net2d.unet2d_vm_light import UltraLight_VM_UNet from pymic.net.net3d.unet2d5 import UNet2D5 from pymic.net.net3d.unet3d import UNet3D +from pymic.net.net3d.grunet import GRUNet +from pymic.net.net3d.fmunetv3 import FMUNetV3 +from pymic.net.net3d.fmunet import FMUNet +from pymic.net.net3d.lcovnet import LCOVNet from pymic.net.net3d.unet3d_scse import UNet3D_ScSE from pymic.net.net3d.unet3d_dual_branch import UNet3D_DualBranch +# from pymic.net.net3d.stunet_wrap import STUNet_wrap +# from pymic.net.net3d.mystunet import MySTUNet + +# from pymic.net.net3d.trans3d.nnFormer_wrap import nnFormer_wrap +# from pymic.net.net3d.trans3d.unetr import UNETR +# from pymic.net.net3d.trans3d.unetr_pp import UNETR_PP +# from pymic.net.net3d.trans3d.MedFormer_v1 import MedFormerV1 +# from pymic.net.net3d.trans3d.MedFormer_v2 import MedFormerV2 +# from pymic.net.net3d.trans3d.MedFormer_v3 import MedFormerV3 +# from pymic.net.net3d.trans3d.MedFormer_va1 import MedFormerVA1 +# from pymic.net.net3d.trans3d.HiFormer_v1 import HiFormer_v1 +# from pymic.net.net3d.trans3d.HiFormer_v2 import HiFormer_v2 +# from pymic.net.net3d.trans3d.HiFormer_v3 import HiFormer_v3 +# from pymic.net.net3d.trans3d.HiFormer_v4 import HiFormer_v4 +# from pymic.net.net3d.trans3d.HiFormer_v5 import HiFormer_v5 +# from pymic.net.net3d.trans3d.SwitchNet import SwitchNet SegNetDict = { + 'AttentionUNet2D': AttentionUNet2D, + 'CANet': CANet, + 'COPLENet': COPLENet, + 'MCNet2D': MCNet2D, + # 'MTNet2D': MTNet2D, 'UNet2D': UNet2D, 'UNet2D_DualBranch': UNet2D_DualBranch, 'UNet2D_CCT': UNet2D_CCT, - 'COPLENet': COPLENet, - 'AttentionUNet2D': AttentionUNet2D, - 'NestedUNet2D': NestedUNet2D, + 'UNet2Dpp': UNet2Dpp, 'UNet2D_ScSE': UNet2D_ScSE, + 'UMambaBot': UMambaBot, + 'UMambaEnc': UMambaEnc, + 'VMUNet':VMUNet, + 'UltraLight_VM_UNet': UltraLight_VM_UNet, + 'TransUNet': TransUNet, + 'SwinUNet': SwinUNet, 'UNet2D5': UNet2D5, + 'GRUNet': GRUNet, + 'LCOVNet': LCOVNet, + 'FMUNet': FMUNet, + 'FMUNetV3': FMUNetV3, 'UNet3D': UNet3D, 'UNet3D_ScSE': UNet3D_ScSE, - 'UNet3D_DualBranch': UNet3D_DualBranch - + 'UNet3D_DualBranch': UNet3D_DualBranch, + # 'STUNet': STUNet_wrap, + # 'MySTUNet': MySTUNet, + # 'nnFormer': nnFormer_wrap, + # 'UNETR': UNETR, + # 'UNETR_PP': UNETR_PP, } diff --git a/pymic/net/net_init.py b/pymic/net/net_init.py new file mode 100644 index 0000000..1f9b48e --- /dev/null +++ b/pymic/net/net_init.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +from torch import nn + + +class Initialization_He(object): + def __init__(self, neg_slope=1e-2): + self.neg_slope = neg_slope + + def __call__(self, module): + if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)): + module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) + if module.bias is not None: + module.bias = nn.init.constant_(module.bias, 0) + + +class Initialization_XavierUniform(object): + def __init__(self, gain=1): + self.gain = gain + + def __call__(self, module): + if isinstance(module, (nn.Conv3d ,nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)): + module.weight = nn.init.xavier_uniform_(module.weight, self.gain) + if module.bias is not None: + module.bias = nn.init.constant_(module.bias, 0) diff --git a/pymic/net_run/__init__.py b/pymic/net_run/__init__.py index 72b8078..e69de29 100644 --- a/pymic/net_run/__init__.py +++ b/pymic/net_run/__init__.py @@ -1,2 +0,0 @@ -from __future__ import absolute_import -from . import * \ No newline at end of file diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index 9131ba0..01ad808 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -56,13 +56,16 @@ def __init__(self, config, stage = 'train'): self.loss_dict = None self.transform_dict = None self.inferer = None + self.postprocess_dict = None + self.postprocessor = None self.tensor_type = config['dataset']['tensor_type'] - self.task_type = config['dataset']['task_type'] #cls, cls_mtbc, seg + self.task_type = config['dataset']['task_type'] self.deterministic = config['training'].get('deterministic', True) self.random_seed = config['training'].get('random_seed', 1) if(self.deterministic): seed_torch(self.random_seed) logging.info("deterministric is true") + def set_datasets(self, train_set, valid_set, test_set): """ @@ -101,6 +104,14 @@ def set_net_dict(self, net_dict): """ self.net_dict = net_dict + def set_postprocess_dict(self, postprocess_dict): + """ + Set the available methods for postprocess, including customized postprocess methods. + + :param postprocess_dict: (dictionary) A dictionary of available postprocess methods. + """ + self.postprocess_dict = postprocess_dict + def set_loss_dict(self, loss_dict): """ Set the available loss functions, including customized loss functions. @@ -139,7 +150,7 @@ def get_checkpoint_name(self): """ ckpt_mode = self.config['testing']['ckpt_mode'] if(ckpt_mode == 0 or ckpt_mode == 1): - ckpt_dir = self.config['training']['ckpt_save_dir'] + ckpt_dir = self.config['training']['ckpt_dir'] ckpt_prefix = self.config['training'].get('ckpt_prefix', None) if(ckpt_prefix is None): ckpt_prefix = ckpt_dir.split('/')[-1] @@ -148,10 +159,21 @@ def get_checkpoint_name(self): with open(txt_name, 'r') as txt_file: it_num = txt_file.read().replace('\n', '') ckpt_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, it_num) + if(ckpt_mode == 1 and not os.path.isfile(ckpt_name)): + ckpt_name = "{0:}/{1:}_best.pt".format(ckpt_dir, ckpt_prefix) else: ckpt_name = self.config['testing']['ckpt_name'] return ckpt_name + @abstractmethod + def get_stage_transform_from_config(self, stage): + """ + Get the transform list required by dataset for training, validation or inference stage. + + :param stage: (str) `train`, `valid` or `test`. + """ + raise(ValueError("not implemented")) + @abstractmethod def get_stage_dataset_from_config(self, stage): """ @@ -246,7 +268,11 @@ def create_dataset(self): if(self.train_set is None): self.train_set = self.get_stage_dataset_from_config('train') if(self.valid_set is None): - self.valid_set = self.get_stage_dataset_from_config('valid') + valid_csv = self.config['dataset'].get('valid_csv', None) + if valid_csv is not None: + self.valid_set = self.get_stage_dataset_from_config('valid') + else: + logging.warning("Dataset for validation is not created, as valid_dir is not provided.") if(self.deterministic): def worker_init_fn(worker_id): # workder_seed = self.random_seed+worker_id @@ -257,18 +283,20 @@ def worker_init_fn(worker_id): else: worker_init = None - bn_train = self.config['dataset']['train_batch_size'] - bn_valid = self.config['dataset'].get('valid_batch_size', 1) - num_worker = self.config['dataset'].get('num_worker', 16) - g_train, g_valid = torch.Generator(), torch.Generator() + num_worker = self.config['dataset'].get('num_worker', 8) + bn_train = self.config['dataset']['train_batch_size'] + g_train = torch.Generator() g_train.manual_seed(self.random_seed) - g_valid.manual_seed(self.random_seed) self.train_loader = torch.utils.data.DataLoader(self.train_set, batch_size = bn_train, shuffle=True, num_workers= num_worker, - worker_init_fn=worker_init, generator = g_train) - self.valid_loader = torch.utils.data.DataLoader(self.valid_set, - batch_size = bn_valid, shuffle=False, num_workers= num_worker, - worker_init_fn=worker_init, generator = g_valid) + worker_init_fn=worker_init, generator = g_train, drop_last = True) + if(self.valid_set is not None): + bn_valid = self.config['dataset'].get('valid_batch_size', 1) + g_valid = torch.Generator() + g_valid.manual_seed(self.random_seed) + self.valid_loader = torch.utils.data.DataLoader(self.valid_set, + batch_size = bn_valid, shuffle=False, num_workers= num_worker, + worker_init_fn=worker_init, generator = g_valid) else: bn_test = self.config['dataset'].get('test_batch_size', 1) if(self.test_set is None): diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index 5610982..728f805 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -13,12 +13,13 @@ from torch.optim import lr_scheduler from torchvision import transforms from tensorboardX import SummaryWriter +from pymic import TaskType from pymic.io.nifty_dataset import ClassificationDataset from pymic.loss.loss_dict_cls import PyMICClsLossDict from pymic.net.net_dict_cls import TorchClsNetDict from pymic.transform.trans_dict import TransformDict from pymic.net_run.agent_abstract import NetRunAgent -from pymic.util.general import mixup +from pymic.util.general import mixup, tensor_shape_match import warnings warnings.filterwarnings('ignore', '.*output shape of zoom.*') @@ -38,7 +39,6 @@ class ClassificationAgent(NetRunAgent): def __init__(self, config, stage = 'train'): super(ClassificationAgent, self).__init__(config, stage) self.transform_dict = TransformDict - assert(self.task_type in ["cls", "cls_nexcl"]) def get_stage_dataset_from_config(self, stage): assert(stage in ['train', 'valid', 'test']) @@ -58,7 +58,7 @@ def get_stage_dataset_from_config(self, stage): data_transform = None else: transform_param = self.config['dataset'] - transform_param['task'] = 'classification' + transform_param['task'] = self.task_type for name in transform_names: if(name not in self.transform_dict): raise(ValueError("Undefined transform {0:}".format(name))) @@ -73,7 +73,8 @@ def get_stage_dataset_from_config(self, stage): modal_num = modal_num, class_num = class_num, with_label= not (stage == 'test'), - transform = data_transform ) + transform = data_transform, + task = self.task_type) return dataset def create_network(self): @@ -97,6 +98,8 @@ def create_loss_calculator(self): if(self.loss_dict is None): self.loss_dict = PyMICClsLossDict loss_name = self.config['training']['loss_type'] + if(loss_name != "SigmoidCELoss" and self.task_type == TaskType.CLASSIFICATION_COEXIST): + raise ValueError("SigmoidCELoss should be used when task_type is cls_coexist") if(loss_name in self.loss_dict): self.loss_calculater = self.loss_dict[loss_name](self.config['training']) else: @@ -119,12 +122,12 @@ def get_evaluation_score(self, outputs, labels): metrics = self.config['training'].get("evaluation_metric", "accuracy") if(metrics != "accuracy"): # default classification accuracy raise ValueError("Not implemeted for metric {0:}".format(metrics)) - if(self.task_type == "cls"): + if(self.task_type == TaskType.CLASSIFICATION_ONE_HOT): out_argmax = torch.argmax(outputs, 1) lab_argmax = torch.argmax(labels, 1) consis = self.convert_tensor_type(out_argmax == lab_argmax) score = torch.mean(consis) - elif(self.task_type == "cls_nexcl"): #nonexclusive classification + elif(self.task_type == TaskType.CLASSIFICATION_COEXIST): preds = self.convert_tensor_type(outputs > 0.5) consis= self.convert_tensor_type(preds == labels.data) score = torch.mean(consis) @@ -209,6 +212,27 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): logging.info('valid loss {0:.4f}, avg {1:} {2:.4f}'.format( valid_scalars['loss'], metrics, valid_scalars[metrics])) + def load_pretrained_weights(self, network, pretrained_dict, device_ids): + if(len(device_ids) > 1): + if(hasattr(network.module, "get_parameters_to_load")): + model_dict = network.module.get_parameters_to_load() + else: + model_dict = network.module.state_dict() + else: + if(hasattr(network, "get_parameters_to_load")): + model_dict = network.get_parameters_to_load() + else: + model_dict = network.state_dict() + pretrained_dict = {k: v for k, v in pretrained_dict.items() if \ + k in model_dict and tensor_shape_match(pretrained_dict[k], model_dict[k])} + logging.info("Initializing the following parameters with pre-trained model") + for k in pretrained_dict: + logging.info(k) + if (len(device_ids) > 1): + network.module.load_state_dict(pretrained_dict, strict = False) + else: + network.load_state_dict(pretrained_dict, strict = False) + def train_valid(self): device_ids = self.config['training']['gpus'] if(len(device_ids) > 1): @@ -218,13 +242,13 @@ def train_valid(self): self.device = torch.device("cuda:{0:}".format(device_ids[0])) self.net.to(self.device) - ckpt_dir = self.config['training']['ckpt_save_dir'] + ckpt_dir = self.config['training']['ckpt_dir'] if(ckpt_dir[-1] == "/"): ckpt_dir = ckpt_dir[:-1] ckpt_prefix = self.config['training'].get('ckpt_prefix', None) if(ckpt_prefix is None): ckpt_prefix = ckpt_dir.split('/')[-1] - iter_start = self.config['training']['iter_start'] + iter_start = 0 iter_max = self.config['training']['iter_max'] iter_valid = self.config['training']['iter_valid'] iter_save = self.config['training']['iter_save'] @@ -240,18 +264,18 @@ def train_valid(self): self.max_val_score = 0.0 self.max_val_it = 0 self.best_model_wts = None - self.checkpoint = None - if(iter_start > 0): - checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start) - self.checkpoint = torch.load(checkpoint_file, map_location = self.device) - assert(self.checkpoint['iteration'] == iter_start) - if(len(device_ids) > 1): - self.net.module.load_state_dict(self.checkpoint['model_state_dict']) - else: - self.net.load_state_dict(self.checkpoint['model_state_dict']) - self.max_val_score = self.checkpoint.get('valid_pred', 0) - self.max_val_it = self.checkpoint['iteration'] - self.best_model_wts = self.checkpoint['model_state_dict'] + ckpt_init_name = self.config['training'].get('ckpt_init_name', None) + ckpt_init_mode = self.config['training'].get('ckpt_init_mode', 0) + + if(ckpt_init_name is not None): + checkpoint = torch.load(ckpt_dir + "/" + ckpt_init_name, map_location = self.device) + pretrained_dict = checkpoint['model_state_dict'] + self.load_pretrained_weights(self.net, pretrained_dict, device_ids) + if(ckpt_init_mode > 0): # Load other information + iter_start = checkpoint['iteration'] + self.max_val_score = checkpoint.get('valid_pred', 0) + self.max_val_it = checkpoint['iteration'] + self.best_model_wts = checkpoint['model_state_dict'] self.create_optimizer(self.get_parameters_to_update()) self.create_loss_calculator() @@ -259,7 +283,7 @@ def train_valid(self): self.trainIter = iter(self.train_loader) logging.info("{0:} training start".format(str(datetime.now())[:-7])) - self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) + self.summ_writer = SummaryWriter(self.config['training']['ckpt_dir']) self.glob_it = iter_start for it in range(iter_start, iter_max, iter_valid): lr_value = self.optimizer.param_groups[0]['lr'] @@ -267,6 +291,7 @@ def train_valid(self): train_scalars = self.training() t1 = time.time() valid_scalars = self.validation() + t2 = time.time() if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step(valid_scalars[metrics]) @@ -285,6 +310,15 @@ def train_valid(self): self.best_model_wts = copy.deepcopy(self.net.module.state_dict()) else: self.best_model_wts = copy.deepcopy(self.net.state_dict()) + save_dict = {'iteration': self.max_val_it, + 'valid_pred': self.max_val_score, + 'model_state_dict': self.best_model_wts, + 'optimizer_state_dict': self.optimizer.state_dict()} + save_name = "{0:}/{1:}_best.pt".format(ckpt_dir, ckpt_prefix) + torch.save(save_dict, save_name) + txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') + txt_file.write(str(self.max_val_it)) + txt_file.close() stop_now = True if(early_stop_it is not None and \ self.glob_it - self.max_val_it > early_stop_it) else False @@ -302,16 +336,6 @@ def train_valid(self): if(stop_now): logging.info("The training is early stopped") break - # save the best performing checkpoint - save_dict = {'iteration': self.max_val_it, - 'valid_pred': self.max_val_score, - 'model_state_dict': self.best_model_wts, - 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it) - torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') - txt_file.write(str(self.max_val_it)) - txt_file.close() logging.info('The best perfroming iter is {0:}, valid {1:} {2:}'.format(\ self.max_val_it, metrics, self.max_val_score)) self.summ_writer.close() @@ -346,15 +370,15 @@ def infer(self): infer_time = time.time() - start_time infer_time_list.append(infer_time) - if (self.task_type == "cls"): + if (self.task_type == TaskType.CLASSIFICATION_ONE_HOT): out_prob = nn.Softmax(dim = 1)(out_digit).detach().cpu().numpy() out_lab = np.argmax(out_prob, axis=1) - else: #self.task_type == "cls_nexcl" + else: #self.task_type == TaskType.CLASSIFICATION_COEXIST out_prob = nn.Sigmoid()(out_digit).detach().cpu().numpy() out_lab = np.asarray(out_prob > 0.5, np.uint8) for i in range(len(names)): print(names[i], out_lab[i]) - if(self.task_type == "cls"): + if(self.task_type == TaskType.CLASSIFICATION_ONE_HOT): out_lab_list.append([names[i]] + [out_lab[i]]) else: out_lab_list.append([names[i]] + out_lab[i].tolist()) diff --git a/pymic/net_run/agent_preprocess.py b/pymic/net_run/agent_preprocess.py new file mode 100644 index 0000000..c53421f --- /dev/null +++ b/pymic/net_run/agent_preprocess.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import os +import sys +import torch +import torchvision.transforms as transforms +from pymic.util.parse_config import * +from pymic.io.image_read_write import save_nd_array_as_image +from pymic.io.nifty_dataset import NiftyDataset +from pymic.transform.trans_dict import TransformDict +from pymic.net_run.agent_abstract import seed_torch +from pymic.net_run.self_sup.util import volume_fusion, nonlienar_volume_fusion,augmented_volume_fusion,self_volume_fusion + +class PreprocessAgent(object): + def __init__(self, config): + super(PreprocessAgent, self).__init__() + self.config = config + self.transform_dict = TransformDict + self.task_type = config['dataset']['task_type'] + self.dataloader = None + self.dataloader_unlab= None + + deterministic = config['dataset'].get('deterministic', True) + if(deterministic): + random_seed = config['dataset'].get('random_seed', 1) + seed_torch(random_seed) + + def get_dataset_from_config(self): + root_dir = self.config['dataset']['data_dir'] + modal_num = self.config['dataset'].get('modal_num', 1) + transform_names = self.config['dataset']["transform"] + + self.transform_list = [] + if(transform_names is None or len(transform_names) == 0): + data_transform = None + else: + transform_param = self.config['dataset'] + transform_param['task'] = self.task_type + for name in transform_names: + if(name not in self.transform_dict): + raise(ValueError("Undefined transform {0:}".format(name))) + one_transform = self.transform_dict[name](transform_param) + self.transform_list.append(one_transform) + data_transform = transforms.Compose(self.transform_list) + + data_csv = self.config['dataset'].get('data_csv', None) + data_csv_unlab = self.config['dataset'].get('data_csv_unlab', None) + batch_size = self.config['dataset'].get('batch_size', 1) + data_shuffle = self.config['dataset'].get('data_shuffle', False) + if(data_csv is not None): + dataset = NiftyDataset(root_dir = root_dir, + csv_file = data_csv, + modal_num = modal_num, + with_label= True, + transform = data_transform, + task = self.task_type) + self.dataloader = torch.utils.data.DataLoader(dataset, + batch_size = batch_size, shuffle=data_shuffle, num_workers= 8, + worker_init_fn=None, generator = torch.Generator()) + if(data_csv_unlab is not None): + dataset_unlab = NiftyDataset(root_dir = root_dir, + csv_file = data_csv_unlab, + modal_num = modal_num, + with_label= False, + transform = data_transform, + task = self.task_type) + self.dataloader_unlab = torch.utils.data.DataLoader(dataset_unlab, + batch_size = batch_size, shuffle=data_shuffle, num_workers= 8, + worker_init_fn=None, generator = torch.Generator()) + + def run(self): + """ + Do preprocessing for labeled and unlabeled data. + """ + self.get_dataset_from_config() + out_dir = self.config['dataset']['output_dir'] + modal_num = self.config['dataset']['modal_num'] + if(not os.path.isdir(out_dir)): + os.mkdir(out_dir) + batch_operation = self.config['dataset'].get('batch_operation', None) + for dataloader in [self.dataloader, self.dataloader_unlab]: + if(dataloader is None): + continue + for data in dataloader: + inputs = data['image'] + labels = data.get('label', None) + img_names = data['names'] + if(len(img_names) == modal_num): # for unlabeled dataset + lab_names = [item.replace(".nii.gz", "_lab.nii.gz") for item in img_names[0]] + else: + lab_names = img_names[-1] + B, C = inputs.shape[0], inputs.shape[1] + spacing = [x.numpy()[0] for x in data['spacing']] + + if(batch_operation is not None): + if('VolumeFusion' in batch_operation): + class_num = self.config['dataset']['VolumeFusion_cls_num'.lower()] + block_range = self.config['dataset']['VolumeFusion_block_range'.lower()] + size_min = self.config['dataset']['VolumeFusion_size_min'.lower()] + size_max = self.config['dataset']['VolumeFusion_size_max'.lower()] + inputs, labels = volume_fusion(inputs, class_num - 1, block_range, size_min, size_max) + elif('SelfVolumeFusion' in batch_operation): + class_num = self.config['dataset']['SelfVolumeFusion_cls_num'.lower()] + fuse_ratio = self.config['dataset']['SelfVolumeFusion_fuse_ratio'.lower()] + size_min = self.config['dataset']['SelfVolumeFusion_size_min'.lower()] + size_max = self.config['dataset']['SelfVolumeFusion_size_max'.lower()] + inputs, labels = self_volume_fusion(inputs, class_num - 1, fuse_ratio, size_min, size_max) + elif('NonLinearVolumeFusion' in batch_operation): + block_range = self.config['dataset']['NonLinearVolumeFusion_block_range'.lower()] + size_min = self.config['dataset']['NonLinearVolumeFusion_size_min'.lower()] + size_max = self.config['dataset']['NonLinearVolumeFusion_size_max'.lower()] + inputs, labels = nonlienar_volume_fusion(inputs, block_range, size_min, size_max) + elif('AugmentedVolumeFusion' in batch_operation): + size_min = self.config['dataset']['AugmentedVolumeFusion_size_min'.lower()] + size_max = self.config['dataset']['AugmentedVolumeFusion_size_max'.lower()] + inputs, labels = augmented_volume_fusion(inputs, size_min, size_max) + + for b in range(B): + for c in range(C): + image_name = out_dir + "/" + img_names[c][b] + print(image_name) + out_dir_full = "/".join(image_name.split("/")[:-1]) + print(out_dir_full) + if(not os.path.exists(out_dir_full)): + os.mkdir(out_dir_full) + save_nd_array_as_image(inputs[b][c], image_name, reference_name = None, spacing=spacing) + if(labels is not None): + label_name = out_dir + "/" + lab_names[b] + print(label_name) + save_nd_array_as_image(labels[b][0], label_name, reference_name = None, spacing=spacing) diff --git a/pymic/net_run/agent_rec.py b/pymic/net_run/agent_rec.py new file mode 100644 index 0000000..0d78b43 --- /dev/null +++ b/pymic/net_run/agent_rec.py @@ -0,0 +1,336 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import copy +import logging +import time +import logging +import numpy as np +import os +import scipy +import torch +import torch.nn as nn +from datetime import datetime +from torch.optim import lr_scheduler +from tensorboardX import SummaryWriter +from pymic.io.image_read_write import save_nd_array_as_image +from pymic.net_run.infer_func import Inferer +from pymic.net_run.agent_seg import SegmentationAgent +from pymic.loss.seg.mse import MAELoss, MSELoss +from pymic.util.general import mixup, tensor_shape_match + +ReconstructionLossDict = { + 'MAELoss': MAELoss, + 'MSELoss': MSELoss + } + +class ReconstructionAgent(SegmentationAgent): + """ + An agent for image reconstruction (pixel-level intensity prediction). + """ + def __init__(self, config, stage = 'train'): + super(ReconstructionAgent, self).__init__(config, stage) + if (self.config['network']['class_num'] != 1): + raise ValueError("For reconstruction tasks, the output channel number should be 1, " + + "but {} was given.".format(self.config['network']['class_num'])) + + def create_loss_calculator(self): + if(self.loss_dict is None): + self.loss_dict = ReconstructionLossDict + loss_name = self.config['training']['loss_type'] + if isinstance(loss_name, (list, tuple)): + raise ValueError("Undefined loss function {0:}".format(loss_name)) + elif (loss_name not in self.loss_dict): + raise ValueError("Undefined loss function {0:}".format(loss_name)) + else: + loss_param = self.config['training'] + base_loss = self.loss_dict[loss_name](self.config['training']) + if(self.config['training'].get('deep_supervise', False)): + raise ValueError("Deep supervised loss not implemented for reconstruction tasks") + # weight = self.config['training'].get('deep_supervise_weight', None) + # mode = self.config['training'].get('deep_supervise_mode', 2) + # params = {'deep_supervise_weight': weight, + # 'deep_supervise_mode': mode, + # 'base_loss':base_loss} + # self.loss_calculator = DeepSuperviseLoss(params) + else: + self.loss_calculator = base_loss + + def training(self): + iter_valid = self.config['training']['iter_valid'] + train_loss = 0 + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 + self.net.train() + for it in range(iter_valid): + t0 = time.time() + try: + data = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data = next(self.trainIter) + # get the inputs + t1 = time.time() + inputs = self.convert_tensor_type(data['image']) + label = self.convert_tensor_type(data['label']) + + # for debug + # from pymic.io.image_read_write import save_nd_array_as_image + # print(inputs.shape) + # for i in range(inputs.shape[0]): + # image_i = inputs[i][0] + # label_i = label[i][0] + # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) + # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) + # save_nd_array_as_image(image_i, image_name, reference_name = None) + # save_nd_array_as_image(label_i, label_name, reference_name = None) + # if(it > 10): + # break + # return + + inputs, label = inputs.to(self.device), label.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + outputs = self.net(inputs) + t2 = time.time() + # for debug + # if it < 5: + # outputs = nn.Tanh()(outputs) + # for i in range(inputs.shape[0]): + # out_name = "temp/output_{0:}_{1:}.nii.gz".format(it, i) + # output = outputs[i][0] + # output = output.cpu().detach().numpy() + # save_nd_array_as_image(output, out_name, reference_name = None) + # else: + # break + + loss = self.get_loss_value(data, outputs, label) + t3 = time.time() + loss.backward() + self.optimizer.step() + t4 = time.time() + train_loss = train_loss + loss.item() + + if(isinstance(outputs, tuple) or isinstance(outputs, list)): + outputs = outputs[0] + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 + + train_avg_loss = train_loss / iter_valid + train_scalers = {'loss': train_avg_loss, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time} + return train_scalers + + def validation(self): + class_num = self.config['network']['class_num'] + if(self.inferer is None): + infer_cfg = self.config['testing'] + infer_cfg['class_num'] = class_num + self.inferer = Inferer(infer_cfg) + + valid_loss_list = [] + validIter = iter(self.valid_loader) + with torch.no_grad(): + self.net.eval() + + # for debug + # save_num = 0 + for data in validIter: + inputs = self.convert_tensor_type(data['image']) + label = self.convert_tensor_type(data['label']) + inputs, label = inputs.to(self.device), label.to(self.device) + outputs = self.inferer.run(self.net, inputs) + # The tensors are on CPU when calculating loss for validation data + loss = self.get_loss_value(data, outputs, label) + valid_loss_list.append(loss.item()) + + # for debug + # print(inputs.shape, label.shape, outputs.shape) + # inputs = inputs.cpu().numpy() + # label = label.cpu().numpy() + # outputs = outputs.cpu().numpy() + # for i in range(inputs.shape[0]): + # image_i = inputs[i][0] + # label_i = label[i][0] + # output_i = outputs[i][0] + # image_name = "temp/case{0:}_image.nii.gz".format(save_num + i) + # label_name = "temp/case{0:}_label.nii.gz".format(save_num + i) + # output_name= "temp/case{0:}_output.nii.gz".format(save_num + i) + # save_nd_array_as_image(image_i, image_name, reference_name = None) + # save_nd_array_as_image(label_i, label_name, reference_name = None) + # save_nd_array_as_image(output_i, output_name, reference_name = None) + # save_num += inputs.shape[0] + # if(save_num > 20): + # break + valid_avg_loss = np.asarray(valid_loss_list).mean() + valid_scalers = {'loss': valid_avg_loss} + return valid_scalers + + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): + loss_scalar ={'train':train_scalars['loss'], + 'valid':valid_scalars['loss']} + self.summ_writer.add_scalars('loss', loss_scalar, glob_it) + self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) + logging.info('train/valid loss {0:.4f}/{1:.4f}'.format(train_scalars['loss'],valid_scalars['loss'])) + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) + + def train_valid(self): + device_ids = self.config['training']['gpus'] + if(len(device_ids) > 1): + self.device = torch.device("cuda:0") + self.net = nn.DataParallel(self.net, device_ids = device_ids) + else: + self.device = torch.device("cuda:{0:}".format(device_ids[0])) + self.net.to(self.device) + + ckpt_dir = self.config['training']['ckpt_dir'] + ckpt_prefix = self.config['training'].get('ckpt_prefix', None) + if(ckpt_prefix is None): + ckpt_prefix = ckpt_dir.split('/')[-1] + # iter_start = self.config['training']['iter_start'] + iter_start = 0 + iter_max = self.config['training']['iter_max'] + iter_valid = self.config['training']['iter_valid'] + iter_save = self.config['training'].get('iter_save', None) + early_stop_it = self.config['training'].get('early_stop_patience', None) + if(iter_save is None): + iter_save_list = [iter_max] + elif(isinstance(iter_save, (tuple, list))): + iter_save_list = iter_save + else: + iter_save_list = range(0, iter_max + 1, iter_save) + + self.min_val_loss = 10000.0 + self.max_val_it = 0 + self.best_model_wts = None + checkpoint = None + # initialize the network with pre-trained weights + ckpt_init_name = self.config['training'].get('ckpt_init_name', None) + ckpt_init_mode = self.config['training'].get('ckpt_init_mode', 0) + ckpt_for_optm = None + if(ckpt_init_name is not None): + checkpoint = torch.load(ckpt_dir + "/" + ckpt_init_name, map_location = self.device) + pretrained_dict = checkpoint['model_state_dict'] + model_dict = self.net.module.state_dict() if (len(device_ids) > 1) else self.net.state_dict() + pretrained_dict = {k: v for k, v in pretrained_dict.items() if \ + k in model_dict and tensor_shape_match(pretrained_dict[k], model_dict[k])} + logging.info("Initializing the following parameters with pre-trained model") + for k in pretrained_dict: + logging.info(k) + if (len(device_ids) > 1): + self.net.module.load_state_dict(pretrained_dict, strict = False) + else: + self.net.load_state_dict(pretrained_dict, strict = False) + if(ckpt_init_mode > 0): # Load other information + self.min_val_loss = checkpoint.get('valid_loss', 10000) + iter_start = checkpoint['iteration'] + self.max_val_it = iter_start + self.best_model_wts = checkpoint['model_state_dict'] + ckpt_for_optm = checkpoint + + self.create_optimizer(self.get_parameters_to_update(), ckpt_for_optm) + self.create_loss_calculator() + + self.trainIter = iter(self.train_loader) + + logging.info("{0:} training start".format(str(datetime.now())[:-7])) + self.summ_writer = SummaryWriter(self.config['training']['ckpt_dir']) + self.glob_it = iter_start + for it in range(iter_start, iter_max, iter_valid): + lr_value = self.optimizer.param_groups[0]['lr'] + t0 = time.time() + train_scalars = self.training() + t1 = time.time() + valid_scalars = self.validation() + t2 = time.time() + if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step(-valid_scalars['loss']) + else: + self.scheduler.step() + + self.glob_it = it + iter_valid + logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) + logging.info('learning rate {0:}'.format(lr_value)) + logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1)) + self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it) + if(valid_scalars['loss'] < self.min_val_loss): + self.min_val_loss = valid_scalars['loss'] + self.max_val_it = self.glob_it + if(len(device_ids) > 1): + self.best_model_wts = copy.deepcopy(self.net.module.state_dict()) + else: + self.best_model_wts = copy.deepcopy(self.net.state_dict()) + + save_dict = {'iteration': self.max_val_it, + 'valid_loss': self.min_val_loss, + 'model_state_dict': self.best_model_wts, + 'optimizer_state_dict': self.optimizer.state_dict()} + save_name = "{0:}/{1:}_best.pt".format(ckpt_dir, ckpt_prefix) + torch.save(save_dict, save_name) + txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') + txt_file.write(str(self.max_val_it)) + txt_file.close() + + stop_now = True if(early_stop_it is not None and \ + self.glob_it - self.max_val_it > early_stop_it) else False + if ((self.glob_it in iter_save_list) or stop_now): + save_dict = {'iteration': self.glob_it, + 'valid_loss': valid_scalars['loss'], + 'model_state_dict': self.net.module.state_dict() \ + if len(device_ids) > 1 else self.net.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict()} + save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.glob_it) + torch.save(save_dict, save_name) + txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt') + txt_file.write(str(self.glob_it)) + txt_file.close() + if(stop_now): + logging.info("The training is early stopped") + break + # save the best performing checkpoint + logging.info('The best performing iter is {0:}, valid loss {1:}'.format(\ + self.max_val_it, self.min_val_loss)) + self.summ_writer.close() + + def save_outputs(self, data): + """ + Save prediction output. + + :param data: (dictionary) A data dictionary with prediciton result and other + information such as input image name. + """ + output_dir = self.config['testing']['output_dir'] + ignore_dir = self.config['testing'].get('filename_ignore_dir', True) + filename_replace_source = self.config['testing'].get('filename_replace_source', None) + filename_replace_target = self.config['testing'].get('filename_replace_target', None) + if(not os.path.exists(output_dir)): + os.makedirs(output_dir, exist_ok=True) + + names, pred = data['names'], data['predict'] + if(isinstance(pred, (list, tuple))): + pred = pred[0] + pred = np.tanh(pred) + if(self.postprocessor is not None): + pred = self.postprocessor(pred) + # pred = scipy.special.expit(pred) + # save the output predictions + test_dir = self.config['dataset'].get('test_dir', None) + if(test_dir is None): + test_dir = self.config['dataset']['train_dir'] + + for i in range(pred.shape[1]): + save_name = names[i][0].split('/')[-1] if ignore_dir else \ + names[i][0].replace('/', '_') + if((filename_replace_source is not None) and (filename_replace_target is not None)): + save_name = save_name.replace(filename_replace_source, filename_replace_target) + print(save_name) + save_name = "{0:}/{1:}".format(output_dir, save_name) + save_nd_array_as_image(pred[i][0], save_name, test_dir + '/' + names[i][0]) + + \ No newline at end of file diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 887a516..0259e3e 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -18,6 +18,7 @@ from pymic.io.image_read_write import save_nd_array_as_image from pymic.io.nifty_dataset import NiftyDataset from pymic.net.net_dict_seg import SegNetDict +from pymic.net.multi_net import MultiNet from pymic.net_run.agent_abstract import NetRunAgent from pymic.net_run.infer_func import Inferer from pymic.loss.loss_dict_seg import SegLossDict @@ -38,53 +39,84 @@ def __init__(self, config, stage = 'train'): self.net_dict = SegNetDict self.postprocess_dict = PostProcessDict self.postprocessor = None - - def get_stage_dataset_from_config(self, stage): - assert(stage in ['train', 'valid', 'test']) - root_dir = self.config['dataset']['root_dir'] - modal_num = self.config['dataset'].get('modal_num', 1) + def get_transform_names_and_parameters(self, stage): + """ + Get a list of transform objects for creating a dataset + """ + assert(stage in ['train', 'valid', 'test']) transform_key = stage + '_transform' - if(stage == "valid" and transform_key not in self.config['dataset']): - transform_key = "train_transform" - transform_names = self.config['dataset'][transform_key] - - self.transform_list = [] - if(transform_names is None or len(transform_names) == 0): - data_transform = None - else: - transform_param = self.config['dataset'] - transform_param['task'] = 'segmentation' - for name in transform_names: + trans_names = self.config['dataset'][transform_key] + trans_params = self.config['dataset'] + trans_params['task'] = self.task_type + return trans_names, trans_params + + def get_stage_dataset_from_config(self, stage): + trans_names, trans_params = self.get_transform_names_and_parameters(stage) + transform_list = [] + if(trans_names is not None and len(trans_names) > 0): + for name in trans_names: if(name not in self.transform_dict): raise(ValueError("Undefined transform {0:}".format(name))) - one_transform = self.transform_dict[name](transform_param) - self.transform_list.append(one_transform) - data_transform = transforms.Compose(self.transform_list) + one_transform = self.transform_dict[name](trans_params) + transform_list.append(one_transform) + data_transform = transforms.Compose(transform_list) - csv_file = self.config['dataset'].get(stage + '_csv', None) - dataset = NiftyDataset(root_dir = root_dir, + csv_file = self.config['dataset'].get(stage + '_csv', None) + if(stage == 'test'): + with_label = False + self.test_transforms = transform_list + else: + with_label = self.config['dataset'].get(stage + '_label', True) + modal_num = self.config['dataset'].get('modal_num', 1) + allow_miss = self.config['dataset'].get('allow_missing_modal', False) + stage_dir = self.config['dataset'].get('train_dir', None) + stage_dim = self.config['dataset'].get('train_dim', 3) + stage_lab_key = self.config['dataset'].get('train_label_key', 'label') + if(stage == 'valid'): # and "valid_dir" in self.config['dataset']): + stage_dir = self.config['dataset'].get('valid_dir', stage_dir) + stage_dim = self.config['dataset'].get('valid_dim', stage_dim) + stage_lab_key = self.config['dataset'].get('valid_label_key', 'label') + if(stage == 'test'): # and "test_dir" in self.config['dataset']): + stage_dir = self.config['dataset'].get('test_dir', stage_dir) + stage_dim = self.config['dataset'].get('test_dim', stage_dim) + stage_lab_key = self.config['dataset'].get('test_label_key', 'label') + logging.info("Creating dataset for {0:}".format(stage)) + dataset = NiftyDataset(root_dir = stage_dir, csv_file = csv_file, modal_num = modal_num, - with_label= not (stage == 'test'), - transform = data_transform ) + image_dim = stage_dim, + allow_missing_modal = allow_miss, + label_key = stage_lab_key, + transform = data_transform, + task = self.task_type) return dataset def create_network(self): if(self.net is None): net_name = self.config['network']['net_type'] - if(net_name not in self.net_dict): - raise ValueError("Undefined network {0:}".format(net_name)) - self.net = self.net_dict[net_name](self.config['network']) + if(isinstance(net_name, (tuple, list))): + self.net = MultiNet(self.net_dict, self.config['network']) + else: + if(net_name not in self.net_dict): + raise ValueError("Undefined network {0:}".format(net_name)) + self.net = self.net_dict[net_name](self.config['network']) if(self.tensor_type == 'float'): self.net.float() else: self.net.double() + if(hasattr(self.net, "set_stage")): + self.net.set_stage(self.stage) param_number = sum(p.numel() for p in self.net.parameters() if p.requires_grad) logging.info('parameter number {0:}'.format(param_number)) def get_parameters_to_update(self): - return self.net.parameters() + if hasattr(self.net, "get_parameters_to_update"): + params = self.net.get_parameters_to_update() + else: + params = self.net.parameters() + return params + def create_loss_calculator(self): if(self.loss_dict is None): @@ -108,22 +140,23 @@ def create_loss_calculator(self): def get_loss_value(self, data, pred, gt, param = None): loss_input_dict = {'prediction':pred, 'ground_truth': gt} - if data.get('pixel_weight', None) is not None: - if(isinstance(pred, tuple) or isinstance(pred, list)): - loss_input_dict['pixel_weight'] = data['pixel_weight'].to(pred[0].device) - else: - loss_input_dict['pixel_weight'] = data['pixel_weight'].to(pred.device) + if(isinstance(pred, tuple) or isinstance(pred, list)): + device = pred[0].device + else: + device = pred.device + pixel_weight = data.get('pixel_weight', None) + if(pixel_weight is not None): + loss_input_dict['pixel_weight'] = pixel_weight.to(device) + + class_weight = self.config['training'].get('class_weight', None) + if(class_weight is not None): + class_num = self.config['network']['class_num'] + assert(len(class_weight) == class_num) + class_weight = torch.from_numpy(np.asarray(class_weight)) + class_weight = self.convert_tensor_type(class_weight) + loss_input_dict['class_weight'] = class_weight.to(device) loss_value = self.loss_calculator(loss_input_dict) return loss_value - - def set_postprocessor(self, postprocessor): - """ - Set post processor after prediction. - - :param postprocessor: post processor, such as an instance of - `pymic.util.post_process.PostProcess`. - """ - self.postprocessor = postprocessor def training(self): class_num = self.config['network']['class_num'] @@ -131,42 +164,52 @@ def training(self): mixup_prob = self.config['training'].get('mixup_probability', 0.0) train_loss = 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - # get the inputs + t1 = time.time() inputs = self.convert_tensor_type(data['image']) labels_prob = self.convert_tensor_type(data['label_prob']) if(mixup_prob > 0 and random() < mixup_prob): inputs, labels_prob = mixup(inputs, labels_prob) - # # for debug + # for debug + # print("current iteration", it) + # if(it > 10): + # break # for i in range(inputs.shape[0]): # image_i = inputs[i][0] - # label_i = labels_prob[i][1] - # pixw_i = pix_w[i][0] - # print(image_i.shape, label_i.shape, pixw_i.shape) + # # label_i = labels_prob[i][1] + # label_i = np.argmax(labels_prob[i], axis = 0) + # # pixw_i = pix_w[i][0] + # print(image_i.shape, label_i.shape) # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) - # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) + # # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) # save_nd_array_as_image(image_i, image_name, reference_name = None) # save_nd_array_as_image(label_i, label_name, reference_name = None) - # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) + # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) # continue + inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) - + # zero the parameter gradients self.optimizer.zero_grad() # forward + backward + optimize outputs = self.net(inputs) + t2 = time.time() loss = self.get_loss_value(data, outputs, labels_prob) + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() # get dice evaluation for each class @@ -177,19 +220,33 @@ def training(self): soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob) dice_list = get_classwise_dice(soft_out, labels_prob) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) train_avg_dice = train_cls_dice[1:].mean() train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ - 'class_dice': train_cls_dice} + 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time} return train_scalers def validation(self): class_num = self.config['network']['class_num'] if(self.inferer is None): - infer_cfg = self.config['testing'] + infer_cfg = {} infer_cfg['class_num'] = class_num + infer_cfg['sliding_window_enable'] = self.config['testing'].get('sliding_window_enable', False) + if(infer_cfg['sliding_window_enable']): + patch_size = self.config['dataset'].get('patch_size', None) + if(patch_size is None): + patch_size = self.config['testing']['sliding_window_size'] + infer_cfg['sliding_window_size'] = patch_size + infer_cfg['sliding_window_stride'] = [i//2 for i in patch_size] self.inferer = Inferer(infer_cfg) valid_loss_list = [] @@ -199,6 +256,9 @@ def validation(self): self.net.eval() for data in validIter: inputs = self.convert_tensor_type(data['image']) + if('label_prob' not in data): + raise ValueError("label_prob is not found in validation data, make sure" + + "that LabelToProbability is used in valid_transform.") labels_prob = self.convert_tensor_type(data['label_prob']) inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) batch_n = inputs.shape[0] @@ -242,7 +302,31 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format( valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) + + def load_pretrained_weights(self, network, pretrained_dict, device_ids): + if(len(device_ids) > 1): + if(hasattr(network.module, "get_parameters_to_load")): + model_dict = network.module.get_parameters_to_load() + else: + model_dict = network.module.state_dict() + else: + if(hasattr(network, "get_parameters_to_load")): + model_dict = network.get_parameters_to_load() + else: + model_dict = network.state_dict() + pretrained_dict = {k: v for k, v in pretrained_dict.items() if \ + k in model_dict and tensor_shape_match(pretrained_dict[k], model_dict[k])} + logging.info("Initializing the following parameters with pre-trained model") + for k in pretrained_dict: + logging.info(k) + if (len(device_ids) > 1): + network.module.load_state_dict(pretrained_dict, strict = False) + else: + network.load_state_dict(pretrained_dict, strict = False) def train_valid(self): device_ids = self.config['training']['gpus'] @@ -253,7 +337,7 @@ def train_valid(self): self.device = torch.device("cuda:{0:}".format(device_ids[0])) self.net.to(self.device) - ckpt_dir = self.config['training']['ckpt_save_dir'] + ckpt_dir = self.config['training']['ckpt_dir'] if(ckpt_dir[-1] == "/"): ckpt_dir = ckpt_dir[:-1] ckpt_prefix = self.config['training'].get('ckpt_prefix', None) @@ -283,20 +367,11 @@ def train_valid(self): if(ckpt_init_name is not None): checkpoint = torch.load(ckpt_dir + "/" + ckpt_init_name, map_location = self.device) pretrained_dict = checkpoint['model_state_dict'] - model_dict = self.net.module.state_dict() if (len(device_ids) > 1) else self.net.state_dict() - pretrained_dict = {k: v for k, v in pretrained_dict.items() if \ - k in model_dict and tensor_shape_match(pretrained_dict[k], model_dict[k])} - logging.info("Initializing the following parameters with pre-trained model") - for k in pretrained_dict: - logging.info(k) - if (len(device_ids) > 1): - self.net.module.load_state_dict(pretrained_dict, strict = False) - else: - self.net.load_state_dict(pretrained_dict, strict = False) + self.load_pretrained_weights(self.net, pretrained_dict, device_ids) if(ckpt_init_mode > 0): # Load other information self.max_val_dice = checkpoint.get('valid_pred', 0) - iter_start = checkpoint['iteration'] - 1 + iter_start = checkpoint['iteration'] self.max_val_it = iter_start self.best_model_wts = checkpoint['model_state_dict'] ckpt_for_optm = checkpoint @@ -306,7 +381,7 @@ def train_valid(self): self.trainIter = iter(self.train_loader) logging.info("{0:} training start".format(str(datetime.now())[:-7])) - self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) + self.summ_writer = SummaryWriter(self.config['training']['ckpt_dir']) self.glob_it = iter_start for it in range(iter_start, iter_max, iter_valid): lr_value = self.optimizer.param_groups[0]['lr'] @@ -325,6 +400,7 @@ def train_valid(self): logging.info('learning rate {0:}'.format(lr_value)) logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1)) self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it) + if(valid_scalars['avg_fg_dice'] > self.max_val_dice): self.max_val_dice = valid_scalars['avg_fg_dice'] self.max_val_it = self.glob_it @@ -332,8 +408,17 @@ def train_valid(self): self.best_model_wts = copy.deepcopy(self.net.module.state_dict()) else: self.best_model_wts = copy.deepcopy(self.net.state_dict()) + save_dict = {'iteration': self.max_val_it, + 'valid_pred': self.max_val_dice, + 'model_state_dict': self.best_model_wts, + 'optimizer_state_dict': self.optimizer.state_dict()} + save_name = "{0:}/{1:}_best.pt".format(ckpt_dir, ckpt_prefix) + torch.save(save_dict, save_name) + txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') + txt_file.write(str(self.max_val_it)) + txt_file.close() - stop_now = True if(early_stop_it is not None and \ + stop_now = True if (early_stop_it is not None and \ self.glob_it - self.max_val_it > early_stop_it) else False if ((self.glob_it in iter_save_list) or stop_now): save_dict = {'iteration': self.glob_it, @@ -350,15 +435,6 @@ def train_valid(self): logging.info("The training is early stopped") break # save the best performing checkpoint - save_dict = {'iteration': self.max_val_it, - 'valid_pred': self.max_val_dice, - 'model_state_dict': self.best_model_wts, - 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it) - torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') - txt_file.write(str(self.max_val_it)) - txt_file.close() logging.info('The best performing iter is {0:}, valid dice {1:}'.format(\ self.max_val_it, self.max_val_dice)) self.summ_writer.close() @@ -388,6 +464,7 @@ def test_time_dropout(m): raise ValueError("ckpt_mode should be 3 if ckpt_name is a list") # load network parameters and set the network as evaluation mode + print("ckpt name", ckpt_name) checkpoint = torch.load(ckpt_name, map_location = device) self.net.load_state_dict(checkpoint['model_state_dict']) @@ -397,7 +474,7 @@ def test_time_dropout(m): self.inferer = Inferer(infer_cfg) postpro_name = self.config['testing'].get('post_process', None) if(self.postprocessor is None and postpro_name is not None): - self.postprocessor = PostProcessDict[postpro_name](self.config['testing']) + self.postprocessor = self.postprocess_dict[postpro_name](self.config['testing']) infer_time_list = [] with torch.no_grad(): for data in self.test_loader: @@ -423,7 +500,7 @@ def test_time_dropout(m): pred = pred.cpu().numpy() data['predict'] = pred # inverse transform - for transform in self.transform_list[::-1]: + for transform in self.test_transforms[::-1]: if (transform.inverse): data = transform.inverse_transform_for_prediction(data) @@ -477,7 +554,7 @@ def infer_with_multiple_checkpoints(self): pred = np.mean(predict_list, axis=0) data['predict'] = pred # inverse transform - for transform in self.transform_list[::-1]: + for transform in self.test_transforms[::-1]: if (transform.inverse): data = transform.inverse_transform_for_prediction(data) @@ -516,15 +593,18 @@ def save_outputs(self, data): for i in range(len(names)): output[i] = self.postprocessor(output[i]) # save the output and (optionally) probability predictions - root_dir = self.config['dataset']['root_dir'] - for i in range(len(names)): - save_name = names[i].split('/')[-1] if ignore_dir else \ - names[i].replace('/', '_') + test_dir = self.config['dataset'].get('test_dir', None) + if(test_dir is None): + test_dir = self.config['dataset']['train_dir'] + + for i in range(output.shape[0]): + save_name = names[i][0].split('/')[-1] if ignore_dir else \ + names[i][0].replace('/', '_') if((filename_replace_source is not None) and (filename_replace_target is not None)): save_name = save_name.replace(filename_replace_source, filename_replace_target) print(save_name) save_name = "{0:}/{1:}".format(output_dir, save_name) - save_nd_array_as_image(output[i], save_name, root_dir + '/' + names[i]) + save_nd_array_as_image(output[i], save_name, test_dir + '/' + names[i][0]) save_name_split = save_name.split('.') if(not save_prob): @@ -542,4 +622,4 @@ def save_outputs(self, data): prob_save_name = "{0:}_prob_{1:}.{2:}".format(save_prefix, c, save_format) if(len(temp_prob.shape) == 2): temp_prob = np.asarray(temp_prob * 255, np.uint8) - save_nd_array_as_image(temp_prob, prob_save_name, root_dir + '/' + names[i]) + save_nd_array_as_image(temp_prob, prob_save_name, test_dir + '/' + names[i][0]) diff --git a/pymic/net_run/get_optimizer.py b/pymic/net_run/get_optimizer.py index ad8fda0..771b448 100644 --- a/pymic/net_run/get_optimizer.py +++ b/pymic/net_run/get_optimizer.py @@ -13,8 +13,9 @@ def get_optimizer(name, net_params, optim_params): # see https://www.codeleading.com/article/44815584159/ param_group = [{'params': net_params, 'initial_lr': lr}] if(keyword_match(name, "SGD")): + nesterov = optim_params.get('nesterov', True) return optim.SGD(param_group, lr, - momentum = momentum, weight_decay = weight_decay) + momentum = momentum, weight_decay = weight_decay, nesterov = nesterov) elif(keyword_match(name, "Adam")): return optim.Adam(param_group, lr, weight_decay = weight_decay) elif(keyword_match(name, "SparseAdam")): diff --git a/pymic/net_run/infer_func.py b/pymic/net_run/infer_func.py index 43162d0..e0e466e 100644 --- a/pymic/net_run/infer_func.py +++ b/pymic/net_run/infer_func.py @@ -2,6 +2,8 @@ from __future__ import print_function, division import torch +import numpy as np +from scipy.ndimage.filters import gaussian_filter from torch.nn.functional import interpolate class Inferer(object): @@ -48,16 +50,29 @@ def __get_prediction_number_and_scales(self, tempx): output_num, scales = 1, None return output_num, scales + def __get_gaussian_weight_map(self, window_size, sigma_scale = 1.0/8): + w = np.zeros(window_size) + center = [i//2 for i in window_size] + sigmas = [i*sigma_scale for i in window_size] + w[tuple(center)] = 1.0 + w = gaussian_filter(w, sigmas, 0, mode='constant', cval=0) + return w + def __infer_with_sliding_window(self, image): """ - Use sliding window to predict segmentation for large images. - Note that the network may output a list of tensors with difference sizes. + Use sliding window to predict segmentation for large images. The outupt of each + sliding window is weighted by a Gaussian map that hihglights contributions of windows + with a centroid closer to a given pixel. + Note that the network may output a list of tensors with difference sizes for multi-scale prediction. """ window_size = [x for x in self.config['sliding_window_size']] window_stride = [x for x in self.config['sliding_window_stride']] + window_batch = self.config.get('sliding_window_batch', 1) class_num = self.config['class_num'] img_full_shape = list(image.shape) batch_size = img_full_shape[0] + assert(batch_size == 1 or window_batch == 1) + img_chns = img_full_shape[1] img_shape = img_full_shape[2:] img_dim = len(img_shape) if(img_dim != 2 and img_dim !=3): @@ -86,57 +101,80 @@ def __infer_with_sliding_window(self, image): crop_start_list.append([d_min, h_min, w_min]) output_shape = [batch_size, class_num] + img_shape - mask_shape = [batch_size, class_num] + window_size - counter = torch.zeros(output_shape).to(image.device) - temp_mask = torch.ones(mask_shape).to(image.device) + weight = torch.zeros(output_shape).to(image.device) + temp_w = self.__get_gaussian_weight_map(window_size) + temp_w = np.broadcast_to(temp_w, [batch_size, class_num] + window_size) + temp_w = torch.from_numpy(np.array(temp_w)).to(image.device) temp_in_shape = img_full_shape[:2] + window_size tempx = torch.ones(temp_in_shape).to(image.device) out_num, scale_list = self.__get_prediction_number_and_scales(tempx) + + window_num = len(crop_start_list) + assert(window_num >= window_batch) + patches_shape = [window_batch, img_chns] + window_size + patches_in = torch.ones(patches_shape).to(image.device) if(out_num == 1): # for a single prediction output = torch.zeros(output_shape).to(image.device) - for c0 in crop_start_list: - c1 = [c0[d] + window_size[d] for d in range(img_dim)] - if(img_dim == 2): - patch_in = image[:, :, c0[0]:c1[0], c0[1]:c1[1]] - else: - patch_in = image[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] - patch_out = self.model(patch_in) - if(isinstance(patch_out, (tuple, list))): - patch_out = patch_out[0] - if(img_dim == 2): - output[:, :, c0[0]:c1[0], c0[1]:c1[1]] += patch_out - counter[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_mask - else: - output[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += patch_out - counter[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_mask - return output/counter + for w_i in range(0, window_num, window_batch): + for k in range(window_batch): + if(w_i + k >= window_num): + break + c0 = crop_start_list[w_i + k] + c1 = [c0[d] + window_size[d] for d in range(img_dim)] + if(img_dim == 2): + patches_in[k] = image[:, :, c0[0]:c1[0], c0[1]:c1[1]] + else: + patches_in[k] = image[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] + patches_out = self.model(patches_in) + if(isinstance(patches_out, (tuple, list))): + patches_out = patches_out[0] + for k in range(window_batch): + if(w_i + k >= window_num): + break + c0 = crop_start_list[w_i + k] + c1 = [c0[d] + window_size[d] for d in range(img_dim)] + if(img_dim == 2): + output[:, :, c0[0]:c1[0], c0[1]:c1[1]] += patches_out[k] * temp_w + weight[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_w + else: + output[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += patches_out[k] * temp_w + weight[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_w + return output/weight else: # for multiple prediction output_list= [] for i in range(out_num): output_shape_i = [batch_size, class_num] + \ [int(img_shape[d] * scale_list[i][d]) for d in range(img_dim)] output_list.append(torch.zeros(output_shape_i).to(image.device)) - - for c0 in crop_start_list: - c1 = [c0[d] + window_size[d] for d in range(img_dim)] - if(img_dim == 2): - patch_in = image[:, :, c0[0]:c1[0], c0[1]:c1[1]] - else: - patch_in = image[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] - patch_out = self.model(patch_in) - - for i in range(out_num): - c0_i = [int(c0[d] * scale_list[i][d]) for d in range(img_dim)] - c1_i = [int(c1[d] * scale_list[i][d]) for d in range(img_dim)] + temp_ws = [interpolate(temp_w, scale_factor = scale_list[i]) for i in range(out_num)] + weights = [interpolate(weight, scale_factor = scale_list[i]) for i in range(out_num)] + for w_i in range(0, window_num, window_batch): + for k in range(window_batch): + if(w_i + k >= window_num): + break + c0 = crop_start_list[w_i + k] + c1 = [c0[d] + window_size[d] for d in range(img_dim)] if(img_dim == 2): - output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1]] += patch_out[i] - counter[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_mask + patches_in[k] = image[:, :, c0[0]:c1[0], c0[1]:c1[1]] else: - output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1], c0_i[2]:c1_i[2]] += patch_out[i] - counter[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_mask + patches_in[k] = image[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] + patches_out = self.model(patches_in) + + for i in range(out_num): + for k in range(window_batch): + if(w_i + k >= window_num): + break + c0 = crop_start_list[w_i + k] + c0_i = [int(c0[d] * scale_list[i][d]) for d in range(img_dim)] + c1_i = [int(c1[d] * scale_list[i][d]) for d in range(img_dim)] + if(img_dim == 2): + output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1]] += patches_out[i][k] * temp_ws[i] + weights[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1]] += temp_ws[i] + else: + output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1], c0_i[2]:c1_i[2]] += patches_out[i][k] * temp_ws[i] + weights[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1], c0_i[2]:c1_i[2]] += temp_ws[i] for i in range(out_num): - counter_i = interpolate(counter, scale_factor = scale_list[i]) - output_list[i] = output_list[i] / counter_i + output_list[i] = output_list[i] / weights[i] return output_list def run(self, model, image): diff --git a/pymic/net_run/noisy_label/nll_clslsr.py b/pymic/net_run/noisy_label/nll_clslsr.py index 0148621..3f1059e 100644 --- a/pymic/net_run/noisy_label/nll_clslsr.py +++ b/pymic/net_run/noisy_label/nll_clslsr.py @@ -142,9 +142,9 @@ def test_time_dropout(m): print(gt.shape, pred_cat.shape) conf = get_confident_map(gt, pred_cat) conf = conf.reshape(-1, 256, 256).astype(np.uint8) * 255 - save_dir = self.config['dataset']['root_dir'] + "/slsr_conf" + save_dir = self.config['dataset']['train_dir'] + "/slsr_conf" for idx in range(len(filename_list)): - filename = filename_list[idx][0].split('/')[-1] + filename = filename_list[idx][0][0].split('/')[-1] conf_map = Image.fromarray(conf[idx]) dst_path = os.path.join(save_dir, filename) conf_map.save(dst_path) @@ -152,32 +152,34 @@ def test_time_dropout(m): def get_confidence_map(cfg_file): config = parse_config(cfg_file) config = synchronize_config(config) + agent = NLLCLSLSR(config, 'test') - # set dataset - transform_names = config['dataset']['valid_transform'] + # set customized dataset for testing, i.e,. inference with training images + trans_names, trans_params = agent.get_transform_names_and_parameters('valid') transform_list = [] - transform_dict = TransformDict - if(transform_names is None or len(transform_names) == 0): - data_transform = None - else: - transform_param = config['dataset'] - transform_param['task'] = 'segmentation' - for name in transform_names: - if(name not in transform_dict): + if(trans_names is not None and len(trans_names) > 0): + for name in trans_names: + if(name not in agent.transform_dict): raise(ValueError("Undefined transform {0:}".format(name))) - one_transform = transform_dict[name](transform_param) + one_transform = agent.transform_dict[name](trans_params) transform_list.append(one_transform) - data_transform = transforms.Compose(transform_list) + data_transform = transforms.Compose(transform_list) + stage_dir = config['dataset']['train_dir'] csv_file = config['dataset']['train_csv'] modal_num = config['dataset'].get('modal_num', 1) - dataset = NiftyDataset(root_dir = config['dataset']['root_dir'], - csv_file = csv_file, - modal_num = modal_num, - with_label= True, - transform = data_transform ) + stage_dim = config['dataset'].get('train_dim', 3) + lab_key = config['dataset'].get('train_label_key', 'label') + + dataset = NiftyDataset(root_dir = stage_dir, + csv_file = csv_file, + modal_num = modal_num, + image_dim = stage_dim, + allow_missing_modal = False, + label_key = lab_key, + transform = data_transform, + task = agent.task_type) - agent = NLLCLSLSR(config, 'test') agent.set_datasets(None, None, dataset) agent.transform_list = transform_list agent.create_dataset() diff --git a/pymic/net_run/noisy_label/nll_co_teaching.py b/pymic/net_run/noisy_label/nll_co_teaching.py index ec8e230..33e375c 100644 --- a/pymic/net_run/noisy_label/nll_co_teaching.py +++ b/pymic/net_run/noisy_label/nll_co_teaching.py @@ -5,35 +5,17 @@ import os import sys import numpy as np +import time import torch import torch.nn as nn -import torch.optim as optim -from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.util import reshape_tensor_to_2D from pymic.net_run.agent_seg import SegmentationAgent -from pymic.net.net_dict_seg import SegNetDict from pymic.util.parse_config import * from pymic.util.ramps import get_rampup_ratio -class BiNet(nn.Module): - def __init__(self, params): - super(BiNet, self).__init__() - net_name = params['net_type'] - self.net1 = SegNetDict[net_name](params) - self.net2 = SegNetDict[net_name](params) - - def forward(self, x): - out1 = self.net1(x) - out2 = self.net2(x) - - if(self.training): - return out1, out2 - else: - return (out1 + out2) / 2 - class NLLCoTeaching(SegmentationAgent): """ Co-teaching for noisy-label learning. @@ -58,14 +40,6 @@ def __init__(self, config, stage = 'train'): logging.warn("only CrossEntropyLoss supported for" + " coteaching, the specified loss {0:} is ingored".format(loss_type)) - def create_network(self): - if(self.net is None): - self.net = BiNet(self.config['network']) - if(self.tensor_type == 'float'): - self.net.float() - else: - self.net.double() - def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] @@ -75,19 +49,20 @@ def training(self): rampup_start = nll_cfg.get('rampup_start', 0) rampup_end = nll_cfg.get('rampup_end', iter_max) - train_loss_no_select1 = 0 - train_loss_no_select2 = 0 - train_loss1 = 0 - train_loss2 = 0 + train_loss_no_select1, train_loss_no_select2 = 0, 0 + train_loss1, train_avg_loss1 = 0, 0 + train_loss2, train_avg_loss2 = 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) labels_prob = self.convert_tensor_type(data['label_prob']) @@ -98,7 +73,7 @@ def training(self): # forward + backward + optimize outputs1, outputs2 = self.net(inputs) - + t2 = time.time() prob1 = nn.Softmax(dim = 1)(outputs1) prob2 = nn.Softmax(dim = 1)(outputs2) prob1_2d = reshape_tensor_to_2D(prob1) * 0.999 + 5e-4 @@ -125,8 +100,9 @@ def training(self): loss2_select = loss2[ind_1_update] loss = loss1_select.mean() + loss2_select.mean() - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss_no_select1 = train_loss_no_select1 + loss1.mean().item() @@ -139,6 +115,11 @@ def training(self): soft_out1, labels_prob = reshape_prediction_and_ground_truth(soft_out1, labels_prob) dice_list = get_classwise_dice(soft_out1, labels_prob).detach().cpu().numpy() train_dice_list.append(dice_list) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss_no_select1 = train_loss_no_select1 / iter_valid train_avg_loss_no_select2 = train_loss_no_select2 / iter_valid train_avg_loss1 = train_loss1 / iter_valid @@ -150,7 +131,9 @@ def training(self): 'loss1':train_avg_loss1, 'loss2': train_avg_loss2, 'loss_no_select1':train_avg_loss_no_select1, 'loss_no_select2':train_avg_loss_no_select2, - 'select_ratio':remb_ratio, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'select_ratio':remb_ratio, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): @@ -177,3 +160,6 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format( valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) diff --git a/pymic/net_run/noisy_label/nll_dast.py b/pymic/net_run/noisy_label/nll_dast.py index 1921e9c..95203ba 100644 --- a/pymic/net_run/noisy_label/nll_dast.py +++ b/pymic/net_run/noisy_label/nll_dast.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import numpy as np import random import torch -import numpy as np +import time import torch.nn as nn import torchvision.transforms as transforms from pymic.io.nifty_dataset import NiftyDataset @@ -117,31 +118,31 @@ def get_noisy_dataset_from_config(self): """ Create a dataset for images with noisy labels based on configuraiton. """ - root_dir = self.config['dataset']['root_dir'] - modal_num = self.config['dataset'].get('modal_num', 1) - transform_names = self.config['dataset']['train_transform'] - - self.transform_list = [] - if(transform_names is None or len(transform_names) == 0): - data_transform = None - else: - transform_param = self.config['dataset'] - transform_param['task'] = 'segmentation' - for name in transform_names: + trans_names, trans_params = self.get_transform_names_and_parameters('train') + transform_list = [] + if(trans_names is not None and len(trans_names) > 0): + for name in trans_names: if(name not in self.transform_dict): raise(ValueError("Undefined transform {0:}".format(name))) - one_transform = self.transform_dict[name](transform_param) - self.transform_list.append(one_transform) - data_transform = transforms.Compose(self.transform_list) + one_transform = self.transform_dict[name](trans_params) + transform_list.append(one_transform) + data_transform = transforms.Compose(transform_list) - csv_file = self.config['dataset'].get('train_csv_noise', None) - dataset = NiftyDataset(root_dir=root_dir, - csv_file = csv_file, - modal_num = modal_num, - with_label= True, - transform = data_transform ) + modal_num = self.config['dataset'].get('modal_num', 1) + stage_dim = self.config['dataset'].get('train_dim', 3) + lab_key = self.config['dataset'].get('train_label_key', 'label') + csv_file = self.config['dataset'].get('train_csv_noise', None) + dataset = NiftyDataset(root_dir = self.config['dataset']['train_dir'], + csv_file = csv_file, + modal_num = modal_num, + image_dim = stage_dim, + allow_missing_modal = False, + label_key = lab_key, + transform = data_transform, + task = self.task_type) return dataset + def create_dataset(self): super(NLLDAST, self).create_dataset() if(self.stage == 'train'): @@ -167,15 +168,15 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = nll_cfg.get('rampup_start', 0) rampup_end = nll_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() rank_length = nll_cfg.get("dast_rank_length", 20) consist_loss = ConsistLoss() for it in range(iter_valid): + t0 = time.time() try: data_cl = next(self.trainIter) except StopIteration: @@ -186,7 +187,7 @@ def training(self): except StopIteration: self.trainIter_noise = iter(self.train_loader_noise) data_no = next(self.trainIter_noise) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_cl['image']) # clean sample y0 = self.convert_tensor_type(data_cl['label_prob']) @@ -200,6 +201,7 @@ def training(self): # forward + backward + optimize b0_pred, b1_pred = self.net(inputs) + t2 = time.time() n0 = list(x0.shape)[0] # number of clean samples b0_x0_pred = b0_pred[:n0] # predication of clean samples from clean branch b0_x1_pred = b0_pred[n0:] # predication of noisy samples from clean branch @@ -235,8 +237,9 @@ def training(self): b0_x1_prob = nn.Softmax(dim = 1)(b0_x1_pred) loss_st = torch.mean(torch.abs(b0_x1_prob - sharpen(pseudo_label, 0.5))) loss = loss + loss_st * w_st - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -252,6 +255,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -260,7 +268,9 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':w_dbc, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers def train_valid(self): diff --git a/pymic/net_run/noisy_label/nll_trinet.py b/pymic/net_run/noisy_label/nll_trinet.py index 25c90cf..4c4198f 100644 --- a/pymic/net_run/noisy_label/nll_trinet.py +++ b/pymic/net_run/noisy_label/nll_trinet.py @@ -2,9 +2,8 @@ from __future__ import print_function, division import logging -import os -import sys import numpy as np +import time import torch import torch.nn as nn import torch.optim as optim @@ -17,24 +16,6 @@ from pymic.util.parse_config import * from pymic.util.ramps import get_rampup_ratio -class TriNet(nn.Module): - def __init__(self, params): - super(TriNet, self).__init__() - net_name = params['net_type'] - self.net1 = SegNetDict[net_name](params) - self.net2 = SegNetDict[net_name](params) - self.net3 = SegNetDict[net_name](params) - - def forward(self, x): - out1 = self.net1(x) - out2 = self.net2(x) - out3 = self.net3(x) - - if(self.training): - return out1, out2, out3 - else: - return (out1 + out2 + out3) / 3 - class NLLTriNet(SegmentationAgent): """ Implementation of trinet for learning from noisy samples for @@ -56,14 +37,6 @@ class NLLTriNet(SegmentationAgent): def __init__(self, config, stage = 'train'): super(NLLTriNet, self).__init__(config, stage) - def create_network(self): - if(self.net is None): - self.net = TriNet(self.config['network']) - if(self.tensor_type == 'float'): - self.net.float() - else: - self.net.double() - def get_loss_and_confident_mask(self, pred, labels_prob, conf_ratio): prob = nn.Softmax(dim = 1)(pred) prob_2d = reshape_tensor_to_2D(prob) * 0.999 + 5e-4 @@ -88,14 +61,16 @@ def training(self): train_loss_no_select2 = 0 train_loss1, train_loss2, train_loss3 = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) labels_prob = self.convert_tensor_type(data['label_prob']) @@ -106,7 +81,7 @@ def training(self): # forward + backward + optimize outputs1, outputs2, outputs3 = self.net(inputs) - + t2 = time.time() rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end) forget_ratio = (1 - select_ratio) * rampup_ratio remb_ratio = 1 - forget_ratio @@ -121,8 +96,9 @@ def training(self): loss2_avg = torch.sum(loss2 * mask13) / mask13.sum() loss3_avg = torch.sum(loss3 * mask12) / mask12.sum() loss = (loss1_avg + loss2_avg + loss3_avg) / 3 - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss_no_select1 = train_loss_no_select1 + loss1.mean().item() @@ -135,6 +111,11 @@ def training(self): soft_out1, labels_prob = reshape_prediction_and_ground_truth(soft_out1, labels_prob) dice_list = get_classwise_dice(soft_out1, labels_prob).detach().cpu().numpy() train_dice_list.append(dice_list) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss_no_select1 = train_loss_no_select1 / iter_valid train_avg_loss_no_select2 = train_loss_no_select2 / iter_valid train_avg_loss1 = train_loss1 / iter_valid @@ -146,7 +127,9 @@ def training(self): 'loss1':train_avg_loss1, 'loss2': train_avg_loss2, 'loss_no_select1':train_avg_loss_no_select1, 'loss_no_select2':train_avg_loss_no_select2, - 'select_ratio':remb_ratio, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'select_ratio':remb_ratio, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): @@ -172,4 +155,7 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format( valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) diff --git a/pymic/net_run/predict.py b/pymic/net_run/predict.py index ca4ef25..d63cbad 100644 --- a/pymic/net_run/predict.py +++ b/pymic/net_run/predict.py @@ -1,26 +1,46 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import argparse import logging import os import sys from datetime import datetime +from pymic import TaskType from pymic.util.parse_config import * from pymic.net_run.agent_cls import ClassificationAgent from pymic.net_run.agent_seg import SegmentationAgent +from pymic.net_run.agent_rec import ReconstructionAgent def main(): """ - The main function for running a network for training or inference. + The main function for running a network for inference. """ if(len(sys.argv) < 2): - print('Number of arguments should be 2. e.g.') - print(' pymic_test config.cfg') + print('Number of arguments should be at least 2. e.g.') + print(' pymic_test config.cfg -test_csv train.csv -output_dir result_dir -ckpt_mode 1') exit() - cfg_file = str(sys.argv[1]) - if(not os.path.isfile(cfg_file)): - raise ValueError("The config file does not exist: " + cfg_file) - config = parse_config(cfg_file) + parser = argparse.ArgumentParser() + parser.add_argument("cfg", help="configuration file for testing") + parser.add_argument("--test_csv", help="the csv file for testing images", + required=False, default=None) + parser.add_argument("--test_dir", help="the dir for testing images", + required=False, default=None) + parser.add_argument("--output_dir", help="the output dir for inference results", + required=False, default=None) + parser.add_argument("--ckpt_dir", help="the dir for trained model", + required=False, default=None) + parser.add_argument("--ckpt_mode", help="the mode for chekpoint: 0-latest, 1-best, 2-customized", + required=False, default=None) + parser.add_argument("--ckpt_name", help="the name chekpoint if ckpt_mode = 2", + required=False, default=None) + parser.add_argument("--gpus", help="the gpus for runing, e.g., [0]", + required=False, default=None) + args = parser.parse_args() + if(not os.path.isfile(args.cfg)): + raise ValueError("The config file does not exist: " + args.cfg) + config = parse_config(args) config = synchronize_config(config) + print(config) log_dir = config['testing']['output_dir'] if(not os.path.exists(log_dir)): os.makedirs(log_dir, exist_ok=True) @@ -32,13 +52,17 @@ def main(): logging.basicConfig(filename=log_dir+"/log_test.txt", level=logging.INFO, format='%(message)s') # for python 3.6 logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging_config(config) - task = config['dataset']['task_type'] - assert task in ['cls', 'cls_nexcl', 'seg'] - if(task == 'cls' or task == 'cls_nexcl'): + dst_cfg = args.cfg if "/" not in args.cfg else args.cfg.split("/")[-1] + wrtie_config(config, log_dir + "/" + dst_cfg) + task = config['dataset']['task_type'] + if(task == TaskType.CLASSIFICATION_ONE_HOT or task == TaskType.CLASSIFICATION_COEXIST): agent = ClassificationAgent(config, 'test') - else: + elif(task == TaskType.SEGMENTATION): agent = SegmentationAgent(config, 'test') + elif(task == TaskType.RECONSTRUCTION): + agent = ReconstructionAgent(config, 'test') + else: + raise ValueError("Undefined task for inference: {0:}".format(task)) agent.run() if __name__ == "__main__": diff --git a/pymic/net_run/preprocess.py b/pymic/net_run/preprocess.py new file mode 100644 index 0000000..63410b5 --- /dev/null +++ b/pymic/net_run/preprocess.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import argparse +import os +import sys +from datetime import datetime +from pymic.util.parse_config import * +from pymic.net_run.agent_preprocess import PreprocessAgent + + +def main(): + """ + The main function for data preprocessing. + """ + if(len(sys.argv) < 2): + print('Number of arguments should be 2. e.g.') + print(' pymic_preprocess config.cfg') + exit() + parser = argparse.ArgumentParser() + parser.add_argument("cfg", help="configuration file for preprocessing") + args = parser.parse_args() + if(not os.path.isfile(args.cfg)): + raise ValueError("The config file does not exist: " + args.cfg) + config = parse_config(args) + config = synchronize_config(config) + agent = PreprocessAgent(config) + agent.run() + +if __name__ == "__main__": + main() + + + diff --git a/pymic/net_run/self_sup/__init__.py b/pymic/net_run/self_sup/__init__.py index 55f26bf..86482f9 100644 --- a/pymic/net_run/self_sup/__init__.py +++ b/pymic/net_run/self_sup/__init__.py @@ -1,2 +1,16 @@ from __future__ import absolute_import -from pymic.net_run.self_sup.self_sl_agent import SelfSLSegAgent \ No newline at end of file +from pymic.net_run.self_sup.self_genesis import SelfSupModelGenesis +from pymic.net_run.self_sup.self_patch_swapping import SelfSupPatchSwapping +# from pymic.net_run.self_sup.self_mim import SelfSupMIM +# from pymic.net_run.self_sup.self_dino import SelfSupDINO +from pymic.net_run.self_sup.self_vox2vec import SelfSupVox2Vec +from pymic.net_run.self_sup.self_volf import SelfSupVolumeFusion + +SelfSupMethodDict = { + # 'DINO': SelfSupDINO, + 'Vox2Vec': SelfSupVox2Vec, + 'ModelGenesis': SelfSupModelGenesis, + 'PatchSwapping': SelfSupPatchSwapping, + 'VolumeFusion': SelfSupVolumeFusion + # 'MaskedImageModeling': SelfSupMIM + } \ No newline at end of file diff --git a/pymic/net_run/self_sup/self_genesis.py b/pymic/net_run/self_sup/self_genesis.py new file mode 100644 index 0000000..85ee194 --- /dev/null +++ b/pymic/net_run/self_sup/self_genesis.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import copy +import logging +import time +from pymic.net_run.agent_rec import ReconstructionAgent + +class SelfSupModelGenesis(ReconstructionAgent): + """ + Patch swapping-based self-supervised learning. + + Reference: Liang Chen et al., Self-supervised learning for medical image analysis + using image context restoration, Medical Image Analysis, 2019. + + A PatchSwaping transform need to be used in the cnfiguration. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `self_supervised_learning` is needed. See :doc:`usage.selfsl` for details. + + In the configuration file, it should look like this: + ``` + [dataset] + task_type = rec + supervise_type = self_sup + train_transform = [..., ..., PatchSwaping] + valid_transform = [..., ..., PatchSwaping] + + [self_supervised_learning] + method_name = ModelGenesis + + """ + def __init__(self, config, stage = 'train'): + super(SelfSupModelGenesis, self).__init__(config, stage) + + def get_transform_names_and_parameters(self, stage): + trans_names, trans_params = super(SelfSupModelGenesis, self).get_transform_names_and_parameters(stage) + # if(stage == 'train'): + # print('training transforms:', trans_names) + # if("LocalShuffling" not in trans_names): + # raise ValueError("LocalShuffling is required for model genesis, \ + # but it is not given in training transform") + # if("NonLinearTransform" not in trans_names): + # raise ValueError("NonLinearTransform is required for model genesis, \ + # but it is not given in training transform") + # if("InOutPainting" not in trans_names): + # raise ValueError("InOutPainting is required for model genesis, \ + # but it is not given in training transform") + return trans_names, trans_params diff --git a/pymic/net_run/self_sup/self_patch_swapping.py b/pymic/net_run/self_sup/self_patch_swapping.py new file mode 100644 index 0000000..1692fa7 --- /dev/null +++ b/pymic/net_run/self_sup/self_patch_swapping.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import copy +import logging +import time +from pymic.net_run.agent_rec import ReconstructionAgent + +class SelfSupPatchSwapping(ReconstructionAgent): + """ + Patch swapping-based self-supervised learning. + + Reference: Liang Chen et al., Self-supervised learning for medical image analysis + using image context restoration, Medical Image Analysis, 2019. + + A PatchSwaping transform need to be used in the cnfiguration. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `self_supervised_learning` is needed. See :doc:`usage.selfsl` for details. + + In the configuration file, it should look like this: + ``` + [dataset] + task_type = rec + supervise_type = self_sup + train_transform = [..., ..., PatchSwaping] + valid_transform = [..., ..., PatchSwaping] + + [self_supervised_learning] + method_name = PatchSwapping + + """ + def __init__(self, config, stage = 'train'): + super(SelfSupPatchSwapping, self).__init__(config, stage) + + def get_transform_names_and_parameters(self, stage): + trans_names, trans_params = super(SelfSupPatchSwapping, self).get_transform_names_and_parameters(stage) + if(stage == 'train'): + print('training transforms:', trans_names) + assert("PatchSwaping" in trans_names) + return trans_names, trans_params + diff --git a/pymic/net_run/self_sup/self_sl_agent.py b/pymic/net_run/self_sup/self_sl_agent.py index 24a6e66..45bee26 100644 --- a/pymic/net_run/self_sup/self_sl_agent.py +++ b/pymic/net_run/self_sup/self_sl_agent.py @@ -3,31 +3,11 @@ import copy import logging import time -import logging -import numpy as np -import random -import torch -import torch.nn as nn -import torchvision.transforms as transforms -from datetime import datetime -from random import random -from torch.optim import lr_scheduler -from tensorboardX import SummaryWriter -from pymic.io.nifty_dataset import NiftyDataset -from pymic.loss.seg.util import get_soft_label -from pymic.loss.seg.util import reshape_prediction_and_ground_truth -from pymic.loss.seg.util import get_classwise_dice -from pymic.net_run.infer_func import Inferer -from pymic.net_run.agent_seg import SegmentationAgent -from pymic.transform.trans_dict import TransformDict -from pymic.loss.seg.mse import MAELoss, MSELoss +from pymic.net_run.agent_rec import ReconstructionAgent -RegressionLossDict = { - 'MAELoss': MAELoss, - 'MSELoss': MSELoss - } -class SelfSLSegAgent(SegmentationAgent): + +class SelfSLSegAgent(ReconstructionAgent): """ Abstract class for self-supervised segmentation. @@ -38,208 +18,8 @@ class SelfSLSegAgent(SegmentationAgent): In the configuration dictionary, in addition to the four sections (`dataset`, `network`, `training` and `inference`) used in fully supervised learning, an - extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. + extra section `self_supervised_learning` is needed. See :doc:`usage.selfsl` for details. """ def __init__(self, config, stage = 'train'): super(SelfSLSegAgent, self).__init__(config, stage) - self.transform_dict = TransformDict - - def create_loss_calculator(self): - if(self.loss_dict is None): - self.loss_dict = RegressionLossDict - loss_name = self.config['training']['loss_type'] - if isinstance(loss_name, (list, tuple)): - raise ValueError("Undefined loss function {0:}".format(loss_name)) - elif (loss_name not in self.loss_dict): - raise ValueError("Undefined loss function {0:}".format(loss_name)) - else: - loss_param = self.config['training'] - loss_param['loss_softmax'] = False - base_loss = self.loss_dict[loss_name](self.config['training']) - if(self.config['training'].get('deep_supervise', False)): - raise ValueError("Deep supervised loss not implemented for self-supervised learning") - # weight = self.config['training'].get('deep_supervise_weight', None) - # mode = self.config['training'].get('deep_supervise_mode', 2) - # params = {'deep_supervise_weight': weight, - # 'deep_supervise_mode': mode, - # 'base_loss':base_loss} - # self.loss_calculator = DeepSuperviseLoss(params) - else: - self.loss_calculator = base_loss - - def training(self): - iter_valid = self.config['training']['iter_valid'] - train_loss = 0 - self.net.train() - for it in range(iter_valid): - try: - data = next(self.trainIter) - except StopIteration: - self.trainIter = iter(self.train_loader) - data = next(self.trainIter) - # get the inputs - inputs = self.convert_tensor_type(data['image']) - label = self.convert_tensor_type(data['label']) - - # for debug - # from pymic.io.image_read_write import save_nd_array_as_image - # for i in range(inputs.shape[0]): - # image_i = inputs[i][0] - # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) - # save_nd_array_as_image(image_i, image_name, reference_name = None) - # return - - inputs, label = inputs.to(self.device), label.to(self.device) - - # zero the parameter gradients - self.optimizer.zero_grad() - - # forward + backward + optimize - outputs = self.net(inputs) - outputs = nn.Sigmoid()(outputs) - loss = self.get_loss_value(data, outputs, label) - loss.backward() - self.optimizer.step() - train_loss = train_loss + loss.item() - # get dice evaluation for each class - if(isinstance(outputs, tuple) or isinstance(outputs, list)): - outputs = outputs[0] - - train_avg_loss = train_loss / iter_valid - train_scalers = {'loss': train_avg_loss} - return train_scalers - - def validation(self): - if(self.inferer is None): - infer_cfg = self.config['testing'] - self.inferer = Inferer(infer_cfg) - - valid_loss_list = [] - validIter = iter(self.valid_loader) - with torch.no_grad(): - self.net.eval() - for data in validIter: - inputs = self.convert_tensor_type(data['image']) - label = self.convert_tensor_type(data['label']) - inputs, label = inputs.to(self.device), label.to(self.device) - outputs = self.inferer.run(self.net, inputs) - outputs = nn.Sigmoid()(outputs) - # The tensors are on CPU when calculating loss for validation data - loss = self.get_loss_value(data, outputs, label) - valid_loss_list.append(loss.item()) - - valid_avg_loss = np.asarray(valid_loss_list).mean() - valid_scalers = {'loss': valid_avg_loss} - return valid_scalers - - def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): - loss_scalar ={'train':train_scalars['loss'], - 'valid':valid_scalars['loss']} - self.summ_writer.add_scalars('loss', loss_scalar, glob_it) - self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) - logging.info('train loss {0:.4f}'.format(train_scalars['loss'])) - logging.info('valid loss {0:.4f}'.format(valid_scalars['loss'])) - - def train_valid(self): - device_ids = self.config['training']['gpus'] - if(len(device_ids) > 1): - self.device = torch.device("cuda:0") - self.net = nn.DataParallel(self.net, device_ids = device_ids) - else: - self.device = torch.device("cuda:{0:}".format(device_ids[0])) - self.net.to(self.device) - ckpt_dir = self.config['training']['ckpt_save_dir'] - ckpt_prefix = self.config['training'].get('ckpt_prefix', None) - if(ckpt_prefix is None): - ckpt_prefix = ckpt_dir.split('/')[-1] - iter_start = self.config['training']['iter_start'] - iter_max = self.config['training']['iter_max'] - iter_valid = self.config['training']['iter_valid'] - iter_save = self.config['training'].get('iter_save', None) - early_stop_it = self.config['training'].get('early_stop_patience', None) - if(iter_save is None): - iter_save_list = [iter_max] - elif(isinstance(iter_save, (tuple, list))): - iter_save_list = iter_save - else: - iter_save_list = range(0, iter_max + 1, iter_save) - - self.min_val_loss = 10000.0 - self.max_val_it = 0 - self.best_model_wts = None - self.checkpoint = None - if(iter_start > 0): - checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start) - self.checkpoint = torch.load(checkpoint_file, map_location = self.device) - # assert(self.checkpoint['iteration'] == iter_start) - if(len(device_ids) > 1): - self.net.module.load_state_dict(self.checkpoint['model_state_dict']) - else: - self.net.load_state_dict(self.checkpoint['model_state_dict']) - self.min_val_loss = self.checkpoint.get('valid_loss', 10000) - # self.max_val_it = self.checkpoint['iteration'] - self.max_val_it = iter_start - self.best_model_wts = self.checkpoint['model_state_dict'] - - self.create_optimizer(self.get_parameters_to_update()) - self.create_loss_calculator() - - self.trainIter = iter(self.train_loader) - - logging.info("{0:} training start".format(str(datetime.now())[:-7])) - self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) - self.glob_it = iter_start - for it in range(iter_start, iter_max, iter_valid): - lr_value = self.optimizer.param_groups[0]['lr'] - t0 = time.time() - train_scalars = self.training() - t1 = time.time() - valid_scalars = self.validation() - t2 = time.time() - if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): - self.scheduler.step(-valid_scalars['loss']) - else: - self.scheduler.step() - - self.glob_it = it + iter_valid - logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) - logging.info('learning rate {0:}'.format(lr_value)) - logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1)) - self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it) - if(valid_scalars['loss'] < self.min_val_loss): - self.min_val_loss = valid_scalars['loss'] - self.max_val_it = self.glob_it - if(len(device_ids) > 1): - self.best_model_wts = copy.deepcopy(self.net.module.state_dict()) - else: - self.best_model_wts = copy.deepcopy(self.net.state_dict()) - - stop_now = True if(early_stop_it is not None and \ - self.glob_it - self.max_val_it > early_stop_it) else False - if ((self.glob_it in iter_save_list) or stop_now): - save_dict = {'iteration': self.glob_it, - 'valid_loss': valid_scalars['loss'], - 'model_state_dict': self.net.module.state_dict() \ - if len(device_ids) > 1 else self.net.state_dict(), - 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.glob_it) - torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt') - txt_file.write(str(self.glob_it)) - txt_file.close() - if(stop_now): - logging.info("The training is early stopped") - break - # save the best performing checkpoint - save_dict = {'iteration': self.max_val_it, - 'valid_loss': self.min_val_loss, - 'model_state_dict': self.best_model_wts, - 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it) - torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') - txt_file.write(str(self.max_val_it)) - txt_file.close() - logging.info('The best performing iter is {0:}, valid loss {1:}'.format(\ - self.max_val_it, self.min_val_loss)) - self.summ_writer.close() \ No newline at end of file + \ No newline at end of file diff --git a/pymic/net_run/self_sup/self_volf.py b/pymic/net_run/self_sup/self_volf.py new file mode 100644 index 0000000..4615979 --- /dev/null +++ b/pymic/net_run/self_sup/self_volf.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +from pymic.net_run.agent_seg import SegmentationAgent + + +class SelfSupVolumeFusion(SegmentationAgent): + """ + Abstract class for self-supervised segmentation. + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. + """ + def __init__(self, config, stage = 'train'): + super(SelfSupVolumeFusion, self).__init__(config, stage) + + def get_transform_names_and_parameters(self, stage): + trans_names, trans_params = super(SelfSupVolumeFusion, self).get_transform_names_and_parameters(stage) + if(stage == 'train'): + print('training transforms:', trans_names) + if("Crop4VolumeFusion" not in trans_names): + raise ValueError("Crop4VolumeFusion is required for VolF, \ + but it is not given in training transform") + if("VolumeFusion" not in trans_names): + raise ValueError("VolumeFusion is required for VolF, \ + but it is not given in training transform") + if("LabelToProbability" not in trans_names): + raise ValueError("LabelToProbability is required for VolF, \ + but it is not given in training transform") + return trans_names, trans_params + + \ No newline at end of file diff --git a/pymic/net_run/self_sup/self_vox2vec.py b/pymic/net_run/self_sup/self_vox2vec.py new file mode 100644 index 0000000..94e2b16 --- /dev/null +++ b/pymic/net_run/self_sup/self_vox2vec.py @@ -0,0 +1,304 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import copy +import logging +import time +import logging +import torch +import torch.nn as nn +from datetime import datetime +from torch.optim import lr_scheduler +from tensorboardX import SummaryWriter +from pymic.io.image_read_write import save_nd_array_as_image +from pymic.net.net3d.fmunetv3 import FMUNetV3 +from pymic.net_run.agent_seg import SegmentationAgent +from pymic.loss.cls.infoNCE import InfoNCELoss + +def select_from_pyramid(feature_pyramid, indices): + """Select features from feature pyramid by their indices w.r.t. base feature map. + + Args: + feature_pyramid (Sequence[torch.Tensor]): Sequence of tensors of shapes ``(B, C_i, D_i, H_i, W_i)``. + indices (torch.Tensor): tensor of shape ``(B, N, 3)`` + + Returns: + torch.Tensor: tensor of shape ``(B, N, \sum_i c_i)`` + """ + out = [] + for i, x in enumerate(feature_pyramid): + batch_size = list(x.shape)[0] + x_move = x.moveaxis(1, -1) + index_i = indices // 2 ** i + x_i = [x_move[b][index_i[b][:, 0], index_i[b][:, 1], index_i[b][:, 2], :] for \ + b in range(batch_size)] + x_i = torch.stack(x_i) + out.append(x_i) + out = torch.cat(out, dim = -1) + return out + +class Vox2VecHead(nn.Module): + def __init__(self, params): + super(Vox2VecHead, self).__init__() + ft_chns = params['feature_chns'] + hidden_dim = params['hidden_dim'] + proj_dim = params['project_dim'] + embed_dim = sum(ft_chns) + self.proj_head = nn.Sequential( + nn.Linear(embed_dim, hidden_dim), + nn.BatchNorm1d(hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.BatchNorm1d(hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, proj_dim) + ) + + def forward(self, x): + output = self.proj_head(x) + output = nn.functional.normalize(output) + return output + +class Vox2VecWrapper(nn.Module): + """ + Perform forward pass separately on each resolution input. + The inputs corresponding to a single resolution are clubbed and single + forward is run on the same resolution inputs. Hence we do several + forward passes = number of different resolutions used. We then + concatenate all the output features and run the head forward on these + concatenated features. + """ + def __init__(self, backbone, head): + super(Vox2VecWrapper, self).__init__() + self.backbone = backbone + self.head = head + + def forward(self, x, vex_idx): + if(isinstance(self.backbone, FMUNetV3)): + x = self.backbone.project(x) + f = self.backbone.encoder(x) + B = list(f[0].shape)[0] + f_fpn = select_from_pyramid(f, vex_idx) + feature_dim = list(f_fpn.shape)[-1] + f_fpn = f_fpn.view(-1, feature_dim) + output = self.head(f_fpn) + proj_dim = list(output.shape)[-1] + output = output.view(B, -1, proj_dim) + return output + +class SelfSupVox2Vec(SegmentationAgent): + """ + An agent for image self-supervised learning with DeSD. + """ + def __init__(self, config, stage = 'train'): + super(SelfSupVox2Vec, self).__init__(config, stage) + + def create_network(self): + super(SelfSupVox2Vec, self).create_network() + proj_dim = self.config['self_supervised_learning'].get('project_dim', 1024) + hidden_dim = self.config['self_supervised_learning'].get('hidden_dim', 1024) + head_params= {'feature_chns': self.config['network']['feature_chns'], + 'hidden_dim':hidden_dim, + 'project_dim':proj_dim} + self.head = Vox2VecHead(head_params) + self.net_wrapper = Vox2VecWrapper(self.net, self.head) + + def create_loss_calculator(self): + # constrastive loss + self_sup_params = self.config['self_supervised_learning'] + self.loss_calculator = InfoNCELoss(self_sup_params) + + def get_parameters_to_update(self): + params = self.net_wrapper.parameters() + return params + + def training(self): + iter_valid = self.config['training']['iter_valid'] + train_loss = 0 + err_info = None + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 + self.net_wrapper.train() + for it in range(iter_valid): + t0 = time.time() + try: + data = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data = next(self.trainIter) + t1 = time.time() + patch1, patch2, vox_ids1, vox_ids2 = data['image'] + inputs = torch.cat([patch1, patch2], dim = 0) + vox_ids = torch.cat([vox_ids1, vox_ids2], dim = 0) + inputs = self.convert_tensor_type(inputs) + inputs = inputs.to(self.device) + vox_ids = vox_ids.to(self.device) + + # for debug + # for i in range(patch1.shape[0]): + # v1_i = patch1[i][0] + # v2_i = patch2[i][0] + # print("patch shape", v1_i.shape, v2_i.shape) + # image_name0 = "temp/image_{0:}_{1:}_v0.nii.gz".format(it, i) + # image_name1 = "temp/image_{0:}_{1:}_v1.nii.gz".format(it, i) + # save_nd_array_as_image(v1_i, image_name0, reference_name = None) + # save_nd_array_as_image(v2_i, image_name1, reference_name = None) + # if(it > 10): + # return + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + out = self.net_wrapper(inputs, vox_ids) + out1, out2 = out.chunk(2) + + t2 = time.time() + loss = self.loss_calculator(out1, out2) + t3 = time.time() + + loss.backward() + self.optimizer.step() + train_loss = train_loss + loss.item() + t4 = time.time() + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 + + train_avg_loss = train_loss / iter_valid + train_scalers = {'loss': train_avg_loss, 'data_time': data_time, + 'gpu_time':gpu_time, 'loss_time':loss_time, 'back_time':back_time, + 'err_info': err_info} + return train_scalers + + + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): + loss_scalar ={'train':train_scalars['loss']} + self.summ_writer.add_scalars('loss', loss_scalar, glob_it) + self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) + logging.info('train loss {0:.4f}'.format(train_scalars['loss'])) + + def train_valid(self): + device_ids = self.config['training']['gpus'] + if(len(device_ids) > 1): + self.device = torch.device("cuda:0") + self.net_wrapper = nn.DataParallel(self.net_wrapper, device_ids = device_ids) + else: + self.device = torch.device("cuda:{0:}".format(device_ids[0])) + self.net_wrapper.to(self.device) + + ckpt_dir = self.config['training']['ckpt_dir'] + ckpt_prefix = self.config['training'].get('ckpt_prefix', None) + if(ckpt_prefix is None): + ckpt_prefix = ckpt_dir.split('/')[-1] + # iter_start = self.config['training']['iter_start'] + iter_start = 0 + iter_max = self.config['training']['iter_max'] + iter_valid = self.config['training']['iter_valid'] + iter_save = self.config['training'].get('iter_save', None) + early_stop_it = self.config['training'].get('early_stop_patience', None) + if(iter_save is None): + iter_save_list = [iter_max] + elif(isinstance(iter_save, (tuple, list))): + iter_save_list = iter_save + else: + iter_save_list = range(0, iter_max + 1, iter_save) + + self.min_loss = 10000.0 + self.min_loss_it = 0 + self.best_model_wts = None + self.bett_head_wts = None + checkpoint = None + # initialize the network with pre-trained weights + ckpt_init_name = self.config['training'].get('ckpt_init_name', None) + ckpt_init_mode = self.config['training'].get('ckpt_init_mode', 0) + ckpt_for_optm = None + if(ckpt_init_name is not None): + checkpoint = torch.load(ckpt_dir + "/" + ckpt_init_name, map_location = self.device) + pretrained_dict = checkpoint['model_state_dict'] + pretrain_head_dict = checkpoint['head_state_dict'] + self.load_pretrained_weights(self.net, pretrained_dict, device_ids) + self.load_pretrained_weights(self.head, pretrain_head_dict, device_ids) + + if(ckpt_init_mode > 0): # Load other information + self.min_loss = checkpoint.get('train_loss', 10000) + iter_start = checkpoint['iteration'] + self.min_loss_it = iter_start + self.best_model_wts = checkpoint['model_state_dict'] + self.best_head_wts = checkpoint['head_state_dict'] + ckpt_for_optm = checkpoint + + self.create_optimizer(self.get_parameters_to_update(), ckpt_for_optm) + self.create_loss_calculator() + + self.trainIter = iter(self.train_loader) + + logging.info("{0:} training start".format(str(datetime.now())[:-7])) + self.summ_writer = SummaryWriter(self.config['training']['ckpt_dir']) + self.glob_it = iter_start + for it in range(iter_start, iter_max, iter_valid): + lr_value = self.optimizer.param_groups[0]['lr'] + + t0 = time.time() + train_scalars = self.training() + t1 = time.time() + if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step(-train_scalars['loss']) + else: + self.scheduler.step() + + self.glob_it = it + iter_valid + logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) + logging.info('learning rate {0:}'.format(lr_value)) + logging.info("training time: {0:.2f}s".format(t1-t0)) + logging.info("data: {0:.2f}s, gpu: {1:.2f}s, loss: {2:.2f}s, back: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['gpu_time'], + train_scalars['loss_time'], train_scalars['back_time'])) + + self.write_scalars(train_scalars, None, lr_value, self.glob_it) + if(train_scalars['loss'] < self.min_loss): + self.min_loss = train_scalars['loss'] + self.min_loss_it = self.glob_it + if(len(device_ids) > 1): + self.best_model_wts = copy.deepcopy(self.net.module.state_dict()) + self.best_head_wts = copy.deepcopy(self.head.module.state_dict()) + else: + self.best_model_wts = copy.deepcopy(self.net.state_dict()) + self.best_head_wts = copy.deepcopy(self.head.state_dict()) + + save_dict = {'iteration': self.min_loss_it, + 'train_loss': self.min_loss, + 'model_state_dict': self.best_model_wts, + 'head_state_dict': self.best_head_wts, + 'optimizer_state_dict': self.optimizer.state_dict()} + save_name = "{0:}/{1:}_best.pt".format(ckpt_dir, ckpt_prefix) + torch.save(save_dict, save_name) + txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') + txt_file.write(str(self.min_loss_it)) + txt_file.close() + + stop_now = True if(early_stop_it is not None and \ + self.glob_it - self.min_loss_it > early_stop_it) else False + if(train_scalars['err_info'] is not None): + logging.info("Early stopped due to error: {0:}".format(train_scalars['err_info'])) + stop_now = True + if ((self.glob_it in iter_save_list) or stop_now): + save_dict = {'iteration': self.glob_it, + 'train_loss': train_scalars['loss'], + 'model_state_dict': self.net.module.state_dict() \ + if len(device_ids) > 1 else self.net.state_dict(), + 'head_state_dict': self.head.module.state_dict() \ + if len(device_ids) > 1 else self.head.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict()} + save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.glob_it) + torch.save(save_dict, save_name) + txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt') + txt_file.write(str(self.glob_it)) + txt_file.close() + if(stop_now): + logging.info("The training is early stopped") + break + # save the best performing checkpoint + logging.info('The best performing iter is {0:}, train loss {1:}'.format(\ + self.min_loss_it, self.min_loss)) + self.summ_writer.close() \ No newline at end of file diff --git a/pymic/net_run/self_sup/util.py b/pymic/net_run/self_sup/util.py new file mode 100644 index 0000000..d6adcc1 --- /dev/null +++ b/pymic/net_run/self_sup/util.py @@ -0,0 +1,325 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import os +import copy +import torch +import random +import numpy as np +from scipy import ndimage +from pymic.io.image_read_write import * +from pymic.util.image_process import * + + +def get_human_region_mask(img): + """ + Get the mask of human region in CT volumes + """ + dim = len(img.shape) + if( dim == 4): + img = img[0] + mask = np.asarray(img > -600) + se = np.ones([3,3,3]) + mask = ndimage.binary_opening(mask, se, iterations = 2) + D, H, W = mask.shape + for h in range(H): + if(mask[:,h,:].sum() < 2000): + mask[:,h, :] = np.zeros((D, W)) + mask = get_largest_k_components(mask, 1) + mask_close = ndimage.binary_closing(mask, se, iterations = 2) + + D, H, W = mask.shape + for d in [1, 2, D-3, D-2]: + mask_close[d] = mask[d] + for d in range(0, D, 2): + mask_close[d, 2:-2, 2:-2] = np.ones((H-4, W-4)) + + # get background component + bg = np.zeros_like(mask) + bgs = get_largest_k_components(1- mask_close, 10) + for bgi in bgs: + indices = np.where(bgi) + if(bgi.sum() < 1000): + break + if(indices[0].min() == 0 or indices[1].min() == 0 or indices[2].min() ==0 or \ + indices[0].max() == D-1 or indices[1].max() == H-1 or indices[2].max() ==W-1): + bg = bg + bgi + fg = 1 - bg + + fg = ndimage.binary_opening(fg, se, iterations = 1) + fg = get_largest_k_components(fg, 1) + if(dim == 4): + fg = np.expand_dims(fg, 0) + fg = np.asarray(fg, np.uint8) + return fg + +def get_human_region_mask_fast(img, itk_spacing): + # downsample + D, H, W = img.shape + # scale_down = [1, 1, 1] + if(itk_spacing[2] <= 1): + scale_down = [1/2, 1/2, 1/2] + else: + scale_down = [1, 1/2, 1/2] + img_sub = ndimage.interpolation.zoom(img, scale_down, order = 0) + mask = get_human_region_mask(img_sub) + D1, H1, W1 = mask.shape + scale_up = [D/D1, H/H1, W/W1] + mask = ndimage.interpolation.zoom(mask, scale_up, order = 0) + return mask + +def crop_ct_scan(input_img, output_img, input_lab = None, output_lab = None, z_axis_density = 0.5): + """ + Crop a CT scan based on the bounding box of the human region. + """ + img_obj = sitk.ReadImage(input_img) + img = sitk.GetArrayFromImage(img_obj) + mask = np.asarray(img > -600) + mask2d = np.mean(mask, axis = 0) > z_axis_density + se = np.ones([3,3]) + mask2d = ndimage.binary_opening(mask2d, se, iterations = 2) + mask2d = get_largest_k_components(mask2d, 1) + bbmin, bbmax = get_ND_bounding_box(mask2d, margin = [0, 0]) + bbmin = [0] + bbmin + bbmax = [img.shape[0]] + bbmax + img_sub = crop_ND_volume_with_bounding_box(img, bbmin, bbmax) + img_sub_obj = sitk.GetImageFromArray(img_sub) + img_sub_obj.SetSpacing(img_obj.GetSpacing()) + img_sub_obj.SetDirection(img_obj.GetDirection()) + sitk.WriteImage(img_sub_obj, output_img) + if(input_lab is not None): + lab_obj = sitk.ReadImage(input_lab) + lab = sitk.GetArrayFromImage(lab_obj) + lab_sub = crop_ND_volume_with_bounding_box(lab, bbmin, bbmax) + lab_sub_obj = sitk.GetImageFromArray(lab_sub) + lab_sub_obj.SetSpacing(img_obj.GetSpacing()) + sitk.WriteImage(lab_sub_obj, output_lab) + +def get_human_body_mask_and_crop(input_dir, out_img_dir, out_mask_dir): + if(not os.path.exists(out_img_dir)): + os.mkdir(out_img_dir) + os.mkdir(out_mask_dir) + + img_names = [item for item in os.listdir(input_dir) if "nii.gz" in item] + img_names = sorted(img_names) + for img_name in img_names: + print(img_name) + input_name = input_dir + "/" + img_name + out_name = out_img_dir + "/" + img_name + mask_name = out_mask_dir + "/" + img_name + if(os.path.isfile(out_name)): + continue + img_obj = sitk.ReadImage(input_name) + img = sitk.GetArrayFromImage(img_obj) + spacing = img_obj.GetSpacing() + + # downsample + D, H, W = img.shape + spacing = img_obj.GetSpacing() + # scale_down = [1, 1, 1] + if(spacing[2] <= 1): + scale_down = [1/2, 1/2, 1/2] + else: + scale_down = [1, 1/2, 1/2] + img_sub = ndimage.interpolation.zoom(img, scale_down, order = 0) + mask = get_human_region_mask(img_sub) + D1, H1, W1 = mask.shape + scale_up = [D/D1, H/H1, W/W1] + mask = ndimage.interpolation.zoom(mask, scale_up, order = 0) + + bbmin, bbmax = get_ND_bounding_box(mask) + img_crop = crop_ND_volume_with_bounding_box(img, bbmin, bbmax) + mask_crop = crop_ND_volume_with_bounding_box(mask, bbmin, bbmax) + + out_img_obj = sitk.GetImageFromArray(img_crop) + out_img_obj.SetSpacing(spacing) + sitk.WriteImage(out_img_obj, out_name) + mask_obj = sitk.GetImageFromArray(mask_crop) + mask_obj.CopyInformation(out_img_obj) + sitk.WriteImage(mask_obj, mask_name) + +def volume_fusion(x, fg_num, block_range, size_min, size_max): + """ + Fuse a subregion of an impage with another one to generate + images and labels for self-supervised segmentation. + input x should be a batch of tensors + """ + #n_min, n_max, + N, C, D, H, W = list(x.shape) + fg_mask = torch.zeros_like(x[:, :1, :, :, :]).to(torch.int32) + # generate mask + for n in range(N): + p_num = random.randint(block_range[0], block_range[1]) + for i in range(p_num): + d = random.randint(size_min[0], size_max[0]) + h = random.randint(size_min[1], size_max[1]) + w = random.randint(size_min[2], size_max[2]) + dc = random.randint(0, D - 1) + hc = random.randint(0, H - 1) + wc = random.randint(0, W - 1) + d0 = dc - d // 2 + h0 = hc - h // 2 + w0 = wc - w // 2 + d1 = min(D, d0 + d) + h1 = min(H, h0 + h) + w1 = min(W, w0 + w) + d0, h0, w0 = max(0, d0), max(0, h0), max(0, w0) + temp_m = torch.ones([1, d1 - d0, h1 - h0, w1 - w0]) * random.randint(1, fg_num) + fg_mask[n, :, d0:d1, h0:h1, w0:w1] = temp_m + fg_w = fg_mask * 1.0 / fg_num + x_roll = torch.roll(x, 1, 0) + x_fuse = fg_w*x_roll + (1.0 - fg_w)*x + # y_prob = get_one_hot_seg(fg_mask.to(torch.int32), fg_num + 1) + return x_fuse, fg_mask + +def nonlinear_transform(x): + v_min = torch.min(x) + v_max = torch.max(x) + x = (x - v_min)/(v_max - v_min) + a = random.random() * 0.7 + 0.15 + b = random.random() * 0.7 + 0.15 + alpha = b / a + beta = (1 - b) / (1 - a) + if(alpha < 1.0 ): + y = torch.maximum(alpha*x, beta*x + 1 - beta) + else: + y = torch.minimum(alpha*x, beta*x + 1 - beta) + if(random.random() < 0.5): + y = 1.0 - y + y = y * (v_max - v_min) + v_min + return y + +def nonlienar_volume_fusion(x, block_range, size_min, size_max): + """ + Fuse a subregion of an impage with another one to generate + images and labels for self-supervised segmentation. + input x should be a batch of tensors + """ + #n_min, n_max, + N, C, D, H, W = list(x.shape) + # apply nonlinear transform to x: + x_nl1 = torch.zeros_like(x).to(torch.float32) + x_nl2 = torch.zeros_like(x).to(torch.float32) + for n in range(N): + x_nl1[n] = nonlinear_transform(x[n]) + x_nl2[n] = nonlinear_transform(x[n]) + x_roll = torch.roll(x_nl2, 1, 0) + mask = torch.zeros_like(x).to(torch.int32) + p_num = random.randint(block_range[0], block_range[1]) + for n in range(N): + for i in range(p_num): + d = random.randint(size_min[0], size_max[0]) + h = random.randint(size_min[1], size_max[1]) + w = random.randint(size_min[2], size_max[2]) + dc = random.randint(0, D - 1) + hc = random.randint(0, H - 1) + wc = random.randint(0, W - 1) + d0 = dc - d // 2 + h0 = hc - h // 2 + w0 = wc - w // 2 + d1 = min(D, d0 + d) + h1 = min(H, h0 + h) + w1 = min(W, w0 + w) + d0, h0, w0 = max(0, d0), max(0, h0), max(0, w0) + temp_m = torch.ones([C, d1 - d0, h1 - h0, w1 - w0]) + if(random.random() < 0.5): + temp_m = temp_m * 2 + mask[n, :, d0:d1, h0:h1, w0:w1] = temp_m + + mask1 = (mask == 1).to(torch.int32) + mask2 = (mask == 2).to(torch.int32) + y = x_nl1 * (1.0 - mask1) + x_nl2 * mask1 + y = y * (1.0 - mask2) + x_roll * mask2 + return y, mask + +def augmented_volume_fusion(x, size_min, size_max): + """ + Fuse a subregion of an impage with another one to generate + images and labels for self-supervised segmentation. + input x should be a batch of tensors + """ + #n_min, n_max, + N, C, D, H, W = list(x.shape) + # apply nonlinear transform to x: + x1 = torch.zeros_like(x).to(torch.float32) + y = torch.zeros_like(x).to(torch.float32) + mask = torch.zeros_like(x).to(torch.int32) + for n in range(N): + x1[n] = nonlinear_transform(x[n]) + y[n] = nonlinear_transform(x[n]) + x2 = torch.roll(x1, 1, 0) + + for n in range(N): + block_size = [random.randint(size_min[i], size_max[i]) for i in range(3)] + d_start = random.randint(0, block_size[0] // 2) + h_start = random.randint(0, block_size[1] // 2) + w_stat = random.randint(0, block_size[2] // 2) + for d in range(d_start, D, block_size[0]): + if(D - d < block_size[0] // 2): + continue + d1 = min(d + block_size[0], D) + for h in range(h_start, H, block_size[1]): + if(H - h < block_size[1] // 2): + continue + h1 = min(h + block_size[1], H) + for w in range(w_stat, W, block_size[2]): + if(W - w < block_size[2] // 2): + continue + w1 = min(w + block_size[2], W) + p = random.random() + if(p < 0.15): # nonlinear intensity augmentation + mask[n, :, d:d1, h:h1, w:w1] = 1 + y[n, :, d:d1, h:h1, w:w1] = x1[n, :, d:d1, h:h1, w:w1] + elif(p < 0.3): # random flip across a certain axis + mask[n, :, d:d1, h:h1, w:w1] = 2 + flip_axis = random.randint(-3, -1) + y[n, :, d:d1, h:h1, w:w1] = torch.flip(y[n, :, d:d1, h:h1, w:w1], (flip_axis,)) + elif(p < 0.45): # nonlinear intensity augmentation and random flip across a certain axis + mask[n, :, d:d1, h:h1, w:w1] = 3 + flip_axis = random.randint(-3, -1) + y[n, :, d:d1, h:h1, w:w1] = torch.flip(x1[n, :, d:d1, h:h1, w:w1], (flip_axis,)) + elif(p < 0.6): # paste from another volume + mask[n, :, d:d1, h:h1, w:w1] = 4 + y[n, :, d:d1, h:h1, w:w1] = x2[n, :, d:d1, h:h1, w:w1] + return y, mask + +def self_volume_fusion(x, fg_num, fuse_ratio, size_min, size_max): + """ + Fuse a subregion of an impage with another one to generate + images and labels for self-supervised segmentation. + input x should be a batch of tensors + """ + #n_min, n_max, + N, C, D, H, W = list(x.shape) + y = 1.0 * x + fg_mask = torch.zeros_like(x[:, :1, :, :, :]).to(torch.int32) + + for n in range(N): + db = random.randint(size_min[0], size_max[0]) + hb = random.randint(size_min[1], size_max[1]) + wb = random.randint(size_min[2], size_max[2]) + d0 = random.randint(0, D % db) + h0 = random.randint(0, H % hb) + w0 = random.randint(0, W % wb) + coord_list_source = [] + for di in range(D // db): + for hi in range(H // hb): + for wi in range(W // wb): + coord_list_source.append([di, hi, wi]) + coord_list_target = copy.deepcopy(coord_list_source) + random.shuffle(coord_list_source) + random.shuffle(coord_list_target) + for i in range(int(len(coord_list_source)*fuse_ratio)): + ds_l = d0 + db * coord_list_source[i][0] + hs_l = h0 + hb * coord_list_source[i][1] + ws_l = w0 + wb * coord_list_source[i][2] + dt_l = d0 + db * coord_list_target[i][0] + ht_l = h0 + hb * coord_list_target[i][1] + wt_l = w0 + wb * coord_list_target[i][2] + s_crop = x[n, :, ds_l:ds_l+db, hs_l:hs_l+hb, ws_l:ws_l+wb] + t_crop = x[n, :, dt_l:dt_l+db, ht_l:ht_l+hb, wt_l:wt_l+wb] + fg_m = random.randint(1, fg_num) + fg_w = fg_m / (fg_num + 0.0) + y[n, :, dt_l:dt_l+db, ht_l:ht_l+hb, wt_l:wt_l+wb] = t_crop * (1.0 - fg_w) + s_crop * fg_w + fg_mask[n, 0, dt_l:dt_l+db, ht_l:ht_l+hb, wt_l:wt_l+wb] = \ + torch.ones([1, db, hb, wb]) * fg_m + return y, fg_mask \ No newline at end of file diff --git a/pymic/net_run/semi_sup/__init__.py b/pymic/net_run/semi_sup/__init__.py index be753c2..769a66b 100644 --- a/pymic/net_run/semi_sup/__init__.py +++ b/pymic/net_run/semi_sup/__init__.py @@ -2,6 +2,8 @@ from pymic.net_run.semi_sup.ssl_abstract import SSLSegAgent from pymic.net_run.semi_sup.ssl_em import SSLEntropyMinimization from pymic.net_run.semi_sup.ssl_mt import SSLMeanTeacher +from pymic.net_run.semi_sup.ssl_mcnet import SSLMCNet +# from pymic.net_run.semi_sup.ssl_cdma import SSLCDMA from pymic.net_run.semi_sup.ssl_uamt import SSLUncertaintyAwareMeanTeacher from pymic.net_run.semi_sup.ssl_cct import SSLCCT from pymic.net_run.semi_sup.ssl_cps import SSLCPS @@ -10,6 +12,8 @@ SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization, 'MeanTeacher': SSLMeanTeacher, + 'MCNet': SSLMCNet, + # 'CDMA': SSLCDMA, 'UAMT': SSLUncertaintyAwareMeanTeacher, 'CCT': SSLCCT, 'CPS': SSLCPS, diff --git a/pymic/net_run/semi_sup/ssl_abstract.py b/pymic/net_run/semi_sup/ssl_abstract.py index 5a46257..4925859 100644 --- a/pymic/net_run/semi_sup/ssl_abstract.py +++ b/pymic/net_run/semi_sup/ssl_abstract.py @@ -35,7 +35,7 @@ def get_unlabeled_dataset_from_config(self): """ Create a dataset for the unlabeled images based on configuration. """ - root_dir = self.config['dataset']['root_dir'] + train_dir = self.config['dataset']['train_dir'] modal_num = self.config['dataset'].get('modal_num', 1) transform_names = self.config['dataset']['train_transform_unlab'] @@ -44,7 +44,7 @@ def get_unlabeled_dataset_from_config(self): data_transform = None else: transform_param = self.config['dataset'] - transform_param['task'] = 'segmentation' + transform_param['task'] = self.task_type for name in transform_names: if(name not in self.transform_dict): raise(ValueError("Undefined transform {0:}".format(name))) @@ -52,12 +52,16 @@ def get_unlabeled_dataset_from_config(self): self.transform_list.append(one_transform) data_transform = transforms.Compose(self.transform_list) - csv_file = self.config['dataset'].get('train_csv_unlab', None) - dataset = NiftyDataset(root_dir=root_dir, + csv_file = self.config['dataset'].get('train_csv_unlab', None) + stage_dim = self.config['dataset'].get('train_dim', 3) + dataset = NiftyDataset(root_dir = train_dir, csv_file = csv_file, modal_num = modal_num, - with_label= False, - transform = data_transform ) + image_dim = stage_dim, + allow_missing_modal = False, + label_key = None, + transform = data_transform, + task = self.task_type) return dataset def create_dataset(self): @@ -76,7 +80,7 @@ def worker_init_fn(worker_id): num_worker = self.config['dataset'].get('num_worker', 16) self.train_loader_unlab = torch.utils.data.DataLoader(self.train_set_unlab, batch_size = bn_train_unlab, shuffle=True, num_workers= num_worker, - worker_init_fn=worker_init) + worker_init_fn=worker_init, drop_last = True) def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): loss_scalar ={'train':train_scalars['loss'], @@ -101,6 +105,9 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format( valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) def train_valid(self): self.trainIter_unlab = iter(self.train_loader_unlab) diff --git a/pymic/net_run/semi_sup/ssl_cct.py b/pymic/net_run/semi_sup/ssl_cct.py index 1943608..81723bd 100644 --- a/pymic/net_run/semi_sup/ssl_cct.py +++ b/pymic/net_run/semi_sup/ssl_cct.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import logging +import time import torch import torch.nn as nn import torch.nn.functional as F @@ -88,13 +89,13 @@ def training(self): rampup_end = ssl_cfg.get('rampup_end', iter_max) unsup_loss_name = ssl_cfg.get('unsupervised_loss', "MSE") self.unsup_loss_f = unsup_loss_dict[unsup_loss_name] - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data_lab = next(self.trainIter) except StopIteration: @@ -105,7 +106,7 @@ def training(self): except StopIteration: self.trainIter_unlab = iter(self.train_loader_unlab) data_unlab = next(self.trainIter_unlab) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) @@ -118,6 +119,7 @@ def training(self): # forward pass output, aux_outputs = self.net(inputs) + t2 = time.time() n0 = list(x0.shape)[0] # get supervised loss @@ -135,8 +137,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -150,6 +153,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -158,5 +166,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers diff --git a/pymic/net_run/semi_sup/ssl_cps.py b/pymic/net_run/semi_sup/ssl_cps.py index 4a3be9c..dc0d325 100644 --- a/pymic/net_run/semi_sup/ssl_cps.py +++ b/pymic/net_run/semi_sup/ssl_cps.py @@ -1,31 +1,17 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import logging +import time import numpy as np import torch -import torch.nn as nn +from random import random from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice +from pymic.io.image_read_write import save_nd_array_as_image from pymic.net_run.semi_sup import SSLSegAgent -from pymic.net.net_dict_seg import SegNetDict from pymic.util.ramps import get_rampup_ratio - -class BiNet(nn.Module): - def __init__(self, params): - super(BiNet, self).__init__() - net_name = params['net_type'] - self.net1 = SegNetDict[net_name](params) - self.net2 = SegNetDict[net_name](params) - - def forward(self, x): - out1 = self.net1(x) - out2 = self.net2(x) - - if(self.training): - return out1, out2 - else: - return (out1 + out2) / 2 +from pymic.util.general import mixup, tensor_shape_match class SSLCPS(SSLSegAgent): """ @@ -47,27 +33,22 @@ class SSLCPS(SSLSegAgent): def __init__(self, config, stage = 'train'): super(SSLCPS, self).__init__(config, stage) - def create_network(self): - if(self.net is None): - self.net = BiNet(self.config['network']) - if(self.tensor_type == 'float'): - self.net.float() - else: - self.net.double() - def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] ssl_cfg = self.config['semi_supervised_learning'] - iter_max = self.config['training']['iter_max'] + iter_max = self.config['training']['iter_max'] + mixup_prob = self.config['training'].get('mixup_probability', 0.0) rampup_start = ssl_cfg.get('rampup_start', 0) rampup_end = ssl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup1, train_loss_pseudo_sup1 = 0, 0 train_loss_sup2, train_loss_pseudo_sup2 = 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data_lab = next(self.trainIter) except StopIteration: @@ -78,20 +59,35 @@ def training(self): except StopIteration: self.trainIter_unlab = iter(self.train_loader_unlab) data_unlab = next(self.trainIter_unlab) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) x1 = self.convert_tensor_type(data_unlab['image']) + + # for debug + # for i in range(x0.shape[0]): + # image_i = x0[i][0] + # label_i = np.argmax(y0[i], axis = 0) + # # pixw_i = pix_w[i][0] + # print(image_i.shape, label_i.shape) + # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) + # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) + # save_nd_array_as_image(image_i, image_name, reference_name = None) + # save_nd_array_as_image(label_i, label_name, reference_name = None) + # continue + if(mixup_prob > 0 and random() < mixup_prob): + x0, y0 = mixup(x0, y0) inputs = torch.cat([x0, x1], dim = 0) inputs, y0 = inputs.to(self.device), y0.to(self.device) # zero the parameter gradients self.optimizer.zero_grad() - outputs1, outputs2 = self.net(inputs) + outputs1, outputs2 = self.net(inputs) outputs_soft1 = torch.softmax(outputs1, dim=1) outputs_soft2 = torch.softmax(outputs2, dim=1) + t2 = time.time() n0 = list(x0.shape)[0] p0 = outputs_soft1[:n0] @@ -113,8 +109,9 @@ def training(self): model1_loss = loss_sup1 + regular_w * pse_sup1 model2_loss = loss_sup2 + regular_w * pse_sup2 loss = model1_loss + model2_loss - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -131,6 +128,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup1 = train_loss_sup1 / iter_valid train_avg_loss_sup2 = train_loss_sup2 / iter_valid @@ -142,7 +144,9 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup1':train_avg_loss_sup1, 'loss_sup2': train_avg_loss_sup2, 'loss_pse_sup1':train_avg_loss_pse_sup1, 'loss_pse_sup2': train_avg_loss_pse_sup2, - 'regular_w':regular_w, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'regular_w':regular_w, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): @@ -170,4 +174,7 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") \ No newline at end of file + ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) \ No newline at end of file diff --git a/pymic/net_run/semi_sup/ssl_em.py b/pymic/net_run/semi_sup/ssl_em.py index fde941b..750e3d3 100644 --- a/pymic/net_run/semi_sup/ssl_em.py +++ b/pymic/net_run/semi_sup/ssl_em.py @@ -2,6 +2,7 @@ from __future__ import print_function, division import logging import numpy as np +import time import torch from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth @@ -40,12 +41,12 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = ssl_cfg.get('rampup_start', 0) rampup_end = ssl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data_lab = next(self.trainIter) except StopIteration: @@ -56,7 +57,7 @@ def training(self): except StopIteration: self.trainIter_unlab = iter(self.train_loader_unlab) data_unlab = next(self.trainIter_unlab) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) @@ -69,6 +70,8 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) + t2 = time.time() + n0 = list(x0.shape)[0] p0 = outputs[:n0] loss_sup = self.get_loss_value(data_lab, p0, y0) @@ -79,8 +82,10 @@ def training(self): regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - # if (self.config['training']['use']) + t3 = time.time() + loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -94,6 +99,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -102,5 +112,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers \ No newline at end of file diff --git a/pymic/net_run/semi_sup/ssl_mcnet.py b/pymic/net_run/semi_sup/ssl_mcnet.py new file mode 100644 index 0000000..d374773 --- /dev/null +++ b/pymic/net_run/semi_sup/ssl_mcnet.py @@ -0,0 +1,139 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import logging +import time +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from pymic.loss.seg.util import get_soft_label +from pymic.loss.seg.util import reshape_prediction_and_ground_truth +from pymic.loss.seg.util import get_classwise_dice +from pymic.net_run.semi_sup import SSLSegAgent +from pymic.util.ramps import get_rampup_ratio + +def sharpening(P, T = 0.1): + T = 1.0/T + P_sharpen = P**T / (P**T + (1-P)**T) + return P_sharpen + +class SSLMCNet(SSLSegAgent): + """ + Mutual Consistency Learning for semi-supervised segmentation. It requires a network + with multiple decoders for learning, such as `pymic.net.net2d.unet2d_mcnet.MCNet2D`. + + * Reference: Yicheng Wu, Zongyuan Ge et al. Mutual consistency learning for + semi-supervised medical image segmentation. + `MIA 2022. `_ + + The original code is at: https://github.com/ycwu1997/MC-Net + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. + """ + def training(self): + class_num = self.config['network']['class_num'] + iter_valid = self.config['training']['iter_valid'] + ssl_cfg = self.config['semi_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = ssl_cfg.get('rampup_start', 0) + rampup_end = ssl_cfg.get('rampup_end', iter_max) + temperature = ssl_cfg.get('temperature', 0.1) + unsup_loss_name = ssl_cfg.get('unsupervised_loss', "MSE") + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 + train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 + self.net.train() + + for it in range(iter_valid): + t0 = time.time() + try: + data_lab = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data_lab = next(self.trainIter) + try: + data_unlab = next(self.trainIter_unlab) + except StopIteration: + self.trainIter_unlab = iter(self.train_loader_unlab) + data_unlab = next(self.trainIter_unlab) + t1 = time.time() + # get the inputs + x0 = self.convert_tensor_type(data_lab['image']) + y0 = self.convert_tensor_type(data_lab['label_prob']) + x1 = self.convert_tensor_type(data_unlab['image']) + inputs = torch.cat([x0, x1], dim = 0) + inputs, y0 = inputs.to(self.device), y0.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward pass to obtain multiple predictions + outputs = self.net(inputs) + t2 = time.time() + num_outputs = len(outputs) + n0 = list(x0.shape)[0] + p0 = F.softmax(outputs[0], dim=1)[:n0] + # for probability prediction and pseudo respectively + p_ori = torch.zeros((num_outputs,) + outputs[0].shape) + y_psu = torch.zeros((num_outputs,) + outputs[0].shape) + + # get supervised loss + loss_sup = 0 + for idx in range(num_outputs): + p0i = outputs[idx][:n0] + loss_sup += self.get_loss_value(data_lab, p0i, y0) + + # get pseudo labels + p_i = F.softmax(outputs[idx], dim=1) + p_ori[idx] = p_i + y_psu[idx] = sharpening(p_i, temperature) + + # get regularization loss + loss_reg = 0.0 + for i in range(num_outputs): + for j in range(num_outputs): + if (i!=j): + loss_reg += F.mse_loss(p_ori[i], y_psu[j], reduction='mean') + + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio + loss = loss_sup + regular_w*loss_reg + t3 = time.time() + loss.backward() + t4 = time.time() + self.optimizer.step() + + train_loss = train_loss + loss.item() + train_loss_sup = train_loss_sup + loss_sup.item() + train_loss_reg = train_loss_reg + loss_reg.item() + # get dice evaluation for each class in annotated images + if(isinstance(p0, tuple) or isinstance(p0, list)): + p0 = p0[0] + p0_argmax = torch.argmax(p0, dim = 1, keepdim = True) + p0_soft = get_soft_label(p0_argmax, class_num, self.tensor_type) + p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) + dice_list = get_classwise_dice(p0_soft, y0) + train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 + train_avg_loss = train_loss / iter_valid + train_avg_loss_sup = train_loss_sup / iter_valid + train_avg_loss_reg = train_loss_reg / iter_valid + train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) + train_avg_dice = train_cls_dice[1:].mean() + + train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, + 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } + return train_scalers diff --git a/pymic/net_run/semi_sup/ssl_mt.py b/pymic/net_run/semi_sup/ssl_mt.py index 2a2abb8..303eeff 100644 --- a/pymic/net_run/semi_sup/ssl_mt.py +++ b/pymic/net_run/semi_sup/ssl_mt.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import logging +import time import torch import numpy as np from pymic.loss.seg.util import get_soft_label @@ -50,13 +51,13 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = ssl_cfg.get('rampup_start', 0) rampup_end = ssl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() self.net_ema.to(self.device) for it in range(iter_valid): + t0 = time.time() try: data_lab = next(self.trainIter) except StopIteration: @@ -67,7 +68,7 @@ def training(self): except StopIteration: self.trainIter_unlab = iter(self.train_loader_unlab) data_unlab = next(self.trainIter_unlab) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) @@ -82,6 +83,8 @@ def training(self): self.optimizer.zero_grad() outputs = self.net(inputs) + t2 = time.time() + n0 = list(x0.shape)[0] p0 = outputs[:n0] loss_sup = self.get_loss_value(data_lab, p0, y0) @@ -98,15 +101,16 @@ def training(self): loss_reg = torch.nn.MSELoss()(p1_soft, p1_ema_soft) loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() # update EMA alpha = ssl_cfg.get('ema_decay', 0.99) alpha = min(1 - 1 / (self.glob_it / iter_valid + 1), alpha) for ema_param, param in zip(self.net_ema.parameters(), self.net.parameters()): - ema_param.data.mul_(alpha).add_(1 - alpha, param.data) + ema_param.data.mul_(alpha).add(param.data, alpha = 1.0 - alpha) train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() @@ -119,6 +123,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -127,5 +136,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time} return train_scalers \ No newline at end of file diff --git a/pymic/net_run/semi_sup/ssl_uamt.py b/pymic/net_run/semi_sup/ssl_uamt.py index 6222fe3..5888a8b 100644 --- a/pymic/net_run/semi_sup/ssl_uamt.py +++ b/pymic/net_run/semi_sup/ssl_uamt.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import logging +import time import torch import numpy as np from pymic.loss.seg.util import get_soft_label @@ -33,13 +34,13 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = ssl_cfg.get('rampup_start', 0) rampup_end = ssl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() self.net_ema.to(self.device) for it in range(iter_valid): + t0 = time.time() try: data_lab = next(self.trainIter) except StopIteration: @@ -50,7 +51,7 @@ def training(self): except StopIteration: self.trainIter_unlab = iter(self.train_loader_unlab) data_unlab = next(self.trainIter_unlab) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) @@ -64,6 +65,7 @@ def training(self): self.optimizer.zero_grad() outputs = self.net(inputs) + t2 = time.time() n0 = list(x0.shape)[0] p0, p1 = torch.tensor_split(outputs, [n0,], dim = 0) outputs_soft = torch.softmax(outputs, dim=1) @@ -100,15 +102,16 @@ def training(self): regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() # update EMA alpha = ssl_cfg.get('ema_decay', 0.99) alpha = min(1 - 1 / (self.glob_it / iter_valid + 1), alpha) for ema_param, param in zip(self.net_ema.parameters(), self.net.parameters()): - ema_param.data.mul_(alpha).add_(1 - alpha, param.data) + ema_param.data.mul_(alpha).add(param.data, alpha = 1.0 - alpha) train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() @@ -121,6 +124,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -129,5 +137,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers \ No newline at end of file diff --git a/pymic/net_run/semi_sup/ssl_urpc.py b/pymic/net_run/semi_sup/ssl_urpc.py index 56bb77e..8447709 100644 --- a/pymic/net_run/semi_sup/ssl_urpc.py +++ b/pymic/net_run/semi_sup/ssl_urpc.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import logging +import time import torch import torch.nn as nn import numpy as np @@ -35,13 +36,13 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = ssl_cfg.get('rampup_start', 0) rampup_end = ssl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() kl_distance = nn.KLDivLoss(reduction='none') for it in range(iter_valid): + t0 = time.time() try: data_lab = next(self.trainIter) except StopIteration: @@ -52,7 +53,7 @@ def training(self): except StopIteration: self.trainIter_unlab = iter(self.train_loader_unlab) data_unlab = next(self.trainIter_unlab) - + t1 = time.time() # get the inputs x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) @@ -65,6 +66,7 @@ def training(self): # forward pass outputs_list = self.net(inputs) + t2 = time.time() n0 = list(x0.shape)[0] # get supervised loss @@ -95,8 +97,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -110,6 +113,11 @@ def training(self): p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) dice_list = get_classwise_dice(p0_soft, y0) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -118,5 +126,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time} return train_scalers diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index 1478527..d98145a 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -1,23 +1,29 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import argparse import logging import os import sys import shutil from datetime import datetime +from pymic import TaskType from pymic.util.parse_config import * from pymic.net_run.agent_cls import ClassificationAgent from pymic.net_run.agent_seg import SegmentationAgent +from pymic.net_run.agent_rec import ReconstructionAgent from pymic.net_run.semi_sup import SSLMethodDict from pymic.net_run.weak_sup import WSLMethodDict +from pymic.net_run.self_sup import SelfSupMethodDict from pymic.net_run.noisy_label import NLLMethodDict -from pymic.net_run.self_sup import SelfSLSegAgent -def get_segmentation_agent(config, sup_type): +def get_seg_rec_agent(config, sup_type): assert(sup_type in ['fully_sup', 'semi_sup', 'self_sup', 'weak_sup', 'noisy_label']) if(sup_type == 'fully_sup'): logging.info("\n********** Fully Supervised Learning **********\n") - agent = SegmentationAgent(config, 'train') + if config['dataset']['task_type'] == TaskType.SEGMENTATION: + agent = SegmentationAgent(config, 'train') + else: + agent = ReconstructionAgent(config, 'train') elif(sup_type == 'semi_sup'): logging.info("\n********** Semi Supervised Learning **********\n") method = config['semi_supervised_learning']['method_name'] @@ -33,28 +39,7 @@ def get_segmentation_agent(config, sup_type): elif(sup_type == 'self_sup'): logging.info("\n********** Self Supervised Learning **********\n") method = config['self_supervised_learning']['method_name'] - if(method == "custom"): - pass - elif(method == "model_genesis"): - transforms = ['RandomFlip', 'LocalShuffling', 'NonLinearTransform', 'InOutPainting'] - genesis_cfg = { - 'randomflip_flip_depth': True, - 'randomflip_flip_height': True, - 'randomflip_flip_width': True, - 'localshuffling_probability': 0.5, - 'nonLineartransform_probability': 0.9, - 'inoutpainting_probability': 0.9, - 'inpainting_probability': 0.2 - } - config['dataset']['train_transform'].extend(transforms) - config['dataset']['valid_transform'].extend(transforms) - config['dataset'].update(genesis_cfg) - logging_config(config['dataset']) - else: - raise ValueError("The specified method {0:} is not implemented. ".format(method) + \ - "Consider to set `self_sl_method = custom` and use customized" + \ - " transforms for self-supervised learning.") - agent = SelfSLSegAgent(config, 'train') + agent = SelfSupMethodDict[method](config, 'train') else: raise ValueError("undefined supervision type: {0:}".format(sup_type)) return agent @@ -64,34 +49,48 @@ def main(): The main function for running a network for training. """ if(len(sys.argv) < 2): - print('Number of arguments should be 2. e.g.') - print(' pymic_train config.cfg') + print('Number of arguments should be at least 2. e.g.') + print(' pymic_train config.cfg -train_csv train.csv') exit() - cfg_file = str(sys.argv[1]) - if(not os.path.isfile(cfg_file)): - raise ValueError("The config file does not exist: " + cfg_file) - config = parse_config(cfg_file) + parser = argparse.ArgumentParser() + parser.add_argument("cfg", help="configuration file for training") + parser.add_argument("--train_csv", help="the csv file for training images", + required=False, default=None) + parser.add_argument("--valid_csv", help="the csv file for validation images", + required=False, default=None) + parser.add_argument("--ckpt_dir", help="the output dir for trained model", + required=False, default=None) + parser.add_argument("--iter_max", help="the maximal iteration number for training", + required=False, default=None) + parser.add_argument("--gpus", help="the gpus for runing, e.g., [0]", + required=False, default=None) + args = parser.parse_args() + if(not os.path.isfile(args.cfg)): + raise ValueError("The config file does not exist: " + args.cfg) + config = parse_config(args) config = synchronize_config(config) - log_dir = config['training']['ckpt_save_dir'] + + log_dir = config['training']['ckpt_dir'] if(not os.path.exists(log_dir)): os.makedirs(log_dir, exist_ok=True) - dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] - shutil.copy(cfg_file, log_dir + "/" + dst_cfg) + datetime_str = str(datetime.now())[:-7].replace(":", "_") if sys.version.startswith("3.9"): - logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), + logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), level=logging.INFO, format='%(message)s', force=True) # for python 3.9 else: - logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), + logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), level=logging.INFO, format='%(message)s') # for python 3.6 logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging_config(config) + dst_cfg = args.cfg if "/" not in args.cfg else args.cfg.split("/")[-1] + wrtie_config(config, log_dir + "/" + dst_cfg) + task = config['dataset']['task_type'] - assert task in ['cls', 'cls_nexcl', 'seg'] - if(task == 'cls' or task == 'cls_nexcl'): + if(task == TaskType.CLASSIFICATION_ONE_HOT or task == TaskType.CLASSIFICATION_COEXIST): agent = ClassificationAgent(config, 'train') else: sup_type = config['dataset'].get('supervise_type', 'fully_sup') - agent = get_segmentation_agent(config, sup_type) + agent = get_seg_rec_agent(config, sup_type) + agent.run() if __name__ == "__main__": diff --git a/pymic/net_run/weak_sup/__init__.py b/pymic/net_run/weak_sup/__init__.py index b3c8332..a583ae8 100644 --- a/pymic/net_run/weak_sup/__init__.py +++ b/pymic/net_run/weak_sup/__init__.py @@ -6,10 +6,12 @@ from pymic.net_run.weak_sup.wsl_tv import WSLTotalVariation from pymic.net_run.weak_sup.wsl_ustm import WSLUSTM from pymic.net_run.weak_sup.wsl_dmpls import WSLDMPLS +from pymic.net_run.weak_sup.wsl_dmsps import WSLDMSPS WSLMethodDict = {'EntropyMinimization': WSLEntropyMinimization, 'GatedCRF': WSLGatedCRF, 'MumfordShah': WSLMumfordShah, 'TotalVariation': WSLTotalVariation, 'USTM': WSLUSTM, - 'DMPLS': WSLDMPLS} \ No newline at end of file + 'DMPLS': WSLDMPLS, + 'DMSPS': WSLDMSPS} \ No newline at end of file diff --git a/pymic/net_run/weak_sup/wsl_abstract.py b/pymic/net_run/weak_sup/wsl_abstract.py index f290465..37b9445 100644 --- a/pymic/net_run/weak_sup/wsl_abstract.py +++ b/pymic/net_run/weak_sup/wsl_abstract.py @@ -24,7 +24,7 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): 'valid':valid_scalars['loss']} loss_sup_scalar = {'train':train_scalars['loss_sup']} loss_upsup_scalar = {'train':train_scalars['loss_reg']} - dice_scalar ={'train':train_scalars['avg_fg_dice'], 'valid':valid_scalars['avg_fg_dice']} + dice_scalar ={'valid':valid_scalars['avg_fg_dice']} self.summ_writer.add_scalars('loss', loss_scalar, glob_it) self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it) self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it) @@ -36,9 +36,10 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ 'valid':valid_scalars['class_dice'][c]} self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) - logging.info('train loss {0:.4f}, avg foreground dice {1:.4f} '.format( - train_scalars['loss'], train_scalars['avg_fg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") + logging.info('train loss {0:.4f}'.format(train_scalars['loss'])) logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format( valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + logging.info("data: {0:.2f}s, forward: {1:.2f}s, loss: {2:.2f}s, backward: {3:.2f}s".format( + train_scalars['data_time'], train_scalars['forward_time'], + train_scalars['loss_time'], train_scalars['backward_time'])) diff --git a/pymic/net_run/weak_sup/wsl_dmpls.py b/pymic/net_run/weak_sup/wsl_dmpls.py index 4212409..ea96bbb 100644 --- a/pymic/net_run/weak_sup/wsl_dmpls.py +++ b/pymic/net_run/weak_sup/wsl_dmpls.py @@ -3,12 +3,13 @@ import logging import numpy as np import random +import time import torch -from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.dice import DiceLoss +from pymic.loss.seg.ce import CrossEntropyLoss from pymic.net_run.weak_sup import WSLSegAgent from pymic.util.ramps import get_rampup_ratio @@ -41,21 +42,25 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] wsl_cfg = self.config['weakly_supervised_learning'] - iter_max = self.config['training']['iter_max'] + iter_max = self.config['training']['iter_max'] rampup_start = wsl_cfg.get('rampup_start', 0) rampup_end = wsl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + pseudo_loss_type = wsl_cfg.get('pseudo_sup_loss', 'dice_loss') + if (pseudo_loss_type not in ('dice_loss', 'ce_loss')): + raise ValueError("""For pseudo supervision loss, only dice_loss and ce_loss \ + are supported.""") + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) y = self.convert_tensor_type(data['label_prob']) @@ -67,6 +72,8 @@ def training(self): # forward + backward + optimize outputs1, outputs2 = self.net(inputs) + t2 = time.time() + loss_sup1 = self.get_loss_value(data, outputs1, y) loss_sup2 = self.get_loss_value(data, outputs2, y) loss_sup = 0.5 * (loss_sup1 + loss_sup2) @@ -80,7 +87,7 @@ def training(self): pseudo_lab = get_soft_label(pseudo_lab, class_num, self.tensor_type) # calculate the pseudo label supervision loss - loss_calculator = DiceLoss() + loss_calculator = DiceLoss() if pseudo_loss_type == 'dice_loss' else CrossEntropyLoss() loss_dict1 = {"prediction":outputs1, 'ground_truth':pseudo_lab} loss_dict2 = {"prediction":outputs2, 'ground_truth':pseudo_lab} loss_reg = 0.5 * (loss_calculator(loss_dict1) + loss_calculator(loss_dict2)) @@ -88,8 +95,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -103,6 +111,11 @@ def training(self): p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) dice_list = get_classwise_dice(p_soft, y) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -111,7 +124,8 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} - + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time} return train_scalers \ No newline at end of file diff --git a/pymic/net_run/weak_sup/wsl_dmsps.py b/pymic/net_run/weak_sup/wsl_dmsps.py new file mode 100644 index 0000000..b610f72 --- /dev/null +++ b/pymic/net_run/weak_sup/wsl_dmsps.py @@ -0,0 +1,202 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import logging +import os +import numpy as np +import random +import time +import torch +import scipy +from pymic.io.image_read_write import save_nd_array_as_image +from pymic.loss.seg.util import get_soft_label +from pymic.loss.seg.util import reshape_prediction_and_ground_truth +from pymic.loss.seg.util import get_classwise_dice +from pymic.loss.seg.dice import DiceLoss +from pymic.loss.seg.ce import CrossEntropyLoss +# from torch.nn.modules.loss import CrossEntropyLoss as TorchCELoss +from pymic.net_run.weak_sup import WSLSegAgent +from pymic.util.ramps import get_rampup_ratio + +class WSLDMSPS(WSLSegAgent): + """ + Weakly supervised segmentation based on Dynamically Mixed Pseudo Labels Supervision. + + * Reference: Meng Han, Xiangde Luo, Xiangjiang Xie, Wenjun Liao, Shichuan Zhang, Tao Song, + Guotai Wang, Shaoting Zhang. DMSPS: Dynamically mixed soft pseudo-label supervision for + scribble-supervised medical image segmentation. + `Medical Image Analysis 2024. `_ + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `weakly_supervised_learning` is needed. See :doc:`usage.wsl` for details. + """ + def __init__(self, config, stage = 'train'): + net_type = config['network']['net_type'] + # if net_type not in ['UNet2D_DualBranch', 'UNet3D_DualBranch']: + # raise ValueError("""For WSL_DMPLS, a dual branch network is expected. \ + # It only supports UNet2D_DualBranch and UNet3D_DualBranch currently.""") + super(WSLDMSPS, self).__init__(config, stage) + + def training(self): + class_num = self.config['network']['class_num'] + iter_valid = self.config['training']['iter_valid'] + wsl_cfg = self.config['weakly_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = wsl_cfg.get('rampup_start', 0) + rampup_end = wsl_cfg.get('rampup_end', iter_max) + pseudo_loss_type = wsl_cfg.get('pseudo_sup_loss', 'ce_loss') + if (pseudo_loss_type not in ('dice_loss', 'ce_loss')): + raise ValueError("""For pseudo supervision loss, only dice_loss and ce_loss \ + are supported.""") + pseudo_loss_func = CrossEntropyLoss() if pseudo_loss_type == 'ce_loss' else DiceLoss() + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 + train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 + self.net.train() + # ce_loss = CrossEntropyLoss() + for it in range(iter_valid): + t0 = time.time() + try: + data = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data = next(self.trainIter) + t1 = time.time() + # get the inputs + inputs = self.convert_tensor_type(data['image']) + y = self.convert_tensor_type(data['label_prob']) + + inputs, y = inputs.to(self.device), y.to(self.device) + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + outputs1, outputs2 = self.net(inputs) + t2 = time.time() + + loss_sup1 = self.get_loss_value(data, outputs1, y) + loss_sup2 = self.get_loss_value(data, outputs2, y) + loss_sup = 0.5 * (loss_sup1 + loss_sup2) + + # torch_ce_loss = TorchCELoss(ignore_index=class_num) + # torch_ce_loss2 = TorchCELoss() + # loss_ce1 = torch_ce_loss(outputs1, label[:].long()) + # loss_ce2 = torch_ce_loss(outputs2, label[:].long()) + # loss_sup = 0.5 * (loss_ce1 + loss_ce2) + + # get pseudo label with dynamic mixture + outputs_soft1 = torch.softmax(outputs1, dim=1) + outputs_soft2 = torch.softmax(outputs2, dim=1) + alpha = random.random() + soft_pseudo_label = alpha * outputs_soft1.detach() + (1.0-alpha) * outputs_soft2.detach() + # loss_reg = 0.5*(torch_ce_loss2(outputs_soft1, soft_pseudo_label) +torch_ce_loss2(outputs_soft2, soft_pseudo_label) ) + + loss_dict1 = {"prediction":outputs_soft1, 'ground_truth':soft_pseudo_label} + loss_dict2 = {"prediction":outputs_soft2, 'ground_truth':soft_pseudo_label} + loss_reg = 0.5 * (pseudo_loss_func(loss_dict1) + pseudo_loss_func(loss_dict2)) + + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = wsl_cfg.get('regularize_w', 8.0) * rampup_ratio + loss = loss_sup + regular_w*loss_reg + t3 = time.time() + loss.backward() + t4 = time.time() + self.optimizer.step() + + train_loss = train_loss + loss.item() + train_loss_sup = train_loss_sup + loss_sup.item() + train_loss_reg = train_loss_reg + loss_reg.item() + # get dice evaluation for each class in annotated images + if(isinstance(outputs1, tuple) or isinstance(outputs1, list)): + outputs1 = outputs1[0] + p_argmax = torch.argmax(outputs1, dim = 1, keepdim = True) + p_soft = get_soft_label(p_argmax, class_num, self.tensor_type) + p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) + dice_list = get_classwise_dice(p_soft, y) + train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 + train_avg_loss = train_loss / iter_valid + train_avg_loss_sup = train_loss_sup / iter_valid + train_avg_loss_reg = train_loss_reg / iter_valid + train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) + train_avg_dice = train_cls_dice[1:].mean() + + train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, + 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time} + return train_scalers + + def save_outputs(self, data): + """ + Save prediction output. + + :param data: (dictionary) A data dictionary with prediciton result and other + information such as input image name. + """ + output_dir = self.config['testing']['output_dir'] + test_mode = self.config['testing'].get('dmsps_test_mode', 0) + uct_threshold = self.config['testing'].get('dmsps_uncertainty_threshold', 0.1) + # DMSPS_test_mode == 0: only save the segmentation label for the main decoder + # DMSPS_test_mode == 1: save all the results, including the the probability map of each decoder, + # the uncertainty map, and the confident predictions + if(not os.path.exists(output_dir)): + os.makedirs(output_dir, exist_ok=True) + + names, pred = data['names'], data['predict'] + pred0, pred1 = pred + prob0 = scipy.special.softmax(pred0, axis = 1) + prob1 = scipy.special.softmax(pred1, axis = 1) + prob_mean = (prob0 + prob1) / 2 + lab0 = np.asarray(np.argmax(prob0, axis = 1), np.uint8) + lab1 = np.asarray(np.argmax(prob1, axis = 1), np.uint8) + lab_mean = np.asarray(np.argmax(prob_mean, axis = 1), np.uint8) + + # save the output and (optionally) probability predictions + test_dir = self.config['dataset'].get('test_dir', None) + if(test_dir is None): + test_dir = self.config['dataset']['train_dir'] + img_name = names[0][0].split('/')[-1] + print(img_name) + lab0_name = img_name + if(".h5" in lab0_name): + lab0_name = lab0_name.replace(".h5", ".nii.gz") + save_nd_array_as_image(lab0[0], output_dir + "/" + lab0_name, test_dir + '/' + names[0][0]) + if(test_mode == 1): + lab1_name = lab0_name.replace(".nii.gz", "_predaux.nii.gz") + save_nd_array_as_image(lab1[0], output_dir + "/" + lab1_name, test_dir + '/' + names[0][0]) + C = pred0.shape[1] + uct = -1.0 * np.sum(prob_mean * np.log(prob_mean), axis=1, keepdims=False)/ np.log(C) + uct_name = lab0_name.replace(".nii.gz", "_uncertainty.nii.gz") + save_nd_array_as_image(uct[0], output_dir + "/" + uct_name, test_dir + '/' + names[0][0]) + conf_mask = uct < uct_threshold + conf_lab = conf_mask * lab_mean + (1 - conf_mask)*4 + conf_lab_name = lab0_name.replace(".nii.gz", "_seeds_expand.nii.gz") + + # get the largest connected component in each slice for each class + D, H, W = conf_lab[0].shape + from pymic.util.image_process import get_largest_k_components + for d in range(D): + lab2d = conf_lab[0][d] + for c in range(C): + lab2d_c = lab2d == c + mask_c = get_largest_k_components(lab2d_c, k = 1) + diff = lab2d_c != mask_c + if(np.sum(diff) > 0): + lab2d[diff] = C + conf_lab[0][d] = lab2d + save_nd_array_as_image(conf_lab[0], output_dir + "/" + conf_lab_name, test_dir + '/' + img_name) + + + + \ No newline at end of file diff --git a/pymic/net_run/weak_sup/wsl_em.py b/pymic/net_run/weak_sup/wsl_em.py index adcd70c..987aa89 100644 --- a/pymic/net_run/weak_sup/wsl_em.py +++ b/pymic/net_run/weak_sup/wsl_em.py @@ -2,12 +2,12 @@ from __future__ import print_function, division import logging import numpy as np +import time import torch from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.ssl import EntropyLoss -from pymic.net_run.agent_seg import SegmentationAgent from pymic.net_run.weak_sup import WSLSegAgent from pymic.util.ramps import get_rampup_ratio @@ -38,18 +38,18 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = wsl_cfg.get('rampup_start', 0) rampup_end = wsl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) y = self.convert_tensor_type(data['label_prob']) @@ -61,6 +61,8 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) + t2 = time.time() + loss_sup = self.get_loss_value(data, outputs, y) loss_dict= {"prediction":outputs, 'softmax':True} loss_reg = EntropyLoss()(loss_dict) @@ -68,8 +70,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -83,6 +86,11 @@ def training(self): p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) dice_list = get_classwise_dice(p_soft, y) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -91,5 +99,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers \ No newline at end of file diff --git a/pymic/net_run/weak_sup/wsl_gatedcrf.py b/pymic/net_run/weak_sup/wsl_gatedcrf.py index 2ce1f95..7eaa67d 100644 --- a/pymic/net_run/weak_sup/wsl_gatedcrf.py +++ b/pymic/net_run/weak_sup/wsl_gatedcrf.py @@ -2,6 +2,7 @@ from __future__ import print_function, division import logging import numpy as np +import time import torch from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth @@ -48,20 +49,19 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = wsl_cfg.get('rampup_start', 0) rampup_end = wsl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] - + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 gatecrf_loss = GatedCRFLoss() self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) y = self.convert_tensor_type(data['label_prob']) @@ -73,6 +73,7 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) + t2 = time.time() loss_sup = self.get_loss_value(data, outputs, y) # for gated CRF loss, the input should be like NCHW @@ -94,8 +95,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -109,6 +111,11 @@ def training(self): p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) dice_list = get_classwise_dice(p_soft, y) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -117,6 +124,8 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers \ No newline at end of file diff --git a/pymic/net_run/weak_sup/wsl_mumford_shah.py b/pymic/net_run/weak_sup/wsl_mumford_shah.py index 2480fee..e917fb3 100644 --- a/pymic/net_run/weak_sup/wsl_mumford_shah.py +++ b/pymic/net_run/weak_sup/wsl_mumford_shah.py @@ -2,6 +2,7 @@ from __future__ import print_function, division import logging import numpy as np +import time import torch from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth @@ -37,20 +38,20 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = wsl_cfg.get('rampup_start', 0) rampup_end = wsl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 reg_loss_calculator = MumfordShahLoss(wsl_cfg) self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) y = self.convert_tensor_type(data['label_prob']) @@ -62,6 +63,7 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) + t2 = time.time() loss_sup = self.get_loss_value(data, outputs, y) loss_dict = {"prediction":outputs, 'image':inputs} loss_reg = reg_loss_calculator(loss_dict) @@ -69,8 +71,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - # if (self.config['training']['use']) + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -84,6 +87,11 @@ def training(self): p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) dice_list = get_classwise_dice(p_soft, y) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -92,6 +100,8 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers \ No newline at end of file diff --git a/pymic/net_run/weak_sup/wsl_tv.py b/pymic/net_run/weak_sup/wsl_tv.py index 9d13c5d..5a43ce0 100644 --- a/pymic/net_run/weak_sup/wsl_tv.py +++ b/pymic/net_run/weak_sup/wsl_tv.py @@ -2,6 +2,7 @@ from __future__ import print_function, division import logging import numpy as np +import time import torch from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth @@ -34,18 +35,18 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = wsl_cfg.get('rampup_start', 0) rampup_end = wsl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) y = self.convert_tensor_type(data['label_prob']) @@ -57,6 +58,7 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) + t2 = time.time() loss_sup = self.get_loss_value(data, outputs, y) loss_dict = {"prediction":outputs, 'softmax':True} loss_reg = TotalVariationLoss()(loss_dict) @@ -64,8 +66,9 @@ def training(self): rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - # if (self.config['training']['use']) + t3 = time.time() loss.backward() + t4 = time.time() self.optimizer.step() train_loss = train_loss + loss.item() @@ -79,6 +82,11 @@ def training(self): p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) dice_list = get_classwise_dice(p_soft, y) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -87,6 +95,8 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers \ No newline at end of file diff --git a/pymic/net_run/weak_sup/wsl_ustm.py b/pymic/net_run/weak_sup/wsl_ustm.py index 0ea3fbc..ac1c6a5 100644 --- a/pymic/net_run/weak_sup/wsl_ustm.py +++ b/pymic/net_run/weak_sup/wsl_ustm.py @@ -3,6 +3,7 @@ import logging import numpy as np import random +import time import torch import torch.nn.functional as F from pymic.loss.seg.util import get_soft_label @@ -54,19 +55,19 @@ def training(self): iter_max = self.config['training']['iter_max'] rampup_start = wsl_cfg.get('rampup_start', 0) rampup_end = wsl_cfg.get('rampup_end', iter_max) - train_loss = 0 - train_loss_sup = 0 - train_loss_reg = 0 + train_loss, train_loss_sup, train_loss_reg = 0, 0, 0 train_dice_list = [] + data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0 self.net.train() self.net_ema.to(self.device) for it in range(iter_valid): + t0 = time.time() try: data = next(self.trainIter) except StopIteration: self.trainIter = iter(self.train_loader) data = next(self.trainIter) - + t1 = time.time() # get the inputs inputs = self.convert_tensor_type(data['image']) y = self.convert_tensor_type(data['label_prob']) @@ -79,6 +80,7 @@ def training(self): # forward + backward + optimize noise = torch.clamp(torch.randn_like(inputs) * 0.1, -0.2, 0.2) outputs = self.net(inputs + noise) + t2 = time.time() out_prob= F.softmax(outputs, dim=1) loss_sup = self.get_loss_value(data, outputs, y) @@ -117,7 +119,7 @@ def training(self): regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg - + t3 = time.time() loss.backward() self.optimizer.step() @@ -125,7 +127,8 @@ def training(self): alpha = wsl_cfg.get('ema_decay', 0.99) alpha = min(1 - 1 / (self.glob_it / iter_valid + 1), alpha) for ema_param, param in zip(self.net_ema.parameters(), self.net.parameters()): - ema_param.data.mul_(alpha).add_(1 - alpha, param.data) + ema_param.data.mul_(alpha).add(param.data, alpha = 1.0 - alpha) + t4 = time.time() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() @@ -138,6 +141,11 @@ def training(self): p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) dice_list = get_classwise_dice(p_soft, y) train_dice_list.append(dice_list.cpu().numpy()) + + data_time = data_time + t1 - t0 + gpu_time = gpu_time + t2 - t1 + loss_time = loss_time + t3 - t2 + back_time = back_time + t4 - t3 train_avg_loss = train_loss / iter_valid train_avg_loss_sup = train_loss_sup / iter_valid train_avg_loss_reg = train_loss_reg / iter_valid @@ -146,5 +154,7 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, - 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice, + 'data_time': data_time, 'forward_time':gpu_time, + 'loss_time':loss_time, 'backward_time':back_time } return train_scalers \ No newline at end of file diff --git a/pymic/test/test_assd.py b/pymic/test/test_assd.py index 35c1804..6b2732a 100644 --- a/pymic/test/test_assd.py +++ b/pymic/test/test_assd.py @@ -18,7 +18,8 @@ def test_assd_2d(): plt.show() def test_assd_3d(): - img_name = "/home/x/projects/PyMIC_project/PyMIC_examples/seg_ssl/ACDC/result/unet2d_baseline/patient001_frame01.nii.gz" + # img_name = "/home/x/projects/PyMIC_project/PyMIC_examples/seg_ssl/ACDC/result/unet2d_baseline/patient001_frame01.nii.gz" + img_name = "/home/disk4t/data/heart/ACDC/preprocess/patient001_frame12_gt.nii.gz" img_obj = sitk.ReadImage(img_name) spacing = img_obj.GetSpacing() spacing = spacing[::-1] diff --git a/pymic/test/test_net2d.py b/pymic/test/test_net2d.py new file mode 100644 index 0000000..9013386 --- /dev/null +++ b/pymic/test/test_net2d.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import numpy as np +from pymic.net.net2d.unet2d import UNet2D +from pymic.net.net2d.unet2d_scse import UNet2D_ScSE +from pymic.net.net2d.umamba_bot import UMambaBot +def test_unet2d(): + params = {'in_chns':4, + 'feature_chns':[16, 32, 64, 128, 256], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'up_mode': 0, + 'multiscale_pred': True} + Net = UNet2D(params) + Net = Net.double() + + x = np.random.rand(4, 4, 10, 256, 256) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + if params['multiscale_pred']: + for y in out: + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) + else: + print(out.shape) + +def test_unet2d_scse(): + params = {'in_chns':4, + 'feature_chns':[16, 32, 64, 128, 256], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'up_mode': 0, + 'multiscale_pred': True} + Net = UNet2D_ScSE(params) + Net = Net.double() + + x = np.random.rand(4, 4, 10, 256, 256) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + if params['multiscale_pred']: + for y in out: + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) + else: + print(out.shape) + +def test_umamba(): + x = np.random.rand(4, 4, 10, 256, 256) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + Net = UMambaBot() + out = Net(xt) + y = out.detach().numpy() + print(y.shape) + +if __name__ == "__main__": + # test_unet2d() + # test_unet2d_scse() + test_umamba() \ No newline at end of file diff --git a/pymic/test/test_net3d.py b/pymic/test/test_net3d.py new file mode 100644 index 0000000..058a6fd --- /dev/null +++ b/pymic/test/test_net3d.py @@ -0,0 +1,203 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import numpy as np +from pymic.net.net3d.unet3d import UNet3D +from pymic.net.net3d.unet3d_scse import UNet3D_ScSE +from pymic.net.net3d.unet2d5 import UNet2D5 +from pymic.net.net3d.grunet import GRUNet +from pymic.net.net3d.lcovnet import LCOVNet +from pymic.net.net3d.trans3d.unetr_pp import UNETR_PP + +def test_unet3d(): + params = {'in_chns':4, + 'class_num': 2, + 'feature_chns':[2, 8, 32, 64], + 'dropout' : [0, 0, 0, 0.5], + 'up_mode': 2, + 'multiscale_pred': False} + Net = UNet3D(params) + Net = Net.double() + + x = np.random.rand(4, 4, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + y = y.detach().numpy() + print(y.shape) + + params = {'in_chns':4, + 'class_num': 2, + 'feature_chns':[2, 8, 32, 64, 128], + 'dropout' : [0, 0, 0, 0.4, 0.5], + 'up_mode': 3, + 'multiscale_pred': True} + Net = UNet3D(params) + Net = Net.double() + + x = np.random.rand(4, 4, 96, 128, 128) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + for y in out: + y = y.detach().numpy() + print(y.shape) + +def test_unet3d_scse(): + params = {'in_chns':4, + 'feature_chns':[2, 8, 32, 48, 64], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'up_mode': 2} + Net = UNet3D_ScSE(params) + Net = Net.double() + + x = np.random.rand(4, 4, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) + +def test_lcovnet(): + params = {'in_chns':4, + 'feature_chns':[16, 32, 64, 128, 256], + 'class_num': 2} + Net = LCOVNet(params) + Net = Net.double() + + x = np.random.rand(4, 4, 96, 96, 96) + xt = torch.from_numpy(x) + xt = xt.clone().detach() + + y = Net(xt) + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) + +def test_unet2d5(): + params = {'in_chns':4, + 'feature_chns':[8, 16, 32, 64, 128], + 'conv_dims': [2, 2, 3, 3, 3], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'up_mode': 2, + 'multiscale_pred': True} + Net = UNet2D5(params) + Net = Net.double() + + x = np.random.rand(4, 4, 32, 128, 128) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + for y in out: + y = y.detach().numpy() + print(y.shape) + + params = {'in_chns':4, + 'feature_chns':[8, 16, 32, 64, 128], + 'conv_dims': [2, 3, 3, 3, 3], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'up_mode': 0, + 'multiscale_pred': True} + Net = UNet2D5(params) + Net = Net.double() + + x = np.random.rand(4, 4, 64, 128, 128) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + for y in out: + y = y.detach().numpy() + print(y.shape) + +def test_mystunet(): + in_chns = 4 + num_class = 4 + # input_channels, num_classes, depth=[1,1,1,1,1,1], dims=[32, 64, 128, 256, 512, 512], + # pool_op_kernel_sizes=None, conv_kernel_sizes=None) + dims=[16, 32, 64, 128, 256, 512] + Net = MySTUNet(in_chns, num_class, dims = dims, pool_op_kernel_sizes = [[2, 2, 2], [2,2,2], [2,2,2], [2,2,2], [1, 1, 1]], + conv_kernel_sizes = [[3, 3, 3], [3,3,3], [3,3,3], [3,3,3], [3,3,3], [3, 3, 3]]) + Net = Net.double() + + x = np.random.rand(4, 4, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + for y in out: + y = y.detach().numpy() + print(y.shape) + +def test_grunet(): + params = {'in_chns':4, + 'feature_chns':[8, 16, 32, 64, 128], + 'dims': [2, 3, 3, 3, 3], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'depth': 2, + 'multiscale_pred': True} + x = np.random.rand(4, 4, 64, 128, 128) + + # params = {'in_chns':4, + # 'feature_chns':[8, 16, 32, 64, 128], + # 'dims': [3, 3, 3, 3, 3], + # 'dropout': [0, 0, 0.3, 0.4, 0.5], + # 'class_num': 2, + # 'depth': 4, + # 'multiscale_pred': True} + # x = np.random.rand(4, 4, 96, 96, 96) + + Net = GRUNet(params) + Net = Net.double() + + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + for y in out: + y = y.detach().numpy() + print(y.shape) + +def test_unetr_pp(): + depths = [128, 64, 32] + for i in range(3): + params = {'in_chns': 4, + 'class_num': 2, + 'img_size': [depths[i], 128, 128], + 'resolution_mode': i + } + net = UNETR_PP(params) + net.double() + + x = np.random.rand(2, 4, depths[i], 128, 128) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = net(xt) + print(len(y)) + for yi in y: + yi = yi.detach().numpy() + print(yi.shape) + + + +if __name__ == "__main__": + # test_unet3d() + # test_unet3d_scse() + test_lcovnet() + # test_unetr_pp() + # test_unet2d5() + # test_mystunet() + # test_fmunetv2() + + \ No newline at end of file diff --git a/pymic/transform/affine.py b/pymic/transform/affine.py new file mode 100644 index 0000000..1717a97 --- /dev/null +++ b/pymic/transform/affine.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import json +import math +import random +import numpy as np +from skimage import transform +from scipy import ndimage +from pymic import TaskType +from pymic.transform.abstract_transform import AbstractTransform +from pymic.util.image_process import * + +class Affine(AbstractTransform): + """ + Apply Affine Transform to an ND volume in the x-y plane. + Input shape should be [C, D, H, W] or [C, H, W]. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `Affine_scale_range`: (list or tuple) The range for scaling, e.g., (0.5, 2.0) + :param `Affine_shear_range`: (list or tuple) The range for shearing angle, e.g., (0, 30) + :param `Affine_rotate_range`: (list or tuple) The range for rotation, e.g., (-45, 45) + :param `Affine_output_size`: (None, list or tuple of length 2) The output size after affine transformation. + For 3D volumes, as we only apply affine transformation in x-y plane, the output slice + number will be the same as the input slice number, so only the output height and width + need to be given here, e.g., (H, W). By default (`None`), the output size will be the + same as the input size. + """ + def __init__(self, params): + super(Affine, self).__init__(params) + self.scale_range = params['Affine_scale_range'.lower()] + self.shear_range = params['Affine_shear_range'.lower()] + self.rotat_range = params['Affine_rotate_range'.lower()] + self.output_shape= params.get('Affine_output_size'.lower(), None) + self.inverse = params.get('Affine_inverse'.lower(), True) + + def _get_affine_param(self, sample, output_shape): + """ + output_shape should only has two dimensions, e.g., (H, W) + """ + input_shape = sample['image'].shape + input_dim = len(input_shape) - 1 + assert(len(output_shape) >=2) + + in_y, in_x = input_shape[-2:] + out_y, out_x = output_shape[-2:] + points = [[0, out_y], + [0, 0], + [out_x, 0], + [out_x, out_y]] + + sx = random.random() * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0] + sy = random.random() * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0] + shx = (random.random() * (self.shear_range[1] - self.shear_range[0]) + self.shear_range[0]) * 3.14159/180 + shy = (random.random() * (self.shear_range[1] - self.shear_range[0]) + self.shear_range[0]) * 3.14159/180 + rot = (random.random() * (self.rotat_range[1] - self.rotat_range[0]) + self.rotat_range[0]) * 3.14159/180 + # get affine transform parameters + new_points = [] + for p in points: + x = sx * p[0] * (math.cos(rot) + math.tan(shy) * math.sin(rot)) - \ + sy * p[1] * (math.tan(shx) * math.cos(rot) + math.sin(rot)) + y = sx * p[0] * (math.sin(rot) - math.tan(shy) * math.cos(rot)) - \ + sy * p[1] * (math.tan(shx) * math.sin(rot) - math.cos(rot)) + new_points.append([x,y]) + bb_min = np.array(new_points).min(axis = 0) + bb_max = np.array(new_points).max(axis = 0) + bbx, bby = int(bb_max[0] - bb_min[0]), int(bb_max[1] - bb_min[1]) + # transform the points to the image coordinate + margin_x = in_x - bbx + margin_y = in_y - bby + p0x = random.random() * margin_x if margin_x > 0 else margin_x / 2 + p0y = random.random() * margin_y if margin_y > 0 else margin_y / 2 + dst = [[new_points[i][0] - bb_min[0] + p0x, new_points[i][1] - bb_min[1] + p0y] \ + for i in range(3)] + + tform = transform.AffineTransform() + tform.estimate(np.array(points[:3]), np.array(dst)) + # to do: need to find a solution to save the affine transform matrix + # Use the matplotlib.transforms.Affine2D function to generate transform matrices, + # and the scipy.ndimage.warp function to warp images using the transform matrices. + # The skimage AffineTransform shear functionality is weird, + # and the scipy affine_transform function for warping images swaps the X and Y axes. + # sample['Affine_Param'] = json.dumps((input_shape, tform["matrix"])) + return sample, tform + + def _apply_affine_to_ND_volume(self, image, output_shape, tform, order = 2): + """ + output_shape should only has two dimensions, e.g., (H, W) + """ + dim = len(image.shape) - 1 + if(dim == 2): + C, H, W = image.shape + output = np.zeros([C] + output_shape) + for c in range(C): + output[c] = ndimage.affine_transform(image[c], tform, + output_shape = output_shape, mode='mirror', order = order) + elif(dim == 3): + C, D, H, W = image.shape + output = np.zeros([C, D] + output_shape) + for c in range(C): + for d in range(D): + output[c,d] = ndimage.affine_transform(image[c,d], tform, + output_shape = output_shape, mode='mirror', order = order) + return output + + def __call__(self, sample): + image = sample['image'] + input_shape = sample['image'].shape + output_shape= input_shape if self.output_shape is None else self.output_shape + aff_out_shape = output_shape[-2:] + sample, tform = self._get_affine_param(sample, aff_out_shape) + image_t = self._apply_affine_to_ND_volume(image, aff_out_shape, tform) + sample['image'] = image_t + + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + label = sample['label'] + label = self._apply_affine_to_ND_volume(label, aff_out_shape, tform, order = 0) + sample['label'] = label + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + weight = sample['pixel_weight'] + weight = self._apply_affine_to_ND_volume(weight, aff_out_shape, tform) + sample['pixel_weight'] = weight + return sample + + def _get_param_for_inverse_transform(self, sample): + if(isinstance(sample['Affine_Param'], list) or \ + isinstance(sample['Affine_Param'], tuple)): + params = json.loads(sample['Affine_Param'][0]) + else: + params = json.loads(sample['Affine_Param']) + return params + + # def inverse_transform_for_prediction(self, sample): + # params = self._get_param_for_inverse_transform(sample) + # origin_shape = params[0] + # tform = params[1] + + # predict = sample['predict'] + # if(isinstance(predict, tuple) or isinstance(predict, list)): + # output_predict = [] + # for predict_i in predict: + # aff_out_shape = origin_shape[-2:] + # output_predict_i = self._apply_affine_to_ND_volume(predict_i, + # aff_out_shape, tform.inverse) + # output_predict.append(output_predict_i) + # else: + # aff_out_shape = origin_shape[-2:] + # output_predict = self._apply_affine_to_ND_volume(predict, aff_out_shape, tform.inverse) + + # sample['predict'] = output_predict + # return sample diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index a27288d..c444acd 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -7,6 +7,7 @@ import random import numpy as np from scipy import ndimage +from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * @@ -55,12 +56,14 @@ def __call__(self, sample): image_t = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) sample['image'] = image_t - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): label = sample['label'] crop_max[0] = label.shape[0] label = crop_ND_volume_with_bounding_box(label, crop_min, crop_max) sample['label'] = label - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): weight = sample['pixel_weight'] crop_max[0] = weight.shape[0] weight = crop_ND_volume_with_bounding_box(weight, crop_min, crop_max) @@ -110,16 +113,20 @@ class CropWithBoundingBox(CenterCrop): :param `CropWithBoundingBox_start`: (None, or list/tuple) The start index along each spatial axis. If None, calculate the start index automatically - so that the cropped region is centered at the non-zero region. + so that the cropped region is centered at the mask region defined by the threshold. :param `CropWithBoundingBox_output_size`: (None or tuple/list): Desired spatial output size. - If None, set it as the size of bounding box of non-zero region. + If None, set it as the size of bounding box of the mask region defined by the threshold. + :param `CropWithBoundingBox_threshold`: (None or float): + Threshold for obtaining a mask. This is used only when + `CropWithBoundingBox_start` is None. Default is 1.0 :param `CropWithBoundingBox_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `True`. """ def __init__(self, params): self.start = params['CropWithBoundingBox_start'.lower()] self.output_size = params['CropWithBoundingBox_output_size'.lower()] + self.threshold = params.get('CropWithBoundingBox_threshold'.lower(), 1.0) self.inverse = params.get('CropWithBoundingBox_inverse'.lower(), True) self.task = params['task'] @@ -127,8 +134,9 @@ def _get_crop_param(self, sample): image = sample['image'] input_shape = sample['image'].shape input_dim = len(input_shape) - 1 - bb_min, bb_max = get_ND_bounding_box(image) - bb_min, bb_max = bb_min[1:], bb_max[1:] + if(self.start is None or self.output_size is None): + bb_min, bb_max = get_ND_bounding_box(image > self.threshold) + bb_min, bb_max = bb_min[1:], bb_max[1:] if(self.start is None): if(self.output_size is None): crop_min, crop_max = bb_min, bb_max @@ -150,7 +158,6 @@ def _get_crop_param(self, sample): crop_min = [0] + crop_min crop_max = list(input_shape[0:1]) + crop_max sample['CropWithBoundingBox_Param'] = json.dumps((input_shape, crop_min, crop_max)) - print("for crop", crop_min, crop_max) return sample, crop_min, crop_max def _get_param_for_inverse_transform(self, sample): @@ -161,7 +168,47 @@ def _get_param_for_inverse_transform(self, sample): params = json.loads(sample['CropWithBoundingBox_Param']) return params +class CropWithForeground(CenterCrop): + """ + Crop the image (shape [C, D, H, W] or [C, H, W]) based on a bounding box. + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `CropWithBoundingBox_start`: (None, or list/tuple) The start index + along each spatial axis. If None, calculate the start index automatically + so that the cropped region is centered at the non-zero region. + :param `CropWithBoundingBox_output_size`: (None or tuple/list): + Desired spatial output size. + If None, set it as the size of bounding box of non-zero region. + :param `CropWithBoundingBox_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `True`. + """ + def __init__(self, params): + self.labels = params.get('CropWithForeground_labels'.lower(), None) + self.margin = params.get('CropWithForeground_margin'.lower(), [5, 10, 10]) + self.inverse = params.get('CropWithForeground_inverse'.lower(), True) + self.task = params['task'] + + def _get_crop_param(self, sample): + image = sample['image'] + label = sample['label'] + input_shape = sample['image'].shape + bb_min, bb_max = get_ND_bounding_box(label, margin=[0] + self.margin) + bb_max[0] = input_shape[0] + + sample['CropWithForeground_Param'] = json.dumps((input_shape, bb_min, bb_max)) + + return sample, bb_min, bb_max + + def _get_param_for_inverse_transform(self, sample): + if(isinstance(sample['CropWithForeground_Param'], list) or \ + isinstance(sample['CropWithForeground_Param'], tuple)): + params = json.loads(sample['CropWithForeground_Param'][0]) + else: + params = json.loads(sample['CropWithForeground_Param']) + return params + class RandomCrop(CenterCrop): """Randomly crop the input image (shape [C, D, H, W] or [C, H, W]). @@ -170,7 +217,9 @@ class RandomCrop(CenterCrop): :param `RandomCrop_output_size`: (list/tuple) Desired output size [D, H, W] or [H, W]. The output channel is the same as the input channel. - If D is None for 3D images, the z-axis is not cropped. + If `None` is set for a certain axis, that axis will not be cropped. For example, + for 3D vlumes, (None, H, W) means only crop in 2D, and (D, None, None) means only + crop along the z axis. :param `RandomCrop_foreground_focus`: (optional, bool) If true, allow crop around the foreground. Default is False. :param `RandomCrop_foreground_ratio`: (optional, float) @@ -178,7 +227,8 @@ class RandomCrop(CenterCrop): `RandomCrop_foreground_focus` is True. :param `RandomCrop_mask_label`: (optional, None, or list/tuple) Specifying the foreground labels for foreground focus cropping when - `RandomCrop_foreground_focus` is True. + `RandomCrop_foreground_focus` is True. If it is None (by default), + the mask label will be the list of all the foreground classes. :param `RandomCrop_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `True`. """ @@ -186,7 +236,7 @@ def __init__(self, params): self.output_size = params['RandomCrop_output_size'.lower()] self.fg_focus = params.get('RandomCrop_foreground_focus'.lower(), False) self.fg_ratio = params.get('RandomCrop_foreground_ratio'.lower(), 0.5) - self.mask_label = params.get('RandomCrop_mask_label'.lower(), [1]) + self.mask_label = params.get('RandomCrop_mask_label'.lower(), None) self.inverse = params.get('RandomCrop_inverse'.lower(), True) self.task = params['Task'.lower()] assert isinstance(self.output_size, (list, tuple)) @@ -195,39 +245,35 @@ def __init__(self, params): def _get_crop_param(self, sample): image = sample['image'] - input_shape = image.shape - input_dim = len(input_shape) - 1 + chns = image.shape[0] + input_shape = image.shape[1:] + input_dim = len(input_shape) assert(input_dim == len(self.output_size)) - temp_output_size = self.output_size - if(input_dim == 3 and self.output_size[0] is None): - # note that output size is [D, H, W] and input is [C, D, H, W] - temp_output_size = [input_shape[1]] + self.output_size[1:] - - crop_margin = [input_shape[i + 1] - temp_output_size[i]\ - for i in range(input_dim)] + + output_size = [item for item in self.output_size] + # print("crop input and output size", input_shape, output_size) + for i in range(input_dim): + if(output_size[i] is None): + output_size[i] = input_shape[i] + # print(output_size) + crop_margin = [input_shape[i] - output_size[i] for i in range(input_dim)] crop_min = [0 if item == 0 else random.randint(0, item) for item in crop_margin] - if(self.fg_focus and random.random() < self.fg_ratio): - label = sample['label'] - mask = np.zeros_like(label) - for temp_lab in self.mask_label: - mask = np.maximum(mask, label == temp_lab) - if(mask.sum() == 0): - bb_min = [0] * (input_dim + 1) - bb_max = mask.shape + crop_max = [crop_min[i] + output_size[i] for i in range(input_dim)] + + label_exist = True if ('label' in sample and sample['label'].sum() > 0) else False + if(label_exist and self.fg_focus and random.random() < self.fg_ratio): + label = sample['label'][0] + if(self.mask_label is None): + mask_label = np.unique(label)[1:] else: - bb_min, bb_max = get_ND_bounding_box(mask) - bb_min, bb_max = bb_min[1:], bb_max[1:] - crop_min = [random.randint(bb_min[i], bb_max[i]) - int(temp_output_size[i]/2) \ - for i in range(input_dim)] - crop_min = [max(0, item) for item in crop_min] - crop_min = [min(crop_min[i], input_shape[i+1] - temp_output_size[i]) \ - for i in range(input_dim)] + mask_label = self.mask_label + random_label = random.choice(mask_label) + crop_min, crop_max = get_random_box_from_mask(label == random_label, output_size, mode = 1) - crop_max = [crop_min[i] + temp_output_size[i] \ - for i in range(input_dim)] crop_min = [0] + crop_min - crop_max = list(input_shape[0:1]) + crop_max - sample['RandomCrop_Param'] = json.dumps((input_shape, crop_min, crop_max)) + crop_max = [chns] + crop_max + + sample['RandomCrop_Param'] = json.dumps((image.shape, crop_min, crop_max)) return sample, crop_min, crop_max def _get_param_for_inverse_transform(self, sample): @@ -240,76 +286,218 @@ def _get_param_for_inverse_transform(self, sample): class RandomResizedCrop(CenterCrop): """ - Randomly crop the input image (shape [C, H, W]). Only 2D images are supported. - + Randomly resize and crop the input image (shape [C, D, H, W]). The arguments should be written in the `params` dictionary, and it has the following fields: - :param `RandomResizedCrop_output_size`: (list/tuple) Desired output size [H, W]. + :param `RandomResizedCrop_output_size`: (list/tuple) Desired output size [D, H, W]. The output channel is the same as the input channel. - :param `RandomResizedCrop_scale`: (list/tuple) Range of scale, e.g. (0.08, 1.0). - :param `RandomResizedCrop_ratio`: (list/tuple) Range of aspect ratio, e.g. (0.75, 1.33). + :param `RandomResizedCrop_scale_lower_bound`: (list/tuple) Lower bound of the range of scale + for each dimension. e.g. (1.0, 0.5, 0.5). + param `RandomResizedCrop_scale_upper_bound`: (list/tuple) Upper bound of the range of scale + for each dimension. e.g. (1.0, 2.0, 2.0). :param `RandomResizedCrop_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `False`. Currently, the inverse transform is not supported, and this transform is assumed to be used only during training stage. """ def __init__(self, params): self.output_size = params['RandomResizedCrop_output_size'.lower()] - self.scale = params['RandomResizedCrop_scale'.lower()] - self.ratio = params['RandomResizedCrop_ratio'.lower()] + self.scale_lower = params['RandomResizedCrop_resize_lower_bound'.lower()] + self.scale_upper = params['RandomResizedCrop_resize_upper_bound'.lower()] + self.prob = params.get('RandomResizedCrop_resize_prob'.lower(), 0.5) + self.fg_ratio = params.get('RandomResizedCrop_foreground_ratio'.lower(), 0.0) + self.mask_label = params.get('RandomResizedCrop_mask_label'.lower(), None) self.inverse = params.get('RandomResizedCrop_inverse'.lower(), False) self.task = params['Task'.lower()] assert isinstance(self.output_size, (list, tuple)) - assert isinstance(self.scale, (list, tuple)) - assert isinstance(self.ratio, (list, tuple)) + assert isinstance(self.scale_lower, (list, tuple)) + assert isinstance(self.scale_upper, (list, tuple)) - def _get_crop_param(self, sample): + def __call__(self, sample): image = sample['image'] - input_shape = image.shape - input_dim = len(input_shape) - 1 - assert(input_dim == 2) + channel, input_size = image.shape[0], image.shape[1:] + input_dim = len(input_size) assert(input_dim == len(self.output_size)) - - scale = self.scale[0] + random.random()*(self.scale[1] - self.scale[0]) - ratio = self.ratio[0] + random.random()*(self.ratio[1] - self.ratio[0]) - crop_w = input_shape[-1] * scale - crop_h = crop_w * ratio - crop_h = min(crop_h, input_shape[-2]) - output_shape = [int(crop_h), int(crop_w)] - - crop_margin = [input_shape[i + 1] - output_shape[i]\ - for i in range(input_dim)] - crop_min = [random.randint(0, item) for item in crop_margin] - crop_max = [crop_min[i] + output_shape[i] \ - for i in range(input_dim)] - crop_min = [0] + crop_min - crop_max = list(input_shape[0:1]) + crop_max - sample['RandomResizedCrop_Param'] = json.dumps((input_shape, crop_min, crop_max)) - return sample, crop_min, crop_max - def __call__(self, sample): - image = sample['image'] - input_shape = image.shape - input_dim = len(input_shape) - 1 - sample, crop_min, crop_max = self._get_crop_param(sample) + # get the resized crop size + resize = random.random() < self.prob + if(resize): + scale = [self.scale_lower[i] + (self.scale_upper[i] - self.scale_lower[i]) * random.random() \ + for i in range(input_dim)] + crop_size = [int(self.output_size[i] * scale[i]) for i in range(input_dim)] + else: + crop_size = self.output_size + + crop_margin = [input_size[i] - crop_size[i] for i in range(input_dim)] + pad_image = min(crop_margin) < 0 + if(pad_image): # pad the image if necessary + pad_size = [max(0, -crop_margin[i]) for i in range(input_dim)] + pad_lower = [int(pad_size[i] / 2) for i in range(input_dim)] + pad_upper = [pad_size[i] - pad_lower[i] for i in range(input_dim)] + pad = [(pad_lower[i], pad_upper[i]) for i in range(input_dim)] + pad = tuple([(0, 0)] + pad) + image = np.pad(image, pad, 'reflect') + crop_margin = [max(0, crop_margin[i]) for i in range(input_dim)] + # ge the bounding box for crop + if(random.random() < self.fg_ratio): + label = sample['label'] + if(pad_image): + label = np.pad(label, pad, 'reflect') + label = label[0] + if(self.mask_label is None): + mask_label = np.unique(label)[1:] + else: + mask_label = self.mask_label + random_label = random.choice(mask_label) + crop_min, crop_max = get_random_box_from_mask(label == random_label, crop_size, mode = 1) + else: + crop_min = [random.randint(0, item) for item in crop_margin] + crop_max = [crop_min[i] + crop_size[i] for i in range(input_dim)] + crop_min = [0] + crop_min + crop_max = [channel] + crop_max image_t = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) - crp_shape = image_t.shape - scale = [(self.output_size[i] + 0.0)/crp_shape[1:][i] for i in range(input_dim)] - scale = [1.0] + scale - image_t = ndimage.interpolation.zoom(image_t, scale, order = 1) + if(resize): + scale = [(self.output_size[i] + 0.0)/crop_size[i] for i in range(input_dim)] + scale = [1.0] + scale + image_t = ndimage.interpolation.zoom(image_t, scale, order = 1) sample['image'] = image_t - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): label = sample['label'] + if(pad_image): + label = np.pad(label, pad, 'reflect') crop_max[0] = label.shape[0] label = crop_ND_volume_with_bounding_box(label, crop_min, crop_max) - label = ndimage.interpolation.zoom(label, scale, order = 0) + if(resize): + order = 0 if(self.task == TaskType.SEGMENTATION) else 1 + label = ndimage.interpolation.zoom(label, scale, order = order) sample['label'] = label - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): weight = sample['pixel_weight'] + if(pad_image): + weight = np.pad(weight, pad, 'reflect') crop_max[0] = weight.shape[0] weight = crop_ND_volume_with_bounding_box(weight, crop_min, crop_max) - weight = ndimage.interpolation.zoom(weight, scale, order = 1) + if(resize): + weight = ndimage.interpolation.zoom(weight, scale, order = 1) sample['pixel_weight'] = weight - return sample \ No newline at end of file + return sample + +class RandomSlice(AbstractTransform): + """Randomly selecting N slices from a volume + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `RandomSlice_output_size`: (int) Desired number of slice for output. + :param `RandomSlice_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `True`. + """ + def __init__(self, params): + self.output_size = params['RandomSlice_output_size'.lower()] + self.fg_focus = params.get('RandomSlice_foreground_focus'.lower(), False) + self.fg_ratio = params.get('RandomSlice_foreground_ratio'.lower(), 0.5) + self.mask_label = params.get('RandomSlice_mask_label'.lower(), None) + self.shuffle = params.get('RandomSlice_shuffle'.lower(), False) + self.inverse = params.get('RandomSlice_inverse'.lower(), False) + self.task = params['Task'.lower()] + + def __call__(self, sample): + image = sample['image'] + D = image.shape[1] + assert( D >= self.output_size) + out_half = self.output_size // 2 + + label_exist = True if ('label' in sample and sample['label'].sum() > 0) else False + if(label_exist and self.fg_focus and random.random() < self.fg_ratio): + label = sample['label'][0] + if(self.mask_label is None): + mask_label = np.unique(label)[1:] + else: + mask_label = self.mask_label + random_label = random.choice(mask_label) + mask = label == random_label + dc = random.choice(np.nonzero(mask)[0]) + else: + dc = random.choice(range(out_half, D - out_half)) + + slice_idx = list(range(D)) + if(self.shuffle): + random.shuffle(slice_idx) + d0 = random.randint(0, D - self.output_size) + d1 = d0 + self.output_size + slice_idx = slice_idx[d0:d1-1] + [dc] + else: + d0 = max(0, dc - out_half) + d1 = d0 + self.output_size + slice_idx = slice_idx[d0:d1] + sample['image'] = image[:, slice_idx, :, :] + + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + label = sample['label'] + sample['label'] = label[:, slice_idx, :, :] + + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + weight = sample['pixel_weight'] + sample['pixel_weight'] = weight[:, slice_idx, :, :] + + return sample + +class CropHumanRegion(CenterCrop): + """ + Crop the human region from a CT for MRI volume. + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `CropWithBoundingBox_start`: (None, or list/tuple) The start index + along each spatial axis. If None, calculate the start index automatically + so that the cropped region is centered at the mask region defined by the threshold. + :param `CropWithBoundingBox_output_size`: (None or tuple/list): + Desired spatial output size. + If None, set it as the size of bounding box of the mask region defined by the threshold. + :param `CropWithBoundingBox_threshold`: (None or float): + Threshold for obtaining a mask. This is used only when + `CropWithBoundingBox_start` is None. Default is 1.0 + :param `CropWithBoundingBox_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `True`. + """ + def __init__(self, params): + self.threshold_i = params.get('CropHumanRegion_intensity_threshold'.lower(), -600) + self.threshold_mode = params.get('CropHumanRegion_threshold_mode'.lower(), 'mean') + self.threshold_z = params.get('CropHumanRegion_zaxis_threshold'.lower(), 0.5) + self.inverse = params.get('CropHumanRegion_inverse'.lower(), True) + self.task = params['task'] + + def _get_crop_param(self, sample): + image = sample['image'] + input_shape = image.shape + mask = np.asarray(image[0] > self.threshold_i) + if(self.threshold_mode == "mean"): + mask2d = np.mean(mask, axis = 0) > self.threshold_z + else: + mask2d = np.max(mask, axis = 0) + se = np.ones([3,3]) + mask2d = ndimage.binary_opening(mask2d, se, iterations = 2) + if(mask2d.sum() > 0): + mask2d = get_largest_k_components(mask2d, 1) + bbmin, bbmax = get_ND_bounding_box(mask2d, margin = [0, 0]) + else: + bbmin = [0] * (image.ndim - 2) + bbmax = list(input_shape[2:]) + crop_min = [0, 0] + bbmin + crop_max = list(input_shape[:2]) + bbmax + sample['CropHumanRegion_Param'] = json.dumps((input_shape, crop_min, crop_max)) + return sample, crop_min, crop_max + + def _get_param_for_inverse_transform(self, sample): + if(isinstance(sample['CropHumanRegion_Param'], list) or \ + isinstance(sample['CropHumanRegion_Param'], tuple)): + params = json.loads(sample['CropHumanRegion_Param'][0]) + else: + params = json.loads(sample['CropHumanRegion_Param']) + return params \ No newline at end of file diff --git a/pymic/transform/crop4dino.py b/pymic/transform/crop4dino.py new file mode 100644 index 0000000..df86787 --- /dev/null +++ b/pymic/transform/crop4dino.py @@ -0,0 +1,175 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import json +import math +import random +import numpy as np +from scipy import ndimage +from pymic import TaskType +from pymic.transform.abstract_transform import AbstractTransform +from pymic.transform.crop import CenterCrop +from pymic.transform.intensity import * +from pymic.util.image_process import * + +class Crop4Dino(CenterCrop): + """ + Randomly crop an volume into two views with augmentation. This is used for + self-supervised pretraining such as DeSD. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `DualViewCrop_output_size`: (list/tuple) Desired output size [D, H, W]. + The output channel is the same as the input channel. + :param `DualViewCrop_scale_lower_bound`: (list/tuple) Lower bound of the range of scale + for each dimension. e.g. (1.0, 0.5, 0.5). + param `DualViewCrop_scale_upper_bound`: (list/tuple) Upper bound of the range of scale + for each dimension. e.g. (1.0, 2.0, 2.0). + :param `DualViewCrop_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `False`. Currently, the inverse transform is not supported, and + this transform is assumed to be used only during training stage. + """ + def __init__(self, params): + self.output_size = params['Crop4Dino_output_size'.lower()] + self.scale_lower = params['Crop4Dino_resize_lower_bound'.lower()] + self.scale_upper = params['Crop4Dino_resize_upper_bound'.lower()] + self.prob = params.get('Crop4Dino_resize_prob'.lower(), 0.5) + self.noise_std_range = params.get('Crop4Dino_noise_std_range'.lower(), (0.05, 0.1)) + self.blur_sigma_range = params.get('Crop4Dino_blur_sigma_range'.lower(), (1.0, 3.0)) + self.gamma_range = params.get('Crop4Dino_gamma_range'.lower(), (0.75, 1.25)) + self.inverse = params.get('Crop4Dino_inverse'.lower(), False) + self.task = params['Task'.lower()] + assert isinstance(self.output_size, (list, tuple)) + assert isinstance(self.scale_lower, (list, tuple)) + assert isinstance(self.scale_upper, (list, tuple)) + + def __call__(self, sample): + image = sample['image'] + channel, input_size = image.shape[0], image.shape[1:] + input_dim = len(input_size) + assert(input_dim == len(self.output_size)) + + # # center crop first + # crop_size = self.output_size + # crop_margin = [input_size[i] - crop_size[i] for i in range(input_dim)] + # crop_min = [int(item/2) for item in crop_margin] + # crop_max = [crop_min[i] + crop_size[i] for i in range(input_dim)] + # crop_min = [0] + crop_min + # crop_max = [channel] + crop_max + # crop0 = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + + crop_num = 2 + crop_img = [] + for crop_i in range(crop_num): + resize = random.random() < self.prob + if(resize): + scale = [self.scale_lower[i] + (self.scale_upper[i] - self.scale_lower[i]) * random.random() \ + for i in range(input_dim)] + crop_size = [int(self.output_size[i] * scale[i]) for i in range(input_dim)] + else: + crop_size = self.output_size + + crop_margin = [input_size[i] - crop_size[i] for i in range(input_dim)] + pad_image = min(crop_margin) < 0 + if(pad_image): # pad the image if necessary + pad_size = [max(0, -crop_margin[i]) for i in range(input_dim)] + pad_lower = [int(pad_size[i] / 2) for i in range(input_dim)] + pad_upper = [pad_size[i] - pad_lower[i] for i in range(input_dim)] + pad = [(pad_lower[i], pad_upper[i]) for i in range(input_dim)] + pad = tuple([(0, 0)] + pad) + image = np.pad(image, pad, 'reflect') + crop_margin = [max(0, crop_margin[i]) for i in range(input_dim)] + + + crop_min = [random.randint(0, item) for item in crop_margin] + crop_max = [crop_min[i] + crop_size[i] for i in range(input_dim)] + crop_min = [0] + crop_min + crop_max = [channel] + crop_max + + crop_out = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + if(resize): + scale = [(self.output_size[i] + 0.0)/crop_size[i] for i in range(input_dim)] + scale = [1.0] + scale + crop_out = ndimage.interpolation.zoom(crop_out, scale, order = 1) + + # add intensity augmentation + C = crop_out.shape[0] + for c in range(C): + if(random.random() < 0.8): + crop_out[c] = gaussian_noise(crop_out[c], self.noise_std_range[0], self.noise_std_range[1]) + + if(random.uniform(0, 1) < 0.5): + crop_out[c] = gaussian_blur(crop_out[c], self.blur_sigma_range[0], self.blur_sigma_range[1]) + else: + alpha = random.uniform(0.0, 2.0) + crop_out[c] = gaussian_sharpen(crop_out[c], self.blur_sigma_range[0], self.blur_sigma_range[1], alpha) + if(random.random() < 0.8): + crop_out[c] = gamma_correction(crop_out[c], self.gamma_range[0], self.gamma_range[1]) + if(random.random() < 0.8): + crop_out[c] = window_level_augment(crop_out[c]) + crop_img.append(crop_out) + sample['image'] = crop_img + return sample + + def __call__backup(self, sample): + image = sample['image'] + channel, input_size = image.shape[0], image.shape[1:] + input_dim = len(input_size) + assert(input_dim == len(self.output_size)) + + # center crop first + crop_size = self.output_size + crop_margin = [input_size[i] - crop_size[i] for i in range(input_dim)] + crop_min = [int(item/2) for item in crop_margin] + crop_max = [crop_min[i] + crop_size[i] for i in range(input_dim)] + crop_min = [0] + crop_min + crop_max = [channel] + crop_max + crop0 = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + + # crop_num = 2 + # crop_img = [] + # for crop_i in range(crop_num): + # get another resized crop size + resize = random.random() < self.prob + if(resize): + scale = [self.scale_lower[i] + (self.scale_upper[i] - self.scale_lower[i]) * random.random() \ + for i in range(input_dim)] + crop_size = [int(self.output_size[i] * scale[i]) for i in range(input_dim)] + else: + crop_size = self.output_size + + crop_margin = [input_size[i] - crop_size[i] for i in range(input_dim)] + pad_image = min(crop_margin) < 0 + if(pad_image): # pad the image if necessary + pad_size = [max(0, -crop_margin[i]) for i in range(input_dim)] + pad_lower = [int(pad_size[i] / 2) for i in range(input_dim)] + pad_upper = [pad_size[i] - pad_lower[i] for i in range(input_dim)] + pad = [(pad_lower[i], pad_upper[i]) for i in range(input_dim)] + pad = tuple([(0, 0)] + pad) + image = np.pad(image, pad, 'reflect') + crop_margin = [max(0, crop_margin[i]) for i in range(input_dim)] + + + crop_min = [random.randint(0, item) for item in crop_margin] + crop_max = [crop_min[i] + crop_size[i] for i in range(input_dim)] + crop_min = [0] + crop_min + crop_max = [channel] + crop_max + + crop_out = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + if(resize): + scale = [(self.output_size[i] + 0.0)/crop_size[i] for i in range(input_dim)] + scale = [1.0] + scale + crop_out = ndimage.interpolation.zoom(crop_out, scale, order = 1) + # crop_img.append(crop_out) + crop_img = [crop0, crop_out] + # add intensity augmentation + # image_t = gaussian_noise(image_t, self.noise_std_range[0], self.noise_std_range[1], 0.8) + # image_t = gaussian_blur(image_t, self.blur_sigma_range[0], self.blur_sigma_range[1], 0.8) + # image_t = brightness_multiplicative(image_t, self.inten_multi_range[0], self.inten_multi_range[1], 0.8) + # image_t = brightness_additive(image_t, self.inten_add_range[0], self.inten_add_range[1], 0.8) + # image_t = contrast_augment(image_t, self.contrast_f_range[0], self.contrast_f_range[1], 0.8) + # image_t = gamma_correction(image_t, self.gamma_range[0], self.gamma_range[1], 0.8) + sample['image'] = crop_img + return sample diff --git a/pymic/transform/crop4vf.py b/pymic/transform/crop4vf.py new file mode 100644 index 0000000..4e07357 --- /dev/null +++ b/pymic/transform/crop4vf.py @@ -0,0 +1,232 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch + +import json +import math +import random +import numpy as np +from imops import crop_to_box +from typing import * +from scipy import ndimage +from pymic import TaskType +from pymic.transform.abstract_transform import AbstractTransform +from pymic.transform.crop import CenterCrop +from pymic.util.image_process import * +from pymic.transform.intensity import * + + +def random_resized_crop(image, output_size, scale_lower, scale_upper): + input_size = image.shape + scale = [scale_lower[i] + (scale_upper[i] - scale_lower[i]) * random.random() \ + for i in range(3)] + crop_size = [min(int(output_size[i] * scale[i]), input_size[1+i]) for i in range(3)] + crop_margin = [input_size[1+i] - crop_size[i] for i in range(3)] + crop_min = [random.randint(0, item) for item in crop_margin] + crop_max = [crop_min[i] + crop_size[i] for i in range(3)] + crop_min = [0] + crop_min + crop_max = [input_size[0]] + crop_max + + image_t = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + scale = [(output_size[i] + 0.0)/crop_size[i] for i in range(3)] + scale = [1.0] + scale + image_t = ndimage.interpolation.zoom(image_t, scale, order = 1) + return image_t + +def random_flip(image): + flip_axis = [] + if(random.random() > 0.5): + flip_axis.append(-1) + if(random.random() > 0.5): + flip_axis.append(-2) + if(random.random() > 0.5): + flip_axis.append(-3) + if(len(flip_axis) > 0): + image = np.flip(image , flip_axis) + return image + + +class Crop4VolumeFusion(AbstractTransform): + """ + Randomly crop an volume into two views with augmentation. This is used for + self-supervised pretraining in Vox2vec. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `Crop4VolumeFusion_output_size`: (list/tuple) Desired output size [D, H, W]. + The output channel is the same as the input channel. + :param `Crop4VolumeFusion_rescale_lower_bound`: (list/tuple) Lower bound of the range of scale + for each dimension. e.g. (1.0, 0.5, 0.5). + param `Crop4VolumeFusion_rescale_upper_bound`: (list/tuple) Upper bound of the range of scale + for each dimension. e.g. (1.0, 2.0, 2.0). + :param `Crop4VolumeFusion_augentation_mode`: (optional, int) The mode for augmentation of cropped volume. + 0: no spatial or intensity augmentatin. + 1: intensity augmentation only +` 2: spatial augmentation only + 3: Both intensity and spatial augmentation (default). + """ + def __init__(self, params): + self.output_size = params['Crop4VolumeFusion_output_size'.lower()] + self.scale_lower = params.get('Crop4VolumeFusion_rescale_lower_bound'.lower(), [0.7, 0.7, 0.7]) + self.scale_upper = params.get('Crop4VolumeFusion_rescale_upper_bound'.lower(), [1.5, 1.5, 1.5]) + self.aug_mode = params.get('Crop4VolumeFusion_augentation_mode'.lower(), 3) + self.task = params['Task'.lower()] + assert isinstance(self.output_size, (list, tuple)) + + def __call__(self, sample): + image = sample['image'] + channel, input_size = image.shape[0], image.shape[1:] + input_dim = len(input_size) + assert channel == 1 + assert(input_dim == len(self.output_size)) + + if(self.aug_mode == 0 or self.aug_mode == 1): + self.scale_lower = [1.0, 1.0, 1.0] + self.scale_upper = [1.0, 1.0, 1.0] + patch_1 = random_resized_crop(image, self.output_size, self.scale_lower, self.scale_upper) + patch_2 = random_resized_crop(image, self.output_size, self.scale_lower, self.scale_upper) + if(self.aug_mode > 1): + patch_1 = random_flip(patch_1) + patch_2 = random_flip(patch_2) + if(self.aug_mode == 1 or self.aug_mode == 3): + p0, p1 = random.uniform(0.1, 2.0), random.uniform(98, 99.9) + patch_1 = adaptive_contrast_adjust(patch_1, p0, p1) + patch_1 = gamma_correction(patch_1, 0.7, 1.5) + + p0, p1 = random.uniform(0.1, 2.0), random.uniform(98, 99.9) + patch_2 = adaptive_contrast_adjust(patch_2, p0, p1) + patch_2 = gamma_correction(patch_2, 0.7, 1.5) + + if(random.random() < 0.25): + patch_1 = 1.0 - patch_1 + patch_2 = 1.0 - patch_2 + + sample['image'] = patch_1, patch_2 + return sample + +class VolumeFusion(AbstractTransform): + """ + Randomly crop an volume into two views with augmentation. This is used for + self-supervised pretraining in Vox2vec. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `DualViewCrop_output_size`: (list/tuple) Desired output size [D, H, W]. + The output channel is the same as the input channel. + :param `DualViewCrop_scale_lower_bound`: (list/tuple) Lower bound of the range of scale + for each dimension. e.g. (1.0, 0.5, 0.5). + param `DualViewCrop_scale_upper_bound`: (list/tuple) Upper bound of the range of scale + for each dimension. e.g. (1.0, 2.0, 2.0). + :param `DualViewCrop_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `False`. Currently, the inverse transform is not supported, and + this transform is assumed to be used only during training stage. + """ + def __init__(self, params): + self.cls_num = params.get('VolumeFusion_cls_num'.lower(), 5) + self.ratio = params.get('VolumeFusion_foreground_ratio'.lower(), 0.7) + self.size_min = params.get('VolumeFusion_patchsize_min'.lower(), [8, 8, 8]) + self.size_max = params.get('VolumeFusion_patchsize_max'.lower(), [32, 32, 32]) + self.task = params['Task'.lower()] + + def __call__(self, sample): + K = self.cls_num - 1 + image1, image2 = sample['image'] + C, D, H, W = image1.shape + db = random.randint(self.size_min[0], self.size_max[0]) + hb = random.randint(self.size_min[1], self.size_max[1]) + wb = random.randint(self.size_min[2], self.size_max[2]) + d_offset = random.randint(0, D % db) + h_offset = random.randint(0, H % hb) + w_offset = random.randint(0, W % wb) + d_n = D // db + h_n = H // hb + w_n = W // wb + Nblock = d_n * h_n * w_n + Nfg = int(d_n * h_n * w_n * self.ratio) + list_fg = [1] * Nfg + [0] * (Nblock - Nfg) + random.shuffle(list_fg) + mask = np.zeros([1, D, H, W], np.uint8) + for d in range(d_n): + for h in range(h_n): + for w in range(w_n): + d0, h0, w0 = d*db + d_offset, h*hb + h_offset, w*wb + w_offset + d1, h1, w1 = d0 + db, h0 + hb, w0 + wb + idx = d*h_n*w_n + h*w_n + w + if(list_fg[idx]> 0): + cls_k = random.randint(1, K) + mask[:, d0:d1, h0:h1, w0:w1] = cls_k + alpha = mask * 1.0 / K + x_fuse = alpha*image1 + (1.0 - alpha)*image2 + sample['image'] = x_fuse + sample['label'] = mask + return sample + +class VolumeFusionShuffle(AbstractTransform): + """ + Randomly crop an volume into two views with augmentation. This is used for + self-supervised pretraining in Vox2vec. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `DualViewCrop_output_size`: (list/tuple) Desired output size [D, H, W]. + The output channel is the same as the input channel. + :param `DualViewCrop_scale_lower_bound`: (list/tuple) Lower bound of the range of scale + for each dimension. e.g. (1.0, 0.5, 0.5). + param `DualViewCrop_scale_upper_bound`: (list/tuple) Upper bound of the range of scale + for each dimension. e.g. (1.0, 2.0, 2.0). + :param `DualViewCrop_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `False`. Currently, the inverse transform is not supported, and + this transform is assumed to be used only during training stage. + """ + def __init__(self, params): + self.cls_num = params.get('VolumeFusionShuffle_cls_num'.lower(), 5) + self.ratio = params.get('VolumeFusionShuffle_foreground_ratio'.lower(), 0.7) + self.size_min = params.get('VolumeFusionShuffle_patchsize_min'.lower(), [8, 8, 8]) + self.size_max = params.get('VolumeFusionShuffle_patchsize_max'.lower(), [32, 32, 32]) + self.task = params['Task'.lower()] + + def __call__(self, sample): + K = self.cls_num - 1 + image1, image2 = sample['image'] + C, D, H, W = image1.shape + x_fuse = image2 * 1.0 + mask = np.zeros([1, D, H, W], np.uint8) + db = random.randint(self.size_min[0], self.size_max[0]) + hb = random.randint(self.size_min[1], self.size_max[1]) + wb = random.randint(self.size_min[2], self.size_max[2]) + d_offset = random.randint(0, D % db) + h_offset = random.randint(0, H % hb) + w_offset = random.randint(0, W % wb) + d_n = D // db + h_n = H // hb + w_n = W // wb + coord_list_source = [] + for di in range(d_n): + for hi in range(h_n): + for wi in range(w_n): + coord_list_source.append([di, hi, wi]) + coord_list_target = copy.deepcopy(coord_list_source) + random.shuffle(coord_list_source) + random.shuffle(coord_list_target) + for i in range(int(len(coord_list_source)*self.ratio)): + ds_l = d_offset + db * coord_list_source[i][0] + hs_l = h_offset + hb * coord_list_source[i][1] + ws_l = w_offset + wb * coord_list_source[i][2] + dt_l = d_offset + db * coord_list_target[i][0] + ht_l = h_offset + hb * coord_list_target[i][1] + wt_l = w_offset + wb * coord_list_target[i][2] + s_crop = image1[:, ds_l:ds_l+db, hs_l:hs_l+hb, ws_l:ws_l+wb] + t_crop = image2[:, dt_l:dt_l+db, ht_l:ht_l+hb, wt_l:wt_l+wb] + fg_m = random.randint(1, K) + fg_w = fg_m / (K + 0.0) + x_fuse[:, dt_l:dt_l+db, ht_l:ht_l+hb, wt_l:wt_l+wb] = t_crop * (1.0 - fg_w) + s_crop * fg_w + mask[0, dt_l:dt_l+db, ht_l:ht_l+hb, wt_l:wt_l+wb] = \ + np.ones([1, db, hb, wb]) * fg_m + sample['image'] = x_fuse + sample['label'] = mask + return sample + diff --git a/pymic/transform/crop4voco.py b/pymic/transform/crop4voco.py new file mode 100644 index 0000000..6c52ca7 --- /dev/null +++ b/pymic/transform/crop4voco.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import json +import math +import random +import numpy as np +from scipy import ndimage +from pymic import TaskType +from pymic.transform.abstract_transform import AbstractTransform +from pymic.transform.crop import CenterCrop +from pymic.transform.intensity import * +from pymic.util.image_process import * + +def get_position_label(roi=96, num_crops=4): + half = roi // 2 + max_roi = roi * num_crops + center_x, center_y = np.random.randint(low=half, high=max_roi - half), \ + np.random.randint(low=half, high=max_roi - half) + + x_min, x_max = center_x - half, center_x + half + y_min, y_max = center_y - half, center_y + half + + total_area = roi * roi + labels = [] + for j in range(num_crops): + for i in range(num_crops): + crop_x_min, crop_x_max = i * roi, (i + 1) * roi + crop_y_min, crop_y_max = j * roi, (j + 1) * roi + + dx = min(crop_x_max, x_max) - max(crop_x_min, x_min) + dy = min(crop_y_max, y_max) - max(crop_y_min, y_min) + if dx <= 0 or dy <= 0: + area = 0 + else: + area = (dx * dy) / total_area + labels.append(area) + + labels = np.asarray(labels).reshape(1, num_crops * num_crops) + return x_min, y_min, labels + +class Crop4VoCo(CenterCrop): + """ + Randomly crop an volume into two views with augmentation. This is used for + self-supervised pretraining such as DeSD. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `DualViewCrop_output_size`: (list/tuple) Desired output size [D, H, W]. + The output channel is the same as the input channel. + :param `DualViewCrop_scale_lower_bound`: (list/tuple) Lower bound of the range of scale + for each dimension. e.g. (1.0, 0.5, 0.5). + param `DualViewCrop_scale_upper_bound`: (list/tuple) Upper bound of the range of scale + for each dimension. e.g. (1.0, 2.0, 2.0). + :param `DualViewCrop_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `False`. Currently, the inverse transform is not supported, and + this transform is assumed to be used only during training stage. + """ + def __init__(self, params): + roi_size = params.get('Crop4VoCo_roi_size'.lower(), 64) + if isinstance(roi_size, int): + self.roi_size = [roi_size] * 3 + else: + self.roi_size = roi_size + self.roi_num = params.get('Crop4VoCo_roi_num'.lower(), 2) + self.base_num = params.get('Crop4VoCo_base_num'.lower(), 4) + + self.inverse = params.get('Crop4VoCo_inverse'.lower(), False) + self.task = params['Task'.lower()] + + def __call__(self, sample): + image = sample['image'] + channel, input_size = image.shape[0], image.shape[1:] + input_dim = len(input_size) + # print(input_size, self.roi_size) + assert(input_size[0] == self.roi_size[0]) + assert(input_size[1] == self.roi_size[1] * self.base_num) + assert(input_size[2] == self.roi_size[2] * self.base_num) + + base_num, roi_num, roi_size = self.base_num, self.roi_num, self.roi_size + base_crops, roi_crops, roi_labels = [], [], [] + crop_size = [channel] + list(roi_size) + for j in range(base_num): + for i in range(base_num): + crop_min = [0, 0, roi_size[1]*j, roi_size[2]*i] + crop_max = [crop_min[d] + crop_size[d] for d in range(4)] + crop_out = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + base_crops.append(crop_out) + + for i in range(roi_num): + x_min, y_min, label = get_position_label(self.roi_size[2], base_num) + # print('label', label) + crop_min = [0, 0, y_min, x_min] + crop_max = [crop_min[d] + crop_size[d] for d in range(4)] + crop_out = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + roi_crops.append(crop_out) + roi_labels.append(label) + roi_labels = np.concatenate(roi_labels, 0).reshape(roi_num, base_num * base_num) + + base_crops = np.stack(base_crops, 0) + roi_crops = np.stack(roi_crops, 0) + sample['image'] = base_crops, roi_crops, roi_labels + return sample + + \ No newline at end of file diff --git a/pymic/transform/crop4vox2vec.py b/pymic/transform/crop4vox2vec.py new file mode 100644 index 0000000..6fdcf83 --- /dev/null +++ b/pymic/transform/crop4vox2vec.py @@ -0,0 +1,160 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch + +import json +import math +import random +import numpy as np +from imops import crop_to_box +from typing import * +from scipy import ndimage +from pymic import TaskType +from pymic.transform.abstract_transform import AbstractTransform +from pymic.transform.crop import CenterCrop +from pymic.util.image_process import * +from pymic.transform.intensity import * + +def normalize_axis_list(axis, ndim): + return list(np.core.numeric.normalize_axis_tuple(axis, ndim)) + +def scale_hu(image_hu: np.ndarray, window_hu: Tuple[float, float]) -> np.ndarray: + min_hu, max_hu = window_hu + assert min_hu < max_hu + return np.clip((image_hu - min_hu) / (max_hu - min_hu), 0, 1) + +# def gaussian_filter( +# x: np.ndarray, +# sigma: Union[float, Sequence[float]], +# axis: Union[int, Sequence[int]] +# ) -> np.ndarray: +# axis = normalize_axis_list(axis, x.ndim) +# sigma = np.broadcast_to(sigma, len(axis)) +# for sgm, ax in zip(sigma, axis): +# x = ndimage.gaussian_filter1d(x, sgm, ax) +# return x + +# def gaussian_sharpen( +# x: np.ndarray, +# sigma_1: Union[float, Sequence[float]], +# sigma_2: Union[float, Sequence[float]], +# alpha: float, +# axis: Union[int, Sequence[int]] +# ) -> np.ndarray: +# """ See https://docs.monai.io/en/stable/transforms.html#gaussiansharpen """ +# blurred = gaussian_filter(x, sigma_1, axis) +# return blurred + alpha * (blurred - gaussian_filter(blurred, sigma_2, axis)) + +def sample_box(image_size, patch_size, anchor_voxel=None): + image_size = np.array(image_size, ndmin=1) + patch_size = np.array(patch_size, ndmin=1) + + if not np.all(image_size >= patch_size): + raise ValueError(f'Can\'t sample patch of size {patch_size} from image of size {image_size}') + + min_start = 0 + max_start = image_size - patch_size + if anchor_voxel is not None: + anchor_voxel = np.array(anchor_voxel, ndmin=1) + min_start = np.maximum(min_start, anchor_voxel - patch_size + 1) + max_start = np.minimum(max_start, anchor_voxel) + start = np.random.randint(min_start, max_start + 1) + return np.array([start, start + patch_size]) + +def sample_views( + image: np.ndarray, + min_overlap: Tuple[int, int, int], + patch_size: Tuple[int, int, int], + max_num_voxels: int, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """ For 3D volumes, the image shape should be [C, D, H, W]. + """ + img_size = image.shape[1:] + overlap = [random.randint(min_overlap[i], patch_size[i]) for i in range(3)] + union_size = [2*patch_size[i] - overlap[i] for i in range(3)] + anchor_max = [img_size[i] - union_size[i] for i in range(3)] + crop_min_1 = [random.randint(0, anchor_max[i]) for i in range(3)] + crop_min_2 = [crop_min_1[i] + patch_size[i] - overlap[i] for i in range(3)] + patch_1 = sample_view(image, crop_min_1, patch_size) + patch_2 = sample_view(image, crop_min_2, patch_size) + + coords = [range(crop_min_2[i], crop_min_2[i] + overlap[i]) for i in range(3)] + coords = np.asarray(np.meshgrid(coords[0], coords[1], coords[2])) + coords = coords.reshape(3, -1).transpose() + roi_voxels_1 = coords - crop_min_1 + roi_voxels_2 = coords - crop_min_2 + + indices = range(coords.shape[0]) + if len(indices) > max_num_voxels: + indices = np.random.choice(indices, max_num_voxels, replace=False) + + return patch_1, patch_2, roi_voxels_1[indices], roi_voxels_2[indices] + + +def sample_view(image, crop_min, patch_size): + """ For 3D volumes, the image shape should be [C, D, H, W]. + """ + assert image.ndim == 4 + C = image.shape[0] + crop_max = [crop_min[i] + patch_size[i] for i in range(3)] + out = crop_ND_volume_with_bounding_box(image, [0] + crop_min, [C] + crop_max) + + # intensity augmentations + for c in range(C): + if(random.random() < 0.8): + out[c] = gaussian_noise(out[c], 0.05, 0.1) + if(random.random() < 0.5): + out[c] = gaussian_blur(out[c], 0.5, 1.5) + else: + alpha = random.uniform(0.0, 2.0) + out[c] = gaussian_sharpen(out[c], 0.5, 2.0, alpha) + if(random.random() < 0.8): + out[c] = gamma_correction(out[c], 0.5, 2.0) + if(random.random() < 0.8): + out[c] = window_level_augment(out[c]) + return out + +class Crop4Vox2Vec(CenterCrop): + """ + Randomly crop an volume into two views with augmentation. This is used for + self-supervised pretraining in Vox2vec. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `DualViewCrop_output_size`: (list/tuple) Desired output size [D, H, W]. + The output channel is the same as the input channel. + :param `DualViewCrop_scale_lower_bound`: (list/tuple) Lower bound of the range of scale + for each dimension. e.g. (1.0, 0.5, 0.5). + param `DualViewCrop_scale_upper_bound`: (list/tuple) Upper bound of the range of scale + for each dimension. e.g. (1.0, 2.0, 2.0). + :param `DualViewCrop_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `False`. Currently, the inverse transform is not supported, and + this transform is assumed to be used only during training stage. + """ + def __init__(self, params): + self.output_size = params['Crop4Vox2Vec_output_size'.lower()] + self.min_overlap = params.get('Crop4Vox2Vec_min_overlap'.lower(), [8, 12, 12]) + self.max_voxel = params.get('Crop4Vox2Vec_max_voxel'.lower(), 1024) + self.inverse = params.get('Crop4Vox2Vec_inverse'.lower(), False) + self.task = params['Task'.lower()] + assert isinstance(self.output_size, (list, tuple)) + + def __call__(self, sample): + image = sample['image'] + channel, input_size = image.shape[0], image.shape[1:] + input_dim = len(input_size) + assert channel == 1 + assert(input_dim == len(self.output_size)) + invalid_size = [input_size[i] < self.output_size[i]*2 - self.min_overlap[i] for i in range(3)] + if True in invalid_size: + raise ValueError("The overlap requirement {0:} is too weak for the given patch size \ + {1:} and input size {2:}".format( self.min_overlap, self.output_size,input_size)) + + patches_1, patches_2, voxels_1, voxels_2 = sample_views(image, + self.min_overlap, self.output_size, self.max_voxel) + sample['image'] = patches_1, patches_2, voxels_1, voxels_2 + return sample + + diff --git a/pymic/transform/extract_channel.py b/pymic/transform/extract_channel.py new file mode 100644 index 0000000..c4974be --- /dev/null +++ b/pymic/transform/extract_channel.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import json +import math +import random +import numpy as np +from scipy import ndimage +from pymic import TaskType +from pymic.transform.abstract_transform import AbstractTransform +from pymic.util.image_process import * + + +class ExtractChannel(AbstractTransform): + """ Random flip the image. The shape is [C, D, H, W] or [C, H, W]. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `RandomFlip_flip_depth`: (bool) + Random flip along depth axis or not, only used for 3D images. + :param `RandomFlip_flip_height`: (bool) Random flip along height axis or not. + :param `RandomFlip_flip_width`: (bool) Random flip along width axis or not. + :param `RandomFlip_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `True`. + """ + def __init__(self, params): + super(ExtractChannel, self).__init__(params) + self.channels = params['ExtractChannel_channels'.lower()] + self.inverse = params.get('ExtractChannel_inverse'.lower(), False) + + def __call__(self, sample): + image = sample['image'] + image_extract = [] + for i in self.channels: + image_extract.append(image[i]) + sample['image'] = np.asarray(image_extract) + return sample diff --git a/pymic/transform/flip.py b/pymic/transform/flip.py index ca0915e..6ffd535 100644 --- a/pymic/transform/flip.py +++ b/pymic/transform/flip.py @@ -6,7 +6,7 @@ import math import random import numpy as np -from scipy import ndimage +from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * @@ -33,8 +33,7 @@ def __init__(self, params): def __call__(self, sample): image = sample['image'] - input_shape = image.shape - input_dim = len(input_shape) - 1 + input_dim = image.ndim flip_axis = [] if(self.flip_width): if(random.random() > 0.5): @@ -52,9 +51,11 @@ def __call__(self, sample): # current pytorch does not support negative strides image_t = np.flip(image, flip_axis).copy() sample['image'] = image_t - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): sample['label'] = np.flip(sample['label'] , flip_axis).copy() - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): sample['pixel_weight'] = np.flip(sample['pixel_weight'] , flip_axis).copy() return sample diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index 3b5ee9d..f05b95b 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -1,10 +1,13 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import copy +import itertools import json import math import random import numpy as np +from scipy import ndimage +from skimage import exposure from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * try: # SciPy >= 0.19 @@ -37,13 +40,83 @@ def bezier_curve(points, nTimes=1000): t = np.linspace(0.0, 1.0, nTimes) - polynomial_array = np.array([ bernstein_poly(i, nPoints-1, t) for i in range(0, nPoints) ]) + polynomial_array = np.array([ bernstein_poly(i, nPoints-1, t) for i in range(0, nPoints)]) xvals = np.dot(xPoints, polynomial_array) yvals = np.dot(yPoints, polynomial_array) return xvals, yvals + +class IntensityClip(AbstractTransform): + """ + Clip the intensity for input image + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `IntensityClip_channels`: (list) A list of int for specifying the channels. + :param `IntensityClip_lower`: (list) The lower bound for clip in each channel. + :param `IntensityClip_upper`: (list) The upper bound for clip in each channel. + :param `IntensityClip_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. + """ + def __init__(self, params): + super(IntensityClip, self).__init__(params) + self.channels = params['IntensityClip_channels'.lower()] + self.lower = params.get('IntensityClip_lower'.lower(), None) + self.upper = params.get('IntensityClip_upper'.lower(), None) + self.perct = params.get('IntensityClip_percentile_mode'.lower(), False) + self.inverse = params.get('IntensityClip_inverse'.lower(), False) + + def __call__(self, sample): + image = sample['image'] + lower = self.lower if self.lower is not None else [None] * len(self.channels) + upper = self.upper if self.upper is not None else [None] * len(self.channels) + for chn in self.channels: + lower_c, upper_c = lower[chn], upper[chn] + if(lower_c is None): + lower_c = np.percentile(image[chn], 0.05) + elif(self.perct): + lower_c = np.percentile(image[chn], lower_c) + if(upper_c is None): + upper_c = np.percentile(image[chn], 99.95) + elif(self.perct): + upper_c = np.percentile(image[chn], upper_c) + image[chn] = np.clip(image[chn], lower_c, upper_c) + sample['image'] = image + return sample + +class HistEqual(AbstractTransform): + """ + Histogram equalization. Note that the output will be in the range of [0, 1]. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `HistEqual_channels`: (list) A list of int for specifying the channels. + :param `HistEqual_bin`: (int) The number of bins. + :param `HistEqual_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. + """ + def __init__(self, params): + super(HistEqual, self).__init__(params) + self.channels = params.get('HistEqual_channels'.lower(), None) + # self.min = params.get('HistEqual_min'.lower(), None) + # self.max = params.get('HistEqual_max'.lower(), None) + self.bin = params.get('HistEqual_bin'.lower(), 2000) + self.inverse = params.get('HistEqual_inverse'.lower(), False) + + def __call__(self, sample): + image = sample['image'] + C = image.shape[0] + chns = range(C) if self.channels is None else self.channels + for i in range(len(chns)): + c = chns[i] + image[c] = exposure.equalize_hist(image[c],nbins= self.bin) + sample['image'] = image + return sample + class GammaCorrection(AbstractTransform): """ Apply random gamma correction to given channels. @@ -61,28 +134,76 @@ class GammaCorrection(AbstractTransform): """ def __init__(self, params): super(GammaCorrection, self).__init__(params) - self.channels = params['GammaCorrection_channels'.lower()] - self.gamma_min = params['GammaCorrection_gamma_min'.lower()] - self.gamma_max = params['GammaCorrection_gamma_max'.lower()] + self.channels = params.get('GammaCorrection_channels'.lower(), None) + self.gamma_min = params.get('GammaCorrection_gamma_min'.lower(), 0.7) + self.gamma_max = params.get('GammaCorrection_gamma_max'.lower(), 1.5) + self.flip_prob = params.get('GammaCorrection_intensity_flip_probability'.lower(), 0.0) self.prob = params.get('GammaCorrection_probability'.lower(), 0.5) self.inverse = params.get('GammaCorrection_inverse'.lower(), False) def __call__(self, sample): - if(np.random.uniform() > self.prob): - return sample image= sample['image'] + if(self.channels is None): + self.channels = range(image.shape[0]) for chn in self.channels: + if(np.random.uniform() > self.prob): + continue gamma_c = random.random() * (self.gamma_max - self.gamma_min) + self.gamma_min img_c = image[chn] v_min = img_c.min() v_max = img_c.max() - img_c = (img_c - v_min)/(v_max - v_min) - img_c = np.power(img_c, gamma_c)*(v_max - v_min) + v_min + if(v_min < v_max): + img_c = (img_c - v_min)/(v_max - v_min) + if(np.random.uniform() < self.flip_prob): + img_c = 1.0 - img_c + img_c = np.power(img_c, gamma_c)*(v_max - v_min) + v_min image[chn] = img_c sample['image'] = image return sample +def gaussian_noise(image, std_min, std_max,): + """ + The input has a shape of [C, D, H, W] or [D, H, W]. + In the former case, volume-level noise will be added. + In the latter case, slice-level noise will ba added. + """ + v_min = image.min() + v_max = image.max() + std = random.random() * (std_max - std_min) + std_min + noise = np.random.normal(0, std, image.shape) + out = image + noise + out = np.clip(out, v_min, v_max) + return out + +def gaussian_blur(image, sigma_min, sigma_max): + sigma = random.random() * (sigma_max - sigma_min) + sigma_min + out = ndimage.gaussian_filter(image, sigma, order = 0) + return out + +def gaussian_sharpen(image, sigma_min, sigma_max, alpha = 10.0): + blurred = gaussian_blur(image, sigma_min, sigma_max) + out = image + (image - blurred) * alpha + return out + +def window_level_augment(image, offset = 0.1): + v_min = image.min() + v_max = image.max() + margin = (v_max - v_min) * offset + v0 = random.uniform(v_min - margin, v_min + margin) + v1 = random.uniform(v_max - margin, v_max + margin) + out = np.clip((image - v0) / (v1 - v0), 0, 1) + return out + +def gamma_correction(image, gamma_min, gamma_max): + v_min = image.min() + v_max = image.max() + if(v_min < v_max): + image = (image - v_min)/(v_max - v_min) + gamma = random.random() * (gamma_max - gamma_min) + gamma_min + image = np.power(image, gamma)*(v_max - v_min) + v_min + return image + class GaussianNoise(AbstractTransform): """ Add Gaussian Noise to given channels. @@ -100,21 +221,65 @@ class GaussianNoise(AbstractTransform): """ def __init__(self, params): super(GaussianNoise, self).__init__(params) - self.channels = params['GaussianNoise_channels'.lower()] - self.mean = params['GaussianNoise_mean'.lower()] - self.std = params['GaussianNoise_std'.lower()] + self.channels = params.get('GaussianNoise_channels'.lower(), None) + self.std_min = params.get('GaussianNoise_std_min'.lower(), 0.02) + self.std_max = params.get('GaussianNoise_std_max'.lower(), 0.1) self.prob = params.get('GaussianNoise_probability'.lower(), 0.5) self.inverse = params.get('GaussianNoise_inverse'.lower(), False) def __call__(self, sample): - if(np.random.uniform() > self.prob): - return sample - image= sample['image'] + image = sample['image'] + if(self.channels is None): + self.channels = range(image.shape[0]) for chn in self.channels: - img_c = image[chn] - noise = np.random.normal(self.mean, self.std, img_c.shape) - image[chn] = img_c + noise + if(np.random.uniform() < self.prob): + image[chn] = gaussian_noise(image[chn], self.std_min, self.std_max) + sample['image'] = image + return sample + +def adaptive_contrast_adjust(image, p0=0.1, p1=99.9): + v_min = image.min() + v_max = image.max() + v0 = np.percentile(image, p0) + v1 = np.percentile(image, p1) + mask_l = image < v0 + mask_m = (image >= v0) * (image <= v1) + mask_u = image > v1 + image[mask_l] = (image[mask_l] - v_min) * 0.1 / (v0 - v_min) + image[mask_m] = (image[mask_m] - v0) / (v1 - v0)*0.8 + 0.1 + image[mask_u] = 0.9 + 0.1 * (image[mask_u] - v1) / (v_max - v1) + return image + +class AdaptiveContrastAdjust(AbstractTransform): + """ + Add Gaussian Noise to given channels. + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `GaussianNoise_channels`: (list) A list of int for specifying the channels. + :param `GaussianNoise_mean`: (float) The mean value of noise. + :param `GaussianNoise_std`: (float) The std of noise. + :param `GaussianNoise_probability`: (optional, float) + The probability of applying GaussianNoise. Default is 0.5. + :param `GaussianNoise_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. + """ + def __init__(self, params): + super(AdaptiveContrastAdjust, self).__init__(params) + self.channels = params.get('AdaptiveContrastAdjust_channels'.lower(), None) + self.p0 = params.get('AdaptiveContrastAdjust_percent_lower'.lower(), 2) + self.p1 = params.get('AdaptiveContrastAdjust_percent_upper'.lower(), 98) + self.prob = params.get('AdaptiveContrastAdjust_probability'.lower(), 0.5) + self.inverse = params.get('AdaptiveContrastAdjust_inverse'.lower(), False) + + def __call__(self, sample): + image = sample['image'] * 1.0 + if(self.channels is None): + self.channels = range(image.shape[0]) + for chn in self.channels: + if(np.random.uniform() < self.prob): + image[chn] = adaptive_contrast_adjust(image[chn], self.p0, self.p1) sample['image'] = image return sample @@ -136,21 +301,55 @@ def __call__(self, sample): class NonLinearTransform(AbstractTransform): def __init__(self, params): super(NonLinearTransform, self).__init__(params) - self.inverse = params.get('NonLinearTransform_inverse'.lower(), False) + self.channels = params.get('NonLinearTransform_channels'.lower(), None) self.prob = params.get('NonLinearTransform_probability'.lower(), 0.5) + self.inverse = params.get('NonLinearTransform_inverse'.lower(), False) + self.block_range = params.get('NonLinearTransform_block_range'.lower(), None) + self.block_size = params.get('NonLinearTransform_block_size'.lower(), [4, 8, 8]) + - def __call__(self, sample): - if(random.random() > self.prob): - return sample - - image= sample['image'] + def apply_nonlinear_transform(self, img): + """ + the input img should be normlized to [0, 1]""" points = [[0, 0], [random.random(), random.random()], [random.random(), random.random()], [1, 1]] - xvals, yvals = bezier_curve(points, nTimes=100000) - if random.random() < 0.5: # Half change to get flip + xvals, yvals = bezier_curve(points, nTimes=10000) + if random.random() < 0.5: # Half chance to get flip xvals = np.sort(xvals) else: xvals, yvals = np.sort(xvals), np.sort(yvals) - image = np.interp(image, xvals, yvals) + + img = np.interp(img, xvals, yvals) + return img + + def __call__(self, sample): + if(random.random() > self.prob): + return sample + + image = sample['image'] + img_shape = image.shape + img_dim = len(img_shape) - 1 + channels = self.channels if self.channels is not None else range(image.shape[0]) + for chn in channels: + # normalize the image intensity to [0, 1] before the non-linear tranform + img_c = image[chn] + v_min, v_max = img_c.min(), img_c.max() + if(v_min < v_max): + img_c = (img_c - v_min)/(v_max - v_min) + if(self.block_range is None): # apply non-linear transform to the entire image + img_c = self.apply_nonlinear_transform(img_c) + else: # non-linear transform to random blocks + img_c_sr = copy.deepcopy(img_c) + for n in range(self.block_range[0], self.block_range[1]): + coord_min = [random.randint(0, img_shape[1+i] - self.block_size[i]) \ + for i in range(img_dim)] + window = img_c_sr[coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1], + coord_min[2]:coord_min[2] + self.block_size[2]] + img_c[coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1], + coord_min[2]:coord_min[2] + self.block_size[2]] = \ + self.apply_nonlinear_transform(window) + image[chn] = img_c * (v_max - v_min) + v_min sample['image'] = image return sample @@ -162,9 +361,8 @@ def __init__(self, params): super(LocalShuffling, self).__init__(params) self.inverse = params.get('LocalShuffling_inverse'.lower(), False) self.prob = params.get('LocalShuffling_probability'.lower(), 0.5) - self.block_range = params.get('LocalShuffling_block_range'.lower(), (5000, 10000)) - self.block_size_min = params.get('LocalShuffling_block_size_min'.lower(), None) - self.block_size_max = params.get('LocalShuffling_block_size_max'.lower(), None) + self.block_range = params.get('LocalShuffling_block_range'.lower(), [40, 80]) + self.block_size = params.get('LocalShuffling_block_size'.lower(), [4, 8, 8]) def __call__(self, sample): if(random.random() > self.prob): @@ -175,49 +373,33 @@ def __call__(self, sample): img_dim = len(img_shape) - 1 assert(img_dim == 2 or img_dim == 3) img_out = copy.deepcopy(image) - if(self.block_size_min is None): - block_size_min = [2] * img_dim - elif(isinstance(self.block_size_min, int)): - block_size_min = [self.block_size_min] * img_dim - else: - assert(len(self.block_size_min) == img_dim) - block_size_min = self.block_size_min - - if(self.block_size_max is None): - block_size_max = [img_shape[1+i]//10 for i in range(img_dim)] - elif(isinstance(self.block_size_min, int)): - block_size_max = [self.block_size_max] * img_dim - else: - assert(len(self.block_size_max) == img_dim) - block_size_max = self.block_size_max + block_num = random.randint(self.block_range[0], self.block_range[1]) for n in range(block_num): - block_size = [random.randint(block_size_min[i], block_size_max[i]) \ - for i in range(img_dim)] - coord_min = [random.randint(0, img_shape[1+i] - block_size[i]) \ + coord_min = [random.randint(0, img_shape[1+i] - self.block_size[i]) \ for i in range(img_dim)] if(img_dim == 2): - window = image[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1]] - n_pixels = block_size[0] * block_size[1] + window = image[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1]] + n_pixels = self.block_size[0] * self.block_size[1] else: - window = image[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1], - coord_min[2]:coord_min[2] + block_size[2]] - n_pixels = block_size[0] * block_size[1] * block_size[2] + window = image[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1], + coord_min[2]:coord_min[2] + self.block_size[2]] + n_pixels = self.block_size[0] * self.block_size[1] * self.block_size[2] window = np.reshape(window, [-1, n_pixels]) np.random.shuffle(np.transpose(window)) window = np.transpose(window) if(img_dim == 2): - window = np.reshape(window, [-1, block_size[0], block_size[1]]) - img_out[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1]] = window + window = np.reshape(window, [-1, self.block_size[0], self.block_size[1]]) + img_out[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1]] = window else: - window = np.reshape(window, [-1, block_size[0], block_size[1], block_size[2]]) - img_out[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1], - coord_min[2]:coord_min[2] + block_size[2]] = window + window = np.reshape(window, [-1, self.block_size[0], self.block_size[1], self.block_size[2]]) + img_out[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1], + coord_min[2]:coord_min[2] + self.block_size[2]] = window sample['image'] = img_out return sample @@ -229,10 +411,9 @@ def __init__(self, params): super(InPainting, self).__init__(params) self.inverse = params.get('InPainting_inverse'.lower(), False) self.prob = params.get('InPainting_probability'.lower(), 0.5) - self.block_range = params.get('InPainting_block_range'.lower(), (1, 6)) - self.block_size_min = params.get('InPainting_block_size_min'.lower(), None) - self.block_size_max = params.get('InPainting_block_size_max'.lower(), None) - + self.block_range = params.get('InPainting_block_range'.lower(), (20, 40)) + self.block_size = params.get('InPainting_block_size'.lower(), [4, 8, 8]) + def __call__(self, sample): if(random.random() > self.prob): return sample @@ -242,38 +423,21 @@ def __call__(self, sample): img_dim = len(img_shape) - 1 assert(img_dim == 2 or img_dim == 3) - if(self.block_size_min is None): - block_size_min = [img_shape[1+i]//6 for i in range(img_dim)] - elif(isinstance(self.block_size_min, int)): - block_size_min = [self.block_size_min] * img_dim - else: - assert(len(self.block_size_min) == img_dim) - block_size_min = self.block_size_min - - if(self.block_size_max is None): - block_size_max = [img_shape[1+i]//3 for i in range(img_dim)] - elif(isinstance(self.block_size_min, int)): - block_size_max = [self.block_size_max] * img_dim - else: - assert(len(self.block_size_max) == img_dim) - block_size_max = self.block_size_max block_num = random.randint(self.block_range[0], self.block_range[1]) - for n in range(block_num): - block_size = [random.randint(block_size_min[i], block_size_max[i]) \ - for i in range(img_dim)] - coord_min = [random.randint(3, img_shape[1+i] - block_size[i] - 3) \ + for n in range(block_num): + coord_min = [random.randint(3, img_shape[1+i] - self.block_size[i] - 3) \ for i in range(img_dim)] if(img_dim == 2): - random_block = np.random.rand(img_shape[0], block_size[0], block_size[1]) - image[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1]] = random_block + random_block = np.random.rand(img_shape[0], self.block_size[0], self.block_size[1]) * 2 -1 + image[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1]] = random_block else: - random_block = np.random.rand(img_shape[0], block_size[0], - block_size[1], block_size[2]) - image[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1], - coord_min[2]:coord_min[2] + block_size[2]] = random_block + random_block = np.random.rand(img_shape[0], self.block_size[0], + self.block_size[1], self.block_size[2]) * 2 -1 + image[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1], + coord_min[2]:coord_min[2] + self.block_size[2]] = random_block sample['image'] = image return sample @@ -285,9 +449,8 @@ def __init__(self, params): super(OutPainting, self).__init__(params) self.inverse = params.get('OutPainting_inverse'.lower(), False) self.prob = params.get('OutPainting_probability'.lower(), 0.5) - self.block_range = params.get('OutPainting_block_range'.lower(), (1, 6)) - self.block_size_min = params.get('OutPainting_block_size_min'.lower(), None) - self.block_size_max = params.get('OutPainting_block_size_max'.lower(), None) + self.block_range = params.get('OutPainting_block_range'.lower(), (2, 8)) + self.block_size = params.get('OutPainting_block_size'.lower(), None) def __call__(self, sample): if(random.random() > self.prob): @@ -297,28 +460,18 @@ def __call__(self, sample): img_shape = image.shape img_dim = len(img_shape) - 1 assert(img_dim == 2 or img_dim == 3) - img_out = np.random.rand(*img_shape) + img_out = np.random.rand(*img_shape) * 2 -1 - if(self.block_size_min is None): - block_size_min = [img_shape[1+i] - 4 * img_shape[1+i]//7 for i in range(img_dim)] - elif(isinstance(self.block_size_min, int)): - block_size_min = [self.block_size_min] * img_dim + if(self.block_size is None): + margin = [16, 32, 32] + block_size = [img_shape[1+i] - margin[i] for i in range(img_dim)] else: - assert(len(self.block_size_min) == img_dim) - block_size_min = self.block_size_min + assert(len(self.block_size) == img_dim) + block_size = self.block_size - if(self.block_size_max is None): - block_size_max = [img_shape[1+i] - 3 * img_shape[1+i]//7 for i in range(img_dim)] - elif(isinstance(self.block_size_min, int)): - block_size_max = [self.block_size_max] * img_dim - else: - assert(len(self.block_size_max) == img_dim) - block_size_max = self.block_size_max block_num = random.randint(self.block_range[0], self.block_range[1]) for n in range(block_num): - block_size = [random.randint(block_size_min[i], block_size_max[i]) \ - for i in range(img_dim)] coord_min = [random.randint(3, img_shape[1+i] - block_size[i] - 3) \ for i in range(img_dim)] if(img_dim == 2): @@ -345,8 +498,8 @@ def __init__(self, params): self.inverse = params.get('InOutPainting_inverse'.lower(), False) self.prob = params.get('InOutPainting_probability'.lower(), 0.5) self.in_prob = params.get('InPainting_probability'.lower(), 0.5) - params['InPainting_probability'] = 1.0 - params['outPainting_probability'] = 1.0 + params['InPainting_probability'.lower()] = 1.0 + params['OutPainting_probability'.lower()] = 1.0 self.inpaint = InPainting(params) self.outpaint = OutPainting(params) @@ -357,4 +510,76 @@ def __call__(self, sample): sample = self.inpaint(sample) else: sample = self.outpaint(sample) + return sample + +class PatchSwaping(AbstractTransform): + """ + Apply patch swaping for context restoration in self-supervised learning. + Reference: Liang Chen et al., Self-supervised learning for medical image analysis + using image context restoration, Medical Image Analysis, 2019. + """ + def __init__(self, params): + super(PatchSwaping, self).__init__(params) + self.block_range = params.get('PatchSwaping_block_range'.lower(), (10, 20)) + self.block_size = params.get('PatchSwaping_block_size'.lower(), [8, 16, 16]) + self.inverse = params.get('PatchSwaping_inverse'.lower(), False) + + def __call__(self, sample): + image= sample['image'] + img_shape = image.shape + img_dim = len(img_shape) - 1 + assert(img_dim == 2 or img_dim == 3) + img_out = copy.deepcopy(image) + + block_num = random.randint(self.block_range[0], self.block_range[1]) + for t in range(block_num): + pos_a0 = [random.randint(0, img_shape[-3+i] - self.block_size[i]) for i in range(img_dim)] + pos_b0 = [random.randint(0, img_shape[-3+i] - self.block_size[i]) for i in range(img_dim)] + pos_a1 = [pos_a0[i] + self.block_size[i] for i in range(img_dim)] + pos_b1 = [pos_b0[i] + self.block_size[i] for i in range(img_dim)] + img_out[:, pos_a0[0]:pos_a1[0], pos_a0[1]:pos_a1[1], pos_a0[2]:pos_a1[2]] = \ + image[:, pos_b0[0]:pos_b1[0], pos_b0[1]:pos_b1[1], pos_b0[2]:pos_b1[2]] + img_out[:, pos_b0[0]:pos_b1[0], pos_b0[1]:pos_b1[1], pos_b0[2]:pos_b1[2]] = \ + image[:, pos_a0[0]:pos_a1[0], pos_a0[1]:pos_a1[1], pos_a0[2]:pos_a1[2]] + + sample['image'] = img_out + sample['label'] = image + return sample + +class MaskedImageModeling(AbstractTransform): + """ + Apply masking for context restoration in self-supervised learning. + Reference: Zekai Chen et al., Masked Image Modeling Advances 3D Medical Image Analysis, + WACV, 2023 . + """ + def __init__(self, params): + super(MaskedImageModeling, self).__init__(params) + self.ratio = params.get('MaskedImageModeling_ratio'.lower(), 0.45) + self.block_size = params.get('MaskedImageModeling_block_size'.lower(), [8, 16, 16]) + self.inverse = params.get('MaskedImageModeling_inverse'.lower(), False) + + def __call__(self, sample): + image= sample['image'] + C, D, H, W = image.shape + img_out = copy.deepcopy(image) + + block = np.zeros([C] + list(self.block_size)) + for d in range(0, D, self.block_size[0]): + d1 = d + self.block_size[0] + if d1 > D: + continue + for h in range(0, H, self.block_size[1]): + h1 = h + self.block_size[1] + if h1 > H: + continue + for w in range(0, W, self.block_size[2]): + w1 = w + self.block_size[2] + if w1 > W: + continue + r = random.random() + if ( r < self.ratio): + img_out[:, d:d1, h:h1, w:w1] = block + + sample['image'] = img_out + sample['label'] = image return sample \ No newline at end of file diff --git a/pymic/transform/label_convert.py b/pymic/transform/label_convert.py index 0dcae37..afbbaaf 100644 --- a/pymic/transform/label_convert.py +++ b/pymic/transform/label_convert.py @@ -7,6 +7,7 @@ import random import numpy as np from scipy import ndimage +from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * @@ -80,19 +81,42 @@ def __init__(self, params): self.inverse = params.get('LabelToProbability_inverse'.lower(), False) def __call__(self, sample): - if(self.task == 'segmentation'): + if(self.task == TaskType.SEGMENTATION): label = sample['label'][0] # sample['label'] is (1, h, w) label_prob = np.zeros((self.class_num, *label.shape), dtype = np.float32) for i in range(self.class_num): label_prob[i] = label == i*np.ones_like(label) sample['label_prob'] = label_prob - elif(self.task == 'classification'): + elif(self.task == TaskType.CLASSIFICATION_ONE_HOT): label_idx = sample['label'] label_prob = np.zeros((self.class_num,), np.float32) label_prob[label_idx] = 1.0 sample['label_prob'] = label_prob + elif(self.task == TaskType.CLASSIFICATION_COEXIST): + sample['label_prob'] = sample['label'] return sample +class LabelSmooth(AbstractTransform): + """ + Apply label smoothing to one-hot labels. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `LabelSmooth_alpha`: (float) Alpha value for label smoothing. + :param `LabelSmooth_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `False`. + """ + def __init__(self, params): + super(LabelSmooth, self).__init__(params) + self.alpha = params['LabelSmooth_alpha'.lower()] + self.inverse = params.get('LabelSmooth_inverse'.lower(), False) + + def __call__(self, sample): + label_prob = sample['label_prob'] + K = list(label_prob.shape)[1] + sample['label_prob'] = label_prob * (1.0 - self.alpha) + self.alpha / K + return sample class PartialLabelToProbability(AbstractTransform): """ @@ -130,5 +154,59 @@ def __call__(self, sample): return sample +class SelfReconstructionLabel(AbstractTransform): + """ + Used for self-supervised learning with image reconstruction tasks. + """ + def __init__(self, params): + """ + class_num (int): the class number in the label map + """ + super(SelfReconstructionLabel, self).__init__(params) + self.inverse = params.get('SelfReconstructionLabel_inverse'.lower(), False) + + def __call__(self, sample): + image = sample['image'] + label = image * 1.0 + sample['label'] = label + return sample + +class MaskedImageModelingLabel(AbstractTransform): + """ + Used for self-supervised learning with image reconstruction tasks. + Only reconstruct the masked region in the input. + The input images is masked in local patches. + """ + def __init__(self, params): + """ + class_num (int): the class number in the label map + """ + super(MaskedImageModelingLabel, self).__init__(params) + self.patch_size = params.get('MaskedImageModelingLabel_patch_size'.lower(), [16, 16, 16]) + self.masking_ratio = params.get('MaskedImageModelingLabel_ratio'.lower(), 0.15) + self.inverse = params.get('MaskedImageModelingLabel_inverse'.lower(), False) + + def __call__(self, sample): + image = sample['image'] + C, D, H, W = image.shape + patch_size = self.patch_size + mask = np.ones([D, H, W], np.float32) + grid_size = [math.ceil((image.shape[i+1] + 0.0) / patch_size[i]) for i in range(3)] + for d in range(grid_size[0]): + d0 = d*patch_size[0] + for h in range(grid_size[1]): + h0 = h*patch_size[1] + for w in range(grid_size[2]): + w0 = w*patch_size[2] + if(random.random() > self.masking_ratio): + continue + d1 = min(d0 + patch_size[0], D) + h1 = min(h0 + patch_size[1], H) + w1 = min(w0 + patch_size[2], W) + mask[d0:d1, h0:h1, w0:w1] = np.zeros([d1 - d0, h1 - h0, w1 - w0]) + sample['pixel_weight'] = 1 - mask + sample['image'] = image * mask + sample['label'] = image + return sample diff --git a/pymic/transform/mix.py b/pymic/transform/mix.py new file mode 100644 index 0000000..6efed6a --- /dev/null +++ b/pymic/transform/mix.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import copy +import json +import math +import random +import numpy as np +from pymic.transform.abstract_transform import AbstractTransform +from pymic.util.image_process import * +try: # SciPy >= 0.19 + from scipy.special import comb +except ImportError: + from scipy.misc import comb + + +class CopyPaste(AbstractTransform): + """ + In-painting of an input image, used for self-supervised learning + """ + def __init__(self, params): + super(CopyPaste, self).__init__(params) + self.inverse = params.get('CopyPaste_inverse'.lower(), False) + self.block_range = params.get('CopyPaste_block_range'.lower(), (1, 6)) + self.block_size_min = params.get('CopyPaste_block_size_min'.lower(), None) + self.block_size_max = params.get('CopyPaste_block_size_max'.lower(), None) + + def __call__(self, sample): + image= sample['image'] + img_shape = image.shape + img_dim = len(img_shape) - 1 + assert(img_dim == 2 or img_dim == 3) + + if(self.block_size_min is None): + block_size_min = [img_shape[1+i]//6 for i in range(img_dim)] + elif(isinstance(self.block_size_min, int)): + block_size_min = [self.block_size_min] * img_dim + else: + assert(len(self.block_size_min) == img_dim) + block_size_min = self.block_size_min + + if(self.block_size_max is None): + block_size_max = [img_shape[1+i]//3 for i in range(img_dim)] + elif(isinstance(self.block_size_min, int)): + block_size_max = [self.block_size_max] * img_dim + else: + assert(len(self.block_size_max) == img_dim) + block_size_max = self.block_size_max + block_num = random.randint(self.block_range[0], self.block_range[1]) + + for n in range(block_num): + block_size = [random.randint(block_size_min[i], block_size_max[i]) \ + for i in range(img_dim)] + coord_min = [random.randint(3, img_shape[1+i] - block_size[i] - 3) \ + for i in range(img_dim)] + if(img_dim == 2): + random_block = np.random.rand(img_shape[0], block_size[0], block_size[1]) + image[:, coord_min[0]:coord_min[0] + block_size[0], + coord_min[1]:coord_min[1] + block_size[1]] = random_block + else: + random_block = np.random.rand(img_shape[0], block_size[0], + block_size[1], block_size[2]) + image[:, coord_min[0]:coord_min[0] + block_size[0], + coord_min[1]:coord_min[1] + block_size[1], + coord_min[2]:coord_min[2] + block_size[2]] = random_block + sample['image'] = image + return sample + +class PatchMix(AbstractTransform): + """ + In-painting of an input image, used for self-supervised learning + """ + def __init__(self, params): + super(PatchMix, self).__init__(params) + self.inverse = params.get('PatchMix_inverse'.lower(), False) + self.threshold = params.get('PatchMix_threshold'.lower(), 0) + self.crop_size = params.get('PatchMix_crop_size'.lower(), [64, 128, 128]) + self.fg_cls_num = params.get('PatchMix_cls_num'.lower(), [4, 40]) + self.patch_num_range= params.get('PatchMix_patch_range'.lower(), [4, 40]) + self.patch_size_min = params.get('PatchMix_patch_size_min'.lower(), [4, 4, 4]) + self.patch_size_max = params.get('PatchMix_patch_size_max'.lower(), [20, 40, 40]) + + def __call__(self, sample): + x0 = self._random_crop_and_flip(sample) + x1 = self._random_crop_and_flip(sample) + C, D, H, W = x0.shape + # generate mask + fg_mask = np.zeros_like(x0, np.uint8) + patch_num = random.randint(self.patch_num_range[0], self.patch_num_range[1]) + for patch in range(patch_num): + d = random.randint(self.patch_size_min[0], self.patch_size_max[0]) + h = random.randint(self.patch_size_min[1], self.patch_size_max[1]) + w = random.randint(self.patch_size_min[2], self.patch_size_max[2]) + d_c = random.randint(0, D) + h_c = random.randint(0, H) + w_c = random.randint(0, W) + d0, d1 = max(0, d_c - d // 2), min(D, d_c + d // 2) + h0, h1 = max(0, h_c - h // 2), min(H, h_c + h // 2) + w0, w1 = max(0, w_c - w // 2), min(W, w_c + w // 2) + temp_m = np.ones([C, d1-d0, h1-h0, w1-w0]) * random.randint(1, self.fg_cls_num) + fg_mask[:, d0:d1, h0:h1, w0:w1] = temp_m + fg_w = fg_mask * 1.0 / self.fg_cls_num + x_fuse = fg_w*x0 + (1.0 - fg_w)*x1 # x1 is used as background + + sample['image'] = x_fuse + sample['label'] = fg_mask + return sample + + def _random_crop_and_flip(self, sample): + image = sample['image'] + input_dim = len(image.shape) - 1 + assert(input_dim == 3) + C, D, H, W = image.shape + + half_size = [x // 2 for x in self.crop_size] + dc = random.randint(half_size[0], D - half_size[0]) + image2d = image[0, dc, :, :] + mask2d = np.zeros_like(image2d) + mask2d[half_size[1]:H+1-half_size[1], half_size[2]:W+1-half_size[2]] = \ + np.ones([H-self.crop_size[1]+1, W-self.crop_size[2]+1]) + if('label' in sample): + temp_mask = sample['label'][0, dc, :, :] > 0 + mask2d = temp_mask * mask2d + elif(self.threshold is not None): + temp_mask = image2d > self.threshold + se = np.ones([3,3]) + temp_mask = ndimage.binary_opening(temp_mask, se, iterations = 2) + temp_mask = get_largest_k_components(temp_mask, 1) + mask2d = temp_mask * mask2d + + indices = np.where(mask2d) + n = random.randint(0, len(indices[0])-1) + center = [indices[i][n] for i in range(2)] + crop_min = [dc - half_size[0], center[0]-half_size[1], center[1] - half_size[2]] + crop_max = [crop_min[i] + self.crop_size[i] for i in range(input_dim)] + crop_min = [0] + crop_min + crop_max = [C] + crop_max + x = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + + flip_axis = [] + if(random.random() > 0.5): + flip_axis.append(-1) + if(random.random() > 0.5): + flip_axis.append(-2) + if(random.random() > 0.5): + flip_axis.append(-3) + if(len(flip_axis) > 0): + x = np.flip(x, flip_axis).copy() + + if(x.shape[1] == 63): + print("crop shape == 63", x.shape) + print(sample['names']) + print(image.shape, crop_min, crop_max) + return x \ No newline at end of file diff --git a/pymic/transform/normalize.py b/pymic/transform/normalize.py index 4e493dd..35c5dc4 100644 --- a/pymic/transform/normalize.py +++ b/pymic/transform/normalize.py @@ -24,51 +24,57 @@ class NormalizeWithMeanStd(AbstractTransform): :param `NormalizeWithMeanStd_std`: (list/tuple or None) The std values along each specified channel. If None, the std values are calculated automatically. - :param `NormalizeWithMeanStd_ignore_non_positive`: (optional, bool) - Only used when mean and std are not given. Default is False. - If True, calculate mean and std in the positive region for normalization, - and set non-positive region to random. If False, calculate - the mean and std values in the entire image region. + :param `NormalizeWithMeanStd_mask_threshold`: (optional, float) + Only used when mean and std are not given. Default is 1.0. + Calculate mean and std in the mask region where the intensity is higher than the mask. + :param `NormalizeWithMeanStd_set_background_to_random`: (optional, bool) + Set background region to random or not, and only applicable when + `NormalizeWithMeanStd_mask_threshold` is not None. Default is True. :param `NormalizeWithMeanStd_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `False`. """ def __init__(self, params): super(NormalizeWithMeanStd, self).__init__(params) - self.chns = params['NormalizeWithMeanStd_channels'.lower()] + self.chns = params.get('NormalizeWithMeanStd_channels'.lower(), None) self.mean = params.get('NormalizeWithMeanStd_mean'.lower(), None) self.std = params.get('NormalizeWithMeanStd_std'.lower(), None) - self.ingore_np = params.get('NormalizeWithMeanStd_ignore_non_positive'.lower(), False) - self.inverse = params.get('NormalizeWithMeanStd_inverse'.lower(), False) + self.mask_thrd = params.get('NormalizeWithMeanStd_mask_threshold'.lower(), None) + self.bg_random = params.get('NormalizeWithMeanStd_set_background_to_random'.lower(), True) + self.inverse = params.get('NormalizeWithMeanStd_inverse'.lower(), False) def __call__(self, sample): image= sample['image'] - chns = self.chns if self.chns is not None else range(image.shape[0]) + if(self.chns is None): + self.chns = range(image.shape[0]) if(self.mean is None): - self.mean = [None] * len(chns) - self.std = [None] * len(chns) + self.mean = [None] * len(self.chns) + self.std = [None] * len(self.chns) - for i in range(len(chns)): - chn = chns[i] + for i in range(len(self.chns)): + chn = self.chns[i] chn_mean, chn_std = self.mean[i], self.std[i] if(chn_mean is None): - if(self.ingore_np): - pixels = image[chn][image[chn] > 0] - chn_mean, chn_std = pixels.mean(), pixels.std() + if(self.mask_thrd is not None): + pixels = image[chn][image[chn] > self.mask_thrd] + if(len(pixels) > 0): + chn_mean, chn_std = pixels.mean(), pixels.std() + 1e-5 + else: + chn_mean, chn_std = 0.0, 1.0 else: - chn_mean, chn_std = image[chn].mean(), image[chn].std() + chn_mean, chn_std = image[chn].mean(), image[chn].std() + 1e-5 chn_norm = (image[chn] - chn_mean)/chn_std - if(self.ingore_np): + if(self.mask_thrd is not None and self.bg_random): chn_random = np.random.normal(0, 1, size = chn_norm.shape) - chn_norm[image[chn] <= 0] = chn_random[image[chn] <= 0] + chn_norm[image[chn] <= self.mask_thrd] = chn_random[image[chn] <=self.mask_thrd] image[chn] = chn_norm sample['image'] = image return sample class NormalizeWithMinMax(AbstractTransform): - """Nomralize the image to [0, 1]. The shape should be [C, D, H, W] or [C, H, W]. + """Nomralize the image to [-1, 1]. The shape should be [C, D, H, W] or [C, H, W]. The arguments should be written in the `params` dictionary, and it has the following fields: @@ -106,13 +112,13 @@ def __call__(self, sample): img_chn[img_chn < v0] = v0 img_chn[img_chn > v1] = v1 - img_chn = (img_chn - v0) / (v1 - v0) + img_chn = 2.0* (img_chn - v0) / (v1 - v0) -1.0 image[chn] = img_chn sample['image'] = image return sample class NormalizeWithPercentiles(AbstractTransform): - """Nomralize the image to [0, 1] with percentiles for given channels. + """Nomralize the image to [-1, 1] with percentiles for given channels. The shape should be [C, D, H, W] or [C, H, W]. The arguments should be written in the `params` dictionary, and it has the @@ -125,14 +131,17 @@ class NormalizeWithPercentiles(AbstractTransform): The min percentile, which must be between 0 and 100 inclusive. :param `NormalizeWithPercentiles_percentile_upper`: (float) The max percentile, which must be between 0 and 100 inclusive. + :param `NormalizeWithPercentiles_output_mode`: (int) 0: the output is in the range [0,1] + Otherwise the output is in the range of [-1, 1] :param `NormalizeWithMinMax_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `False`. """ def __init__(self, params): super(NormalizeWithPercentiles, self).__init__(params) - self.chns = params['NormalizeWithPercentiles_channels'.lower()] - self.percent_lower = params['NormalizeWithPercentiles_percentile_lower'.lower()] - self.percent_upper = params['NormalizeWithPercentiles_percentile_upper'.lower()] + self.chns = params.get('NormalizeWithPercentiles_channels'.lower(), None) + self.percent_lower = params.get('NormalizeWithPercentiles_percentile_lower'.lower(), 0.1) + self.percent_upper = params.get('NormalizeWithPercentiles_percentile_upper'.lower(), 99.9) + self.out_mode = params.get('NormalizeWithPercentiles_output_mode'.lower(), 0) self.inverse = params.get('NormalizeWithPercentiles_inverse'.lower(), False) def __call__(self, sample): @@ -146,7 +155,13 @@ def __call__(self, sample): img_chn[img_chn < v0] = v0 img_chn[img_chn > v1] = v1 - img_chn = (img_chn - v0) / (v1 - v0) + if(self.out_mode == 0): + img_chn = (img_chn - v0) / (v1 - v0) + img_chn = np.clip(img_chn, 0, 1) + else: + img_chn = 2.0* (img_chn - v0) / (v1 - v0) -1.0 + img_chn = np.clip(img_chn, -1, 1) + image[chn] = img_chn sample['image'] = image return sample \ No newline at end of file diff --git a/pymic/transform/pad.py b/pymic/transform/pad.py index 0ec196c..509643d 100644 --- a/pymic/transform/pad.py +++ b/pymic/transform/pad.py @@ -6,7 +6,7 @@ import math import random import numpy as np -from scipy import ndimage +from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * @@ -21,7 +21,7 @@ class Pad(AbstractTransform): following fields: :param `Pad_output_size`: (list/tuple) The output size along each spatial axis. - :param `Pad_ceil_mode`: (optional, bool) If true (by default), the real output size will + :param `Pad_ceil_mode`: (optional, bool) If true, the real output size will be the minimal integer multiples of output_size higher than the input size. For example, the input image has a shape of [3, 100, 100], `Pad_output_size` = [32, 32], and the real output size will be [3, 128, 128] if `Pad_ceil_mode` = True. @@ -38,6 +38,11 @@ def __call__(self, sample): image = sample['image'] input_shape = image.shape input_dim = len(input_shape) - 1 + + if(input_dim == 3): + if(len(self.output_size) == 2): + # for 3D images, igore the z-axis + self.output_size = [input_shape[1]] + list(self.output_size) assert(len(self.output_size) == input_dim) if(self.ceil_mode): multiple = [int(math.ceil(float(input_shape[1+i])/self.output_size[i]))\ @@ -59,11 +64,13 @@ def __call__(self, sample): sample['image'] = image_t - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): label = sample['label'] label = np.pad(label, pad, 'reflect') if(max(margin) > 0) else label sample['label'] = label - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): weight = sample['pixel_weight'] weight = np.pad(weight, pad, 'reflect') if(max(margin) > 0) else weight sample['pixel_weight'] = weight diff --git a/pymic/transform/rescale.py b/pymic/transform/rescale.py index 2a671fd..ba519c7 100644 --- a/pymic/transform/rescale.py +++ b/pymic/transform/rescale.py @@ -5,6 +5,7 @@ import random import numpy as np from scipy import ndimage +from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * @@ -16,9 +17,9 @@ class Rescale(AbstractTransform): following fields: :param `Rescale_output_size`: (list/tuple or int) The output size along each spatial axis, - such as [D, H, W] or [H, W]. If D is None, the input image is only reslcaled in 2D. - If int, the smallest axis is matched to output_size keeping aspect ratio the same - as the input. + such as [D, H, W] or [H, W]. For 3D images, if D is None, or the lenght of tuple/list is 2, + the input image is only reslcaled in 2D. If int, the smallest axis is matched to output_size + keeping aspect ratio the same as the input. :param `Rescale_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `True`. """ @@ -37,6 +38,8 @@ def __call__(self, sample): output_size = self.output_size if(output_size[0] is None): output_size[0] = input_shape[1] + if(input_dim == 3 and len(self.output_size) == 2): + output_size = [input_shape[1]] + list(output_size) assert(len(output_size) == input_dim) else: min_edge = min(input_shape[1:]) @@ -48,11 +51,13 @@ def __call__(self, sample): sample['image'] = image_t sample['Rescale_origin_shape'] = json.dumps(input_shape) - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): label = sample['label'] label = ndimage.interpolation.zoom(label, scale, order = 0) sample['label'] = label - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): weight = sample['pixel_weight'] weight = ndimage.interpolation.zoom(weight, scale, order = 1) sample['pixel_weight'] = weight @@ -67,12 +72,22 @@ def inverse_transform_for_prediction(self, sample): origin_shape = json.loads(sample['Rescale_origin_shape']) origin_dim = len(origin_shape) - 1 predict = sample['predict'] - input_shape = predict.shape - scale = [(origin_shape[1:][i] + 0.0)/input_shape[2:][i] for \ - i in range(origin_dim)] - scale = [1.0, 1.0] + scale - output_predict = ndimage.interpolation.zoom(predict, scale, order = 1) + if(isinstance(predict, tuple) or isinstance(predict, list)): + output_predict = [] + for predict_i in predict: + input_shape = predict_i.shape + scale = [(origin_shape[1:][i] + 0.0)/input_shape[2:][i] for \ + i in range(origin_dim)] + scale = [1.0, 1.0] + scale + output_predict_i = ndimage.interpolation.zoom(predict_i, scale, order = 1) + output_predict.append(output_predict_i) + else: + input_shape = predict.shape + scale = [(origin_shape[1:][i] + 0.0)/input_shape[2:][i] for \ + i in range(origin_dim)] + scale = [1.0, 1.0] + scale + output_predict = ndimage.interpolation.zoom(predict, scale, order = 1) sample['predict'] = output_predict return sample @@ -97,27 +112,26 @@ def __init__(self, params): self.ratio0 = params["RandomRescale_lower_bound".lower()] self.ratio1 = params["RandomRescale_upper_bound".lower()] self.prob = params.get('RandomRescale_probability'.lower(), 0.5) - self.inverse = params.get("RandomRescale_inverse".lower(), True) + self.inverse = params.get("RandomRescale_inverse".lower(), False) assert isinstance(self.ratio0, (float, list, tuple)) assert isinstance(self.ratio1, (float, list, tuple)) def __call__(self, sample): - # if(random.random() > self.prob): - # print("rescale not started") - # sample['RandomRescale_triggered'] = False - # return sample - # else: - # print("rescale started") - # sample['RandomRescale_triggered'] = True + image = sample['image'] input_shape = image.shape input_dim = len(input_shape) - 1 - + assert(input_dim == len(self.ratio0) and input_dim == len(self.ratio1)) + if isinstance(self.ratio0, (list, tuple)): - for i in range(len(self.ratio0)): + for i in range(input_dim): + if(self.ratio0[i] is None): + self.ratio0[i] = 1.0 + if(self.ratio1[i] is None): + self.ratio1[i] = 1.0 assert(self.ratio0[i] <= self.ratio1[i]) scale = [self.ratio0[i] + random.random()*(self.ratio1[i] - self.ratio0[i]) \ - for i in range(len(self.ratio0))] + for i in range(input_dim)] else: scale = self.ratio0 + random.random()*(self.ratio1 - self.ratio0) scale = [scale] * input_dim @@ -126,11 +140,13 @@ def __call__(self, sample): sample['image'] = image_t sample['RandomRescale_Param'] = json.dumps(input_shape) - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): label = sample['label'] label = ndimage.interpolation.zoom(label, scale, order = 0) sample['label'] = label - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): weight = sample['pixel_weight'] weight = ndimage.interpolation.zoom(weight, scale, order = 1) sample['pixel_weight'] = weight @@ -138,8 +154,6 @@ def __call__(self, sample): return sample def inverse_transform_for_prediction(self, sample): - if(not sample['RandomRescale_triggered']): - return sample if(isinstance(sample['RandomRescale_Param'], list) or \ isinstance(sample['RandomRescale_Param'], tuple)): origin_shape = json.loads(sample['RandomRescale_Param'][0]) @@ -152,6 +166,76 @@ def inverse_transform_for_prediction(self, sample): i in range(origin_dim)] scale = [1.0, 1.0] + scale + output_predict = ndimage.interpolation.zoom(predict, scale, order = 1) + sample['predict'] = output_predict + return sample + + +class Resample(Rescale): + """Resample the image to a given spatial resolution. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `Resample_output_spacing`: (list/tuple or int) The output spacing along each spatial axis, + such as [Ds, Hs, Ws] or [Hs, Ws]. If Ds is None, the input image is only reslcaled in 2D. + :param `Resample_ignore_zspacing_range`: (list/tuple) The range of zspacing that would be ingored. + :param `Resample_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `True`. + """ + def __init__(self, params): + super(Rescale, self).__init__(params) + self.output_spacing = params["Resample_output_spacing".lower()] + self.ignore_zspacing= params.get("Resample_ignore_zspacing_range".lower(), None) + self.inverse = params.get("Resample_inverse".lower(), True) + + def __call__(self, sample): + image = sample['image'] + input_shape = image.shape + + input_dim = len(input_shape) - 1 + spacing = sample['spacing'] + out_spacing = [item for item in self.output_spacing] + for i in range(input_dim): + out_spacing[i] = spacing[i] if out_spacing[i] is None else out_spacing[i] + if(self.ignore_zspacing is not None): + if(spacing[0] > self.ignore_zspacing[0] and spacing[0] < self.ignore_zspacing[1]): + out_spacing[0] = spacing[0] + scale = [spacing[i] / out_spacing[i] for i in range(input_dim)] + scale = [1.0] + scale + + image_t = ndimage.interpolation.zoom(image, scale, order = 1) + + sample['image'] = image_t + sample['spacing'] = out_spacing + sample['Resample_origin_shape'] = json.dumps(input_shape) + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + label = sample['label'] + label = ndimage.interpolation.zoom(label, scale, order = 0) + sample['label'] = label + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + weight = sample['pixel_weight'] + weight = ndimage.interpolation.zoom(weight, scale, order = 1) + sample['pixel_weight'] = weight + + return sample + + def inverse_transform_for_prediction(self, sample): + if(isinstance(sample['Resample_origin_shape'], list) or \ + isinstance(sample['Resample_origin_shape'], tuple)): + origin_shape = json.loads(sample['Resample_origin_shape'][0]) + else: + origin_shape = json.loads(sample['Resample_origin_shape']) + + origin_dim = len(origin_shape) - 1 + predict = sample['predict'] + input_shape = predict.shape + scale = [(origin_shape[1:][i] + 0.0)/input_shape[2:][i] for \ + i in range(origin_dim)] + scale = [1.0, 1.0] + scale + output_predict = ndimage.interpolation.zoom(predict, scale, order = 1) sample['predict'] = output_predict return sample \ No newline at end of file diff --git a/pymic/transform/rotate.py b/pymic/transform/rotate.py index 2aa06d4..5de77d7 100644 --- a/pymic/transform/rotate.py +++ b/pymic/transform/rotate.py @@ -5,6 +5,7 @@ import random import numpy as np from scipy import ndimage +from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * @@ -33,8 +34,8 @@ class RandomRotate(AbstractTransform): def __init__(self, params): super(RandomRotate, self).__init__(params) self.angle_range_d = params['RandomRotate_angle_range_d'.lower()] - self.angle_range_h = params['RandomRotate_angle_range_h'.lower()] - self.angle_range_w = params['RandomRotate_angle_range_w'.lower()] + self.angle_range_h = params.get('RandomRotate_angle_range_h'.lower(), None) + self.angle_range_w = params.get('RandomRotate_angle_range_w'.lower(), None) self.prob = params.get('RandomRotate_probability'.lower(), 0.5) self.inverse = params.get('RandomRotate_inverse'.lower(), True) @@ -51,11 +52,6 @@ def __apply_transformation(self, image, transform_param_list, order = 1): return image def __call__(self, sample): - # if(random.random() > self.prob): - # sample['RandomRotate_triggered'] = False - # return sample - # else: - # sample['RandomRotate_triggered'] = True image = sample['image'] input_shape = image.shape input_dim = len(input_shape) - 1 @@ -78,10 +74,12 @@ def __call__(self, sample): sample['RandomRotate_Param'] = json.dumps(transform_param_list) image_t = self.__apply_transformation(image, transform_param_list, 1) sample['image'] = image_t - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): sample['label'] = self.__apply_transformation(sample['label'] , transform_param_list, 0) - if('pixel_weight' in sample and self.task == 'segmentation'): + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): sample['pixel_weight'] = self.__apply_transformation(sample['pixel_weight'] , transform_param_list, 1) return sample @@ -99,4 +97,49 @@ def inverse_transform_for_prediction(self, sample): transform_param_list[i][0] = - transform_param_list[i][0] sample['predict'] = self.__apply_transformation(sample['predict'] , transform_param_list, 1) + return sample + +class RandomRot90(AbstractTransform): + """ + Random rotate an image in x-y plane with angles in [90, 180, 270]. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `RandomRot90_probability`: (optional, float) + The probability of applying RandomRot90. Default is 0.75. + :param `RandomRot90_inverse`: (optional, bool) + Is inverse transform needed for inference. Default is `True`. + """ + def __init__(self, params): + super(RandomRot90, self).__init__(params) + self.prob = params.get('RandomRot90_probability'.lower(), 0.75) + self.inverse = params.get('RandomRot90_inverse'.lower(), True) + + def __call__(self, sample): + if(random.random() > self.prob): + sample['RandomRot90_triggered'] = False + sample['RandomRot90_Param'] = 0 + return sample + else: + sample['RandomRot90_triggered'] = True + image = sample['image'] + rote_k = random.randint(1, 3) + sample['RandomRot90_Param'] = rote_k + image_t = np.rot90(image, rote_k, (-2, -1)) + sample['image'] = image_t + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + sample['label'] = np.rot90(sample['label'], rote_k, (-2, -1)) + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + sample['pixel_weight'] = np.rot90(sample['pixel_weight'], rote_k, (-2, -1)) + return sample + + def inverse_transform_for_prediction(self, sample): + if(not sample['RandomRot90_triggered']): + return sample + rote_k = sample['RandomRot90_Param'] + rote_i = 4 - rote_k + sample['predict'] = np.rot90(sample['predict'], rote_i, (-2, -1)) return sample \ No newline at end of file diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index bc72c93..e4bfe24 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -14,6 +14,8 @@ 'LabelConvert': LabelConvert, 'LabelConvertNonzero': LabelConvertNonzero, 'LabelToProbability': LabelToProbability, + 'IntensityClip': IntensityClip, + 'NonLinearTransform': NonLinearTransform, 'NormalizeWithMeanStd': NormalizeWithMeanStd, 'NormalizeWithMinMax': NormalizeWithMinMax, 'NormalizeWithPercentiles': NormalizeWithPercentiles, @@ -23,12 +25,15 @@ 'RandomRescale': RandomRescale, 'RandomFlip': RandomFlip, 'RandomRotate': RandomRotate, + 'RandomRot90': RandomRot90, 'ReduceLabelDim': ReduceLabelDim, 'Rescale': Rescale, + 'SelfSuperviseLabel': SelfSuperviseLabel, 'Pad': Pad. """ from __future__ import print_function, division +from pymic.transform.affine import * from pymic.transform.intensity import * from pymic.transform.flip import * from pymic.transform.pad import * @@ -38,35 +43,58 @@ from pymic.transform.threshold import * from pymic.transform.normalize import * from pymic.transform.crop import * +from pymic.transform.crop4dino import Crop4Dino +from pymic.transform.crop4voco import Crop4VoCo +from pymic.transform.crop4vox2vec import Crop4Vox2Vec +from pymic.transform.crop4vf import Crop4VolumeFusion, VolumeFusion, VolumeFusionShuffle from pymic.transform.label_convert import * TransformDict = { + 'Affine': Affine, + 'AdaptiveContrastAdjust': AdaptiveContrastAdjust, 'ChannelWiseThreshold': ChannelWiseThreshold, 'ChannelWiseThresholdWithNormalize': ChannelWiseThresholdWithNormalize, 'CropWithBoundingBox': CropWithBoundingBox, + 'CropWithForeground': CropWithForeground, + 'CropHumanRegion': CropHumanRegion, 'CenterCrop': CenterCrop, + 'Crop4Dino': Crop4Dino, + 'Crop4VoCo': Crop4VoCo, + 'Crop4Vox2Vec': Crop4Vox2Vec, + 'Crop4VolumeFusion': Crop4VolumeFusion, 'GrayscaleToRGB': GrayscaleToRGB, 'GammaCorrection': GammaCorrection, 'GaussianNoise': GaussianNoise, + 'HistEqual': HistEqual, 'InPainting': InPainting, 'InOutPainting': InOutPainting, 'LabelConvert': LabelConvert, 'LabelConvertNonzero': LabelConvertNonzero, 'LabelToProbability': LabelToProbability, 'LocalShuffling': LocalShuffling, + 'IntensityClip': IntensityClip, + 'MaskedImageModeling': MaskedImageModeling, 'NonLinearTransform': NonLinearTransform, 'NormalizeWithMeanStd': NormalizeWithMeanStd, 'NormalizeWithMinMax': NormalizeWithMinMax, 'NormalizeWithPercentiles': NormalizeWithPercentiles, 'PartialLabelToProbability':PartialLabelToProbability, 'RandomCrop': RandomCrop, + 'RandomSlice': RandomSlice, 'RandomResizedCrop': RandomResizedCrop, 'RandomRescale': RandomRescale, 'RandomTranspose': RandomTranspose, 'RandomFlip': RandomFlip, 'RandomRotate': RandomRotate, + 'RandomRot90': RandomRot90, 'ReduceLabelDim': ReduceLabelDim, 'Rescale': Rescale, + 'Resample': Resample, + 'SelfReconstructionLabel': SelfReconstructionLabel, + 'MaskedImageModelingLabel': MaskedImageModelingLabel, 'OutPainting': OutPainting, 'Pad': Pad, + 'PatchSwaping':PatchSwaping, + 'VolumeFusion': VolumeFusion, + 'VolumeFusionShuffle': VolumeFusionShuffle } diff --git a/pymic/transform/transpose.py b/pymic/transform/transpose.py index 9c73bda..6f5d5fa 100644 --- a/pymic/transform/transpose.py +++ b/pymic/transform/transpose.py @@ -4,6 +4,7 @@ import json import random import numpy as np +from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform @@ -39,11 +40,12 @@ def __call__(self, sample): if(transpose_axis is not None): image_t = np.transpose(image, transpose_axis) sample['image'] = image_t - if('label' in sample and self.task == 'segmentation'): + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): sample['label'] = np.transpose(sample['label'] , transpose_axis) - if('pixel_weight' in sample and self.task == 'segmentation'): - sample['pixel_weight'] = np.transpose(sample['pixel_weight'] , transpose_axis) - + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + sample['pixel_weight'] = np.transpose(sample['pixel_weight'] , transpose_axis) return sample def inverse_transform_for_prediction(self, sample): diff --git a/pymic/util/evaluation_cls.py b/pymic/util/evaluation_cls.py index af11a17..686a811 100644 --- a/pymic/util/evaluation_cls.py +++ b/pymic/util/evaluation_cls.py @@ -3,7 +3,7 @@ Evaluation module for classification tasks. """ from __future__ import absolute_import, print_function - +import argparse import os import csv import sys @@ -75,15 +75,15 @@ def binary_evaluation(config): The arguments are given in the `config` dictionary. It should have the following fields: - :param metric_list: (list) A list of evaluation metrics. + :param metric: (list) A list of evaluation metrics. The supported metrics are {`accuracy`, `recall`, `sensitivity`, `specificity`, `precision`, `auc`}. - :param ground_truth_csv: (str) The csv file for ground truth. - :param predict_prob_csv: (str) The csv file for prediction probability. + :param gt_csv: (str) The csv file for ground truth. + :param pred_prob_csv: (str) The csv file for prediction probability. """ - metric_list = config['metric_list'] - gt_csv = config['ground_truth_csv'] - prob_csv= config['predict_prob_csv'] + metric_list = config['metric'] + gt_csv = config['gt_csv'] + prob_csv= config['pred_prob_csv'] gt_items = pd.read_csv(gt_csv) prob_items = pd.read_csv(prob_csv) assert(len(gt_items) == len(prob_items)) @@ -111,15 +111,15 @@ def nexcl_evaluation(config): The arguments are given in the `config` dictionary. It should have the following fields: - :param metric_list: (list) A list of evaluation metrics. + :param metric: (list) A list of evaluation metrics. The supported metrics are {`accuracy`, `recall`, `sensitivity`, `specificity`, `precision`, `auc`}. - :param ground_truth_csv: (str) The csv file for ground truth. - :param predict_prob_csv: (str) The csv file for prediction probability. + :param gt_csv: (str) The csv file for ground truth. + :param pred_prob_csv: (str) The csv file for prediction probability. """ - metric_list = config['metric_list'] - gt_csv = config['ground_truth_csv'] - prob_csv = config['predict_prob_csv'] + metric_list = config['metric'] + gt_csv = config['gt_csv'] + prob_csv = config['pred_prob_csv'] gt_items = pd.read_csv(gt_csv) prob_items= pd.read_csv(prob_csv) assert(len(gt_items) == len(prob_items)) @@ -163,25 +163,35 @@ def main(): .. code-block:: none - pymic_evaluate_cls config.cfg + pymic_evaluate_cls -cfg config.cfg The configuration file should have an `evaluation` section with the following fields: :param task_type: (str) `cls` or `cls_nexcl`. - :param metric_list: (list) A list of evaluation metrics. + :param metric: (list) A list of evaluation metrics. The supported metrics are {`accuracy`, `recall`, `sensitivity`, `specificity`, `precision`, `auc`}. - :param ground_truth_csv: (str) The csv file for ground truth. - :param predict_prob_csv: (str) The csv file for prediction probability. + :param gt_csv: (str) The csv file for ground truth. + :param pred_prob_csv: (str) The csv file for prediction probability. """ - if(len(sys.argv) < 2): - print('Number of arguments should be 2. e.g.') - print(' pymic_evaluate_cls config.cfg') - exit() - config_file = str(sys.argv[1]) - assert(os.path.isfile(config_file)) - config = parse_config(config_file)['evaluation'] + parser = argparse.ArgumentParser() + parser.add_argument("--cfg", help="configuration file for evaluation", + required=False, default=None) + parser.add_argument("--metric", help="evaluation metrics, e.g., accuracy, or [accuracy, auc]", + required=False, default=None) + parser.add_argument("--gt_csv", help="csv file for ground truth", + required=False, default=None) + parser.add_argument("--pred_prob_csv", help="csv file for probability prediction", + required=False, default=None) + args = parser.parse_args() + print(args) + if(args.cfg is not None): + config = parse_config(args)['evaluation'] + + # config_file = str(sys.argv[1]) + # assert(os.path.isfile(config_file)) + # config = parse_config(config_file)['evaluation'] task_type = config.get('task_type', "cls") if(task_type == "cls"): # default exclusive classification binary_evaluation(config) diff --git a/pymic/util/evaluation_seg.py b/pymic/util/evaluation_seg.py index ba04a73..5099d3b 100644 --- a/pymic/util/evaluation_seg.py +++ b/pymic/util/evaluation_seg.py @@ -3,15 +3,19 @@ Evaluation module for segmenation tasks. """ from __future__ import absolute_import, print_function +import argparse import csv import os import sys import pandas as pd import numpy as np +from os.path import join from scipy import ndimage from pymic.io.image_read_write import * from pymic.util.image_process import * -from pymic.util.parse_config import parse_config +from pymic.util.general import is_image_name +from pymic.util.parse_config import parse_config, parse_value_from_string + def binary_dice(s, g, resize = False): @@ -104,22 +108,30 @@ def binary_hd95(s, g, spacing = None): """ s_edge = get_edge_points(s) g_edge = get_edge_points(g) - image_dim = len(s.shape) - assert(image_dim == len(g.shape)) - if(spacing == None): - spacing = [1.0] * image_dim + ns = s_edge.sum() + ng = g_edge.sum() + if(ns + ng == 0): + hd95 = 0.0 + elif(ns * ng == 0): + hd95 = 100.0 else: - assert(image_dim == len(spacing)) - s_dis = ndimage.distance_transform_edt(1-s_edge, sampling = spacing) - g_dis = ndimage.distance_transform_edt(1-g_edge, sampling = spacing) - - dist_list1 = s_dis[g_edge > 0] - dist_list1 = sorted(dist_list1) - dist1 = dist_list1[int(len(dist_list1)*0.95)] - dist_list2 = g_dis[s_edge > 0] - dist_list2 = sorted(dist_list2) - dist2 = dist_list2[int(len(dist_list2)*0.95)] - return max(dist1, dist2) + image_dim = len(s.shape) + assert(image_dim == len(g.shape)) + if(spacing == None): + spacing = [1.0] * image_dim + else: + assert(image_dim == len(spacing)) + s_dis = ndimage.distance_transform_edt(1-s_edge, sampling = spacing) + g_dis = ndimage.distance_transform_edt(1-g_edge, sampling = spacing) + + dist_list1 = s_dis[g_edge > 0] + dist_list1 = sorted(dist_list1) + dist1 = dist_list1[int(len(dist_list1)*0.95)] + dist_list2 = g_dis[s_edge > 0] + dist_list2 = sorted(dist_list2) + dist2 = dist_list2[int(len(dist_list2)*0.95)] + hd95 = max(dist1, dist2) + return hd95 def binary_assd(s, g, spacing = None): @@ -146,9 +158,14 @@ def binary_assd(s, g, spacing = None): ns = s_edge.sum() ng = g_edge.sum() - s_dis_g_edge = s_dis * g_edge - g_dis_s_edge = g_dis * s_edge - assd = (s_dis_g_edge.sum() + g_dis_s_edge.sum()) / (ns + ng) + if(ns + ng == 0): + assd = 0.0 + elif(ns*ng == 0): + assd = 20.0 + else: + s_dis_g_edge = s_dis * g_edge + g_dis_s_edge = g_dis * s_edge + assd = (s_dis_g_edge.sum() + g_dis_s_edge.sum()) / (ns + ng) return assd # relative volume error evaluation @@ -195,8 +212,10 @@ def get_binary_evaluation_score(s_volume, g_volume, spacing, metric): score = binary_iou(s_volume,g_volume) elif(metric_lower == 'assd'): score = binary_assd(s_volume, g_volume, spacing) + score = min(score, 20) # to reject outliers elif(metric_lower == "hd95"): score = binary_hd95(s_volume, g_volume, spacing) + score = min(score, 50) # to reject outliers elif(metric_lower == "rve"): score = binary_relative_volume_error(s_volume, g_volume) elif(metric_lower == "volume"): @@ -252,72 +271,59 @@ def evaluation(config): :param label_fuse: (option, bool) If true, fuse the labels in the `label_list` as the foreground, and other labels as the background. Default is False. :param organ_name: (str) The name of the organ for segmentation. - :param ground_truth_folder_root: (str) The root dir of ground truth images. - :param segmentation_folder_root: (str or list) The root dir of segmentation images. + :param ground_truth_folder: (str) The root dir of ground truth images. + :param segmentation_folder: (str or list) The root dir of segmentation images. When a list is given, each list element should be the root dir of the results of one method. :param evaluation_image_pair: (str) The csv file that provide the segmentation images and the corresponding ground truth images. - :param ground_truth_label_convert_source: (optional, list) The list of source - labels for label conversion in the ground truth. - :param ground_truth_label_convert_target: (optional, list) The list of target - labels for label conversion in the ground truth. - :param segmentation_label_convert_source: (optional, list) The list of source - labels for label conversion in the segmentation. - :param segmentation_label_convert_target: (optional, list) The list of target - labels for label conversion in the segmentation. """ metric_list = config['metric_list'] - label_list = config['label_list'] - label_fuse = config.get('label_fuse', False) - organ_name = config['organ_name'] - gt_root = config['ground_truth_folder_root'] - seg_root = config['segmentation_folder_root'] - if(not(isinstance(seg_root, tuple) or isinstance(seg_root, list))): - seg_root = [seg_root] - image_pair_csv = config['evaluation_image_pair'] - ground_truth_label_convert_source = config.get('ground_truth_label_convert_source', None) - ground_truth_label_convert_target = config.get('ground_truth_label_convert_target', None) - segmentation_label_convert_source = config.get('segmentation_label_convert_source', None) - segmentation_label_convert_target = config.get('segmentation_label_convert_target', None) - - image_items = pd.read_csv(image_pair_csv) - item_num = len(image_items) - - for seg_root_n in seg_root: # for each segmentation method + if(not isinstance(metric_list, list)): + metric_list = [metric_list] + label_list = config.get('label_list', None) + if(label_list is None): + label_list = range(1, config["class_number"]) + elif(not isinstance(label_list, list)): + label_list = [label_list] + label_fuse = config.get('label_fuse', False) + output_name = config.get('output_name', None) + gt_dir = config['ground_truth_folder'] + seg_dirs = config['segmentation_folder'] + image_pair_csv = config.get('evaluation_image_pair', None) + + if(not isinstance(seg_dirs, (tuple, list))): + seg_dirs = [seg_dirs] + if(image_pair_csv is not None): + image_pair = pd.read_csv(image_pair_csv) + gt_names, seg_names = image_pair.iloc[:, 0], image_pair.iloc[:, 1] + else: + seg_names = sorted(os.listdir(seg_dirs[0])) + seg_names = [item for item in seg_names if is_image_name(item)] + gt_names = seg_names + + for seg_dir in seg_dirs: for metric in metric_list: + print(metric) score_all_data = [] name_score_list= [] - for i in range(item_num): - gt_name = image_items.iloc[i, 0] - seg_name = image_items.iloc[i, 1] - # seg_name = seg_name.replace(".nii.gz", "_pred.nii.gz") - gt_full_name = gt_root + '/' + gt_name - seg_full_name = seg_root_n + '/' + seg_name - + for i in range(len(gt_names)): + gt_full_name = join(gt_dir, gt_names[i]) + seg_full_name = join(seg_dir, seg_names[i]) s_dict = load_image_as_nd_array(seg_full_name) g_dict = load_image_as_nd_array(gt_full_name) s_volume = s_dict["data_array"]; s_spacing = s_dict["spacing"] g_volume = g_dict["data_array"]; g_spacing = g_dict["spacing"] # for dim in range(len(s_spacing)): # assert(s_spacing[dim] == g_spacing[dim]) - if((ground_truth_label_convert_source is not None) and \ - ground_truth_label_convert_target is not None): - g_volume = convert_label(g_volume, ground_truth_label_convert_source, \ - ground_truth_label_convert_target) - - if((segmentation_label_convert_source is not None) and \ - segmentation_label_convert_target is not None): - s_volume = convert_label(s_volume, segmentation_label_convert_source, \ - segmentation_label_convert_target) score_vector = get_multi_class_evaluation_score(s_volume, g_volume, label_list, label_fuse, s_spacing, metric ) if(len(label_list) > 1): score_vector.append(np.asarray(score_vector).mean()) score_all_data.append(score_vector) - name_score_list.append([seg_name] + score_vector) - print(seg_name, score_vector) + name_score_list.append([seg_names[i]] + score_vector) + print(seg_names[i], score_vector) score_all_data = np.asarray(score_all_data) score_mean = score_all_data.mean(axis = 0) score_std = score_all_data.std(axis = 0) @@ -325,8 +331,11 @@ def evaluation(config): name_score_list.append(['std'] + list(score_std)) # save the result as csv - score_csv = "{0:}/{1:}_{2:}_all.csv".format(seg_root_n, organ_name, metric) - with open(score_csv, mode='w') as csv_file: + if(output_name is None): + metric_output_name = "{0:}/eval_{1:}.csv".format(seg_dir, metric) + else: + metric_output_name = output_name + with open(metric_output_name, mode='w') as csv_file: csv_writer = csv.writer(csv_file, delimiter=',', quotechar='"',quoting=csv.QUOTE_MINIMAL) head = ['image'] + ["class_{0:}".format(i) for i in label_list] @@ -342,23 +351,56 @@ def evaluation(config): def main(): """ Main function for evaluation of segmentation results. - A configuration file is needed for runing. e.g., + You can use a configuration file for runing. e.g., .. code-block:: none - pymic_evaluate_cls config.cfg + pymic_evaluate_seg -cfg config.cfg The configuration file should have an `evaluation` section. See :mod:`pymic.util.evaluation_seg.evaluation` for details of the configuration required. + + In addition, you can also provide a list of args in the command if -cfg is not used. For example: + + .. code-block:: none + + pymic_evaluate_seg -metric dice -cls_index 255 -gt_dir ground_truth_dir -seg_dir segmentation_dir + """ - if(len(sys.argv) < 2): - print('Number of arguments should be 2. e.g.') - print(' pymic_evaluate_seg config.cfg') - exit() - config_file = str(sys.argv[1]) - assert(os.path.isfile(config_file)) - config = parse_config(config_file)['evaluation'] + parser = argparse.ArgumentParser() + parser.add_argument("--cfg", help="configuration file for evaluation", + required=False, default=None) + parser.add_argument("--metric", help="evaluation metrics, e.g., dice, or [dice, assd]", + required=False, default=None) + parser.add_argument("--cls_num", help="number of classes", + required=False, default=None) + parser.add_argument("--cls_index", help="The class index for evaluation, e.g., 255, [1, 2]", + required=False, default=None) + parser.add_argument("--gt_dir", help="path of folder for ground truth", + required=False, default=None) + parser.add_argument("--seg_dir", help="path of folder for segmentation", + required=False, default=None) + parser.add_argument("--name_pair", help="the .csv file for name mapping in case" + " the names of one case are different in the gt_dir " + " and seg_dir", + required=False, default=None) + parser.add_argument("--out", help="the output .csv file name", + required=False, default=None) + args = parser.parse_args() + print(args) + if(args.cfg is not None): + config = parse_config(args)['evaluation'] + else: + config = {} + config['metric_list'] = parse_value_from_string(args.metric) + config['label_list'] = None if args.cls_index is None else parse_value_from_string(args.cls_index) + config['class_number']= None if args.cls_num is None else parse_value_from_string(args.cls_num) + config['ground_truth_folder'] = args.gt_dir + config['segmentation_folder'] = args.seg_dir + config['evaluation_image_pair'] = args.name_pair + config['output_name'] = args.out + print(config) evaluation(config) - + if __name__ == '__main__': main() diff --git a/pymic/util/general.py b/pymic/util/general.py index 075d2e1..cf52746 100644 --- a/pymic/util/general.py +++ b/pymic/util/general.py @@ -26,7 +26,15 @@ def tensor_shape_match(a,b): return False return True - +def is_image_name(x): + valid_names = ["jpg", "jpeg", "png", "bmp", "nii.gz", + "tif", "nii", "nii.gz", "mha"] + valid = False + for item in valid_names: + if(x.endswith(item)): + valid = True + break + return valid def get_one_hot_seg(label, class_num): """ diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index 9a484e8..c31f28f 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -1,9 +1,13 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, print_function +import csv +import random +import pandas as pd import numpy as np import SimpleITK as sitk from scipy import ndimage +from pymic.io.image_read_write import load_image_as_nd_array def get_ND_bounding_box(volume, margin = None): """ @@ -34,6 +38,18 @@ def get_ND_bounding_box(volume, margin = None): bb_max[i] = min(bb_max[i] + margin[i], input_shape[i]) return bb_min, bb_max +def get_human_region_from_ct(image, threshold_i = -600, threshold_z = 0.6): + input_shape = image.shape + mask = np.asarray(image > threshold_i) + mask2d = np.mean(mask, axis = 0) > threshold_z + se = np.ones([3,3]) + mask2d = ndimage.binary_opening(mask2d, se, iterations = 2) + mask2d = get_largest_k_components(mask2d, 1) + bbmin, bbmax = get_ND_bounding_box(mask2d, margin = [0, 0]) + bb_min = [0] + bbmin + bb_max = list(input_shape[:1]) + bbmax + return bb_min, bb_max + def crop_ND_volume_with_bounding_box(volume, bb_min, bb_max): """ Extract a subregion form an ND image. @@ -57,7 +73,7 @@ def crop_ND_volume_with_bounding_box(volume, bb_min, bb_max): output = volume[bb_min[0]:bb_max[0], bb_min[1]:bb_max[1], bb_min[2]:bb_max[2], bb_min[3]:bb_max[3], bb_min[4]:bb_max[4]] else: raise ValueError("the dimension number shoud be 2 to 5") - return output + return output * 1 def set_ND_volume_roi_with_bounding_box_range(volume, bb_min, bb_max, sub_volume, addition = True): """ @@ -96,7 +112,7 @@ def set_ND_volume_roi_with_bounding_box_range(volume, bb_min, bb_max, sub_volume raise ValueError("array dimension should be 2 to 5") return out -def crop_and_pad_ND_array_to_desired_shape(image, out_shape, pad_mod): +def crop_and_pad_ND_array_to_desired_shape(image, out_shape, pad_mod='reflect'): """ Crop and pad an image to a given shape. @@ -136,6 +152,120 @@ def crop_and_pad_ND_array_to_desired_shape(image, out_shape, pad_mod): return image_pad +def random_crop_ND_volume(volume, out_shape): + """ + randomly crop a volume with to a given shape. + + :param volume: The input ND array. + :param out_shape: (list) The desired output shape. + """ + in_shape = volume.shape + dim = len(in_shape) + + # pad the image first if the input size is smaller than the output size + pad_shape = [max(out_shape[i], in_shape[i]) for i in range(dim)] + mgnp = [pad_shape[i] - in_shape[i] for i in range(dim)] + if(max(mgnp) == 0): + image_pad = volume + else: + ml = [int(mgnp[i]/2) for i in range(dim)] + mr = [mgnp[i] - ml[i] for i in range(dim)] + pad = [(ml[i], mr[i]) for i in range(dim)] + pad = tuple(pad) + image_pad = np.pad(volume, pad, 'reflect') + + bb_min = [random.randint(0, pad_shape[i] - out_shape[i]) for i in range(dim)] + bb_max = [bb_min[i] + out_shape[i] for i in range(dim)] + crop_volume = crop_ND_volume_with_bounding_box(image_pad, bb_min, bb_max) + return crop_volume + +def get_random_box_from_mask(mask, out_shape, mode = 0): + """ + get a bounding box of a subvolume according to a mask + + mode == 0: The output bounding box should be a sub region of the mask region + mode == 1: The center point of the output bounding box can be ahy where of the mask region + """ + dim = len(out_shape) + left_margin = [int(out_shape[i]/2) for i in range(dim)] + right_margin = [out_shape[i] - left_margin[i] for i in range(dim)] + + if(mode == 0): + bb_mask_min, bb_mask_max = get_ND_bounding_box(mask) + bb_valid_min, bb_valid_max = [], [] + for i in range(dim): + mask_size = bb_mask_max[i] - bb_mask_min[i] + if(mask_size > out_shape[i]): + valid_left = bb_mask_min[i] + left_margin[i] + valid_right = bb_mask_max[i] - right_margin[i] + else: + valid_left = (bb_mask_max[i] - bb_mask_min[i]) // 2 + valid_right = valid_left + 1 + bb_valid_min.append(valid_left) + bb_valid_max.append(valid_right) + + valid_region_shape = [bb_valid_max[i] - bb_valid_min[i] for i in range(dim)] + valid_mask = np.zeros_like(mask) + valid_mask = set_ND_volume_roi_with_bounding_box_range(valid_mask, + bb_valid_min, bb_valid_max, np.ones(valid_region_shape, np.bool), addition = True) + valid_mask = valid_mask * mask + else: + valid_mask = mask + + indices = np.where(valid_mask) + voxel_num = len(indices[0]) + j = random.randint(0, voxel_num - 1) + bb_c = [int(indices[i][j]) for i in range(dim)] + bb_min = [max(0, bb_c[i] - left_margin[i]) for i in range(dim)] + mask_shape = np.shape(mask) + bb_min = [min(bb_min[i], mask_shape[i] - out_shape[i]) for i in range(dim)] + bb_max = [bb_min[i] + out_shape[i] for i in range(dim)] + + return bb_min, bb_max + +def random_crop_ND_volume_with_mask(volume, out_shape, mask): + """ + randomly crop a volume with to a given shape. + + :param volume: The input ND array. + :param out_shape: (list) The desired output shape. + :param mask: A binary ND array. Default is None. If not None, + the center of the cropped region should be limited to the mask region. + """ + in_shape = volume.shape + dim = len(in_shape) + # pad the image first if the input size is smaller than the output size + pad_shape = [max(out_shape[i], in_shape[i]) for i in range(dim)] + mgnp = [pad_shape[i] - in_shape[i] for i in range(dim)] + if(max(mgnp) == 0): + image_pad, mask_pad = volume, mask + else: + ml = [int(mgnp[i]/2) for i in range(dim)] + mr = [mgnp[i] - ml[i] for i in range(dim)] + pad = [(ml[i], mr[i]) for i in range(dim)] + pad = tuple(pad) + image_pad = np.pad(volume, pad, 'reflect') + mask_pad = np.pad(mask, pad, 'constant') + + bb_min, bb_max = get_random_box_from_mask(mask_pad, out_shape) + # left_margin = [int(out_shape[i]/2) for i in range(dim)] + # right_margin= [pad_shape[i] - (out_shape[i] - left_margin[i]) + 1 for i in range(dim)] + + # valid_center_shape = [right_margin[i] - left_margin[i] for i in range(dim)] + # valid_mask = np.zeros(pad_shape) + # valid_mask = set_ND_volume_roi_with_bounding_box_range(valid_mask, + # left_margin, right_margin, np.ones(valid_center_shape)) + # valid_mask = valid_mask * mask_pad + + # indexes = np.where(valid_mask) + # voxel_num = len(indexes[0]) + # j = random.randint(0, voxel_num) + # bb_c = [indexes[i][j] for i in range(dim)] + # bb_min = [bb_c[i] - left_margin[i] for i in range(dim)] + # bb_max = [bb_min[i] + out_shape[i] for i in range(dim)] + crop_volume = crop_ND_volume_with_bounding_box(image_pad, bb_min, bb_max) + return crop_volume + def get_largest_k_components(image, k = 1): """ Get the largest K components from 2D or 3D binary image. @@ -143,7 +273,8 @@ def get_largest_k_components(image, k = 1): :param image: The input ND array for binary segmentation. :param k: (int) The value of k. - :return: An output array with only the largest K components of the input. + :return: An output array (k == 1) or a list of ND array (k>1) + with only the largest K components of the input. """ dim = len(image.shape) if(image.sum() == 0 ): @@ -156,11 +287,12 @@ def get_largest_k_components(image, k = 1): sizes = ndimage.sum(image, labeled_array, range(1, numpatches + 1)) sizes_sort = sorted(sizes, reverse = True) kmin = min(k, numpatches) - output = np.zeros_like(image) + output = [] for i in range(kmin): labeli = np.where(sizes == sizes_sort[i])[0] + 1 - output = output + np.asarray(labeled_array == labeli, np.uint8) - return output + output_i = np.asarray(labeled_array == labeli, np.uint8) + output.append(output_i) + return output[0] if k == 1 else output def get_euclidean_distance(image, dim = 3, spacing = [1.0, 1.0, 1.0]): """ @@ -200,14 +332,14 @@ def convert_label(label, source_list, target_list): :param target_list: A list of target labels, e.g. [0, 1, 2, 3] """ assert(len(source_list) == len(target_list)) - label_converted = np.zeros_like(label) + label_converted = label * 1 for i in range(len(source_list)): - label_temp = np.asarray(label == source_list[i], label.dtype) - label_temp = label_temp * target_list[i] - label_converted = label_converted + label_temp + label_s = np.asarray(label == source_list[i], label.dtype) + label_t = label_s * target_list[i] + label_converted[label_s > 0] = label_t[label_s > 0] return label_converted -def resample_sitk_image_to_given_spacing(image, spacing, order): +def resample_sitk_image_to_given_spacing(image, spacing, order = 3): """ Resample an sitk image objct to a given spacing. @@ -226,3 +358,86 @@ def resample_sitk_image_to_given_spacing(image, spacing, order): out_img.SetSpacing(spacing) out_img.SetDirection(image.GetDirection()) return out_img + +def get_image_info(img_names, output_csv = None): + spacing_list, shape_list = [], [] + for img_name in img_names: + img_obj = sitk.ReadImage(img_name) + img_arr = sitk.GetArrayFromImage(img_obj) + spacing = img_obj.GetSpacing() + shape = img_arr.shape + spacing_list.append(spacing) + shape_list.append(shape) + print(img_name, spacing, shape) + spacings = np.asarray(spacing_list) + shapes = np.asarray(shape_list) + spacing_min = spacings.min(axis = 0) + spacing_max = spacings.max(axis = 0) + spacing_median = np.percentile(spacings, 50, axis = 0) + print("spacing min", spacing_min) + print("spacing max", spacing_max) + print("spacing median", spacing_median) + + shape_min = shapes.min(axis = 0) + shape_max = shapes.max(axis = 0) + shape_median = np.percentile(shapes, 50, axis = 0) + print("shape min", shape_min) + print("shape max", shape_max) + print("shape median", shape_median) + + if(output_csv is not None): + img_names_short = [item.split("/")[-1] for item in img_names] + img_names_short.extend(["spacing min", "spacing max", "spacing median", + "shape min", "shape max", "shape median"]) + spacing_list.extend([spacing_min, spacing_max, spacing_median, + shape_min, shape_max, shape_median]) + shape_list.extend(['']* 6) + out_dict = {"img_name": img_names_short, + "spacing": spacing_list, + "shape": shape_list} + df = pd.DataFrame.from_dict(out_dict) + df.to_csv(output_csv, index=False) + +def get_average_mean_std(data_dir, data_csv): + df = pd.read_csv(data_csv) + mean_list, std_list = [], [] + for i in range(len(df)): + img_name = data_dir + "/" + df.iloc[i, 0] + lab_name = data_dir + "/" + df.iloc[i, 1] + img = load_image_as_nd_array(img_name)["data_array"][0] + lab = load_image_as_nd_array(lab_name)["data_array"][0] + voxels = img[lab>0] + mean = voxels.mean() + std = voxels.std() + mean_list.append(mean) + std_list.append(std) + print(img_name, mean, std) + mean = np.asarray(mean_list).mean() + std = np.asarray(std_list).mean() + print("mean and std value", mean, std) + +def get_label_info(data_dir, label_csv, class_num): + df = pd.read_csv(label_csv) + size_list = [] + # mean_list, std_list = [], [] + num_no_tumor = 0 + for i in range(len(df)): + lab_name = data_dir + "/" + df.iloc[i, 1] + lab = load_image_as_nd_array(lab_name)["data_array"][0] + size_per_class = [] + for c in range(1, class_num): + labc = lab == c + size_per_class.append(np.sum(labc)) + if(np.sum(labc) == 0): + num_no_tumor = num_no_tumor + 1 + size_list.append(size_per_class) + print(lab_name, size_per_class) + size = np.asarray(size_list) + size_min = size.min(axis = 0) + size_max = size.max(axis = 0) + size_mean = size.mean(axis = 0) + + print("size min", size_min) + print("size max", size_max) + print("size mean", size_mean) + print("case number without tumor", num_no_tumor) \ No newline at end of file diff --git a/pymic/util/parse_config.py b/pymic/util/parse_config.py index 232762f..3be02cb 100644 --- a/pymic/util/parse_config.py +++ b/pymic/util/parse_config.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, print_function +from pymic import TaskDict import configparser import logging @@ -83,33 +84,98 @@ def parse_value_from_string(val_str): val = val_str return val -def parse_config(filename): +def parse_config(args): config = configparser.ConfigParser() - config.read(filename) + config.read(args.cfg) output = {} for section in config.sections(): output[section] = {} for key in config[section]: val_str = str(config[section][key]) + if hasattr(args, key): + args_key = getattr(args, key) + if(args_key is not None): + val_str = args_key + print(section, key, val_str) if(len(val_str)>0): val = parse_value_from_string(val_str) output[section][key] = val else: val = None - print(section, key, val) + + for key in ["train_dir", "train_csv", "valid_csv", "test_dir", "test_csv"]: + if key in args and getattr(args, key) is not None: + output["dataset"][key] = parse_value_from_string(getattr(args, key)) + for key in ["ckpt_dir", "iter_max", "gpus"]: + if key in args and getattr(args, key) is not None: + output["training"][key] = parse_value_from_string(getattr(args, key)) + for key in ["output_dir", "ckpt_mode", "ckpt_name"]: + if key in args and getattr(args, key) is not None: + output["testing"][key] = parse_value_from_string(getattr(args, key)) return output def synchronize_config(config): data_cfg = config['dataset'] - net_cfg = config['network'] + data_cfg["task_type"] = TaskDict[data_cfg["task_type"]] + if('network' in config): + net_cfg = config['network'] # data_cfg["modal_num"] = net_cfg["in_chns"] - data_cfg["LabelToProbability_class_num".lower()] = net_cfg["class_num"] - if "PartialLabelToProbability" in data_cfg['train_transform']: + data_cfg["LabelToProbability_class_num".lower()] = net_cfg["class_num"] + transform = [] + if('transform' in data_cfg and data_cfg['transform'] is not None): + transform.extend(data_cfg['transform']) + if('train_transform' in data_cfg and data_cfg['train_transform'] is not None): + transform.extend(data_cfg['train_transform']) + if('valid_transform' in data_cfg and data_cfg['valid_transform'] is not None): + transform.extend(data_cfg['valid_transform']) + if('test_transform' in data_cfg and data_cfg['test_transform'] is not None): + transform.extend(data_cfg['test_transform']) + if ( "PartialLabelToProbability" in transform and 'network' in config): data_cfg["PartialLabelToProbability_class_num".lower()] = net_cfg["class_num"] + patch_size = data_cfg.get('patch_size', None) + if(patch_size is not None): + if('Pad' in transform and 'Pad_output_size'.lower() not in data_cfg): + data_cfg['Pad_output_size'.lower()] = patch_size + if('CenterCrop' in transform and 'CenterCrop_output_size'.lower() not in data_cfg): + data_cfg['CenterCrop_output_size'.lower()] = patch_size + if('RandomCrop' in transform and 'RandomCrop_output_size'.lower() not in data_cfg): + data_cfg['RandomCrop_output_size'.lower()] = patch_size + if('RandomResizedCrop' in transform and \ + 'RandomResizedCrop_output_size'.lower() not in data_cfg): + data_cfg['RandomResizedCrop_output_size'.lower()] = patch_size + if('testing' in config): + test_cfg = config['testing'] + sliding_window_enable = test_cfg.get("sliding_window_enable", False) + if(sliding_window_enable): + sliding_window_size = test_cfg.get("sliding_window_size", None) + if(sliding_window_size is None): + test_cfg["sliding_window_size"] = patch_size + sliding_window_stride = test_cfg.get("sliding_window_stride", None) + if(sliding_window_stride is None): + test_cfg["sliding_window_stride"] = [item // 2 for item in patch_size] + config['testing'] = test_cfg config['dataset'] = data_cfg - config['network'] = net_cfg + # config['network'] = net_cfg return config +def wrtie_config(config, output_name): + logging.info("The running configuations are: ") + with open(output_name, 'w') as f: + for section in config: + if(isinstance(config[section], dict)): + line = '[' + section + ']' + f.write('\n' + line + '\n') + logging.info(line) + for key in config[section]: + value = config[section][key] + line = "{0:} = {1:}".format(key, value) + f.write(line + '\n') + logging.info(line) + else: + line = "{0:} = {1:}".format(section, config[section]) + f.write(line + "\n") + logging.info(line) + def logging_config(config): for section in config: if(isinstance(config[section], dict)): diff --git a/pymic/util/preprocess.py b/pymic/util/preprocess.py deleted file mode 100644 index 5f20372..0000000 --- a/pymic/util/preprocess.py +++ /dev/null @@ -1,62 +0,0 @@ -import os -import numpy as np -import SimpleITK as sitk -from pymic.io.image_read_write import load_image_as_nd_array -from pymic.transform.trans_dict import TransformDict -from pymic.util.parse_config import parse_config - -def get_transform_list(trans_config_file): - """ - Create a list of transforms given a configuration file. - """ - config = parse_config(trans_config_file) - transform_list = [] - - transform_param = config['dataset'] - transform_param['task'] = 'segmentation' - transform_names = config['dataset']['transform'] - for name in transform_names: - print(name) - if(name not in TransformDict): - raise(ValueError("Undefined transform {0:}".format(name))) - one_transform = TransformDict[name](transform_param) - transform_list.append(one_transform) - return transform_list - -def preprocess_with_transform(transforms, img_in_name, img_out_name, - lab_in_name = None, lab_out_name = None): - """ - Using a list of data transforms for preprocessing, - such as image normalization, cropping, etc. - TODO: support multip-modality preprocessing. - - :param transforms: (list) A list of transform objects. - :param img_in_name: (str) Input file name. - :param img_out_name: (str) Output file name. - :param lab_in_name: (optional, str) If None, load the image's - corresponding label for preprocessing as well. - :param lab_out_name: (optional, str) The output label name. - """ - image_dict = load_image_as_nd_array(img_in_name) - sample = {'image': np.asarray(image_dict['data_array'], np.float32), - 'origin':image_dict['origin'], - 'spacing': image_dict['spacing'], - 'direction':image_dict['direction']} - if(lab_in_name is not None): - label_dict = load_image_as_nd_array(lab_in_name) - sample['label'] = label_dict['data_array'] - for transform in transforms: - sample = transform(sample) - - out_img = sitk.GetImageFromArray(sample['image'][0]) - out_img.SetSpacing(sample['spacing']) - out_img.SetOrigin(sample['origin']) - out_img.SetDirection(sample['direction']) - sitk.WriteImage(out_img, img_out_name) - if(lab_in_name is not None and lab_out_name is not None): - out_lab = sitk.GetImageFromArray(sample['label'][0]) - out_lab.CopyInformation(out_img) - sitk.WriteImage(out_lab, lab_out_name) - - - diff --git a/requirements.txt b/requirements.txt index cac47f3..e70fc83 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,14 @@ h5py matplotlib>=3.1.2 -numpy>=1.17.4 -pandas>=0.25.3 -scikit-image>=0.16.2 -scikit-learn>=0.22 -scipy>=1.3.3 -SimpleITK>=2.0.0 +numpy>=1.23.5 +pandas>=1.5.2 +scikit-image>=0.19.3 +scikit-learn>=1.2.0 +scipy>=1.10.0 +SimpleITK>=2.0.2 tensorboard tensorboardX -torch>=1.1.12 -torchvision>=0.13.0 +torch>=1.13.1 +torchvision>=0.14.1 +causal-conv1d>=1.5.0 +mamba-ssm>=2.2.4 diff --git a/setup.py b/setup.py index 36daf9a..879ee6c 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.4.0", + version = "0.5.4", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, @@ -41,6 +41,7 @@ python_requires = '>=3.6', entry_points = { 'console_scripts': [ + 'pymic_preprocess = pymic.net_run.preprocess:main', 'pymic_train = pymic.net_run.train:main', 'pymic_test = pymic.net_run.predict:main', 'pymic_eval_cls = pymic.util.evaluation_cls:main',