695 lines
28 KiB
Python
695 lines
28 KiB
Python
|
|
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Liu Yue)
|
|||
|
|
#
|
|||
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|||
|
|
# you may not use this file except in compliance with the License.
|
|||
|
|
# You may obtain a copy of the License at
|
|||
|
|
#
|
|||
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|||
|
|
#
|
|||
|
|
# Unless required by applicable law or agreed to in writing, software
|
|||
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|||
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
|
|
# See the License for the specific language governing permissions and
|
|||
|
|
# limitations under the License.
|
|||
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
import argparse
|
|||
|
|
import logging
|
|||
|
|
import io
|
|||
|
|
import torch
|
|||
|
|
import torchaudio
|
|||
|
|
import librosa
|
|||
|
|
import numpy as np
|
|||
|
|
import time
|
|||
|
|
from urllib.parse import quote
|
|||
|
|
from typing import Optional, List, Dict, Callable
|
|||
|
|
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
|||
|
|
from fastapi.responses import StreamingResponse
|
|||
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|||
|
|
import uvicorn
|
|||
|
|
|
|||
|
|
# Absolute directory of this script; the third_party/Matcha-TTS directory next
# to it is appended to the module search path.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(f'{ROOT_DIR}/third_party/Matcha-TTS')
|
|||
|
|
|
|||
|
|
# 预先下载 wetext 模型,避免在初始化时下载(离线环境支持)
|
|||
|
|
def _preload_wetext_model():
    """Make sure the wetext model is in the local ModelScope cache.

    Called before CosyVoice2 is imported. If ``wetext`` cannot be imported the
    function returns immediately. If the four expected ``.fst`` files already
    exist under ``~/.cache/modelscope/hub/pengzhendong/wetext`` nothing is
    downloaded; otherwise ``modelscope.snapshot_download`` is invoked. Any
    failure is logged as a warning and swallowed so start-up can continue.
    """
    try:
        # Only used as an availability probe.
        import wetext  # noqa: F401
    except ImportError:
        # wetext is not installed, so there is nothing to preload.
        return

    try:
        hub_root = os.path.expanduser('~/.cache/modelscope/hub')
        wetext_cache_dir = os.path.join(hub_root, 'pengzhendong', 'wetext')

        # Files that must all be present for the cache to count as complete.
        fst_paths = [
            os.path.join(wetext_cache_dir, lang, 'tn', fst_name)
            for lang in ('zh', 'en')
            for fst_name in ('tagger.fst', 'verbalizer.fst')
        ]

        if not all(map(os.path.exists, fst_paths)):
            # At least one file is missing: try to fetch the model.
            logging.info('正在下载 wetext 模型(如果已存在则使用缓存)...')
            from modelscope import snapshot_download
            downloaded_dir = snapshot_download("pengzhendong/wetext")
            logging.info(f'wetext 模型已就绪: {downloaded_dir}')
        else:
            logging.info(f'wetext 模型已存在于缓存: {wetext_cache_dir}')
            return
    except Exception as e:
        # Best effort only (e.g. offline machine): warn and carry on.
        logging.warning(f'无法预先下载 wetext 模型: {e},将在初始化时尝试下载')
|
|||
|
|
|
|||
|
|
# Preload the wetext model before CosyVoice2 is imported below.
_preload_wetext_model()
|
|||
|
|
|
|||
|
|
from cosyvoice.cli.cosyvoice import CosyVoice2
|
|||
|
|
from cosyvoice.utils.file_utils import load_wav
|
|||
|
|
from cosyvoice.utils.common import set_all_random_seed
|
|||
|
|
|
|||
|
|
# Raise the matplotlib logger threshold to WARNING.
logging.getLogger('matplotlib').setLevel(logging.WARNING)

# FastAPI application object; all routes in this file are registered on it.
app = FastAPI(title="CosyVoice API Server", version="1.0.0")

# Enable cross-origin (CORS) support.
# NOTE(review): wildcard origins together with allow_credentials=True is
# discouraged by the FastAPI CORS documentation — confirm whether credentialed
# cross-origin requests are actually required.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|||
|
|
|
|||
|
|
# Global state
# Model instance; assigned in the ``__main__`` block, None until then.
cosyvoice = None
# Sample rate passed to load_wav() for uploaded prompt audio.
prompt_sr = 16000
# Peak amplitude ceiling applied by postprocess().
max_val = 0.8
# speaker_id -> cached features ("embedding", "speech_feat", "speech_token").
# NOTE(review): extract_speaker_features() also stores a "prompt_text" str in
# each entry, so the value annotation below is narrower than the real content.
spk2info: Dict[str, Dict[str, torch.Tensor]] = {}
# File used by torch.save()/torch.load() for the cache; set in ``__main__``.
spk2info_path: Optional[str] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def postprocess(speech, top_db=60, hop_length=220, win_length=440):
    """Trim silence, limit the peak level and append a short silent tail.

    Args:
        speech: Audio passed to ``librosa.effects.trim``; the later calls use
            ``.abs()`` and ``torch.concat`` on the result
            (presumably a ``(1, samples)`` tensor — TODO confirm).
        top_db: Threshold forwarded to ``librosa.effects.trim``.
        hop_length: Hop length forwarded to ``librosa.effects.trim``.
        win_length: Forwarded to ``librosa.effects.trim`` as ``frame_length``.

    Returns:
        The processed audio with ``0.2 * cosyvoice.sample_rate`` zero samples
        appended along dim 1.
    """
    trimmed, _ = librosa.effects.trim(
        speech,
        top_db=top_db,
        frame_length=win_length,
        hop_length=hop_length,
    )
    peak = trimmed.abs().max()
    if peak > max_val:
        # Scale down so the loudest sample sits exactly at max_val.
        trimmed = trimmed / peak * max_val
    silent_tail = torch.zeros(1, int(cosyvoice.sample_rate * 0.2))
    return torch.concat([trimmed, silent_tail], dim=1)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _encode_wav_bytes(samples, sample_rate):
    """Encode a 1-D float sample array as a 16-bit PCM WAV file and return its bytes."""
    # Clip before the int16 cast: a full-scale sample (1.0) scales to 32768,
    # which does not fit in int16 and would otherwise overflow.
    pcm = np.clip(samples * (2 ** 15), -(2 ** 15), 2 ** 15 - 1).astype(np.int16)
    buffer = io.BytesIO()
    torchaudio.save(buffer, torch.from_numpy(pcm).unsqueeze(0), sample_rate, format='wav')
    return buffer.getvalue()


def generate_wav_stream(model_output, sample_rate, stream_mode=False):
    """Yield WAV-encoded audio built from the model's output chunks.

    Args:
        model_output: Iterable of dicts, each holding a tensor under the key
            ``'tts_speech'``.
        sample_rate: Sample rate written into the WAV data.
        stream_mode: When True, each model chunk is encoded and yielded on its
            own. When False, all chunks are concatenated and yielded once.

    Yields:
        ``bytes`` objects, each a complete WAV file. Nothing is yielded when
        ``model_output`` produces no chunks.
    """
    if stream_mode:
        # NOTE(review): every yielded chunk carries its own WAV header, so the
        # concatenated stream is a sequence of WAV files, not one WAV file.
        for item in model_output:
            yield _encode_wav_bytes(item['tts_speech'].numpy().flatten(), sample_rate)
        return

    # Non-streaming: gather everything, then encode a single file.
    chunks = [item['tts_speech'].numpy().flatten() for item in model_output]
    if chunks:
        yield _encode_wav_bytes(np.concatenate(chunks), sample_rate)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_content_disposition(filename: str) -> str:
    """Build a Content-Disposition header value that tolerates non-ASCII names.

    The value carries an ASCII-only ``filename`` fallback and a
    percent-encoded UTF-8 ``filename*`` parameter.

    Args:
        filename: Desired download name; may contain arbitrary user input.

    Returns:
        The header value, e.g. ``attachment; filename="a.wav"; filename*=UTF-8''a.wav``.
    """
    # Drop control characters (CR/LF included) so the name cannot break out of
    # the header line.
    cleaned = "".join(ch for ch in filename if ch.isprintable())

    safe_ascii = cleaned.encode('ascii', errors='ignore').decode('ascii')
    # A raw double quote or backslash would terminate/escape the quoted-string.
    safe_ascii = safe_ascii.replace('\\', '_').replace('"', '_') or 'download.wav'

    # safe='' so that '/' is percent-encoded as well.
    quoted = quote(cleaned or safe_ascii, safe='')
    return f'attachment; filename="{safe_ascii}"; filename*=UTF-8\'\'{quoted}'
|
|||
|
|
|
|||
|
|
|
|||
|
|
def load_speaker_info():
    """Populate the global ``spk2info`` cache from ``spk2info_path``.

    The cache is reset to an empty dict when the model or the path is not set,
    or when the file does not exist. Otherwise the file is read with
    ``torch.load`` onto ``cosyvoice.frontend.device``.
    """
    global spk2info

    if cosyvoice is None or spk2info_path is None:
        spk2info = {}
        return

    if not os.path.exists(spk2info_path):
        spk2info = {}
        logging.info("未找到音色缓存文件,初始化为空")
        return

    logging.info(f"加载音色缓存: {spk2info_path}")
    # NOTE(review): torch.load unpickles the file; only point spk2info_path at
    # files this server wrote itself.
    spk2info = torch.load(spk2info_path, map_location=cosyvoice.frontend.device)
    logging.info(f"✓ 已加载 {len(spk2info)} 个音色特征")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def save_speaker_info():
    """Persist the global ``spk2info`` cache to ``spk2info_path``.

    The data is first written to a sibling ``.tmp`` file and then moved over
    the destination with ``os.replace``, so a failed write never leaves a
    truncated cache file behind.

    Raises:
        RuntimeError: If ``spk2info_path`` has not been initialised.
    """
    if spk2info_path is None:
        raise RuntimeError("spk2info_path 未初始化,无法保存音色信息")

    tmp_path = f"{spk2info_path}.tmp"
    try:
        torch.save(spk2info, tmp_path)
        os.replace(tmp_path, spk2info_path)
    finally:
        # After a successful replace the temp file is gone; this only removes
        # the leftover of a failed write.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    logging.info(f"音色缓存已保存到 {spk2info_path}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def save_uploaded_file(upload_file: UploadFile) -> str:
    """Copy an uploaded file to a temporary file and return its path.

    The temporary file keeps the upload's extension and is NOT deleted
    automatically; the caller is responsible for removing it.

    Args:
        upload_file: Upload object exposing ``filename`` and a readable ``file``.

    Returns:
        Path of the temporary file that now holds the upload's content.
    """
    import shutil
    import tempfile

    # The client may omit the filename; os.path.splitext(None) would raise.
    suffix = os.path.splitext(upload_file.filename or "")[1]
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        with tmp:
            # Chunked copy: avoids holding the whole upload in memory.
            shutil.copyfileobj(upload_file.file, tmp)
    except Exception:
        # The caller never learns the path on failure, so clean up here.
        if os.path.exists(tmp.name):
            os.remove(tmp.name)
        raise
    return tmp.name
|
|||
|
|
|
|||
|
|
|
|||
|
|
def extract_speaker_features(speaker_id: str, prompt_wav_path: str, prompt_text: str, force_regenerate: bool = False) -> bool:
    """Extract a speaker's embedding / speech_feat / speech_token and cache them.

    Args:
        speaker_id: Key under which the features are stored in ``spk2info``.
        prompt_wav_path: Path of the reference audio file.
        prompt_text: Reference text stored alongside the features.
        force_regenerate: Re-extract even when ``speaker_id`` is already cached.

    Returns:
        True when the features are cached (freshly extracted or already
        present), False when extraction raised an exception.

    Raises:
        RuntimeError: If the model has not been loaded.
    """
    if cosyvoice is None:
        raise RuntimeError("模型未加载,无法提取音色特征")

    already_cached = speaker_id in spk2info
    if already_cached and not force_regenerate:
        logging.info(f"音色 {speaker_id} 已存在,跳过提取")
        return True

    try:
        logging.info(f"开始提取音色 {speaker_id} 特征")
        frontend = cosyvoice.frontend
        wav_16k = load_wav(prompt_wav_path, 16000)
        embedding = frontend._extract_spk_embedding(wav_16k)

        # speech_feat is computed from audio at the model's own sample rate.
        to_model_rate = torchaudio.transforms.Resample(orig_freq=16000, new_freq=cosyvoice.sample_rate)
        speech_feat, _ = frontend._extract_speech_feat(to_model_rate(wav_16k))
        speech_token, _ = frontend._extract_speech_token(wav_16k)

        # NOTE(review): no length alignment between speech_feat and
        # speech_token is done here — confirm the model does not need one.
        spk2info[speaker_id] = {
            "embedding": embedding,
            "speech_feat": speech_feat,
            "speech_token": speech_token,
            "prompt_text": prompt_text,
        }
        logging.info(f"音色 {speaker_id} 特征提取完成")
        return True
    except Exception as e:
        logging.error(f"提取音色特征失败: {e}")
        return False
|
|||
|
|
|
|||
|
|
|
|||
|
|
def tts_with_cached_features(
    tts_text: str,
    speaker_id: str,
    prompt_text: Optional[str] = "",
    stream: bool = False,
    speed: float = 1.0,
    text_frontend: bool = True,
    progress_callback: Optional[Callable[[int, int], None]] = None,
):
    """Synthesize speech using a speaker's cached features.

    This is a generator: nothing (including the validation below) runs until
    the first item is requested.

    Args:
        tts_text: Text to synthesize; it is split into segments by
            ``cosyvoice.frontend.text_normalize``.
        speaker_id: Key of the cached speaker in ``spk2info``.
        prompt_text: Prompt text; when empty/None the text stored with the
            cached speaker is used, falling back to an empty string.
        stream: Forwarded to ``cosyvoice.model.tts``.
        speed: Forwarded to ``cosyvoice.model.tts``.
        text_frontend: Forwarded to ``text_normalize``.
        progress_callback: Optional ``callback(current, total)`` invoked once
            per segment, right before that segment is synthesized.

    Yields:
        Every item produced by ``cosyvoice.model.tts`` for each segment.

    Raises:
        RuntimeError: If the model has not been loaded.
        ValueError: If ``speaker_id`` is not cached.
    """
    if cosyvoice is None:
        raise RuntimeError("模型未加载")
    if speaker_id not in spk2info:
        raise ValueError(f"音色 {speaker_id} 不存在或未缓存")

    speaker_info = spk2info[speaker_id]
    frontend = cosyvoice.frontend

    # Never hand None to the tokenizer: fall back to the cached text, then "".
    prompt_text = prompt_text or speaker_info.get("prompt_text") or ""

    segments = list(frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend))
    if not segments:
        segments = [tts_text]
    total_segments = len(segments)

    # Everything derived from the prompt/speaker is identical for all
    # segments, so it is built once instead of once per segment.
    prompt_text_token, prompt_text_token_len = frontend._extract_text_token(prompt_text)
    speech_token_len = torch.tensor(
        [speaker_info["speech_token"].shape[1]],
        dtype=torch.int32,
    ).to(frontend.device)
    speech_feat_len = torch.tensor(
        [speaker_info["speech_feat"].shape[1]],
        dtype=torch.int32,
    ).to(frontend.device)
    shared_input = {
        "prompt_text": prompt_text_token,
        "prompt_text_len": prompt_text_token_len,
        "llm_prompt_speech_token": speaker_info["speech_token"],
        "llm_prompt_speech_token_len": speech_token_len,
        "flow_prompt_speech_token": speaker_info["speech_token"],
        "flow_prompt_speech_token_len": speech_token_len,
        "prompt_speech_feat": speaker_info["speech_feat"],
        "prompt_speech_feat_len": speech_feat_len,
        "llm_embedding": speaker_info["embedding"],
        "flow_embedding": speaker_info["embedding"],
    }

    for idx, chunk_text in enumerate(segments, start=1):
        text_token, text_token_len = frontend._extract_text_token(chunk_text)
        model_input = {"text": text_token, "text_len": text_token_len, **shared_input}
        if progress_callback:
            progress_callback(idx, total_segments)
        yield from cosyvoice.model.tts(**model_input, stream=stream, speed=speed)
|
|||
|
|
|
|||
|
|
@app.get("/")
async def root():
    """Root path: return the API name, version and a short endpoint summary."""
    endpoint_summary = {
        "/inference": "统一推理接口,支持所有模式",
        "/list_speakers": "获取可用的预训练音色列表",
        "/health": "健康检查",
    }
    return {
        "name": "CosyVoice API Server",
        "version": "1.0.0",
        "endpoints": endpoint_summary,
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.get("/health")
async def health():
    """Health check: always reports "healthy" plus whether the model is loaded."""
    model_loaded = cosyvoice is not None
    return {"status": "healthy", "model_loaded": model_loaded}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.get("/list_speakers")
async def list_speakers():
    """Return the speakers reported by ``cosyvoice.list_available_spks()``.

    Responds with HTTP 503 when the model is not loaded. A falsy result from
    the model is reported as an empty list.
    """
    if cosyvoice is None:
        raise HTTPException(status_code=503, detail="模型未加载")
    return {"speakers": cosyvoice.list_available_spks() or []}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.get("/api/speakers", response_model=List[str])
async def list_cached_speakers():
    """Return the IDs of all speakers whose features are cached in ``spk2info``.

    Responds with HTTP 500 when the model is not loaded.
    """
    if cosyvoice is None:
        raise HTTPException(status_code=500, detail="模型未加载")
    logging.info(f"当前缓存音色数量: {len(spk2info)}")
    return list(spk2info)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.get("/api/speakers/info")
async def get_cached_speakers_info():
    """Return per-speaker details about the cached features.

    For every cached speaker the response states which of ``embedding``,
    ``speech_feat`` and ``speech_token`` are present, together with the string
    form of each tensor's shape (None when the feature is missing).
    Responds with HTTP 500 when the model is not loaded.
    """
    if cosyvoice is None:
        raise HTTPException(status_code=500, detail="模型未加载")

    def shape_or_none(info, key):
        return str(info[key].shape) if key in info else None

    speakers_info = [
        {
            "speaker_id": speaker_id,
            "has_embedding": "embedding" in info,
            "has_speech_feat": "speech_feat" in info,
            "has_speech_token": "speech_token" in info,
            "embedding_shape": shape_or_none(info, "embedding"),
            "speech_feat_shape": shape_or_none(info, "speech_feat"),
            "speech_token_shape": shape_or_none(info, "speech_token"),
        }
        for speaker_id, info in spk2info.items()
    ]
    return {"total_count": len(speakers_info), "speakers": speakers_info}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.post("/api/speakers/add")
async def add_speaker(
    speaker_id: str = Form(..., description="音色ID"),
    prompt_text: str = Form(..., description="参考文本"),
    prompt_wav: UploadFile = File(..., description="参考音频文件"),
    force_regenerate: bool = Form(False, description="是否强制重新生成"),
):
    """Upload reference audio and cache the extracted speaker features.

    Responds with HTTP 400 when the speaker already exists and
    ``force_regenerate`` is false, and HTTP 500 when the model is not loaded
    or extraction fails. The features are only held in memory; the response
    message points at /api/speakers/save for persistence.
    """
    if cosyvoice is None:
        raise HTTPException(status_code=500, detail="模型未加载")

    exists = speaker_id in spk2info
    if exists and not force_regenerate:
        raise HTTPException(
            status_code=400,
            detail=f"音色 {speaker_id} 已存在,可设置 force_regenerate=true 重新生成",
        )

    temp_path = None
    try:
        temp_path = save_uploaded_file(prompt_wav)
        extracted = extract_speaker_features(
            speaker_id=speaker_id,
            prompt_wav_path=temp_path,
            prompt_text=prompt_text,
            force_regenerate=force_regenerate,
        )
        if extracted:
            return {
                "success": True,
                "speaker_id": speaker_id,
                "message": f"音色 {speaker_id} 特征已缓存,如需持久化请调用 /api/speakers/save",
                "cached_features": ["embedding", "speech_feat", "speech_token"],
            }
        raise HTTPException(status_code=500, detail="特征提取失败")
    except HTTPException:
        # Pass deliberate HTTP errors through untouched.
        raise
    except Exception as e:
        logging.error(f"添加音色失败: {e}")
        raise HTTPException(status_code=500, detail=f"添加音色失败: {e}")
    finally:
        # The uploaded audio is only needed during extraction.
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.post("/api/speakers/save")
async def save_cached_speakers():
    """Write the in-memory speaker cache to disk via ``save_speaker_info()``.

    Responds with HTTP 500 when the model is not loaded or saving fails.
    """
    if cosyvoice is None:
        raise HTTPException(status_code=500, detail="模型未加载")

    try:
        save_speaker_info()
    except Exception as e:
        logging.error(f"保存音色缓存失败: {e}")
        raise HTTPException(status_code=500, detail=f"保存失败: {e}")

    speaker_ids = list(spk2info.keys())
    return {
        "success": True,
        "file_path": spk2info_path,
        "total_speakers": len(spk2info),
        "speakers": speaker_ids,
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.delete("/api/speakers/{speaker_id}")
async def delete_cached_speaker(speaker_id: str):
    """Remove one speaker from the in-memory cache.

    Responds with HTTP 500 when the model is not loaded and HTTP 404 when the
    speaker is unknown. The file on disk is not touched; the response message
    points at /api/speakers/save for that.
    """
    if cosyvoice is None:
        raise HTTPException(status_code=500, detail="模型未加载")
    if speaker_id not in spk2info:
        raise HTTPException(status_code=404, detail=f"音色 {speaker_id} 不存在")

    spk2info.pop(speaker_id)
    logging.info(f"音色 {speaker_id} 已从缓存移除")
    return {
        "success": True,
        "message": f"音色 {speaker_id} 已删除,如需同步磁盘请调用 /api/speakers/save",
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.post("/api/speakers/regenerate/{speaker_id}")
async def regenerate_cached_speaker(
    speaker_id: str,
    prompt_text: str = Form(...),
    prompt_wav: UploadFile = File(...),
):
    """Re-extract the features of an already cached speaker from new audio.

    Responds with HTTP 404 when the speaker is not cached and HTTP 500 when
    the model is not loaded or extraction fails.
    """
    if cosyvoice is None:
        raise HTTPException(status_code=500, detail="模型未加载")
    if speaker_id not in spk2info:
        raise HTTPException(status_code=404, detail=f"音色 {speaker_id} 不存在")

    temp_path = None
    try:
        temp_path = save_uploaded_file(prompt_wav)
        regenerated = extract_speaker_features(
            speaker_id=speaker_id,
            prompt_wav_path=temp_path,
            prompt_text=prompt_text,
            force_regenerate=True,
        )
        if regenerated:
            return {"success": True, "message": f"音色 {speaker_id} 特征已更新"}
        raise HTTPException(status_code=500, detail="特征重新生成失败")
    except HTTPException:
        # Pass deliberate HTTP errors through untouched.
        raise
    except Exception as e:
        logging.error(f"重新生成音色失败: {e}")
        raise HTTPException(status_code=500, detail=f"重新生成失败: {e}")
    finally:
        # The uploaded audio is only needed during extraction.
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.post("/api/synthesis/sft")
async def synthesis_sft_cached(
    text: str = Form(..., description="要合成的文本"),
    speaker_id: str = Form(..., description="缓存音色ID"),
    prompt_text: str = Form("", description="可选提示文本,默认使用缓存中保存的文本"),
    stream: bool = Form(False, description="是否流式返回"),
    speed: float = Form(1.0, description="语速调节"),
):
    """Synthesize speech with the cached features of ``speaker_id``.

    Responds with HTTP 500 when the model is not loaded, HTTP 400 for blank
    text and HTTP 404 for an unknown speaker.

    In streaming mode the response carries ``X-Progress-Total`` (number of
    text segments). In non-streaming mode it carries ``X-Processing-Time``
    (seconds spent synthesizing).
    """
    if cosyvoice is None:
        raise HTTPException(status_code=500, detail="模型未加载")
    if not text.strip():
        raise HTTPException(status_code=400, detail="文本不能为空")
    if speaker_id not in spk2info:
        raise HTTPException(status_code=404, detail=f"音色 {speaker_id} 不存在")

    start_time = time.perf_counter()

    def log_progress(current: int, total: int):
        logging.info(f"SFT缓存合成进度 - 音色:{speaker_id} {current}/{total}")

    try:
        logging.info(f"SFT缓存合成开始 - 音色:{speaker_id}, 文本长度:{len(text)}")
        # Lazy generator: no synthesis work happens until it is iterated.
        model_output = tts_with_cached_features(
            tts_text=text,
            speaker_id=speaker_id,
            prompt_text=prompt_text,
            stream=stream,
            speed=speed,
            progress_callback=log_progress,
        )

        headers = {
            "Content-Disposition": build_content_disposition(f"synthesis_{speaker_id}.wav"),
            "X-Sample-Rate": str(cosyvoice.sample_rate),
        }

        if stream:
            # Headers are sent before the generator runs, so the segment count
            # cannot come from the progress callback (it would always be 0).
            # Count the segments up front with the same normalization call the
            # synthesis generator uses.
            segments = cosyvoice.frontend.text_normalize(text, split=True, text_frontend=True)
            headers["X-Progress-Total"] = str(len(list(segments)) or 1)

            def timed_stream():
                try:
                    for chunk in generate_wav_stream(model_output, cosyvoice.sample_rate, stream_mode=True):
                        yield chunk
                finally:
                    elapsed = time.perf_counter() - start_time
                    logging.info(f"SFT缓存合成完成(流式) - 音色:{speaker_id}, 耗时:{elapsed:.3f}s")

            return StreamingResponse(timed_stream(), media_type="audio/wav", headers=headers)

        # Non-streaming: synthesize everything now so the timing is accurate.
        audio_bytes = b"".join(
            generate_wav_stream(model_output, cosyvoice.sample_rate, stream_mode=False)
        )
        elapsed = time.perf_counter() - start_time
        headers["X-Processing-Time"] = f"{elapsed:.3f}"
        logging.info(f"SFT缓存合成完成 - 音色:{speaker_id}, 耗时:{elapsed:.3f}s")
        return StreamingResponse(
            iter([audio_bytes]),
            media_type="audio/wav",
            headers=headers,
        )
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        logging.error(f"SFT 缓存合成失败: {e}")
        raise HTTPException(status_code=500, detail=f"合成失败: {e}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def _read_prompt_speech(prompt_wav):
    """Read an uploaded prompt wav and return it after load_wav() + postprocess().

    An AssertionError raised while loading/processing is reported as HTTP 400.
    """
    wav_buffer = io.BytesIO(await prompt_wav.read())
    try:
        return postprocess(load_wav(wav_buffer, prompt_sr))
    except AssertionError as e:
        raise HTTPException(status_code=400, detail=str(e))


@app.post("/inference")
async def inference(
    tts_text: str = Form(..., description="需要合成的文本"),
    mode: str = Form(..., description="推理模式: 'sft'(预训练音色), 'zero_shot'(3s极速复刻), 'cross_lingual'(跨语种复刻), 'instruct'(自然语言控制)"),
    spk_id: Optional[str] = Form(None, description="预训练音色ID(sft和instruct模式需要)"),
    prompt_text: Optional[str] = Form(None, description="prompt文本(zero_shot模式需要)"),
    prompt_wav: Optional[UploadFile] = File(None, description="prompt音频文件(zero_shot和cross_lingual模式需要)"),
    instruct_text: Optional[str] = Form(None, description="instruct文本(instruct模式需要)"),
    seed: Optional[int] = Form(0, description="随机种子,0表示随机"),
    stream: bool = Form(False, description="是否流式推理"),
    speed: float = Form(1.0, description="速度调节(0.5-2.0,仅支持非流式推理)", ge=0.5, le=2.0)
):
    """
    Unified inference endpoint covering every inference mode.

    Modes:
    - sft: pretrained speaker; requires spk_id
    - zero_shot: voice cloning from a short prompt; requires prompt_text and prompt_wav
    - cross_lingual: cross-lingual cloning; requires prompt_wav
    - instruct: natural-language control; requires spk_id and instruct_text

    Responds with HTTP 503 when the model is not loaded, HTTP 400 for invalid
    parameters (including a prompt wav rejected while loading) and HTTP 500
    for any other failure while setting up inference.
    """
    if cosyvoice is None:
        raise HTTPException(status_code=503, detail="模型未加载")

    # Parameter validation
    if mode == 'sft':
        if not spk_id:
            raise HTTPException(status_code=400, detail="sft模式需要提供spk_id")
        available_spks = cosyvoice.list_available_spks()
        if not available_spks or spk_id not in available_spks:
            raise HTTPException(status_code=400, detail=f"无效的spk_id: {spk_id},可用音色: {available_spks}")

    elif mode == 'zero_shot':
        if not prompt_text:
            raise HTTPException(status_code=400, detail="zero_shot模式需要提供prompt_text")
        if not prompt_wav:
            raise HTTPException(status_code=400, detail="zero_shot模式需要提供prompt_wav")

    elif mode == 'cross_lingual':
        if not prompt_wav:
            raise HTTPException(status_code=400, detail="cross_lingual模式需要提供prompt_wav")
        if cosyvoice.instruct is True:
            raise HTTPException(status_code=400, detail="当前模型不支持cross_lingual模式,请使用非Instruct模型")

    elif mode == 'instruct':
        if not spk_id:
            raise HTTPException(status_code=400, detail="instruct模式需要提供spk_id")
        if not instruct_text:
            raise HTTPException(status_code=400, detail="instruct模式需要提供instruct_text")
        if cosyvoice.instruct is False:
            raise HTTPException(status_code=400, detail="当前模型不支持instruct模式,请使用CosyVoice-300M-Instruct模型")
        available_spks = cosyvoice.list_available_spks()
        if not available_spks or spk_id not in available_spks:
            raise HTTPException(status_code=400, detail=f"无效的spk_id: {spk_id},可用音色: {available_spks}")

    else:
        raise HTTPException(status_code=400, detail=f"无效的模式: {mode},支持的模式: sft, zero_shot, cross_lingual, instruct")

    # Seed the RNGs only for a positive integer seed. The isinstance guard
    # keeps a None (or otherwise non-int) seed from raising TypeError.
    if isinstance(seed, int) and seed > 0:
        set_all_random_seed(seed)

    # Streaming only supports speed 1.0
    if stream and speed != 1.0:
        raise HTTPException(status_code=400, detail="流式推理模式下速度必须为1.0")

    try:
        # Run inference
        if mode == 'sft':
            logging.info('get sft inference request')
            model_output = cosyvoice.inference_sft(tts_text, spk_id, stream=stream, speed=speed)

        elif mode == 'zero_shot':
            logging.info('get zero_shot inference request')
            prompt_speech_16k = await _read_prompt_speech(prompt_wav)
            model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream, speed=speed)

        elif mode == 'cross_lingual':
            logging.info('get cross_lingual inference request')
            prompt_speech_16k = await _read_prompt_speech(prompt_wav)
            model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed)

        else:  # instruct
            logging.info('get instruct inference request')
            model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text, stream=stream, speed=speed)

        # Return the audio stream
        return StreamingResponse(
            generate_wav_stream(model_output, cosyvoice.sample_rate, stream_mode=stream),
            media_type="audio/wav",
            headers={
                "Content-Disposition": "attachment; filename=output.wav",
                "X-Sample-Rate": str(cosyvoice.sample_rate)
            }
        )

    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 for a bad prompt wav) must not
        # be rewritten into a generic 500 by the handler below.
        raise
    except Exception as e:
        logging.error(f"推理错误: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"推理失败: {str(e)}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.post("/inference_sft")
async def inference_sft(
    tts_text: str = Form(...),
    spk_id: str = Form(...),
    seed: Optional[int] = Form(0),
    stream: bool = Form(False),
    speed: float = Form(1.0, ge=0.5, le=2.0)
):
    """Pretrained-speaker inference endpoint; delegates to ``inference`` with mode 'sft'."""
    # inference() is called directly here, not through FastAPI, so any
    # argument left out would take its Form(...)/File(...) default object
    # instead of None. Pass every unused parameter explicitly.
    return await inference(
        tts_text=tts_text,
        mode='sft',
        spk_id=spk_id,
        prompt_text=None,
        prompt_wav=None,
        instruct_text=None,
        seed=seed,
        stream=stream,
        speed=speed,
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.post("/inference_zero_shot")
async def inference_zero_shot(
    tts_text: str = Form(...),
    prompt_text: str = Form(...),
    prompt_wav: UploadFile = File(...),
    seed: Optional[int] = Form(0),
    stream: bool = Form(False),
    speed: float = Form(1.0, ge=0.5, le=2.0)
):
    """Zero-shot voice-cloning endpoint; delegates to ``inference`` with mode 'zero_shot'."""
    # inference() is called directly here, not through FastAPI, so any
    # argument left out would take its Form(...)/File(...) default object
    # instead of None. Pass every unused parameter explicitly.
    return await inference(
        tts_text=tts_text,
        mode='zero_shot',
        spk_id=None,
        prompt_text=prompt_text,
        prompt_wav=prompt_wav,
        instruct_text=None,
        seed=seed,
        stream=stream,
        speed=speed,
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.post("/inference_cross_lingual")
async def inference_cross_lingual(
    tts_text: str = Form(...),
    prompt_wav: UploadFile = File(...),
    seed: Optional[int] = Form(0),
    stream: bool = Form(False),
    speed: float = Form(1.0, ge=0.5, le=2.0)
):
    """Cross-lingual cloning endpoint; delegates to ``inference`` with mode 'cross_lingual'."""
    # inference() is called directly here, not through FastAPI, so any
    # argument left out would take its Form(...)/File(...) default object
    # instead of None. Pass every unused parameter explicitly.
    return await inference(
        tts_text=tts_text,
        mode='cross_lingual',
        spk_id=None,
        prompt_text=None,
        prompt_wav=prompt_wav,
        instruct_text=None,
        seed=seed,
        stream=stream,
        speed=speed,
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.post("/inference_instruct")
async def inference_instruct(
    tts_text: str = Form(...),
    spk_id: str = Form(...),
    instruct_text: str = Form(...),
    seed: Optional[int] = Form(0),
    stream: bool = Form(False),
    speed: float = Form(1.0, ge=0.5, le=2.0)
):
    """Natural-language-control endpoint; delegates to ``inference`` with mode 'instruct'."""
    # inference() is called directly here, not through FastAPI, so any
    # argument left out would take its Form(...)/File(...) default object
    # instead of None. Pass every unused parameter explicitly.
    return await inference(
        tts_text=tts_text,
        mode='instruct',
        spk_id=spk_id,
        prompt_text=None,
        prompt_wav=None,
        instruct_text=instruct_text,
        seed=seed,
        stream=stream,
        speed=speed,
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == '__main__':
    # Command-line options.
    parser = argparse.ArgumentParser(description='CosyVoice API Server')
    parser.add_argument('--port', type=int, default=8000, help='服务器端口')
    parser.add_argument('--host', type=str, default='0.0.0.0', help='服务器地址')
    parser.add_argument(
        '--model_dir',
        type=str,
        default='pretrain/CosyVoice2-0.5B',
        help='模型路径(本地路径或modelscope repo id)',
    )
    args = parser.parse_args()

    # Load the model (CosyVoice2 only); any failure aborts start-up.
    logging.info(f"正在加载 CosyVoice2 模型: {args.model_dir}")
    try:
        cosyvoice = CosyVoice2(args.model_dir, load_jit=False, load_trt=False, fp16=False)
        logging.info("成功加载 CosyVoice2 模型")
    except Exception as e:
        logging.error(f"模型加载失败: {e}")
        raise RuntimeError(f'无法加载 CosyVoice2 模型: {e}')

    # The speaker cache file lives inside the model directory.
    spk2info_path = os.path.join(args.model_dir, 'spk2info.pt')
    load_speaker_info()
    if not spk2info:
        logging.info("当前无缓存音色,可通过 /api/speakers/add 添加")
    else:
        # Log at most the first five speaker IDs.
        logging.info(f"已加载 {len(spk2info)} 个缓存音色: {list(spk2info.keys())[:5]}{'...' if len(spk2info) > 5 else ''}")

    prompt_sr = 16000
    logging.info(f"API服务器启动在 http://{args.host}:{args.port}")
    logging.info(f"可用音色: {cosyvoice.list_available_spks()}")

    uvicorn.run(app, host=args.host, port=args.port)
|
|||
|
|
|