-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathapp.py
More file actions
420 lines (358 loc) · 14.3 KB
/
app.py
File metadata and controls
420 lines (358 loc) · 14.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
import os
import time
import asyncio
from datetime import datetime
from typing import List, Optional
from dotenv import load_dotenv
# Bootstrap a .env file on first start.
def ensure_env_file_exists():
    """Create .env from .env.example when .env is missing.

    Does nothing if .env already exists or there is no .env.example. Otherwise
    every ``KEY=value`` line of .env.example (blank lines, ``#`` comments and
    lines without ``=`` are ignored) is copied into .env, except that a key
    already present in the process environment takes the environment's value.
    Failures are reported with a print and otherwise ignored (best effort).
    """
    if os.path.exists(".env") or not os.path.exists(".env.example"):
        return
    try:
        # Defaults from .env.example: split on the first "=" and trim both sides.
        defaults = {}
        with open(".env.example", "r") as example:
            for raw in example:
                entry = raw.strip()
                if not entry or entry.startswith("#") or "=" not in entry:
                    continue
                name, _, setting = entry.partition("=")
                defaults[name.strip()] = setting.strip()

        # Environment variables (e.g. from Docker) win over the example defaults.
        merged = {name: os.environ.get(name, setting) for name, setting in defaults.items()}

        with open(".env", "w") as env_out:
            env_out.writelines(f"{name}={setting}\n" for name, setting in merged.items())
        print("✅ Created default .env file from .env.example and environment variables.")
    except Exception as e:
        print(f"⚠️ Error creating default .env file: {e}")
# Make sure a .env exists before loading it: when .env is missing and
# .env.example is present, ensure_env_file_exists() generates .env from the
# example, preferring values already set in the process environment.
ensure_env_file_exists()
# Load variables from .env into os.environ. override=True means values in .env
# replace variables that are already set in the process environment.
load_dotenv(override=True)
from fastapi import FastAPI, Request, Form, HTTPException, Depends
from fastapi.responses import HTMLResponse, FileResponse, JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
import json
from tts_engine import generate_speech_from_api, generate_speech_streaming, AVAILABLE_VOICES, DEFAULT_VOICE, VOICE_TO_LANGUAGE, AVAILABLE_LANGUAGES
# FastAPI application instance (imported by Uvicorn as "app:app").
app = FastAPI(
    title="Orpheus-FASTAPI",
    description="High-performance Text-to-Speech server using Orpheus-FASTAPI",
    version="1.0.0",
)

# No custom startup hook is registered: Uvicorn's own
# "INFO: Application startup complete." log line signals readiness.

# Directories that are both created on import and served as static files:
# generated audio under /outputs and UI assets under /static.
_SERVED_DIRS = ("outputs", "static")

# Create them first so StaticFiles finds existing directories when mounted.
for _served_dir in _SERVED_DIRS:
    os.makedirs(_served_dir, exist_ok=True)

for _served_dir in _SERVED_DIRS:
    app.mount(f"/{_served_dir}", StaticFiles(directory=_served_dir), name=_served_dir)

# Jinja2 templates for the web UI pages.
templates = Jinja2Templates(directory="templates")
# API models
class SpeechRequest(BaseModel):
    # Request body shared by /v1/audio/speech and /v1/audio/speech/stream.
    # (Documented with '#' comments: a class docstring would be copied into
    # the OpenAPI schema description by pydantic.)
    input: str  # text to synthesize; both endpoints reject an empty string
    model: str = "orpheus"  # read only by the streaming endpoint, where it selects the temperature
    voice: str = DEFAULT_VOICE  # passed to tts_engine; also embedded in output filenames
    response_format: str = "wav"  # NOTE(review): not read by any endpoint visible here; output is always WAV
    speed: float = 1.0  # NOTE(review): not read by any endpoint visible here
class APIResponse(BaseModel):
    # Same keys as the JSON payload returned by the legacy /speak endpoint.
    # NOTE(review): no endpoint visible here references this model (/speak
    # builds its dict by hand) — confirm whether it is still needed.
    status: str  # "ok" on success
    voice: str  # voice used for generation
    output_file: str  # path of the generated WAV under outputs/
    generation_time: float  # wall-clock seconds, rounded to 2 decimals
# OpenAI-compatible API endpoint
@app.post("/v1/audio/speech")
async def create_speech_api(request: SpeechRequest):
    """
    Generate speech from text using the Orpheus TTS model.

    Compatible with OpenAI's /v1/audio/speech endpoint. For longer texts
    (>1000 characters), batched generation is used to improve reliability and
    avoid truncation issues. The audio is written under ``outputs/`` and
    returned as a WAV file. ``response_format`` and ``speed`` are accepted for
    client compatibility but are not applied here.

    Raises:
        HTTPException: 400 if ``input`` is empty; 500 if generation did not
        produce the output file.
    """
    if not request.input:
        raise HTTPException(status_code=400, detail="Missing input text")

    # The voice comes straight from the request body: keep only filename-safe
    # characters so a value such as "../../x" cannot escape outputs/.
    safe_voice = "".join(
        ch if ch.isalnum() or ch in "-_" else "_" for ch in request.voice
    )
    # Microseconds keep two requests in the same second from writing (and
    # overwriting) the same file.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    filename = f"{safe_voice}_{timestamp}.wav"
    output_path = f"outputs/{filename}"

    use_batching = len(request.input) > 1000
    if use_batching:
        print(f"Using batched generation for long text ({len(request.input)} characters)")

    # NOTE(review): this call is synchronous and blocks the event loop for the
    # whole generation; moving it to a worker thread requires confirming that
    # tts_engine is safe to call concurrently.
    generate_speech_from_api(
        prompt=request.input,
        voice=request.voice,
        output_file=output_path,
        use_batching=use_batching,
        max_batch_chars=1000,  # ~1000 character chunks (roughly 1 paragraph)
    )

    # Fail with a clear message instead of an error while streaming the response.
    if not os.path.exists(output_path):
        raise HTTPException(status_code=500, detail="Speech generation produced no audio file")

    return FileResponse(
        path=output_path,
        media_type="audio/wav",
        filename=filename,
    )
# Streaming endpoint for real-time audio chunking
@app.post("/v1/audio/speech/stream")
async def stream_speech_api(request: SpeechRequest):
    """
    Stream synthesized speech to the client as it is generated.

    Chunks produced by ``generate_speech_streaming`` are forwarded unchanged
    in a chunked HTTP response labelled ``audio/wav``. Sampling uses
    temperature 1.0 for the "orpheus" model and 0.6 otherwise, top_p 0.9.
    NOTE(review): an earlier docstring described the chunks as PCM16 mono
    24 kHz — confirm the exact format against tts_engine.

    Raises:
        HTTPException: 400 if ``input`` is empty.
    """
    if not request.input:
        raise HTTPException(status_code=400, detail="Missing input text")

    async def forward_chunks():
        # Built inside the generator so the engine call only starts once the
        # response body is first iterated.
        stream = generate_speech_streaming(
            prompt=request.input,
            voice=request.voice,
            temperature=0.6 if request.model != "orpheus" else 1.0,
            top_p=0.9,
            output_format="wav",
        )
        async for piece in stream:
            yield piece

    # Disable caching and announce chunked transfer for browser playback.
    return StreamingResponse(
        forward_chunks(),
        media_type="audio/wav",
        headers={"Transfer-Encoding": "chunked", "Cache-Control": "no-cache"},
    )
@app.get("/v1/audio/voices")
async def list_voices():
    """Return the list of available voices.

    Responds with ``{"status": "ok", "voices": AVAILABLE_VOICES}``.

    Raises:
        HTTPException: 404 if no voices are available.
    """
    # An empty or missing voice list is falsy, so a single check covers both.
    if not AVAILABLE_VOICES:
        raise HTTPException(status_code=404, detail="No voices available")
    return JSONResponse(content={"status": "ok", "voices": AVAILABLE_VOICES})
# Legacy API endpoint for compatibility
@app.post("/speak")
async def speak(request: Request):
    """Legacy endpoint for compatibility with existing clients.

    Expects a JSON object ``{"text": ..., "voice": ...}`` (``voice`` defaults
    to DEFAULT_VOICE). Writes a WAV under ``outputs/`` and returns JSON with
    ``status``, ``voice``, ``output_file`` and ``generation_time`` (seconds).
    Malformed requests get HTTP 400 with an ``{"error": ...}`` body.
    """
    try:
        data = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
        return JSONResponse(status_code=400, content={"error": "Request body must be valid JSON"})
    if not isinstance(data, dict):
        return JSONResponse(status_code=400, content={"error": "Request body must be a JSON object"})

    text = data.get("text", "")
    voice = data.get("voice", DEFAULT_VOICE)
    if not text:
        return JSONResponse(
            status_code=400,
            content={"error": "Missing 'text'"}
        )
    if not isinstance(text, str):
        return JSONResponse(status_code=400, content={"error": "'text' must be a string"})

    # The voice is client-supplied: keep only filename-safe characters so it
    # cannot steer the output path outside outputs/.
    safe_voice = "".join(
        ch if ch.isalnum() or ch in "-_" else "_" for ch in str(voice)
    )
    # Microseconds avoid two requests in the same second sharing a file.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    output_path = f"outputs/{safe_voice}_{timestamp}.wav"

    # Check if we should use batched generation for longer texts
    use_batching = len(text) > 1000
    if use_batching:
        print(f"Using batched generation for long text ({len(text)} characters)")

    start = time.perf_counter()
    generate_speech_from_api(
        prompt=text,
        voice=voice,
        output_file=output_path,
        use_batching=use_batching,
        max_batch_chars=1000
    )
    generation_time = round(time.perf_counter() - start, 2)

    return JSONResponse(content={
        "status": "ok",
        "voice": voice,
        "output_file": output_path,
        "generation_time": generation_time
    })
# Web UI routes
@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    """Render the TTS web UI at the site root.

    Renders ``tts.html`` directly (no redirect) with the same context as
    ``GET /web/``, including the current configuration, so the page behaves
    the same on both URLs.
    """
    return templates.TemplateResponse(
        "tts.html",
        {
            "request": request,
            "voices": AVAILABLE_VOICES,
            "config": get_current_config(),
            "VOICE_TO_LANGUAGE": VOICE_TO_LANGUAGE,
            "AVAILABLE_LANGUAGES": AVAILABLE_LANGUAGES
        }
    )
@app.get("/web/", response_class=HTMLResponse)
async def web_ui(request: Request):
    """Serve the main TTS page together with the active configuration."""
    context = {
        "request": request,
        "voices": AVAILABLE_VOICES,
        "config": get_current_config(),
        "VOICE_TO_LANGUAGE": VOICE_TO_LANGUAGE,
        "AVAILABLE_LANGUAGES": AVAILABLE_LANGUAGES,
    }
    return templates.TemplateResponse("tts.html", context)
@app.get("/get_config")
async def get_config():
    """Return the merged configuration (.env.example, then .env, then process env) as JSON."""
    return JSONResponse(content=get_current_config())
@app.post("/save_config")
async def save_config(request: Request):
    """Save configuration to the .env file.

    Expects a JSON object mapping setting names to values. Known integer and
    float settings are normalised through ``int``/``float``; a value that fails
    conversion is written as submitted. The file is rewritten with exactly the
    submitted keys, one ``KEY=value`` per line. Changes apply after a restart.

    Returns HTTP 400 (and writes nothing) if the body is not a JSON object, or
    if a key is empty or contains ``=``, or a key/value contains a line break,
    since either would corrupt or inject lines into .env.
    """
    int_keys = {"ORPHEUS_MAX_TOKENS", "ORPHEUS_API_TIMEOUT", "ORPHEUS_PORT", "ORPHEUS_SAMPLE_RATE"}
    # ORPHEUS_REPETITION_PENALTY is deliberately not listed: it is hardcoded now.
    float_keys = {"ORPHEUS_TEMPERATURE", "ORPHEUS_TOP_P"}

    try:
        data = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
        return JSONResponse(status_code=400, content={"status": "error", "message": "Request body must be valid JSON."})
    if not isinstance(data, dict):
        return JSONResponse(status_code=400, content={"status": "error", "message": "Request body must be a JSON object."})

    # Build every line before opening the file so a rejected entry never
    # leaves .env truncated.
    lines = []
    for key, value in data.items():
        if key in int_keys:
            converter = int
        elif key in float_keys:
            converter = float
        else:
            converter = None
        if converter is not None:
            try:
                value = str(converter(value))
            except (ValueError, TypeError):
                pass  # keep the submitted value unchanged
        value = str(value)
        if not key or "=" in key or any(ch in key + value for ch in "\r\n"):
            return JSONResponse(
                status_code=400,
                content={"status": "error", "message": f"Invalid configuration entry: {key!r}"},
            )
        lines.append(f"{key}={value}\n")

    # Write configuration to .env file
    with open(".env", "w") as f:
        f.writelines(lines)

    return JSONResponse(content={"status": "ok", "message": "Configuration saved successfully. Restart server to apply changes."})
@app.post("/restart_server")
async def restart_server():
    """Trigger a Uvicorn reload by rewriting restart.flag.

    The write happens 0.5 s later on a daemon timer thread so that this
    response can reach the client before the server goes down.
    """
    import threading

    def _write_restart_flag():
        # Any change to restart.flag is picked up by the reloader.
        with open("restart.flag", "w") as flag:
            flag.write(str(time.time()))
        print("🔄 Restart flag created, server will reload momentarily...")

    timer = threading.Timer(0.5, _write_restart_flag)
    timer.daemon = True
    timer.start()

    return JSONResponse(content={"status": "ok", "message": "Server is restarting. Please wait a moment..."})
def _read_env_file(path):
    """Parse ``KEY=value`` lines of *path* into a dict; empty if the file is missing.

    Blank lines, ``#`` comments and lines without ``=`` are skipped. Each line
    is split on its first ``=`` and whitespace around key and value is
    stripped. For duplicate keys the last occurrence wins.
    """
    entries = {}
    if not os.path.exists(path):
        return entries
    with open(path, "r") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, _, value = line.partition("=")
            # Strip around "=" so "KEY = v" and "KEY=v" name the same setting
            # (matching how ensure_env_file_exists writes .env).
            entries[key.strip()] = value.strip()
    return entries


def get_current_config():
    """Read the current configuration from .env.example and .env.

    Precedence, lowest to highest: defaults from .env.example, values from
    .env, then the value of the same variable in the process environment if
    set. Only keys that appear in one of the two files are returned.

    Returns:
        dict[str, str]: setting name -> value.
    """
    config = {**_read_env_file(".env.example"), **_read_env_file(".env")}
    for key in config:
        env_value = os.environ.get(key)
        if env_value is not None:
            config[key] = env_value
    return config
@app.post("/web/", response_class=HTMLResponse)
async def generate_from_web(
    request: Request,
    text: str = Form(...),
    voice: str = Form(DEFAULT_VOICE)
):
    """Handle form submission from the web UI.

    Synthesizes *text* with *voice* into a WAV under ``outputs/`` and
    re-renders ``tts.html`` with the result, or with an error message when
    *text* is empty. The current configuration is included in the context,
    matching ``GET /web/``.
    """
    context = {
        "request": request,
        "voices": AVAILABLE_VOICES,
        "config": get_current_config(),
        "VOICE_TO_LANGUAGE": VOICE_TO_LANGUAGE,
        "AVAILABLE_LANGUAGES": AVAILABLE_LANGUAGES,
    }
    if not text:
        context["error"] = "Please enter some text."
        return templates.TemplateResponse("tts.html", context)

    # The voice is a client-supplied form field: keep only filename-safe
    # characters so it cannot steer the output path outside outputs/.
    safe_voice = "".join(
        ch if ch.isalnum() or ch in "-_" else "_" for ch in voice
    )
    # Microseconds avoid two submissions in the same second sharing a file.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    output_path = f"outputs/{safe_voice}_{timestamp}.wav"

    # Check if we should use batched generation for longer texts
    use_batching = len(text) > 1000
    if use_batching:
        print(f"Using batched generation for long text from web form ({len(text)} characters)")

    start = time.perf_counter()
    generate_speech_from_api(
        prompt=text,
        voice=voice,
        output_file=output_path,
        use_batching=use_batching,
        max_batch_chars=1000
    )
    generation_time = round(time.perf_counter() - start, 2)

    context.update(
        success=True,
        text=text,
        voice=voice,
        output_file=output_path,
        generation_time=generation_time,
    )
    return templates.TemplateResponse("tts.html", context)
if __name__ == "__main__":
    import uvicorn

    # Warn (but keep going) when the server's own host/port settings are absent.
    required_settings = ["ORPHEUS_HOST", "ORPHEUS_PORT"]
    missing_settings = [s for s in required_settings if s not in os.environ]
    if missing_settings:
        print(f"⚠️ Missing environment variable(s): {', '.join(missing_settings)}")
        print(" Using fallback values for server startup.")

    # An unset or empty ORPHEUS_HOST falls back to all interfaces.
    host = os.environ.get("ORPHEUS_HOST")
    if not host:
        print("⚠️ ORPHEUS_HOST not set, using 0.0.0.0 as fallback")
        host = "0.0.0.0"

    try:
        port = int(os.environ.get("ORPHEUS_PORT", "5005"))
    except (ValueError, TypeError):
        print("⚠️ Invalid ORPHEUS_PORT value, using 5005 as fallback")
        port = 5005

    # 0.0.0.0 is not browsable; show localhost in the printed URLs instead.
    display_host = "localhost" if host == "0.0.0.0" else host
    print(f"🔥 Starting Orpheus-FASTAPI Server on {host}:{port}")
    print(f"💬 Web UI available at http://{display_host}:{port}")
    print(f"📖 API docs available at http://{display_host}:{port}/docs")

    # Read current API_URL for user information
    api_url = os.environ.get("ORPHEUS_API_URL")
    if not api_url:
        print("⚠️ ORPHEUS_API_URL not set. Please configure in .env file before generating speech.")
    else:
        print(f"🔗 Using LLM inference server at: {api_url}")

    # reload=True makes /restart_server work: the reloader watches the current
    # directory and restart.flag is one of the included patterns.
    # NOTE(review): "app:app" assumes this module is named app.py.
    uvicorn.run(
        "app:app",
        host=host,
        port=port,
        reload=True,
        reload_dirs=["."],
        reload_includes=["*.py", "*.html", "restart.flag"],
    )