# TheVirtualOne/tts_server.py
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Liu Yue)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import argparse
import logging
import io
import torch
import torchaudio
import librosa
import numpy as np
import time
from urllib.parse import quote
from typing import Optional, List, Dict, Callable
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
# Pre-download the wetext model so it is not fetched during initialization (offline support)
def _preload_wetext_model():
    """Pre-download the wetext model into the local cache to avoid downloading at init time.

    This function checks for (and, if needed, downloads) the wetext model before CosyVoice2
    is imported. If the model already exists in the cache, nothing is downloaded, so the
    server can run in offline environments.
    """
    try:
        # Check whether wetext is installed
        import wetext
    except ImportError:
        # wetext is not installed; return and fall back to ttsfrd
        return
    try:
        # Check whether the ModelScope cache already contains the wetext model
        cache_root = os.path.expanduser('~/.cache/modelscope/hub')
        wetext_cache_dir = os.path.join(cache_root, 'pengzhendong', 'wetext')
        # Files that must be present for the cached model to be usable
        required_files = [
            os.path.join(wetext_cache_dir, 'zh', 'tn', 'tagger.fst'),
            os.path.join(wetext_cache_dir, 'zh', 'tn', 'verbalizer.fst'),
            os.path.join(wetext_cache_dir, 'en', 'tn', 'tagger.fst'),
            os.path.join(wetext_cache_dir, 'en', 'tn', 'verbalizer.fst'),
        ]
        # If every file exists, the model is already downloaded; nothing to do
        if all(os.path.exists(f) for f in required_files):
            logging.info(f'wetext model already cached at: {wetext_cache_dir}')
            return
        # Otherwise try to download (requires network access).
        # Note: snapshot_download returns the cached path without re-downloading if present.
        logging.info('Downloading wetext model (cache is used if it already exists)...')
        from modelscope import snapshot_download
        downloaded_dir = snapshot_download("pengzhendong/wetext")
        logging.info(f'wetext model ready: {downloaded_dir}')
    except Exception as e:
        # Download failed (possibly offline); log a warning and keep going.
        # Initialization will retry the download or raise later.
        logging.warning(f'Could not pre-download the wetext model: {e}; will retry at init time')


# Pre-load the wetext model before importing CosyVoice2
_preload_wetext_model()
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav
from cosyvoice.utils.common import set_all_random_seed
logging.getLogger('matplotlib').setLevel(logging.WARNING)
app = FastAPI(title="CosyVoice API Server", version="1.0.0")
# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global state
cosyvoice = None
prompt_sr = 16000
max_val = 0.8
spk2info: Dict[str, Dict[str, torch.Tensor]] = {}
spk2info_path: Optional[str] = None

def postprocess(speech, top_db=60, hop_length=220, win_length=440):
    """Post-process audio: trim silence and normalize."""
    speech, _ = librosa.effects.trim(
        speech, top_db=top_db,
        frame_length=win_length,
        hop_length=hop_length
    )
    if speech.abs().max() > max_val:
        speech = speech / speech.abs().max() * max_val
    # Append 0.2 s of trailing silence
    speech = torch.concat([speech, torch.zeros(1, int(cosyvoice.sample_rate * 0.2))], dim=1)
    return speech

def generate_wav_stream(model_output, sample_rate, stream_mode=False):
    """Generate a WAV-format audio stream from model output.

    In stream mode, every yielded chunk is a standalone WAV payload with its own header.
    """
    if stream_mode:
        # Streaming mode: yield one WAV payload per chunk
        for i in model_output:
            audio_data = i['tts_speech'].numpy().flatten()
            # Convert to int16
            audio_int16 = (audio_data * (2 ** 15)).astype(np.int16)
            # Build the WAV byte stream
            buffer = io.BytesIO()
            torchaudio.save(buffer, torch.from_numpy(audio_int16).unsqueeze(0), sample_rate, format='wav')
            yield buffer.getvalue()
    else:
        # Non-streaming mode: collect every chunk, then yield a single WAV payload
        audio_chunks = []
        for i in model_output:
            audio_chunks.append(i['tts_speech'].numpy().flatten())
        if len(audio_chunks) > 0:
            # Concatenate all chunks
            audio_data = np.concatenate(audio_chunks)
            # Convert to int16
            audio_int16 = (audio_data * (2 ** 15)).astype(np.int16)
            # Build the WAV byte stream
            buffer = io.BytesIO()
            torchaudio.save(buffer, torch.from_numpy(audio_int16).unsqueeze(0), sample_rate, format='wav')
            yield buffer.getvalue()
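
# A minimal client-side sketch (hypothetical; not part of this server) for consuming the
# non-streaming response. It assumes the server is reachable at http://localhost:8000 and
# that a pretrained speaker id such as '中文女' exists:
#
#   import requests
#   resp = requests.post("http://localhost:8000/inference",
#                        data={"tts_text": "你好,世界", "mode": "sft", "spk_id": "中文女"})
#   with open("output.wav", "wb") as f:
#       f.write(resp.content)  # the body is a single complete WAV file
#
# In stream mode the response body is a sequence of standalone WAV payloads, so a robust
# client should split on RIFF headers rather than writing the raw bytes into one file.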

def build_content_disposition(filename: str) -> str:
    """Build a Content-Disposition header compatible with non-ASCII filenames."""
    safe_ascii = filename.encode('ascii', errors='ignore').decode('ascii') or 'download.wav'
    quoted = quote(filename)
    return f'attachment; filename="{safe_ascii}"; filename*=UTF-8\'\'{quoted}'
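
# For example, a Chinese filename keeps an ASCII fallback plus an RFC 5987 encoded form:
#
#   build_content_disposition("合成.wav")
#   # -> attachment; filename=".wav"; filename*=UTF-8''%E5%90%88%E6%88%90.wav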

def load_speaker_info():
    """Load the saved speaker cache."""
    global spk2info
    if cosyvoice is None or spk2info_path is None:
        spk2info = {}
        return
    if os.path.exists(spk2info_path):
        logging.info(f"Loading speaker cache: {spk2info_path}")
        spk2info = torch.load(spk2info_path, map_location=cosyvoice.frontend.device)
        logging.info(f"✓ Loaded features for {len(spk2info)} speakers")
    else:
        spk2info = {}
        logging.info("No speaker cache file found; starting empty")


def save_speaker_info():
    """Persist the speaker cache to disk."""
    if spk2info_path is None:
        raise RuntimeError("spk2info_path is not initialized; cannot save speaker info")
    torch.save(spk2info, spk2info_path)
    logging.info(f"Speaker cache saved to {spk2info_path}")

def save_uploaded_file(upload_file: UploadFile) -> str:
    """Save an uploaded audio file to a temporary path and return that path."""
    import tempfile
    suffix = os.path.splitext(upload_file.filename)[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        content = upload_file.file.read()
        tmp.write(content)
    return tmp.name

def extract_speaker_features(speaker_id: str, prompt_wav_path: str, prompt_text: str, force_regenerate: bool = False) -> bool:
    """Extract and cache a speaker's embedding / speech_feat / speech_token."""
    global spk2info
    if cosyvoice is None:
        raise RuntimeError("Model not loaded; cannot extract speaker features")
    if speaker_id in spk2info and not force_regenerate:
        logging.info(f"Speaker {speaker_id} already cached; skipping extraction")
        return True
    try:
        logging.info(f"Extracting features for speaker {speaker_id}")
        prompt_speech_16k = load_wav(prompt_wav_path, 16000)
        embedding = cosyvoice.frontend._extract_spk_embedding(prompt_speech_16k)
        resample_op = torchaudio.transforms.Resample(orig_freq=16000, new_freq=cosyvoice.sample_rate)
        prompt_speech_resample = resample_op(prompt_speech_16k)
        speech_feat, _ = cosyvoice.frontend._extract_speech_feat(prompt_speech_resample)
        speech_token, _ = cosyvoice.frontend._extract_speech_token(prompt_speech_16k)
        spk2info[speaker_id] = {
            "embedding": embedding,
            "speech_feat": speech_feat,
            "speech_token": speech_token,
            "prompt_text": prompt_text,
        }
        logging.info(f"Feature extraction for speaker {speaker_id} finished")
        return True
    except Exception as e:
        logging.error(f"Speaker feature extraction failed: {e}")
        return False

def tts_with_cached_features(
    tts_text: str,
    speaker_id: str,
    prompt_text: Optional[str] = "",
    stream: bool = False,
    speed: float = 1.0,
    text_frontend: bool = True,
    progress_callback: Optional[Callable[[int, int], None]] = None,
):
    """Fast SFT-style synthesis using cached speaker features."""
    if cosyvoice is None:
        raise RuntimeError("Model not loaded")
    if speaker_id not in spk2info:
        raise ValueError(f"Speaker {speaker_id} does not exist or is not cached")
    speaker_info = spk2info[speaker_id]
    if not prompt_text and "prompt_text" in speaker_info:
        prompt_text = speaker_info["prompt_text"]
    segments = list(cosyvoice.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend))
    if not segments:
        segments = [tts_text]
    total_segments = len(segments)
    for idx, chunk_text in enumerate(segments, start=1):
        text_token, text_token_len = cosyvoice.frontend._extract_text_token(chunk_text)
        prompt_text_token, prompt_text_token_len = cosyvoice.frontend._extract_text_token(prompt_text)
        speech_token_len = torch.tensor(
            [speaker_info["speech_token"].shape[1]],
            dtype=torch.int32,
        ).to(cosyvoice.frontend.device)
        speech_feat_len = torch.tensor(
            [speaker_info["speech_feat"].shape[1]],
            dtype=torch.int32,
        ).to(cosyvoice.frontend.device)
        model_input = {
            "text": text_token,
            "text_len": text_token_len,
            "prompt_text": prompt_text_token,
            "prompt_text_len": prompt_text_token_len,
            "llm_prompt_speech_token": speaker_info["speech_token"],
            "llm_prompt_speech_token_len": speech_token_len,
            "flow_prompt_speech_token": speaker_info["speech_token"],
            "flow_prompt_speech_token_len": speech_token_len,
            "prompt_speech_feat": speaker_info["speech_feat"],
            "prompt_speech_feat_len": speech_feat_len,
            "llm_embedding": speaker_info["embedding"],
            "flow_embedding": speaker_info["embedding"],
        }
        if progress_callback:
            progress_callback(idx, total_segments)
        for output in cosyvoice.model.tts(**model_input, stream=stream, speed=speed):
            yield output
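
# A minimal server-side usage sketch (hypothetical speaker id 'my_voice', assumed to be
# already cached via extract_speaker_features or /api/speakers/add):
#
#   outputs = tts_with_cached_features("今天天气不错", speaker_id="my_voice")
#   wav_bytes = b"".join(generate_wav_stream(outputs, cosyvoice.sample_rate, stream_mode=False))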
@app.get("/")
async def root():
"""根路径返回API信息"""
return {
"name": "CosyVoice API Server",
"version": "1.0.0",
"endpoints": {
"/inference": "统一推理接口,支持所有模式",
"/list_speakers": "获取可用的预训练音色列表",
"/health": "健康检查"
}
}
@app.get("/health")
async def health():
"""健康检查"""
return {"status": "healthy", "model_loaded": cosyvoice is not None}
@app.get("/list_speakers")
async def list_speakers():
"""获取可用的预训练音色列表"""
if cosyvoice is None:
raise HTTPException(status_code=503, detail="模型未加载")
speakers = cosyvoice.list_available_spks()
return {"speakers": speakers if speakers else []}
@app.get("/api/speakers", response_model=List[str])
async def list_cached_speakers():
"""获取已缓存特征的音色ID"""
if cosyvoice is None:
raise HTTPException(status_code=500, detail="模型未加载")
logging.info(f"当前缓存音色数量: {len(spk2info)}")
return list(spk2info.keys())
@app.get("/api/speakers/info")
async def get_cached_speakers_info():
"""获取缓存音色的详细信息"""
if cosyvoice is None:
raise HTTPException(status_code=500, detail="模型未加载")
speakers_info = []
for speaker_id, info in spk2info.items():
speakers_info.append({
"speaker_id": speaker_id,
"has_embedding": "embedding" in info,
"has_speech_feat": "speech_feat" in info,
"has_speech_token": "speech_token" in info,
"embedding_shape": str(info["embedding"].shape) if "embedding" in info else None,
"speech_feat_shape": str(info["speech_feat"].shape) if "speech_feat" in info else None,
"speech_token_shape": str(info["speech_token"].shape) if "speech_token" in info else None,
})
return {"total_count": len(speakers_info), "speakers": speakers_info}
@app.post("/api/speakers/add")
async def add_speaker(
speaker_id: str = Form(..., description="音色ID"),
prompt_text: str = Form(..., description="参考文本"),
prompt_wav: UploadFile = File(..., description="参考音频文件"),
force_regenerate: bool = Form(False, description="是否强制重新生成"),
):
"""上传音频并提取缓存音色特征"""
if cosyvoice is None:
raise HTTPException(status_code=500, detail="模型未加载")
if speaker_id in spk2info and not force_regenerate:
raise HTTPException(
status_code=400,
detail=f"音色 {speaker_id} 已存在,可设置 force_regenerate=true 重新生成",
)
temp_path = None
try:
temp_path = save_uploaded_file(prompt_wav)
success = extract_speaker_features(
speaker_id=speaker_id,
prompt_wav_path=temp_path,
prompt_text=prompt_text,
force_regenerate=force_regenerate,
)
if not success:
raise HTTPException(status_code=500, detail="特征提取失败")
return {
"success": True,
"speaker_id": speaker_id,
"message": f"音色 {speaker_id} 特征已缓存,如需持久化请调用 /api/speakers/save",
"cached_features": ["embedding", "speech_feat", "speech_token"],
}
except HTTPException:
raise
except Exception as e:
logging.error(f"添加音色失败: {e}")
raise HTTPException(status_code=500, detail=f"添加音色失败: {e}")
finally:
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
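
# Example upload (hypothetical host/port and a local sample.wav):
#
#   curl -X POST http://localhost:8000/api/speakers/add \
#        -F "speaker_id=my_voice" \
#        -F "prompt_text=这是参考音频对应的文本" \
#        -F "prompt_wav=@sample.wav"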
@app.post("/api/speakers/save")
async def save_cached_speakers():
"""将缓存音色保存到磁盘"""
if cosyvoice is None:
raise HTTPException(status_code=500, detail="模型未加载")
try:
save_speaker_info()
return {
"success": True,
"file_path": spk2info_path,
"total_speakers": len(spk2info),
"speakers": list(spk2info.keys()),
}
except Exception as e:
logging.error(f"保存音色缓存失败: {e}")
raise HTTPException(status_code=500, detail=f"保存失败: {e}")
@app.delete("/api/speakers/{speaker_id}")
async def delete_cached_speaker(speaker_id: str):
"""删除指定缓存音色"""
if cosyvoice is None:
raise HTTPException(status_code=500, detail="模型未加载")
if speaker_id not in spk2info:
raise HTTPException(status_code=404, detail=f"音色 {speaker_id} 不存在")
del spk2info[speaker_id]
logging.info(f"音色 {speaker_id} 已从缓存移除")
return {
"success": True,
"message": f"音色 {speaker_id} 已删除,如需同步磁盘请调用 /api/speakers/save",
}
@app.post("/api/speakers/regenerate/{speaker_id}")
async def regenerate_cached_speaker(
speaker_id: str,
prompt_text: str = Form(...),
prompt_wav: UploadFile = File(...),
):
"""重新生成缓存音色"""
if cosyvoice is None:
raise HTTPException(status_code=500, detail="模型未加载")
if speaker_id not in spk2info:
raise HTTPException(status_code=404, detail=f"音色 {speaker_id} 不存在")
temp_path = None
try:
temp_path = save_uploaded_file(prompt_wav)
success = extract_speaker_features(
speaker_id=speaker_id,
prompt_wav_path=temp_path,
prompt_text=prompt_text,
force_regenerate=True,
)
if not success:
raise HTTPException(status_code=500, detail="特征重新生成失败")
return {"success": True, "message": f"音色 {speaker_id} 特征已更新"}
except HTTPException:
raise
except Exception as e:
logging.error(f"重新生成音色失败: {e}")
raise HTTPException(status_code=500, detail=f"重新生成失败: {e}")
finally:
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
@app.post("/api/synthesis/sft")
async def synthesis_sft_cached(
text: str = Form(..., description="要合成的文本"),
speaker_id: str = Form(..., description="缓存音色ID"),
prompt_text: str = Form("", description="可选提示文本,默认使用缓存中保存的文本"),
stream: bool = Form(False, description="是否流式返回"),
speed: float = Form(1.0, description="语速调节"),
):
"""使用缓存音色特征进行高速 SFT 合成"""
if cosyvoice is None:
raise HTTPException(status_code=500, detail="模型未加载")
if not text.strip():
raise HTTPException(status_code=400, detail="文本不能为空")
if speaker_id not in spk2info:
raise HTTPException(status_code=404, detail=f"音色 {speaker_id} 不存在")
start_time = time.perf_counter()
progress_total = {"total": 0}
def log_progress(current: int, total: int):
progress_total["total"] = total
logging.info(f"SFT缓存合成进度 - 音色:{speaker_id} {current}/{total}")
try:
logging.info(f"SFT缓存合成开始 - 音色:{speaker_id}, 文本长度:{len(text)}")
model_output = tts_with_cached_features(
tts_text=text,
speaker_id=speaker_id,
prompt_text=prompt_text,
stream=stream,
speed=speed,
progress_callback=log_progress,
)
headers = {
"Content-Disposition": build_content_disposition(f"synthesis_{speaker_id}.wav"),
"X-Sample-Rate": str(cosyvoice.sample_rate),
}
if stream:
headers["X-Progress-Total"] = str(progress_total["total"])
def timed_stream():
try:
for chunk in generate_wav_stream(model_output, cosyvoice.sample_rate, stream_mode=True):
yield chunk
finally:
elapsed = time.perf_counter() - start_time
logging.info(f"SFT缓存合成完成(流式) - 音色:{speaker_id}, 耗时:{elapsed:.3f}s")
return StreamingResponse(timed_stream(), media_type="audio/wav", headers=headers)
audio_bytes = b"".join(
generate_wav_stream(model_output, cosyvoice.sample_rate, stream_mode=False)
)
elapsed = time.perf_counter() - start_time
headers["X-Processing-Time"] = f"{elapsed:.3f}"
logging.info(f"SFT缓存合成完成 - 音色:{speaker_id}, 耗时:{elapsed:.3f}s")
return StreamingResponse(
iter([audio_bytes]),
media_type="audio/wav",
headers=headers,
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
logging.error(f"SFT 缓存合成失败: {e}")
raise HTTPException(status_code=500, detail=f"合成失败: {e}")
@app.post("/inference")
async def inference(
tts_text: str = Form(..., description="需要合成的文本"),
mode: str = Form(..., description="推理模式: 'sft'(预训练音色), 'zero_shot'(3s极速复刻), 'cross_lingual'(跨语种复刻), 'instruct'(自然语言控制)"),
spk_id: Optional[str] = Form(None, description="预训练音色IDsft和instruct模式需要"),
prompt_text: Optional[str] = Form(None, description="prompt文本zero_shot模式需要"),
prompt_wav: Optional[UploadFile] = File(None, description="prompt音频文件zero_shot和cross_lingual模式需要"),
instruct_text: Optional[str] = Form(None, description="instruct文本instruct模式需要"),
seed: Optional[int] = Form(0, description="随机种子0表示随机"),
stream: bool = Form(False, description="是否流式推理"),
speed: float = Form(1.0, description="速度调节(0.5-2.0,仅支持非流式推理)", ge=0.5, le=2.0)
):
"""
统一推理接口,支持所有推理模式
模式说明:
- sft: 预训练音色模式需要spk_id
- zero_shot: 3s极速复刻模式需要prompt_text和prompt_wav
- cross_lingual: 跨语种复刻模式需要prompt_wav
- instruct: 自然语言控制模式需要spk_id和instruct_text
"""
if cosyvoice is None:
raise HTTPException(status_code=503, detail="模型未加载")
# 参数验证
if mode == 'sft':
if not spk_id:
raise HTTPException(status_code=400, detail="sft模式需要提供spk_id")
available_spks = cosyvoice.list_available_spks()
if not available_spks or spk_id not in available_spks:
raise HTTPException(status_code=400, detail=f"无效的spk_id: {spk_id},可用音色: {available_spks}")
elif mode == 'zero_shot':
if not prompt_text:
raise HTTPException(status_code=400, detail="zero_shot模式需要提供prompt_text")
if not prompt_wav:
raise HTTPException(status_code=400, detail="zero_shot模式需要提供prompt_wav")
elif mode == 'cross_lingual':
if not prompt_wav:
raise HTTPException(status_code=400, detail="cross_lingual模式需要提供prompt_wav")
if cosyvoice.instruct is True:
raise HTTPException(status_code=400, detail="当前模型不支持cross_lingual模式请使用非Instruct模型")
elif mode == 'instruct':
if not spk_id:
raise HTTPException(status_code=400, detail="instruct模式需要提供spk_id")
if not instruct_text:
raise HTTPException(status_code=400, detail="instruct模式需要提供instruct_text")
if cosyvoice.instruct is False:
raise HTTPException(status_code=400, detail="当前模型不支持instruct模式请使用CosyVoice-300M-Instruct模型")
available_spks = cosyvoice.list_available_spks()
if not available_spks or spk_id not in available_spks:
raise HTTPException(status_code=400, detail=f"无效的spk_id: {spk_id},可用音色: {available_spks}")
else:
raise HTTPException(status_code=400, detail=f"无效的模式: {mode},支持的模式: sft, zero_shot, cross_lingual, instruct")
# 设置随机种子
if seed > 0:
set_all_random_seed(seed)
# 流式模式下速度必须为1.0
if stream and speed != 1.0:
raise HTTPException(status_code=400, detail="流式推理模式下速度必须为1.0")
try:
# 执行推理
if mode == 'sft':
logging.info('get sft inference request')
model_output = cosyvoice.inference_sft(tts_text, spk_id, stream=stream, speed=speed)
elif mode == 'zero_shot':
logging.info('get zero_shot inference request')
# 保存上传的文件到临时位置
temp_file = io.BytesIO(await prompt_wav.read())
try:
prompt_speech_16k = postprocess(load_wav(temp_file, prompt_sr))
except AssertionError as e:
raise HTTPException(status_code=400, detail=str(e))
model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream, speed=speed)
elif mode == 'cross_lingual':
logging.info('get cross_lingual inference request')
# 保存上传的文件到临时位置
temp_file = io.BytesIO(await prompt_wav.read())
try:
prompt_speech_16k = postprocess(load_wav(temp_file, prompt_sr))
except AssertionError as e:
raise HTTPException(status_code=400, detail=str(e))
model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed)
else: # instruct
logging.info('get instruct inference request')
model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text, stream=stream, speed=speed)
# 返回音频流
return StreamingResponse(
generate_wav_stream(model_output, cosyvoice.sample_rate, stream_mode=stream),
media_type="audio/wav",
headers={
"Content-Disposition": "attachment; filename=output.wav",
"X-Sample-Rate": str(cosyvoice.sample_rate)
}
)
except Exception as e:
logging.error(f"推理错误: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"推理失败: {str(e)}")
@app.post("/inference_sft")
async def inference_sft(
tts_text: str = Form(...),
spk_id: str = Form(...),
seed: Optional[int] = Form(0),
stream: bool = Form(False),
speed: float = Form(1.0, ge=0.5, le=2.0)
):
"""预训练音色推理接口"""
return await inference(tts_text=tts_text, mode='sft', spk_id=spk_id, seed=seed, stream=stream, speed=speed)
@app.post("/inference_zero_shot")
async def inference_zero_shot(
tts_text: str = Form(...),
prompt_text: str = Form(...),
prompt_wav: UploadFile = File(...),
seed: Optional[int] = Form(0),
stream: bool = Form(False),
speed: float = Form(1.0, ge=0.5, le=2.0)
):
"""3s极速复刻推理接口"""
return await inference(tts_text=tts_text, mode='zero_shot', prompt_text=prompt_text,
prompt_wav=prompt_wav, seed=seed, stream=stream, speed=speed)
@app.post("/inference_cross_lingual")
async def inference_cross_lingual(
tts_text: str = Form(...),
prompt_wav: UploadFile = File(...),
seed: Optional[int] = Form(0),
stream: bool = Form(False),
speed: float = Form(1.0, ge=0.5, le=2.0)
):
"""跨语种复刻推理接口"""
return await inference(tts_text=tts_text, mode='cross_lingual', prompt_wav=prompt_wav,
seed=seed, stream=stream, speed=speed)
@app.post("/inference_instruct")
async def inference_instruct(
tts_text: str = Form(...),
spk_id: str = Form(...),
instruct_text: str = Form(...),
seed: Optional[int] = Form(0),
stream: bool = Form(False),
speed: float = Form(1.0, ge=0.5, le=2.0)
):
"""自然语言控制推理接口"""
return await inference(tts_text=tts_text, mode='instruct', spk_id=spk_id,
instruct_text=instruct_text, seed=seed, stream=stream, speed=speed)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='CosyVoice API Server')
    parser.add_argument('--port', type=int, default=8000, help='Server port')
    parser.add_argument('--host', type=str, default='0.0.0.0', help='Server address')
    parser.add_argument('--model_dir', type=str, default='pretrain/CosyVoice2-0.5B',
                        help='Model path (local path or ModelScope repo id)')
    args = parser.parse_args()
    # Load the model (only CosyVoice2 is supported)
    logging.info(f"Loading CosyVoice2 model: {args.model_dir}")
    try:
        cosyvoice = CosyVoice2(args.model_dir, load_jit=False, load_trt=False, fp16=False)
        logging.info("CosyVoice2 model loaded successfully")
    except Exception as e:
        logging.error(f"Model loading failed: {e}")
        raise RuntimeError(f'Could not load CosyVoice2 model: {e}')
    spk2info_path = os.path.join(args.model_dir, 'spk2info.pt')
    load_speaker_info()
    if spk2info:
        logging.info(f"Loaded {len(spk2info)} cached speakers: {list(spk2info.keys())[:5]}{'...' if len(spk2info) > 5 else ''}")
    else:
        logging.info("No cached speakers yet; add them via /api/speakers/add")
    prompt_sr = 16000
    logging.info(f"API server listening on http://{args.host}:{args.port}")
    logging.info(f"Available speakers: {cosyvoice.list_available_spks()}")
    uvicorn.run(app, host=args.host, port=args.port)
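
# Example launch (the model directory defaults to pretrain/CosyVoice2-0.5B):
#
#   python tts_server.py --host 0.0.0.0 --port 8000 --model_dir pretrain/CosyVoice2-0.5B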