diff --git a/.gitignore b/.gitignore index 2c1e1f9..fb24679 100644 --- a/.gitignore +++ b/.gitignore @@ -69,3 +69,6 @@ htmlcov/ # Distribution *.whl + +# Server outputs +output_audio/ diff --git a/requirements_modified.txt b/requirements_modified.txt new file mode 100644 index 0000000..42b8e93 --- /dev/null +++ b/requirements_modified.txt @@ -0,0 +1,37 @@ +# Basic Dependencies +transformers==4.52.3 +datasets==3.6.0 +numpy<2.0.0 +accelerate +deepspeed==0.17.3 +librosa + +# CV3 +conformer==0.3.2 +diffusers==0.29.0 +gdown==5.1.0 +hydra-core==1.3.2 +HyperPyYAML==1.2.2 +inflect==7.3.1 +lightning==2.2.4 +loguru +matplotlib==3.7.5 +modelscope +networkx==3.1 +omegaconf==2.3.0 +onnx==1.16.0 +openai-whisper +protobuf==4.25 +pyarrow==18.1.0 +pydantic==2.7.0 +pyworld==0.3.4 +rich==13.7.1 +soundfile==0.12.1 +uvicorn==0.30.0 +wetext==0.0.4 +wget==3.2 +x_transformers + +# CUDA 13.0 nightly (GB10 / sm_121) +# Install with: +# pip install --pre --index-url https://download.pytorch.org/whl/nightly/cu130 torch torchaudio torchvision diff --git a/simple_server.py b/simple_server.py new file mode 100644 index 0000000..43c8a97 --- /dev/null +++ b/simple_server.py @@ -0,0 +1,707 @@ +import os +os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "12.1+PTX") +import sys +import uuid +import json +import re +import time +import torch +torch._C._jit_set_nvfuser_enabled(False) +torch._C._jit_set_texpr_fuser_enabled(False) +try: + torch.backends.cuda.enable_flash_sdp(False) + torch.backends.cuda.enable_mem_efficient_sdp(False) + torch.backends.cuda.enable_math_sdp(True) +except Exception: + pass +import torchaudio +import soundfile as sf +import librosa +from flask import Flask, request, jsonify, send_file, Response, stream_with_context +from flask_cors import CORS +from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoProcessor + +sys.path.append(os.getcwd()) +from funaudiochat.register import register_funaudiochat +register_funaudiochat() + +from utils.cosyvoice_detokenizer import get_audio_detokenizer, token2wav +from utils.constant import ( + DEFAULT_S2M_GEN_KWARGS, + DEFAULT_SP_GEN_KWARGS, + DEFAULT_S2M_PROMPT, + FUNCTION_CALLING_PROMPT, + AUDIO_TEMPLATE, +) + +app = Flask(__name__) +CORS(app) + +MODEL_DEFAULT_ID = "Fun-Audio-Chat-8B" +PERSONAPLEX_MODEL_ID = "nvidia/personaplex-7b-v1" +ASR_MODEL_SIZE = os.environ.get("ASR_MODEL_SIZE", "base") +ASR_LANGUAGE = os.environ.get("ASR_LANGUAGE", "en") +SPK_EMB_PATHS = [ + 'pretrained_models/Fun-CosyVoice3-0.5B-2512/spk_emb.pt', + 'utils/new_spk2info.pt', +] + +# Global model variables +model = None +processor = None +tts_model = None +current_model_id = None +device = None +output_dir = './output_audio' +asr_model = None + + +def resolve_model_path(model_id: str) -> str: + if os.path.isdir(model_id): + return model_id + + local_candidate = os.path.join('pretrained_models', model_id) + if os.path.isdir(local_candidate): + return local_candidate + + return model_id + + +def is_personaplex(model_id: str) -> bool: + return PERSONAPLEX_MODEL_ID in model_id.lower() or 'personaplex' in model_id.lower() + + +def load_model_if_needed(model_id: str): + global model, processor, tts_model, current_model_id, device + + if model is not None and model_id == current_model_id: + return + + print(f'Loading model: {model_id}') + force_cpu = os.environ.get('FORCE_CPU', '').strip().lower() in ('1', 'true', 'yes') + device = torch.device('cpu' if force_cpu else ('cuda' if torch.cuda.is_available() else 'cpu')) + model_path = resolve_model_path(model_id) + + if is_personaplex(model_id): + if not os.path.isdir(model_path) and model_id == PERSONAPLEX_MODEL_ID: + raise RuntimeError( + 'PersonaPlex model not found. Download it and set PERSONAPLEX_MODEL_PATH or place it ' + 'under pretrained_models/personaplex-7b-v1.' + ) + + config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + model = AutoModelForSeq2SeqLM.from_pretrained( + model_path, + config=config, + torch_dtype=torch.bfloat16, + trust_remote_code=True, + device_map='auto' + ).eval() + + if hasattr(model, 'config') and hasattr(model.config, 'attn_implementation'): + model.config.attn_implementation = 'eager' + if hasattr(model, 'config'): + setattr(model.config, '_attn_implementation', 'eager') + if hasattr(model, 'language_model') and hasattr(model.language_model, 'config') and hasattr(model.language_model.config, 'attn_implementation'): + model.language_model.config.attn_implementation = 'eager' + if hasattr(model, 'language_model') and hasattr(model.language_model, 'config'): + setattr(model.language_model.config, '_attn_implementation', 'eager') + + if hasattr(model, 'sp_gen_kwargs'): + # Match web_demo defaults to avoid CRQ dimension mismatches. + model.sp_gen_kwargs.update(DEFAULT_SP_GEN_KWARGS) + model.sp_gen_kwargs.update({ + 'text_greedy': False, + 'disable_speech': False, + }) + + tts_model = get_audio_detokenizer() + current_model_id = model_id + os.makedirs(output_dir, exist_ok=True) + print('Models loaded successfully!') + + +def parse_tools(tools_raw: str | None): + if not tools_raw: + return None + try: + tools = json.loads(tools_raw) + except json.JSONDecodeError as exc: + raise ValueError(f"Invalid tools JSON: {exc}") from exc + if not isinstance(tools, list): + raise ValueError('Tools must be a JSON array') + return tools + + +def build_system_prompt(system_prompt: str, tools: list | None): + if not tools: + return system_prompt + + tools_definition = "\n".join([json.dumps(tool, ensure_ascii=False) for tool in tools]) + tools_prompt = FUNCTION_CALLING_PROMPT.replace("{tools_definition}", tools_definition) + if system_prompt: + return f"{system_prompt}\n\n{tools_prompt}" + return tools_prompt + + +def ensure_speech_prompt(system_prompt: str) -> str: + if DEFAULT_S2M_PROMPT and DEFAULT_S2M_PROMPT not in system_prompt: + if system_prompt.strip(): + return f"{system_prompt}\n\n{DEFAULT_S2M_PROMPT}" + return DEFAULT_S2M_PROMPT + return system_prompt + + +def strip_tool_calls(text: str) -> str: + text = re.sub(r".*?", "", text, flags=re.DOTALL) + text = re.sub(r"\[[a-zA-Z0-9_]+\s+[^\]]+\]", "", text) + text = re.sub(r"\{\s*\"(action|tool|name|function)\".*?\}", "", text, flags=re.DOTALL) + return text + + +def extract_tool_calls(text: str, tools: list | None = None): + tool_calls = [] + + matches = re.findall(r"\s*(\{.*?\})\s*", text, re.DOTALL) + for match in matches: + try: + data = json.loads(match) + except json.JSONDecodeError: + continue + name = data.get('name') or data.get('function') or 'tool' + arguments = data.get('arguments') or {} + tool_calls.append({ + 'name': name, + 'arguments': json.dumps(arguments, ensure_ascii=False) if not isinstance(arguments, str) else arguments, + }) + + bracket_matches = re.findall(r"\[([a-zA-Z0-9_]+)\s+([^\]]+)\]", text) + for name, args in bracket_matches: + parsed_args = {} + for key, value in re.findall(r'([a-zA-Z0-9_]+)\s*=\s*\"([^\"]*)\"', args): + parsed_args[key] = value + if not parsed_args: + parsed_args = {"input": args.strip()} + tool_calls.append({ + 'name': name, + 'arguments': json.dumps(parsed_args, ensure_ascii=False), + }) + + call_matches = re.findall(r"([a-zA-Z0-9_]+)\s*\(([^)]*)\)", text) + for name, args in call_matches: + args = args.strip() + if not args: + parsed_args = {} + else: + string_match = re.match(r'\"([^\"]+)\"', args) or re.match(r"\'([^\']+)\'", args) + if string_match: + parsed_args = {"input": string_match.group(1)} + else: + parsed_args = {"input": args} + tool_calls.append({ + 'name': name, + 'arguments': json.dumps(parsed_args, ensure_ascii=False), + }) + + json_matches = re.findall(r"\{.*?\}", text, flags=re.DOTALL) + for match in json_matches: + try: + data = json.loads(match) + except json.JSONDecodeError: + continue + name = data.get('name') or data.get('function') or data.get('tool') or data.get('action') + arguments = {k: v for k, v in data.items() if k not in {'name', 'function', 'tool', 'action'}} + if not name: + if tools and len(tools) == 1: + name = tools[0].get('name') or 'tool' + else: + continue + tool_calls.append({ + 'name': name, + 'arguments': json.dumps(arguments, ensure_ascii=False), + }) + + return tool_calls + + +def fallback_tool_calls(text: str, tools: list | None): + if not tools or len(tools) != 1: + return [] + tool = tools[0] or {} + name = tool.get('name') or 'tool' + tool_blob = json.dumps(tool, ensure_ascii=False).lower() + if 'weather' not in name.lower() and 'weather' not in tool_blob: + return [] + + city = extract_city_from_text(text) + + if not city: + return [] + + return [{ + 'name': name, + 'arguments': json.dumps({'city': city}, ensure_ascii=False), + }] + + +def extract_city_from_text(text: str): + match = re.search(r"weather\s*(?:in|for)\s+([A-Za-z][A-Za-z\- ]+)", text, flags=re.IGNORECASE) + if match: + return match.group(1).strip() + + match = re.search(r"\bin\s+([A-Z][A-Za-z\-]+)\b", text) + if match: + return match.group(1).strip() + + return None + + +def extract_location_from_text(text: str): + match = re.search(r"\b(?:in|for|at)\s+([A-Za-z][A-Za-z\- ]+)", text, flags=re.IGNORECASE) + if match: + return match.group(1).strip() + return None + + +def extract_math_expression(text: str): + match = re.search(r"([-+/*().\d\s]{3,})", text) + if not match: + return None + expr = match.group(1).strip() + if any(char.isdigit() for char in expr): + return expr + return None + + +def extract_translation(text: str): + match = re.search(r"translate\s+(.+?)\s+to\s+([A-Za-z \-]+)", text, flags=re.IGNORECASE) + if match: + return match.group(1).strip(), match.group(2).strip() + match = re.search(r"translate\s+(.+?)\s+into\s+([A-Za-z \-]+)", text, flags=re.IGNORECASE) + if match: + return match.group(1).strip(), match.group(2).strip() + return None, None + + +def extract_search_query(text: str): + match = re.search(r"(?:search for|look up|find)\s+(.+)", text, flags=re.IGNORECASE) + if match: + return match.group(1).strip() + return None + + +def find_tool_by_keywords(tools: list | None, keywords: list[str]): + if not tools: + return None + for tool in tools: + blob = f"{tool.get('name', '')} {tool.get('description', '')}".lower() + if any(keyword in blob for keyword in keywords): + return tool + return None + + +def infer_tool_calls_from_text(text: str, tools: list | None): + if not text or not tools: + return [] + + tool = find_tool_by_keywords(tools, ['weather']) + if tool: + city = extract_city_from_text(text) + if city: + return [{ + 'name': tool.get('name') or 'tool', + 'arguments': json.dumps({'city': city}, ensure_ascii=False), + }] + + tool = find_tool_by_keywords(tools, ['time', 'clock']) + if tool and re.search(r"\btime\b", text, flags=re.IGNORECASE): + location = extract_location_from_text(text) or 'UTC' + return [{ + 'name': tool.get('name') or 'tool', + 'arguments': json.dumps({'location': location}, ensure_ascii=False), + }] + + tool = find_tool_by_keywords(tools, ['calculate', 'calculator', 'math']) + if tool: + expr = extract_math_expression(text) + if expr: + return [{ + 'name': tool.get('name') or 'tool', + 'arguments': json.dumps({'expression': expr}, ensure_ascii=False), + }] + + tool = find_tool_by_keywords(tools, ['translate', 'translation']) + if tool: + text_value, target_language = extract_translation(text) + if text_value and target_language: + return [{ + 'name': tool.get('name') or 'tool', + 'arguments': json.dumps({'text': text_value, 'target_language': target_language}, ensure_ascii=False), + }] + + tool = find_tool_by_keywords(tools, ['summarize', 'summary']) + if tool and re.search(r"\bsummar", text, flags=re.IGNORECASE): + return [{ + 'name': tool.get('name') or 'tool', + 'arguments': json.dumps({'text': text.strip()}, ensure_ascii=False), + }] + + tool = find_tool_by_keywords(tools, ['search', 'web']) + if tool: + query = extract_search_query(text) + if query: + return [{ + 'name': tool.get('name') or 'tool', + 'arguments': json.dumps({'query': query}, ensure_ascii=False), + }] + + return [] + + +def load_asr_model(): + global asr_model + if asr_model is not None: + return asr_model + + try: + from faster_whisper import WhisperModel + except Exception as exc: + print(f'ASR unavailable: {exc}') + return None + + device_name = 'cuda' if torch.cuda.is_available() else 'cpu' + compute_type = 'float16' if torch.cuda.is_available() else 'int8' + try: + asr_model = WhisperModel(ASR_MODEL_SIZE, device=device_name, compute_type=compute_type) + return asr_model + except Exception as exc: + print(f'ASR load failed on {device_name}: {exc}') + if device_name != 'cpu': + try: + asr_model = WhisperModel(ASR_MODEL_SIZE, device='cpu', compute_type='int8') + return asr_model + except Exception as exc_cpu: + print(f'ASR load failed on cpu: {exc_cpu}') + asr_model = None + return None + + +def transcribe_audio_array(audio_array, sr: int): + model = load_asr_model() + if model is None: + return None + + if sr != 16000: + audio_array = librosa.resample(audio_array, orig_sr=sr, target_sr=16000) + sr = 16000 + + audio_array = audio_array.astype("float32") + try: + segments, _ = model.transcribe( + audio_array, + beam_size=5, + language=ASR_LANGUAGE, + ) + text = ''.join([segment.text for segment in segments]).strip() + return text or None + except Exception as exc: + print(f'ASR transcription failed: {exc}') + return None + + +def override_weather_tool_call(tool_calls: list, transcript: str, tools: list | None): + if not transcript or not tools or len(tools) != 1: + return tool_calls + tool = tools[0] or {} + name = tool.get('name') or 'tool' + tool_blob = json.dumps(tool, ensure_ascii=False).lower() + if 'weather' not in name.lower() and 'weather' not in tool_blob: + return tool_calls + + city = extract_city_from_text(transcript) + if not city: + return tool_calls + + return [{ + 'name': name, + 'arguments': json.dumps({'city': city}, ensure_ascii=False), + }] + + +def load_audio(file_storage, target_sr: int): + temp_path = os.path.join(output_dir, f'input_{uuid.uuid4()}.webm') + file_storage.save(temp_path) + audio_array, sr = librosa.load(temp_path, sr=target_sr, mono=True) + return audio_array, sr + + +def prepare_inputs(audio_array, system_prompt: str): + conversation = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": AUDIO_TEMPLATE}, + ] + text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False) + inputs = processor( + text=text, + audio=[audio_array], + return_tensors="pt", + return_token_type_ids=False, + ).to(device) + return inputs + + +def get_generate_ids(outputs): + if isinstance(outputs, tuple): + return outputs[0] + if hasattr(outputs, 'sequences'): + return outputs.sequences + return outputs + + +def decode_generation(outputs, inputs): + generate_ids = get_generate_ids(outputs) + speech_ids = None + if isinstance(outputs, tuple) and len(outputs) > 1: + speech_ids = outputs[1] + + if hasattr(inputs, 'input_ids') and inputs.input_ids is not None: + generate_ids = generate_ids[:, inputs.input_ids.size(1):] + text = processor.decode(generate_ids[0], skip_special_tokens=True) + + if speech_ids is not None and hasattr(speech_ids, 'numel') and speech_ids.numel() > 0: + audio_tokens = speech_ids[0].to(dtype=torch.long) + else: + audio_tokens = generate_ids[0][generate_ids[0] >= 32000] + return text, audio_tokens + + +def load_speaker_embedding(): + for path in SPK_EMB_PATHS: + if not os.path.exists(path): + continue + data = torch.load(path) + if isinstance(data, dict): + if '中文女' in data and 'embedding' in data['中文女']: + return data['中文女']['embedding'] + if 'embedding' in data: + return data['embedding'] + return None + + +def synthesize_audio(audio_tokens): + if audio_tokens is None or audio_tokens.numel() == 0: + return None + + # Move to CPU and drop special/pad tokens that CosyVoice doesn't model. + audio_tokens = audio_tokens.detach().to('cpu') + audio_tokens = audio_tokens[(audio_tokens >= 0) & (audio_tokens < 6561)] + if audio_tokens.numel() == 0: + return None + + spk_embedding = load_speaker_embedding() + if spk_embedding is None: + print('No speaker embedding found; skipping audio synthesis.') + return None + + try: + audio_output = token2wav( + tts_model, + audio_tokens, + spk_embedding, + ) + except Exception as exc: + print(f'TTS failed: {exc}') + return None + + output_path = os.path.join(output_dir, f'output_{uuid.uuid4()}.wav') + audio_np = audio_output.detach().cpu().numpy() + if audio_np.ndim > 1: + audio_np = audio_np[0] + sf.write(output_path, audio_np, 24000) + return output_path + + +def chunk_text(text: str, chunk_size: int = 48): + return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)] + + +def get_generation_kwargs(): + kwargs = dict(DEFAULT_S2M_GEN_KWARGS) + kwargs['use_cache'] = True + + if not kwargs.get('bad_words_ids'): + try: + kwargs['bad_words_ids'] = [[ + processor.tokenizer.convert_tokens_to_ids('<|audio_bos|>'), + processor.tokenizer.convert_tokens_to_ids('<|sil|>') + ]] + except Exception: + pass + + eos_token_id = getattr(processor.tokenizer, 'eos_token_id', None) + if eos_token_id is None and hasattr(model.config, 'text_config'): + eos_token_id = getattr(model.config.text_config, 'eos_token_id', None) + if eos_token_id is not None: + kwargs['eos_token_id'] = eos_token_id + + pad_token_id = getattr(processor.tokenizer, 'pad_token_id', None) + if pad_token_id is None and hasattr(model.config, 'text_config'): + pad_token_id = getattr(model.config.text_config, 'pad_token_id', None) + if pad_token_id is not None: + kwargs['pad_token_id'] = pad_token_id + + return kwargs + + +@app.route('/health', methods=['GET']) +def health(): + return jsonify({'status': 'ok', 'model_loaded': model is not None, 'model_id': current_model_id}) + + +@app.route('/process-audio', methods=['POST']) +def process_audio(): + try: + if 'audio' not in request.files: + return jsonify({'error': 'No audio file provided'}), 400 + + audio_file = request.files['audio'] + system_prompt = request.form.get('system_prompt', DEFAULT_S2M_PROMPT) + system_prompt = ensure_speech_prompt(system_prompt) + tools = parse_tools(request.form.get('tools')) + model_id = request.form.get('model', MODEL_DEFAULT_ID).strip() or MODEL_DEFAULT_ID + + load_model_if_needed(model_id) + + target_sr = 24000 if is_personaplex(model_id) else 16000 + audio_array, _ = load_audio(audio_file, target_sr) + system_prompt = build_system_prompt(system_prompt, tools) + + inputs = prepare_inputs(audio_array, system_prompt) + + with torch.no_grad(): + outputs = model.generate( + **inputs, + **get_generation_kwargs(), + ) + + response_text, audio_tokens = decode_generation(outputs, inputs) + tool_calls = extract_tool_calls(response_text, tools) + if not tool_calls: + tool_calls = fallback_tool_calls(response_text, tools) + if tools: + transcript = transcribe_audio_array(audio_array, target_sr) + if transcript: + tool_calls = override_weather_tool_call(tool_calls, transcript, tools) + if not tool_calls: + tool_calls = infer_tool_calls_from_text(transcript, tools) + if not tool_calls: + tool_calls = fallback_tool_calls(transcript, tools) + if tools and not tool_calls: + tool_calls = infer_tool_calls_from_text(response_text, tools) + response_text = strip_tool_calls(response_text).strip() + if tool_calls and not response_text: + response_text = "Calling tool..." + + audio_url = None + output_path = synthesize_audio(audio_tokens) + if output_path: + audio_url = f'/audio/{os.path.basename(output_path)}' + + return jsonify({ + 'text': response_text, + 'audio_url': audio_url, + 'tool_calls': tool_calls, + 'status': 'success' + }) + + except Exception as exc: + print(f'Error: {str(exc)}') + import traceback + traceback.print_exc() + return jsonify({'error': str(exc)}), 500 + + +@app.route('/process-audio-stream', methods=['POST']) +def process_audio_stream(): + try: + if 'audio' not in request.files: + return jsonify({'error': 'No audio file provided'}), 400 + + audio_file = request.files['audio'] + system_prompt = request.form.get('system_prompt', DEFAULT_S2M_PROMPT) + system_prompt = ensure_speech_prompt(system_prompt) + tools = parse_tools(request.form.get('tools')) + model_id = request.form.get('model', MODEL_DEFAULT_ID).strip() or MODEL_DEFAULT_ID + + load_model_if_needed(model_id) + + target_sr = 24000 if is_personaplex(model_id) else 16000 + audio_array, _ = load_audio(audio_file, target_sr) + system_prompt = build_system_prompt(system_prompt, tools) + + inputs = prepare_inputs(audio_array, system_prompt) + + with torch.no_grad(): + outputs = model.generate( + **inputs, + **get_generation_kwargs(), + ) + + response_text, audio_tokens = decode_generation(outputs, inputs) + tool_calls = extract_tool_calls(response_text, tools) + if not tool_calls: + tool_calls = fallback_tool_calls(response_text, tools) + if tools: + transcript = transcribe_audio_array(audio_array, target_sr) + if transcript: + tool_calls = override_weather_tool_call(tool_calls, transcript, tools) + if not tool_calls: + tool_calls = infer_tool_calls_from_text(transcript, tools) + if not tool_calls: + tool_calls = fallback_tool_calls(transcript, tools) + if tools and not tool_calls: + tool_calls = infer_tool_calls_from_text(response_text, tools) + response_text = strip_tool_calls(response_text).strip() + if tool_calls and not response_text: + response_text = "Calling tool..." + output_path = synthesize_audio(audio_tokens) + audio_url = f'/audio/{os.path.basename(output_path)}' if output_path else None + + def generate_events(): + for chunk in chunk_text(response_text): + yield f"data: {json.dumps({'delta': chunk})}\n\n" + time.sleep(0.02) + + for tool in tool_calls: + yield f"data: {json.dumps({'tool_call': tool})}\n\n" + + if audio_url: + yield f"data: {json.dumps({'audio_url': audio_url})}\n\n" + + yield "data: [DONE]\n\n" + + return Response( + stream_with_context(generate_events()), + mimetype='text/event-stream', + headers={ + 'Cache-Control': 'no-cache', + 'X-Accel-Buffering': 'no', + } + ) + + except Exception as exc: + print(f'Error: {str(exc)}') + import traceback + traceback.print_exc() + return jsonify({'error': str(exc)}), 500 + + +@app.route('/audio/', methods=['GET']) +def get_audio(filename): + return send_file(os.path.join(output_dir, filename), mimetype='audio/wav') + + +if __name__ == '__main__': + load_model_if_needed(MODEL_DEFAULT_ID) + app.run(host='0.0.0.0', port=11236, debug=False, threaded=True) diff --git a/start_server.sh b/start_server.sh new file mode 100755 index 0000000..b0fa8e7 --- /dev/null +++ b/start_server.sh @@ -0,0 +1,6 @@ +#!/bin/bash +cd ~/Fun-Audio-Chat +source venv/bin/activate +export PATH=$HOME/bin:$PATH +export PYTHONPATH=$PWD:$PYTHONPATH +python simple_server.py diff --git a/web_demo/server/README.md b/web_demo/server/README.md index 5d783ea..e46d123 100644 --- a/web_demo/server/README.md +++ b/web_demo/server/README.md @@ -16,6 +16,16 @@ From the project root directory, run: python -m web_demo.server.server ``` +### GPU Compatibility Launcher (sm_121 and newer) + +If you see NVRTC errors like `invalid value for --gpu-architecture` on newer GPUs, +use the compatibility launcher to set conservative CUDA env defaults before importing +Torch: + +```bash +FUN_AUDIOCHAT_CUDA_COMPAT=1 python -m web_demo.server.server_gpu_compat --host 0.0.0.0 --port 11235 +``` + ## Command Line Arguments | Argument | Type | Default | Description | @@ -53,4 +63,3 @@ Additional model parameters can be configured in `utils/constants.py`: ## License The present code is provided under the MIT license. - diff --git a/web_demo/server/server_gpu_compat.py b/web_demo/server/server_gpu_compat.py new file mode 100644 index 0000000..a045606 --- /dev/null +++ b/web_demo/server/server_gpu_compat.py @@ -0,0 +1,26 @@ +""" +GPU-compat launcher for FunAudioChat server. + +Use this when running on new architectures (e.g., sm_121) with older CUDA/NVRTC. +It sets conservative env defaults before importing torch to avoid NVRTC arch errors +and prefers PTX fallback where possible. + +Usage: + FUN_AUDIOCHAT_CUDA_COMPAT=1 python web_demo/server/server_gpu_compat.py --host 0.0.0.0 --port 11235 +""" +import os + +# Only apply when explicitly enabled. +if os.environ.get("FUN_AUDIOCHAT_CUDA_COMPAT") == "1": + # Prefer PTX fallback for new architectures when SASS is unavailable. + os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "12.0+PTX") + # Avoid aggressive kernel fusion paths that may JIT-compile with NVRTC. + os.environ.setdefault("TORCHINDUCTOR_DISABLE", "1") + os.environ.setdefault("XFORMERS_DISABLE_TRITON", "1") + os.environ.setdefault("FLASH_ATTENTION_DISABLE", "1") + +from web_demo.server import server + + +if __name__ == "__main__": + server.main()