|
778 | 778 | "cell_type": "markdown", |
779 | 779 | "metadata": {}, |
780 | 780 | "source": [ |
781 | | - "## Train a Sortformer Diarizer Model" |
| 781 | + "## Train a Sortformer Diarizer Model\n", |
| 782 | + "\n", |
| 783 | + "### Training an offline Sortformer model" |
782 | 784 | ] |
783 | 785 | }, |
784 | 786 | { |
|
823 | 825 | "source": [ |
824 | 826 | "curr_dir = os.getcwd() + \"/\"\n", |
825 | 827 | "config.model.train_ds.manifest_filepath = f'{curr_dir}simulated_train/sortformer_train.json'\n", |
826 | | - "config.model.test_ds.manifest_filepath = f'{curr_dir}simulated_valid/sortformer_valid.json'\n", |
| 828 | + "# config.model.test_ds.manifest_filepath = f'{curr_dir}simulated_valid/sortformer_valid.json'\n", |
827 | 829 | "config.model.validation_ds.manifest_filepath = f'{curr_dir}simulated_valid/sortformer_valid.json'\n", |
828 | 830 | "config.trainer.strategy = \"ddp_notebook\"\n", |
829 | 831 | "config.batch_size = 3\n", |
|
858 | 860 | "sortformer_model.maybe_init_from_pretrained_checkpoint(config)\n", |
859 | 861 | "trainer.fit(sortformer_model)" |
860 | 862 | ] |
| 863 | + }, |
| 864 | + { |
| 865 | + "cell_type": "markdown", |
| 866 | + "metadata": {}, |
| 867 | + "source": [ |
| 868 | + "### Training a streaming Sortformer model\n", |
| 869 | + "\n", |
| 870 | + "If you want to train a streaming version of Sortformer, you can download the following YAML file." |
| 871 | + ] |
| 872 | + }, |
| 873 | + { |
| 874 | + "cell_type": "code", |
| 875 | + "execution_count": null, |
| 876 | + "metadata": {}, |
| 877 | + "outputs": [], |
| 878 | + "source": [ |
| 879 | + "\n", |
| 880 | + "!wget -P conf https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/examples/speaker_tasks/diarization/conf/neural_diarizer/streaming_sortformer_diarizer_4spk-v2.yaml\n", |
| 881 | + "MODEL_CONFIG = os.path.join(NEMO_ROOT,'conf/streaming_sortformer_diarizer_4spk-v2.yaml')\n", |
| 882 | + "config = OmegaConf.load(MODEL_CONFIG)\n", |
| 883 | + "\n", |
| 884 | + "curr_dir = os.getcwd() + \"/\"\n", |
| 885 | + "config.model.train_ds.manifest_filepath = f'{curr_dir}simulated_train/sortformer_train.json'\n", |
| 886 | + "config.model.test_ds.manifest_filepath = f'{curr_dir}simulated_valid/sortformer_valid.json'\n", |
| 887 | + "config.model.validation_ds.manifest_filepath = f'{curr_dir}simulated_valid/sortformer_valid.json'\n", |
| 888 | + "config.trainer.strategy = \"ddp_notebook\"\n", |
| 889 | + "config.batch_size = 3\n", |
| 890 | + "\n", |
| 891 | + "config.trainer.devices=1\n", |
| 892 | + "config.accelerator=\"gpu\"\n", |
| 893 | + "print(os.getcwd())\n", |
| 894 | + "\n", |
| 895 | + "print(\"config.model.train_ds.manifest_filepath \", config.model.train_ds.manifest_filepath )" |
| 896 | + ] |
| 897 | + }, |
| 898 | + { |
| 899 | + "cell_type": "markdown", |
| 900 | + "metadata": {}, |
| 901 | + "source": [ |
| 902 | + "Initiate a streaming Sortformer diarization training session using the given configurations." |
| 903 | + ] |
| 904 | + }, |
| 905 | + { |
| 906 | + "cell_type": "code", |
| 907 | + "execution_count": null, |
| 908 | + "metadata": {}, |
| 909 | + "outputs": [], |
| 910 | + "source": [ |
| 911 | + "trainer = pl.Trainer(devices=1, accelerator='gpu', max_epochs=50,\n", |
| 912 | + " enable_checkpointing=False, logger=False,\n", |
| 913 | + " log_every_n_steps=5, check_val_every_n_epoch=10)\n", |
| 914 | + "\n", |
| 915 | + "exp_manager(trainer, config.get(\"exp_manager\", None))\n", |
| 916 | + "streaming_sortformer_model = SortformerEncLabelModel(cfg=config.model, trainer=trainer)\n", |
| 917 | + "streaming_sortformer_model.maybe_init_from_pretrained_checkpoint(config)\n", |
| 918 | + "trainer.fit(streaming_sortformer_model)" |
| 919 | + ] |
861 | 920 | } |
862 | 921 | ], |
863 | 922 | "metadata": { |
864 | 923 | "kernelspec": { |
865 | | - "display_name": "Python 3 (ipykernel)", |
| 924 | + "display_name": "nv082124", |
866 | 925 | "language": "python", |
867 | 926 | "name": "python3" |
868 | 927 | }, |
|
0 commit comments