Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 7fcdeee

Browse files
Format
1 parent 13c923c commit 7fcdeee

File tree

3 files changed

+101
-54
lines changed

3 files changed

+101
-54
lines changed

src/orpheus_cpp/__main__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
)
2222
from fastrtc.utils import create_message
2323
from huggingface_hub import InferenceClient
24+
2425
from orpheus_cpp.model import OrpheusCpp
26+
2527
async_client = httpx.AsyncClient()
2628

2729
client = InferenceClient(model="meta-llama/Llama-3.2-3B-Instruct")
@@ -46,6 +48,7 @@ def generate_message():
4648
msg = msg.replace('"', "")
4749
return msg
4850

51+
4952
model = OrpheusCpp()
5053

5154

@@ -68,7 +71,9 @@ async def receive(self, frame: tuple[int, npt.NDArray[np.int16]]) -> None:
6871
all_audio = np.array([], dtype=np.int16)
6972
started_playback = False
7073

71-
async for (sample_rate, chunk) in model.stream_tts(msg, options={"voice_id": voice_id}):
74+
async for sample_rate, chunk in model.stream_tts(
75+
msg, options={"voice_id": voice_id}
76+
):
7277
all_audio = np.concatenate([all_audio, chunk.squeeze()])
7378
if not started_playback:
7479
started_playback = True

src/orpheus_cpp/model.py

Lines changed: 94 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
1-
import onnxruntime
1+
import asyncio
2+
import platform
3+
import sys
4+
import threading
5+
from typing import (
6+
AsyncGenerator,
7+
Generator,
8+
Iterator,
9+
Literal,
10+
NotRequired,
11+
TypedDict,
12+
cast,
13+
)
14+
215
import numpy as np
16+
import onnxruntime
317
from huggingface_hub import hf_hub_download
4-
import sys
5-
import platform
6-
from typing import Generator, Iterator, cast, TypedDict, AsyncGenerator, NotRequired, Literal
718
from numpy.typing import NDArray
8-
import asyncio
9-
import threading
1019

1120

1221
class TTSOptions(TypedDict):
@@ -22,42 +31,56 @@ class TTSOptions(TypedDict):
2231
"""Minimum probability for top-p sampling. Default: 0.05"""
2332
pre_buffer_size: NotRequired[float]
2433
"""Seconds of audio to generate before yielding the first chunk. Smoother audio streaming at the cost of higher time to wait for the first chunk."""
25-
voice_id: NotRequired[Literal["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]]
34+
voice_id: NotRequired[
35+
Literal["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]
36+
]
2637
"""The voice to use for the TTS. Default: "tara"."""
2738

39+
2840
CUSTOM_TOKEN_PREFIX = "<custom_token_"
2941

42+
3043
class OrpheusCpp:
3144
def __init__(self, verbose: bool = True):
3245
import importlib.util
46+
3347
if importlib.util.find_spec("llama_cpp") is None:
3448
if sys.platform == "darwin":
3549
# Check if macOS 11.0+ on arm64 (Apple Silicon)
3650
is_arm64 = platform.machine() == "arm64"
3751
version = platform.mac_ver()[0].split(".")
3852
is_macos_11_plus = len(version) >= 2 and int(version[0]) >= 11
3953
is_macos_10_less = len(version) >= 2 and int(version[0]) < 11
40-
54+
4155
if is_arm64 and is_macos_11_plus:
4256
extra_index_url = "--extra-index-url https://abetlen.github.io/llama-cpp-python/whl/metal"
4357
elif is_macos_10_less:
44-
raise ImportError("llama_cpp does not have pre-built wheels for macOS 10.x "
45-
"Follow install instructions at https://github.com/abetlen/llama-cpp-python")
58+
raise ImportError(
59+
"llama_cpp does not have pre-built wheels for macOS 10.x "
60+
"Follow install instructions at https://github.com/abetlen/llama-cpp-python"
61+
)
4662
else:
4763
extra_index_url = "--extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
4864
else:
4965
extra_index_url = "--extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
5066

51-
raise ImportError(f"llama_cpp is not installed. Please install it using `pip install llama-cpp-python {extra_index_url}`.")
67+
raise ImportError(
68+
f"llama_cpp is not installed. Please install it using `pip install llama-cpp-python {extra_index_url}`."
69+
)
5270

53-
model_file = hf_hub_download(repo_id="isaiahbjork/orpheus-3b-0.1-ft-Q4_K_M-GGUF",
54-
filename="orpheus-3b-0.1-ft-q4_k_m.gguf")
71+
model_file = hf_hub_download(
72+
repo_id="isaiahbjork/orpheus-3b-0.1-ft-Q4_K_M-GGUF",
73+
filename="orpheus-3b-0.1-ft-q4_k_m.gguf",
74+
)
5575
from llama_cpp import Llama
76+
5677
self._llm = Llama(model_path=model_file, n_ctx=0, verbose=verbose)
5778

5879
repo_id = "onnx-community/snac_24khz-ONNX"
5980
snac_model_file = "decoder_model.onnx"
60-
snac_model_path = hf_hub_download(repo_id, subfolder="onnx", filename=snac_model_file)
81+
snac_model_path = hf_hub_download(
82+
repo_id, subfolder="onnx", filename=snac_model_file
83+
)
6184

6285
# Load SNAC model with optimizations
6386
self._snac_session = onnxruntime.InferenceSession(
@@ -67,16 +90,16 @@ def __init__(self, verbose: bool = True):
6790

6891
def _token_to_id(self, token_text: str, index: int) -> int | None:
6992
token_string = token_text.strip()
70-
93+
7194
# Find the last token in the string
7295
last_token_start = token_string.rfind(CUSTOM_TOKEN_PREFIX)
73-
96+
7497
if last_token_start == -1:
7598
return None
76-
99+
77100
# Extract the last token
78101
last_token = token_string[last_token_start:]
79-
102+
80103
# Process the last token
81104
if last_token.startswith(CUSTOM_TOKEN_PREFIX) and last_token.endswith(">"):
82105
try:
@@ -87,8 +110,10 @@ def _token_to_id(self, token_text: str, index: int) -> int | None:
87110
return None
88111
else:
89112
return None
90-
91-
def _decode(self, token_gen: Generator[str, None, None]) -> Generator[np.ndarray, None, None]:
113+
114+
def _decode(
115+
self, token_gen: Generator[str, None, None]
116+
) -> Generator[np.ndarray, None, None]:
92117
"""Asynchronous token decoder that converts token stream to audio stream."""
93118
buffer = []
94119
count = 0
@@ -97,79 +122,90 @@ def _decode(self, token_gen: Generator[str, None, None]) -> Generator[np.ndarray
97122
if token is not None and token > 0:
98123
buffer.append(token)
99124
count += 1
100-
125+
101126
# Convert to audio when we have enough tokens
102127
if count % 7 == 0 and count > 27:
103128
buffer_to_proc = buffer[-28:]
104129
audio_samples = self._convert_to_audio(buffer_to_proc)
105130
if audio_samples is not None:
106131
yield audio_samples
107-
132+
108133
def _convert_to_audio(self, multiframe: list[int]) -> np.ndarray | None:
109134
if len(multiframe) < 28: # Ensure we have enough tokens
110135
return None
111-
136+
112137
num_frames = len(multiframe) // 7
113-
frame = multiframe[:num_frames*7]
114-
138+
frame = multiframe[: num_frames * 7]
139+
115140
# Initialize empty numpy arrays instead of torch tensors
116141
codes_0 = np.array([], dtype=np.int32)
117142
codes_1 = np.array([], dtype=np.int32)
118143
codes_2 = np.array([], dtype=np.int32)
119-
144+
120145
for j in range(num_frames):
121-
i = 7*j
146+
i = 7 * j
122147
# Append values to numpy arrays
123148
codes_0 = np.append(codes_0, frame[i])
124-
125-
codes_1 = np.append(codes_1, [frame[i+1], frame[i+4]])
126-
127-
codes_2 = np.append(codes_2, [frame[i+2], frame[i+3], frame[i+5], frame[i+6]])
128-
149+
150+
codes_1 = np.append(codes_1, [frame[i + 1], frame[i + 4]])
151+
152+
codes_2 = np.append(
153+
codes_2, [frame[i + 2], frame[i + 3], frame[i + 5], frame[i + 6]]
154+
)
155+
129156
# Reshape arrays to match the expected input format (add batch dimension)
130157
codes_0 = np.expand_dims(codes_0, axis=0)
131158
codes_1 = np.expand_dims(codes_1, axis=0)
132159
codes_2 = np.expand_dims(codes_2, axis=0)
133-
160+
134161
# Check that all tokens are between 0 and 4096
135-
if (np.any(codes_0 < 0) or np.any(codes_0 > 4096) or
136-
np.any(codes_1 < 0) or np.any(codes_1 > 4096) or
137-
np.any(codes_2 < 0) or np.any(codes_2 > 4096)):
162+
if (
163+
np.any(codes_0 < 0)
164+
or np.any(codes_0 > 4096)
165+
or np.any(codes_1 < 0)
166+
or np.any(codes_1 > 4096)
167+
or np.any(codes_2 < 0)
168+
or np.any(codes_2 > 4096)
169+
):
138170
return None
139-
171+
140172
# Create input dictionary for ONNX session
141173

142174
snac_input_names = [x.name for x in self._snac_session.get_inputs()]
143175

144176
input_dict = dict(zip(snac_input_names, [codes_0, codes_1, codes_2]))
145-
177+
146178
# Run inference
147179
audio_hat = self._snac_session.run(None, input_dict)[0]
148-
180+
149181
# Process output
150182
audio_np = audio_hat[:, :, 2048:4096]
151183
audio_int16 = (audio_np * 32767).astype(np.int16)
152184
audio_bytes = audio_int16.tobytes()
153185
return audio_bytes
154186

155-
def tts(self, text: str, options: TTSOptions | None = None) -> tuple[int, NDArray[np.int16]]:
187+
def tts(
188+
self, text: str, options: TTSOptions | None = None
189+
) -> tuple[int, NDArray[np.int16]]:
156190
buffer = []
157191
for _, array in self.stream_tts_sync(text, options):
158192
buffer.append(array)
159193
return (24_000, np.concatenate(buffer, axis=1))
160-
194+
161195
async def stream_tts(
162196
self, text: str, options: TTSOptions | None = None
163197
) -> AsyncGenerator[tuple[int, NDArray[np.float32]], None]:
164-
165198
queue = asyncio.Queue()
166199
finished = asyncio.Event()
200+
167201
def strem_to_queue(text, options, queue, finished):
168202
for chunk in self.stream_tts_sync(text, options):
169203
queue.put_nowait(chunk)
170204
finished.set()
171-
172-
thread = threading.Thread(target=strem_to_queue, args=(text, options, queue, finished))
205+
206+
thread = threading.Thread(
207+
target=strem_to_queue, args=(text, options, queue, finished)
208+
)
173209
thread.start()
174210
while not finished.is_set():
175211
try:
@@ -180,18 +216,25 @@ def strem_to_queue(text, options, queue, finished):
180216
chunk = queue.get_nowait()
181217
yield chunk
182218

183-
def _token_gen(self, text: str, options: TTSOptions | None = None) -> Generator[str, None, None]:
219+
def _token_gen(
220+
self, text: str, options: TTSOptions | None = None
221+
) -> Generator[str, None, None]:
184222
from llama_cpp import CreateCompletionStreamResponse
223+
185224
options = options or TTSOptions()
186225
voice_id = options.get("voice_id", "tara")
187226
text = f"<|audio|>{voice_id}: {text}<|eot_id|><custom_token_4>"
188-
token_gen = self._llm(text, max_tokens=options.get("max_tokens", 2_048), stream=True,
189-
temperature=options.get("temperature", 0.8),
190-
top_p=options.get("top_p", 0.95),
191-
top_k=options.get("top_k", 40),
192-
min_p=options.get("min_p", 0.05))
227+
token_gen = self._llm(
228+
text,
229+
max_tokens=options.get("max_tokens", 2_048),
230+
stream=True,
231+
temperature=options.get("temperature", 0.8),
232+
top_p=options.get("top_p", 0.95),
233+
top_k=options.get("top_k", 40),
234+
min_p=options.get("min_p", 0.05),
235+
)
193236
for token in cast(Iterator[CreateCompletionStreamResponse], token_gen):
194-
yield token['choices'][0]['text']
237+
yield token["choices"][0]["text"]
195238

196239
def stream_tts_sync(
197240
self, text: str, options: TTSOptions | None = None
@@ -212,4 +255,3 @@ def stream_tts_sync(
212255
yield (24_000, audio_array)
213256
if not started_playback:
214257
yield (24_000, pre_buffer)
215-

tests/test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from orpheus_cpp import OrpheusCpp
21
from scipy.io.wavfile import write
32

3+
from orpheus_cpp import OrpheusCpp
44

55
orpheus = OrpheusCpp()
66

0 commit comments

Comments
 (0)