<!DOCTYPE html>
<html lang="en"><head><meta charset="UTF-8"><meta name="viewport" content="width=device-width,initial-scale=1.0">
<title>FunASR Training Guide</title>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&family=JetBrains+Mono:wght@400;500&display=swap" rel="stylesheet">
<link rel="stylesheet" href="style.css">
</head><body>
<nav class="nav"><div class="container">
<a href="index.html" class="nav-logo">FunASR</a>
<div class="nav-links">
<a href="index.html">Home</a>
<a href="tutorial.html">Tutorial</a>
<a href="training.html" class="active">Training</a>
<a href="model-registration.html">Develop</a>
<a href="api.html">API</a>
</div>
<a href="https://github.com/modelscope/FunASR" class="nav-github">GitHub</a>
</div></nav>
<div class="content"><div class="container narrow">
<h1>Training & Fine-tuning</h1>
<p>Fine-tune pretrained models on your own data using FunASR's training framework.</p>
<div class="toc-grid">
<a href="#overview">Overview</a>
<a href="#data">Data Preparation</a>
<a href="#paraformer">Fine-tune Paraformer</a>
<a href="#sensevoice">Fine-tune SenseVoice</a>
<a href="#nano">Fine-tune Fun-ASR-Nano</a>
<a href="#params">Parameter Reference</a>
<a href="#multi-gpu">Multi-GPU Training</a>
<a href="#deepspeed">DeepSpeed</a>
<a href="#monitor">Monitoring</a>
<a href="#inference-after">Use Fine-tuned Model</a>
<a href="#tips">Tips & Troubleshooting</a>
</div>
<!-- Overview -->
<h2 id="overview">Overview</h2>
<p>FunASR's training framework supports:</p>
<ul>
<li><strong>Fine-tuning</strong> any pretrained model on custom domain data</li>
<li><strong>Multi-GPU</strong> training with PyTorch DDP (single/multi-node)</li>
<li><strong>DeepSpeed</strong> ZeRO Stage 1/2/3 for large model training</li>
<li><strong>Dynamic batching</strong> by token count or example count</li>
<li><strong>Checkpoint averaging</strong> over the top-N checkpoints to produce the final model</li>
<li><strong>Resume training</strong> from interruption</li>
</ul>
<p>The training entry point is <code>funasr-train-ds</code> (or <code>funasr/bin/train_ds.py</code>), launched via <code>torchrun</code> for distributed training.</p>
<!-- Data -->
<h2 id="data">Data Preparation</h2>
<h3>Standard Format (Paraformer, SenseVoice)</h3>
<p>Training data uses JSONL format — one JSON object per line:</p>
<pre>{"key": "utt001", "source": "/path/to/audio.wav", "source_len": 90, "target": "这是转写文本", "target_len": 6}
{"key": "utt002", "source": "/path/to/audio2.wav", "source_len": 150, "target": "hello world", "target_len": 2}</pre>
<table>
<tr><th>Field</th><th>Type</th><th>Description</th></tr>
<tr><td><code>key</code></td><td>str</td><td>Unique utterance ID</td></tr>
<tr><td><code>source</code></td><td>str</td><td>Audio file path (local path or URL)</td></tr>
<tr><td><code>source_len</code></td><td>int</td><td>Audio length in fbank frames (1 frame = 10ms)</td></tr>
<tr><td><code>target</code></td><td>str</td><td>Transcription text</td></tr>
<tr><td><code>target_len</code></td><td>int</td><td>Number of text tokens</td></tr>
</table>
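<p>As a rough illustration of how the fields in the table relate to each other, the Python sketch below builds one JSONL record. The helper names (<code>approx_source_len</code>, <code>approx_target_len</code>, <code>make_record</code>) are hypothetical and not part of FunASR. It assumes 100 frames per second (from the 10 ms frame step above) and approximates <code>target_len</code> by counting each CJK character plus each whitespace-separated word. That matches the two sample records, but the model's real tokenizer may count differently.</p>

```python
import json
import re

# CJK Unified Ideographs; each character is counted as one token (assumption).
_CJK = re.compile(r"[\u4e00-\u9fff]")


def approx_source_len(duration_seconds: float) -> int:
    """Approximate number of fbank frames, assuming one frame per 10 ms."""
    return int(round(duration_seconds * 100))


def approx_target_len(text: str) -> int:
    """Count CJK characters individually and all other text by whitespace words."""
    cjk_chars = len(_CJK.findall(text))
    other_words = len(_CJK.sub(" ", text).split())
    return cjk_chars + other_words


def make_record(key: str, path: str, duration_seconds: float, text: str) -> dict:
    """Assemble one training record with the five fields from the table."""
    return {
        "key": key,
        "source": path,
        "source_len": approx_source_len(duration_seconds),
        "target": text,
        "target_len": approx_target_len(text),
    }


record = make_record("utt001", "/path/to/audio.wav", 0.9, "这是转写文本")
# ensure_ascii=False keeps Chinese text readable in the JSONL file.
print(json.dumps(record, ensure_ascii=False))
```

<p>Writing one such <code>json.dumps</code> line per utterance produces a file in the format shown above.</p>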
<h4>Generate from wav.scp + text.txt</h4>
<p>If you have Kaldi-style data files, convert them:</p>
<pre># train_wav.scp (tab-separated: id \t path)
utt001 /data/audio/001.wav
utt002 /data/audio/002.wav
# train_text.txt (tab-separated: id \t text)
utt001 这是转写文本
utt002 hello world</pre>
<pre># Convert to jsonl
scp2jsonl \
++scp_file_list='["/data/list/train_wav.scp", "/data/list/train_text.txt"]' \
++data_type_list='["source", "target"]' \
++jsonl_file_out="/data/list/train.jsonl"</pre>
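<p>To make the pairing step concrete, here is a minimal Python sketch of joining the two Kaldi-style files by utterance ID. It is not the <code>scp2jsonl</code> implementation: the function names are hypothetical, and it leaves out the length fields that the real tool also writes.</p>

```python
def parse_kaldi_lines(lines):
    """Map utterance ID -> rest of the line, splitting on the first whitespace run."""
    table = {}
    for line in lines:
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and comments
        parts = line.split(maxsplit=1)
        if len(parts) == 2:
            table[parts[0]] = parts[1]
    return table


def join_scp_and_text(wav_lines, text_lines):
    """Pair audio paths with transcripts; also report IDs present in only one file."""
    wavs = parse_kaldi_lines(wav_lines)
    texts = parse_kaldi_lines(text_lines)
    records = [
        {"key": key, "source": wavs[key], "target": texts[key]}
        for key in wavs
        if key in texts
    ]
    missing = sorted(set(wavs) ^ set(texts))
    return records, missing


wav_lines = [
    "utt001\t/data/audio/001.wav",
    "utt002\t/data/audio/002.wav",
    "utt003\t/data/audio/003.wav",
]
text_lines = ["utt001\t这是转写文本", "utt002\thello world"]
records, missing = join_scp_and_text(wav_lines, text_lines)
print(records)
print("unpaired ids:", missing)
```

<p>Checking for unpaired IDs before conversion is a cheap way to catch utterances that have audio but no transcript, or the reverse.</p>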
<h3>ChatML Format (Fun-ASR-Nano)</h3>
<p>Fun-ASR-Nano uses ChatML conversation format:</p>
<pre>{"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "语音转写:<|startofspeech|>!/path/to/audio.wav<|endofspeech|>"},
{"role": "assistant", "content": "几点了?"}
], "speech_length": 145, "text_length": 3}</pre>
<table>
<tr><th>Field</th><th>Description</th></tr>
<tr><td><code>messages[0]</code></td><td>System prompt (fixed: "You are a helpful assistant.")</td></tr>
<tr><td><code>messages[1]</code></td><td>User: prompt + audio path wrapped in <code><|startofspeech|>!...<|endofspeech|></code></td></tr>
<tr><td><code>messages[2]</code></td><td>Assistant: transcription text</td></tr>
<tr><td><code>speech_length</code></td><td>Number of fbank frames (10ms each)</td></tr>
<tr><td><code>text_length</code></td><td>Number of tokens (tokenized by Qwen3-0.6B)</td></tr>
</table>
<div class="tip"><strong>Prompt variations:</strong><br>
• Chinese: <code>语音转写:</code><br>
• English: <code>Speech transcription:</code><br>
• Cross-language: <code>语音转写成英文:</code><br>
• No ITN: <code>语音转写,不进行文本规整:</code></div>
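<p>The record layout above can be assembled with a few lines of Python. This is a sketch only: <code>make_chatml_record</code> is a hypothetical helper, and <code>speech_length</code> and <code>text_length</code> are passed in as precomputed values because counting tokens requires the tokenizer named in the table.</p>

```python
import json

SYSTEM_PROMPT = "You are a helpful assistant."


def make_chatml_record(audio_path, transcript, speech_length, text_length,
                       prompt="语音转写:"):
    """Build one ChatML training record; the audio path is wrapped in speech markers."""
    user_content = f"{prompt}<|startofspeech|>!{audio_path}<|endofspeech|>"
    return {
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": transcript},
        ],
        "speech_length": speech_length,
        "text_length": text_length,
    }


record = make_chatml_record("/path/to/audio.wav", "几点了?", 145, 3)
# One record per line; ensure_ascii=False keeps the Chinese prompt readable.
line = json.dumps(record, ensure_ascii=False)
print(line)
```

<p>Passing a different <code>prompt</code>, for example <code>Speech transcription:</code>, selects one of the prompt variations listed above.</p>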
<p>Convert from wav.scp + text.txt:</p>
<pre>python tools/scp2jsonl.py \
++scp_file=data/train_wav.scp \
++transcript_file=data/train_text.txt \
++jsonl_file=data/train_example.jsonl</pre>
<!-- Paraformer -->
<h2 id="paraformer">Fine-tune Paraformer</h2>
<pre>cd examples/industrial_data_pretraining/paraformer
bash finetune.sh</pre>
<p>Or customize the key parameters:</p>
<pre>export CUDA_VISIBLE_DEVICES="0,1"
gpu_num=2
torchrun --nproc_per_node $gpu_num \
funasr/bin/train_ds.py \
++model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
++train_data_set_list="data/train.jsonl" \
++valid_data_set_list="data/val.jsonl" \
++dataset_conf.batch_size=6000 \
++dataset_conf.batch_type="token" \
++dataset_conf.num_workers=4 \
++train_conf.max_epoch=50 \
++train_conf.validate_interval=2000 \
++train_conf.save_checkpoint_interval=2000 \
++train_conf.keep_nbest_models=20 \
++train_conf.avg_nbest_model=10 \
++optim_conf.lr=0.0002 \
++output_dir="./outputs"</pre>
<!-- SenseVoice -->
<h2 id="sensevoice">Fine-tune SenseVoice</h2>
<pre>cd examples/industrial_data_pretraining/sense_voice
bash finetune.sh</pre>
<p>Same data format as Paraformer (source/target JSONL). Key difference: SenseVoice uses its own dataset class internally.</p>
<!-- Fun-ASR-Nano -->
<h2 id="nano">Fine-tune Fun-ASR-Nano</h2>
<pre>cd examples/industrial_data_pretraining/fun_asr_nano
bash finetune.sh</pre>
<p>Key differences from Paraformer:</p>
<ul>
<li>Uses ChatML data format (see above)</li>
<li><code>++trust_remote_code=true</code> required</li>
<li>Supports <strong>selective freezing</strong>: freeze encoder/adaptor while training LLM decoder</li>
</ul>
<pre># Freeze encoder + adaptor, only train LLM (recommended for domain adaptation)
++audio_encoder_conf.freeze=true
++audio_adaptor_conf.freeze=true
++llm_conf.freeze=false
# Full fine-tune (all parameters)
++audio_encoder_conf.freeze=false
++audio_adaptor_conf.freeze=false
++llm_conf.freeze=false</pre>
<div class="note"><strong>Recommended strategy:</strong> Start with LLM-only fine-tuning (faster, less data needed). If results are insufficient, unfreeze adaptor. Only unfreeze encoder with very large datasets (>1000h).</div>
<!-- Parameters -->
<h2 id="params">Parameter Reference</h2>
<h3>Dataset Parameters</h3>
<table>
<tr><th>Parameter</th><th>Default</th><th>Description</th></tr>
<tr><td><code>dataset_conf.batch_type</code></td><td>"token"</td><td><code>"token"</code>: dynamic batching by total token count. <code>"example"</code>: fixed number of samples per batch.</td></tr>
<tr><td><code>dataset_conf.batch_size</code></td><td>6000</td><td>Token mode: total frames per batch. Example mode: number of samples.</td></tr>
<tr><td><code>dataset_conf.sort_size</code></td><td>1024</td><td>Buffer size for length-based sorting (improves padding efficiency).</td></tr>
<tr><td><code>dataset_conf.num_workers</code></td><td>4</td><td>Number of data-loading worker processes.</td></tr>
<tr><td><code>dataset_conf.data_split_num</code></td><td>1</td><td>Split data into N groups for large-scale training (reduces memory).</td></tr>
<tr><td><code>dataset_conf.max_token_length</code></td><td>—</td><td>Filter: skip samples longer than this (in frames/tokens).</td></tr>
<tr><td><code>dataset_conf.min_token_length</code></td><td>—</td><td>Filter: skip samples shorter than this.</td></tr>
</table>
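<p>The interaction between <code>batch_type="token"</code>, <code>batch_size</code>, and <code>sort_size</code> can be pictured with the following Python sketch. It is an illustration, not FunASR's sampler: it assumes the budget applies to the sum of sample lengths in a batch, whereas a real sampler may budget on padded length instead. The function name is hypothetical.</p>

```python
def batch_by_tokens(lengths, batch_size, sort_size=1024):
    """Group sample lengths into batches whose summed length stays within batch_size.

    Samples are sorted inside windows of sort_size items so that similar
    lengths end up together, which reduces padding.
    """
    batches = []
    for start in range(0, len(lengths), sort_size):
        window = sorted(lengths[start:start + sort_size])
        current, total = [], 0
        for length in window:
            # Start a new batch when adding this sample would exceed the budget.
            if current and total + length > batch_size:
                batches.append(current)
                current, total = [], 0
            current.append(length)
            total += length
        if current:
            batches.append(current)
    return batches


batches = batch_by_tokens([3000, 2500, 2000, 1500, 500], batch_size=6000)
print(batches)
```

<p>Short utterances are packed many to a batch and long ones few to a batch, which is why token mode keeps memory use steadier than a fixed sample count.</p>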
<h3>Training Parameters</h3>
<table>
<tr><th>Parameter</th><th>Default</th><th>Description</th></tr>
<tr><td><code>train_conf.max_epoch</code></td><td>50</td><td>Total training epochs.</td></tr>
<tr><td><code>train_conf.log_interval</code></td><td>1</td><td>Print loss every N steps.</td></tr>
<tr><td><code>train_conf.validate_interval</code></td><td>2000</td><td>Run validation every N steps.</td></tr>
<tr><td><code>train_conf.save_checkpoint_interval</code></td><td>2000</td><td>Save model every N steps.</td></tr>
<tr><td><code>train_conf.keep_nbest_models</code></td><td>20</td><td>Keep top N models (by validation accuracy).</td></tr>
<tr><td><code>train_conf.avg_nbest_model</code></td><td>10</td><td>Average top N models for final checkpoint.</td></tr>
<tr><td><code>train_conf.resume</code></td><td>true</td><td>Resume from last checkpoint if exists.</td></tr>
<tr><td><code>train_conf.use_deepspeed</code></td><td>false</td><td>Enable DeepSpeed ZeRO optimization.</td></tr>
<tr><td><code>optim_conf.lr</code></td><td>0.0002</td><td>Learning rate.</td></tr>
</table>
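<p>To show what <code>keep_nbest_models</code> and <code>avg_nbest_model</code> mean in practice, the sketch below ranks checkpoints by validation accuracy and averages the parameters of the best ones. It uses plain Python lists in place of tensors so that it runs without PyTorch. The function names are hypothetical and this is not FunASR's averaging code.</p>

```python
def select_nbest(checkpoints, n):
    """Return the state dicts of the n checkpoints with the highest accuracy.

    checkpoints is a list of (validation_accuracy, state_dict) pairs.
    """
    ranked = sorted(checkpoints, key=lambda item: item[0], reverse=True)
    return [state for _, state in ranked[:n]]


def average_checkpoints(state_dicts):
    """Element-wise mean of every parameter across the given state dicts."""
    if not state_dicts:
        raise ValueError("need at least one checkpoint to average")
    count = len(state_dicts)
    averaged = {}
    for name in state_dicts[0]:
        columns = zip(*(state[name] for state in state_dicts))
        averaged[name] = [sum(values) / count for values in columns]
    return averaged


checkpoints = [
    (0.70, {"w": [0.0, 0.0]}),
    (0.80, {"w": [1.0, 3.0]}),
    (0.79, {"w": [3.0, 5.0]}),
]
best = select_nbest(checkpoints, n=2)
final = average_checkpoints(best)
print(final)
```

<p>Averaging several good checkpoints tends to smooth out step-to-step noise, which is the motivation for keeping more checkpoints than are averaged.</p>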
<!-- Multi-GPU -->
<h2 id="multi-gpu">Multi-GPU Training</h2>
<h3>Single Machine, Multiple GPUs</h3>
<pre>export CUDA_VISIBLE_DEVICES="0,1,2,3"
gpu_num=4
torchrun --nnodes 1 --nproc_per_node $gpu_num \
funasr/bin/train_ds.py ${train_args}</pre>
<h3>Multiple Machines</h3>
<pre># Machine 1 (master, IP=192.168.1.1)
torchrun --nnodes 2 --node_rank 0 --nproc_per_node 4 \
--master_addr=192.168.1.1 --master_port=12345 \
funasr/bin/train_ds.py ${train_args}
# Machine 2
torchrun --nnodes 2 --node_rank 1 --nproc_per_node 4 \
--master_addr=192.168.1.1 --master_port=12345 \
funasr/bin/train_ds.py ${train_args}</pre>
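<p>For reference, <code>torchrun</code> exposes each process's position through the <code>RANK</code>, <code>LOCAL_RANK</code>, and <code>WORLD_SIZE</code> environment variables. The Python sketch below shows how those values follow from the launch flags used above. The helper names are illustrative only.</p>

```python
def world_size(nnodes, nproc_per_node):
    """Total number of training processes across all machines."""
    return nnodes * nproc_per_node


def global_rank(node_rank, nproc_per_node, local_rank):
    """Global rank of a process from its machine index and local GPU index."""
    return node_rank * nproc_per_node + local_rank


# The two-machine example above: 2 nodes with 4 processes each.
layout = {
    (node, local): global_rank(node, 4, local)
    for node in range(2)
    for local in range(4)
}
for (node, local), rank in sorted(layout.items()):
    print(f"node_rank={node} local_rank={local} -> rank={rank} of {world_size(2, 4)}")
```

<p>Rank 0 always lives on the machine started with <code>--node_rank 0</code>, so that machine's address is the one to pass as <code>--master_addr</code>.</p>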
<!-- DeepSpeed -->
<h2 id="deepspeed">DeepSpeed</h2>
<p>For large models (for example Fun-ASR-Nano, 800M parameters), enable DeepSpeed ZeRO:</p>
<pre>++train_conf.use_deepspeed=true
++train_conf.deepspeed_config=./deepspeed_conf/ds_stage1.json</pre>
<p>Stage 1 config (recommended starting point):</p>
<pre>{
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"bf16": {"enabled": true},
"zero_optimization": {
"stage": 1,
"reduce_bucket_size": 5e8,
"allgather_bucket_size": 5e8
}
}</pre>
<div class="note"><strong>When to use which stage:</strong><br>
• Stage 1: Optimizer state partitioned. Good for most cases.<br>
• Stage 2: + Gradient partitioned. For larger models.<br>
• Stage 3: + Parameter partitioned. Maximum memory savings but slower.</div>
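<p>A back-of-the-envelope estimate can help when choosing a stage. The Python sketch below uses the common accounting for mixed-precision Adam training: 2 bytes per parameter for weights, 2 for gradients, and 12 for optimizer state. It covers model state only; activations, buffers, and fragmentation are excluded, so real usage will be higher. The function name and the 4-GPU setup are assumptions for illustration.</p>

```python
def zero_model_state_gb(num_params, num_gpus, stage):
    """Estimate per-GPU model-state memory in GB for ZeRO stage 0 to 3."""
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2, or 3")
    params = 2.0 * num_params      # half-precision weights
    grads = 2.0 * num_params       # half-precision gradients
    optimizer = 12.0 * num_params  # fp32 weights, momentum, and variance
    if stage >= 1:
        optimizer /= num_gpus      # stage 1 partitions optimizer state
    if stage >= 2:
        grads /= num_gpus          # stage 2 also partitions gradients
    if stage >= 3:
        params /= num_gpus         # stage 3 also partitions parameters
    return (params + grads + optimizer) / 1e9


estimates = {stage: zero_model_state_gb(800e6, num_gpus=4, stage=stage)
             for stage in range(4)}
for stage, gigabytes in estimates.items():
    print(f"stage {stage}: about {gigabytes:.1f} GB of model state per GPU")
```

<p>Under these assumptions most of the saving comes from Stage 1, which is consistent with recommending it as the starting point.</p>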
<!-- Monitor -->
<h2 id="monitor">Monitoring Training</h2>
<h4>Log file</h4>
<pre>tail -f outputs/log.txt
# Example output:
# train, rank: 0, epoch: 0/50, step: 6990, (loss_avg_rank: 0.327),
# (acc_avg_epoch: 0.795), (lr: 1.165e-04),
# GPU memory: usage: 3.8GB, peak: 18.3GB</pre>
<p>Key metrics to watch:</p>
<ul>
<li><code>loss_avg_epoch</code> (and the per-rank <code>loss_avg_rank</code> shown in the sample output): should decrease over time</li>
<li><code>acc_avg_epoch</code>: should increase (most important metric)</li>
<li><code>lr</code>: learning rate at current step</li>
<li><code>GPU memory</code>: peak should not exceed your GPU VRAM</li>
</ul>
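<p>To track these metrics programmatically, the parenthesised <code>(name: value)</code> pairs can be pulled out of each log line with a regular expression. The Python sketch below assumes the log format matches the sample output above. <code>parse_metrics</code> is a hypothetical helper, not a FunASR function.</p>

```python
import re

# Matches "(name: number)" pairs such as "(lr: 1.165e-04)".
_METRIC = re.compile(r"\((\w+): ([-+0-9.eE]+)\)")


def parse_metrics(line):
    """Return a dict of metric name -> float for one training log line."""
    return {name: float(value) for name, value in _METRIC.findall(line)}


line = ("train, rank: 0, epoch: 0/50, step: 6990, (loss_avg_rank: 0.327), "
        "(acc_avg_epoch: 0.795), (lr: 1.165e-04),")
metrics = parse_metrics(line)
print(metrics)
```

<p>Applying this to every line of <code>outputs/log.txt</code> gives a series that can be plotted or checked for a stalled loss.</p>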
<h4>TensorBoard</h4>
<pre>tensorboard --logdir outputs/log/tensorboard
# Open http://localhost:6006</pre>
<!-- Use Fine-tuned Model -->
<h2 id="inference-after">Use Your Fine-tuned Model</h2>
<h4>If outputs/ has configuration.json</h4>
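<pre>from funasr import AutoModel
# Load the fine-tuned model from the training output directory
model = AutoModel(model="./outputs")
# generate() returns a list with one result dict per input
res = model.generate(input="test.wav")
print(res[0]["text"])</pre>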
<h4>If no configuration.json</h4>
<pre>funasr ++model="./outputs" \
--config-path="./outputs" \
--config-name="config.yaml" \
++init_param="./outputs/model.pt" \
++input="test.wav"</pre>
<!-- Tips -->
<h2 id="tips">Tips & Troubleshooting</h2>
<h4>OOM during training</h4>
<ol>
<li>Reduce <code>dataset_conf.batch_size</code></li>
<li>Add <code>dataset_conf.max_token_length=2000</code> to filter long utterances</li>
<li>Enable DeepSpeed (partitions optimizer states)</li>
<li>Reduce <code>dataset_conf.num_workers</code></li>
</ol>
<h4>Training loss stuck / NaN gradients</h4>
<ul>
<li>Reduce learning rate (try 0.00005)</li>
<li>Check data quality — corrupted audio files cause NaN</li>
<li>For Fun-ASR-Nano: start with encoder frozen</li>
</ul>
<h4>Validation accuracy not improving</h4>
<ul>
<li>Increase training data (at least ~10 hours for fine-tuning)</li>
<li>Check domain match — model may not generalize to very different domains</li>
<li>Try unfreezing more layers gradually</li>
</ul>
<h4>Large-scale data (>10,000 hours)</h4>
<p>Use data splitting to avoid memory issues:</p>
<pre># Split data into chunks, load 2 at a time
++dataset_conf.data_split_num=256
# data.list contains paths to split jsonl files:
# data/train.0.jsonl
# data/train.1.jsonl
# ...
++train_data_set_list="data/data.list"</pre>
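<p>One possible way to produce the split files and <code>data.list</code> is sketched below in Python. The round-robin assignment and the <code>split_jsonl</code> helper are assumptions for illustration. Only the shard naming (<code>train.0.jsonl</code>, <code>train.1.jsonl</code>, ...) and the one-path-per-line list file follow the layout described above.</p>

```python
import tempfile
from pathlib import Path


def split_jsonl(src, out_dir, num_shards):
    """Distribute the lines of src round-robin into num_shards files.

    Returns the path of a data.list file holding one shard path per line.
    """
    src, out_dir = Path(src), Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    shard_paths = [out_dir / f"{src.stem}.{i}.jsonl" for i in range(num_shards)]
    handles = [path.open("w", encoding="utf-8") for path in shard_paths]
    try:
        written = 0
        with src.open(encoding="utf-8") as source:
            for line in source:
                if not line.strip():
                    continue  # ignore blank lines
                handles[written % num_shards].write(line.rstrip("\n") + "\n")
                written += 1
    finally:
        for handle in handles:
            handle.close()
    list_path = out_dir / "data.list"
    list_path.write_text("".join(f"{path}\n" for path in shard_paths),
                         encoding="utf-8")
    return list_path


# Small demonstration on a temporary five-line file.
with tempfile.TemporaryDirectory() as tmp:
    source_file = Path(tmp) / "train.jsonl"
    source_file.write_text(
        "".join(f'{{"key": "utt{i}"}}\n' for i in range(5)), encoding="utf-8")
    list_file = split_jsonl(source_file, Path(tmp) / "shards", num_shards=2)
    listed = list_file.read_text(encoding="utf-8").splitlines()
    shard_names = [Path(p).name for p in listed]
    shard_sizes = [len(Path(p).read_text(encoding="utf-8").splitlines())
                   for p in listed]
    print(shard_names, shard_sizes)
```

<p>Round-robin assignment keeps the shards close to equal in size. Shuffling the source file first would additionally mix domains across shards.</p>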
<h4>Resume after crash</h4>
<p>Set <code>++train_conf.resume=true</code> (the default). Training then resumes automatically from the latest checkpoint in <code>output_dir</code>.</p>
</div></div>
<footer>
<p>FunASR · Tongyi Lab, Alibaba Group</p>
<p><a href="index.html">Home</a> · <a href="tutorial.html">Tutorial</a> · <a href="training.html">Training</a> · <a href="model-registration.html">Develop</a> · <a href="api.html">API</a> · <a href="https://github.com/modelscope/FunASR">GitHub</a></p>
</footer>
</body></html>