-import onnxruntime
+import asyncio
+import platform
+import sys
+import threading
+from typing import (
+    AsyncGenerator,
+    Generator,
+    Iterator,
+    Literal,
+    NotRequired,
+    TypedDict,
+    cast,
+)
+
 import numpy as np
+import onnxruntime
 from huggingface_hub import hf_hub_download
-import sys
-import platform
-from typing import Generator, Iterator, cast, TypedDict, AsyncGenerator, NotRequired, Literal
 from numpy.typing import NDArray
-import asyncio
-import threading
 
 
 class TTSOptions(TypedDict):
@@ -22,42 +31,56 @@ class TTSOptions(TypedDict):
2231 """Minimum probability for top-p sampling. Default: 0.05"""
2332 pre_buffer_size : NotRequired [float ]
2433 """Seconds of audio to generate before yielding the first chunk. Smoother audio streaming at the cost of higher time to wait for the first chunk."""
-    voice_id: NotRequired[Literal["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]]
+    voice_id: NotRequired[
+        Literal["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]
+    ]
     """The voice to use for the TTS. Default: "tara"."""
 
+
 CUSTOM_TOKEN_PREFIX = "<custom_token_"
 
+
 class OrpheusCpp:
     def __init__(self, verbose: bool = True):
         import importlib.util
+
         if importlib.util.find_spec("llama_cpp") is None:
             if sys.platform == "darwin":
                 # Check if macOS 11.0+ on arm64 (Apple Silicon)
                 is_arm64 = platform.machine() == "arm64"
                 version = platform.mac_ver()[0].split(".")
                 is_macos_11_plus = len(version) >= 2 and int(version[0]) >= 11
                 is_macos_10_less = len(version) >= 2 and int(version[0]) < 11
-
+
                 if is_arm64 and is_macos_11_plus:
                     extra_index_url = "--extra-index-url https://abetlen.github.io/llama-cpp-python/whl/metal"
                 elif is_macos_10_less:
-                    raise ImportError("llama_cpp does not have pre-built wheels for macOS 10.x "
-                                      "Follow install instructions at https://github.com/abetlen/llama-cpp-python")
+                    raise ImportError(
+                        "llama_cpp does not have pre-built wheels for macOS 10.x. "
+                        "Follow the install instructions at https://github.com/abetlen/llama-cpp-python"
+                    )
                 else:
                     extra_index_url = "--extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
             else:
                 extra_index_url = "--extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
 
-            raise ImportError(f"llama_cpp is not installed. Please install it using `pip install llama-cpp-python {extra_index_url}`.")
+            raise ImportError(
+                f"llama_cpp is not installed. Please install it using `pip install llama-cpp-python {extra_index_url}`."
+            )
 
-        model_file = hf_hub_download(repo_id="isaiahbjork/orpheus-3b-0.1-ft-Q4_K_M-GGUF",
-                                     filename="orpheus-3b-0.1-ft-q4_k_m.gguf")
+        model_file = hf_hub_download(
+            repo_id="isaiahbjork/orpheus-3b-0.1-ft-Q4_K_M-GGUF",
+            filename="orpheus-3b-0.1-ft-q4_k_m.gguf",
+        )
         from llama_cpp import Llama
+
         self._llm = Llama(model_path=model_file, n_ctx=0, verbose=verbose)
 
         repo_id = "onnx-community/snac_24khz-ONNX"
         snac_model_file = "decoder_model.onnx"
-        snac_model_path = hf_hub_download(repo_id, subfolder="onnx", filename=snac_model_file)
+        snac_model_path = hf_hub_download(
+            repo_id, subfolder="onnx", filename=snac_model_file
+        )
 
         # Load SNAC model with optimizations
         self._snac_session = onnxruntime.InferenceSession(
@@ -67,16 +90,16 @@ def __init__(self, verbose: bool = True):
 
     def _token_to_id(self, token_text: str, index: int) -> int | None:
         token_string = token_text.strip()
-
+
         # Find the last token in the string
         last_token_start = token_string.rfind(CUSTOM_TOKEN_PREFIX)
-
+
         if last_token_start == -1:
             return None
-
+
         # Extract the last token
         last_token = token_string[last_token_start:]
-
+
         # Process the last token
         if last_token.startswith(CUSTOM_TOKEN_PREFIX) and last_token.endswith(">"):
             try:
@@ -87,8 +110,10 @@ def _token_to_id(self, token_text: str, index: int) -> int | None:
                 return None
         else:
             return None
-
-    def _decode(self, token_gen: Generator[str, None, None]) -> Generator[np.ndarray, None, None]:
+
+    def _decode(
+        self, token_gen: Generator[str, None, None]
+    ) -> Generator[bytes, None, None]:
        """Token decoder that converts a token stream into a stream of audio chunks (int16 PCM bytes)."""
         buffer = []
         count = 0
@@ -97,79 +122,90 @@ def _decode(self, token_gen: Generator[str, None, None]) -> Generator[np.ndarray, None, None]:
             if token is not None and token > 0:
                 buffer.append(token)
                 count += 1
-
+
                 # Convert to audio when we have enough tokens
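+                # (7 tokens form one SNAC frame; count > 27 waits for a full
+                # 28-token window, which is re-decoded over the last 28 tokens
+                # each time 7 new tokens arrive)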
                 if count % 7 == 0 and count > 27:
                     buffer_to_proc = buffer[-28:]
                     audio_samples = self._convert_to_audio(buffer_to_proc)
                     if audio_samples is not None:
                         yield audio_samples
-
+
     def _convert_to_audio(self, multiframe: list[int]) -> bytes | None:
         if len(multiframe) < 28:  # Ensure we have enough tokens
             return None
-
+
         num_frames = len(multiframe) // 7
-        frame = multiframe[:num_frames * 7]
-
+        frame = multiframe[: num_frames * 7]
+
         # Initialize empty numpy arrays instead of torch tensors
         codes_0 = np.array([], dtype=np.int32)
         codes_1 = np.array([], dtype=np.int32)
         codes_2 = np.array([], dtype=np.int32)
-
+
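+        # SNAC uses three hierarchical codebooks; each 7-token frame interleaves
+        # 1 coarse code (level 0), 2 medium codes (level 1), and 4 fine codes
+        # (level 2), unpacked by the loop below.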
         for j in range(num_frames):
-            i = 7*j
+            i = 7 * j
             # Append values to numpy arrays
             codes_0 = np.append(codes_0, frame[i])
-
-            codes_1 = np.append(codes_1, [frame[i+1], frame[i+4]])
-
-            codes_2 = np.append(codes_2, [frame[i+2], frame[i+3], frame[i+5], frame[i+6]])
-
+
+            codes_1 = np.append(codes_1, [frame[i + 1], frame[i + 4]])
+
+            codes_2 = np.append(
+                codes_2, [frame[i + 2], frame[i + 3], frame[i + 5], frame[i + 6]]
+            )
+
         # Reshape arrays to match the expected input format (add batch dimension)
         codes_0 = np.expand_dims(codes_0, axis=0)
         codes_1 = np.expand_dims(codes_1, axis=0)
         codes_2 = np.expand_dims(codes_2, axis=0)
-
+
         # Check that all tokens are between 0 and 4096
-        if (np.any(codes_0 < 0) or np.any(codes_0 > 4096) or
-            np.any(codes_1 < 0) or np.any(codes_1 > 4096) or
-            np.any(codes_2 < 0) or np.any(codes_2 > 4096)):
+        if (
+            np.any(codes_0 < 0)
+            or np.any(codes_0 > 4096)
+            or np.any(codes_1 < 0)
+            or np.any(codes_1 > 4096)
+            or np.any(codes_2 < 0)
+            or np.any(codes_2 > 4096)
+        ):
             return None
-
+
         # Create input dictionary for ONNX session
 
         snac_input_names = [x.name for x in self._snac_session.get_inputs()]
 
         input_dict = dict(zip(snac_input_names, [codes_0, codes_1, codes_2]))
-
+
         # Run inference
         audio_hat = self._snac_session.run(None, input_dict)[0]
-
+
         # Process output
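+        # Keep only the middle slice of the decoded window: successive 28-token
+        # windows overlap, so trimming the edges appears to avoid boundary
+        # artifacts between yielded chunks.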
         audio_np = audio_hat[:, :, 2048:4096]
         audio_int16 = (audio_np * 32767).astype(np.int16)
         audio_bytes = audio_int16.tobytes()
         return audio_bytes
 
-    def tts(self, text: str, options: TTSOptions | None = None) -> tuple[int, NDArray[np.int16]]:
+    def tts(
+        self, text: str, options: TTSOptions | None = None
+    ) -> tuple[int, NDArray[np.int16]]:
         buffer = []
         for _, array in self.stream_tts_sync(text, options):
             buffer.append(array)
         return (24_000, np.concatenate(buffer, axis=1))
-
+
     async def stream_tts(
         self, text: str, options: TTSOptions | None = None
     ) -> AsyncGenerator[tuple[int, NDArray[np.int16]], None]:
-
         queue = asyncio.Queue()
         finished = asyncio.Event()
+
         def stream_to_queue(text, options, queue, finished):
             for chunk in self.stream_tts_sync(text, options):
                 queue.put_nowait(chunk)
             finished.set()
-
-        thread = threading.Thread(target=stream_to_queue, args=(text, options, queue, finished))
+
+        thread = threading.Thread(
+            target=stream_to_queue, args=(text, options, queue, finished)
+        )
         thread.start()
         while not finished.is_set():
             try:
@@ -180,18 +216,25 @@ def stream_to_queue(text, options, queue, finished):
             chunk = queue.get_nowait()
             yield chunk
 
-    def _token_gen(self, text: str, options: TTSOptions | None = None) -> Generator[str, None, None]:
+    def _token_gen(
+        self, text: str, options: TTSOptions | None = None
+    ) -> Generator[str, None, None]:
         from llama_cpp import CreateCompletionStreamResponse
+
         options = options or TTSOptions()
         voice_id = options.get("voice_id", "tara")
         text = f"<|audio|>{voice_id}: {text}<|eot_id|><custom_token_4>"
-        token_gen = self._llm(text, max_tokens=options.get("max_tokens", 2_048), stream=True,
-                              temperature=options.get("temperature", 0.8),
-                              top_p=options.get("top_p", 0.95),
-                              top_k=options.get("top_k", 40),
-                              min_p=options.get("min_p", 0.05))
+        token_gen = self._llm(
+            text,
+            max_tokens=options.get("max_tokens", 2_048),
+            stream=True,
+            temperature=options.get("temperature", 0.8),
+            top_p=options.get("top_p", 0.95),
+            top_k=options.get("top_k", 40),
+            min_p=options.get("min_p", 0.05),
+        )
         for token in cast(Iterator[CreateCompletionStreamResponse], token_gen):
-            yield token['choices'][0]['text']
+            yield token["choices"][0]["text"]
 
     def stream_tts_sync(
         self, text: str, options: TTSOptions | None = None
@@ -212,4 +255,3 @@ def stream_tts_sync(
             yield (24_000, audio_array)
         if not started_playback:
             yield (24_000, pre_buffer)
-
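
For reference, a minimal usage sketch of the API touched by this diff (the `orpheus_cpp` import path and the option values are assumptions for illustration, not part of this commit):

    from orpheus_cpp import OrpheusCpp, TTSOptions  # assumed import path

    tts = OrpheusCpp(verbose=False)
    options: TTSOptions = {"voice_id": "tara", "pre_buffer_size": 1.5}

    # One-shot synthesis: returns (sample_rate, int16 array of shape (1, num_samples)).
    sample_rate, audio = tts.tts("Hello from Orpheus.", options)

    # Incremental synthesis: yields (sample_rate, chunk) pairs as audio is decoded.
    for sample_rate, chunk in tts.stream_tts_sync("Streaming example.", options):
        ...  # feed each chunk to an audio sink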