Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 57ef50d

Browse files
authored
End_to_End_Diarization_Training.ipynb (#14680)
Signed-off-by: taejinp <[email protected]>
1 parent 4d15f4c commit 57ef50d

1 file changed

Lines changed: 62 additions & 3 deletions

File tree

tutorials/speaker_tasks/End_to_End_Diarization_Training.ipynb

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,9 @@
778778
"cell_type": "markdown",
779779
"metadata": {},
780780
"source": [
781-
"## Train a Sortformer Diarizer Model"
781+
"## Train a Sortformer Diarizer Model\n",
782+
"\n",
783+
"### Training an offline Sortformer model"
782784
]
783785
},
784786
{
@@ -823,7 +825,7 @@
823825
"source": [
824826
"curr_dir = os.getcwd() + \"/\"\n",
825827
"config.model.train_ds.manifest_filepath = f'{curr_dir}simulated_train/sortformer_train.json'\n",
826-
"config.model.test_ds.manifest_filepath = f'{curr_dir}simulated_valid/sortformer_valid.json'\n",
828+
"# config.model.test_ds.manifest_filepath = f'{curr_dir}simulated_valid/sortformer_valid.json'\n",
827829
"config.model.validation_ds.manifest_filepath = f'{curr_dir}simulated_valid/sortformer_valid.json'\n",
828830
"config.trainer.strategy = \"ddp_notebook\"\n",
829831
"config.batch_size = 3\n",
@@ -858,11 +860,68 @@
858860
"sortformer_model.maybe_init_from_pretrained_checkpoint(config)\n",
859861
"trainer.fit(sortformer_model)"
860862
]
863+
},
864+
{
865+
"cell_type": "markdown",
866+
"metadata": {},
867+
"source": [
868+
"### Training a streaming Sortformer model\n",
869+
"\n",
870+
"If you want to train a streaming version of Sortformer, you can download the following YAML file."
871+
]
872+
},
873+
{
874+
"cell_type": "code",
875+
"execution_count": null,
876+
"metadata": {},
877+
"outputs": [],
878+
"source": [
879+
"\n",
880+
"!wget -P conf https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/examples/speaker_tasks/diarization/conf/neural_diarizer/streaming_sortformer_diarizer_4spk-v2.yaml\n",
881+
"MODEL_CONFIG = os.path.join(NEMO_ROOT,'conf/streaming_sortformer_diarizer_4spk-v2.yaml')\n",
882+
"config = OmegaConf.load(MODEL_CONFIG)\n",
883+
"\n",
884+
"curr_dir = os.getcwd() + \"/\"\n",
885+
"config.model.train_ds.manifest_filepath = f'{curr_dir}simulated_train/sortformer_train.json'\n",
886+
"config.model.test_ds.manifest_filepath = f'{curr_dir}simulated_valid/sortformer_valid.json'\n",
887+
"config.model.validation_ds.manifest_filepath = f'{curr_dir}simulated_valid/sortformer_valid.json'\n",
888+
"config.trainer.strategy = \"ddp_notebook\"\n",
889+
"config.batch_size = 3\n",
890+
"\n",
891+
"config.trainer.devices=1\n",
892+
"config.accelerator=\"gpu\"\n",
893+
"print(os.getcwd())\n",
894+
"\n",
895+
"print(\"config.model.train_ds.manifest_filepath \", config.model.train_ds.manifest_filepath )"
896+
]
897+
},
898+
{
899+
"cell_type": "markdown",
900+
"metadata": {},
901+
"source": [
902+
"Initiate a streaming Sortformer diarization training session using the given configurations."
903+
]
904+
},
905+
{
906+
"cell_type": "code",
907+
"execution_count": null,
908+
"metadata": {},
909+
"outputs": [],
910+
"source": [
911+
"trainer = pl.Trainer(devices=1, accelerator='gpu', max_epochs=50,\n",
912+
" enable_checkpointing=False, logger=False,\n",
913+
" log_every_n_steps=5, check_val_every_n_epoch=10)\n",
914+
"\n",
915+
"exp_manager(trainer, config.get(\"exp_manager\", None))\n",
916+
"streaming_sortformer_model = SortformerEncLabelModel(cfg=config.model, trainer=trainer)\n",
917+
"streaming_sortformer_model.maybe_init_from_pretrained_checkpoint(config)\n",
918+
"trainer.fit(streaming_sortformer_model)"
919+
]
861920
}
862921
],
863922
"metadata": {
864923
"kernelspec": {
865-
"display_name": "Python 3 (ipykernel)",
924+
"display_name": "nv082124",
866925
"language": "python",
867926
"name": "python3"
868927
},

0 commit comments

Comments
 (0)