956 lines
32 KiB
Python
956 lines
32 KiB
Python
from Convention.Runtime.GlobalConfig import *
|
||
import asyncio
|
||
from datetime import datetime
|
||
from llama_index.llms.ollama import Ollama
|
||
from llama_index.core.chat_engine import SimpleChatEngine
|
||
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
|
||
from llama_index.core import Settings
|
||
from Convention.Runtime.File import ToolFile
|
||
import requests
|
||
import wave
|
||
import io
|
||
import base64
|
||
import json
|
||
import time
|
||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
|
||
from fastapi.responses import HTMLResponse, FileResponse
|
||
from fastapi.staticfiles import StaticFiles
|
||
from contextlib import asynccontextmanager
|
||
from typing import Optional, Set
|
||
from pydantic import BaseModel
|
||
|
||
# Identifier of this process run, derived from the start time ("%Y%m%d_%H%M%S").
# initialize_config() uses it as the per-run sub-folder of the temp directory.
chat_start_id = datetime.now().strftime("%Y%m%d_%H%M%S")

# ---------------------------------------------------------------------------
# Runtime state shared by the WebSocket handlers and the auto-speak task
# ---------------------------------------------------------------------------
config: Optional[ProjectConfig] = None            # assigned by initialize_config()
chat_engine: Optional[SimpleChatEngine] = None    # assigned by initialize_chat_engine()
connected_clients: Set[WebSocket] = set()         # sockets currently being served
temp_dir: Optional[ToolFile] = None               # folder where synthesized audio is saved
last_message_time: float = 0.0                    # time.time() when the last chat turn started
auto_speak_task: Optional[asyncio.Task] = None    # background task created in lifespan()
is_processing: bool = False                       # True while a chat turn is being handled
current_wait_second: float = 15.0                 # current (dynamically adjusted) auto-speak interval
last_user_message_time: float = 0.0               # time.time() of the last real user message

# ---------------------------------------------------------------------------
# Settings; these are the fallbacks, initialize_config() overwrites them
# ---------------------------------------------------------------------------
# LLM backend
OLLAMA_URL: str = "http://localhost:11434"
OLLAMA_MODEL: str = "gemma3:4b"
RESPONSE_TIMEOUT: int = 60
TEMPERATURE: float = 1.3
MAX_CONTENT_LENGTH: Optional[int] = None
SYSTEM_PROMPT_PATH: Optional[str] = None

# TTS backend
TTS_SERVER_URL: str = "http://localhost:43400"
TTS_PROMPT_TEXT: Optional[str] = None
TTS_PROMPT_WAV_PATH: Optional[str] = None
TTS_SPEAKER_ID: str = "tts_speaker"
STREAM_ENABLE: bool = False

# Behaviour
VERBOSE: bool = False
AUTO_SPEAK_WAIT_SECOND: float = 15.0
|
||
|
||
def initialize_config():
    """Load the runtime settings from the project configuration into module globals.

    Creates the global ``ProjectConfig`` and reads each setting through
    ``FindItem(key, default)``. It then resolves the per-run ``temp_dir``
    (``<temp>/<chat_start_id>``), optionally echoes the key LLM settings when
    VERBOSE is on, and finally calls ``SaveProperties`` on the config.
    """
    global config, temp_dir, OLLAMA_URL, OLLAMA_MODEL, RESPONSE_TIMEOUT, TEMPERATURE
    global MAX_CONTENT_LENGTH, SYSTEM_PROMPT_PATH, TTS_SERVER_URL, TTS_PROMPT_TEXT
    global TTS_PROMPT_WAV_PATH, TTS_SPEAKER_ID, STREAM_ENABLE, VERBOSE, AUTO_SPEAK_WAIT_SECOND

    config = ProjectConfig()
    lookup = config.FindItem

    # The lookup order is kept as-is; FindItem may record defaults in that order.
    OLLAMA_URL = lookup("ollama_url", "http://localhost:11434")
    OLLAMA_MODEL = lookup("ollama_model", "gemma3:4b")
    RESPONSE_TIMEOUT = lookup("response_timeout", 60)
    TEMPERATURE = lookup("temperature", 1.3)
    MAX_CONTENT_LENGTH = lookup("max_content_length", None)
    SYSTEM_PROMPT_PATH = lookup("system_prompt_path", None)
    TTS_SERVER_URL = lookup("tts_server_url", "http://localhost:43400")
    TTS_PROMPT_TEXT = lookup("tts_prompt_text", None)
    TTS_PROMPT_WAV_PATH = lookup("tts_prompt_wav_path", None)
    TTS_SPEAKER_ID = lookup("tts_speaker_id", "tts_speaker")
    STREAM_ENABLE = lookup("stream_enable", False)
    VERBOSE = lookup("verbose", False)
    AUTO_SPEAK_WAIT_SECOND = lookup("auto_speak_wait_second", 15.0)

    # ToolFile path composition: <temp>/<chat_start_id>. The trailing "| None"
    # is kept verbatim. NOTE(review): presumably it marks the result as a
    # directory; confirm against ToolFile.
    temp_dir = config.GetFile("temp") | chat_start_id | None

    if VERBOSE:
        highlight = ConsoleFrontColor.LIGHTYELLOW_EX
        PrintColorful(highlight, f"OLLAMA_URL: {OLLAMA_URL}")
        PrintColorful(highlight, f"OLLAMA_MODEL: {OLLAMA_MODEL}")
        PrintColorful(highlight, f"RESPONSE_TIMEOUT: {RESPONSE_TIMEOUT}")
        PrintColorful(highlight, f"TEMPERATURE: {TEMPERATURE}")
        PrintColorful(highlight, f"AUTO_SPEAK_WAIT_SECOND: {AUTO_SPEAK_WAIT_SECOND}")

    config.SaveProperties()
|
||
|
||
def initialize_chat_engine():
    """Build the Ollama-backed ``SimpleChatEngine`` and store it in ``chat_engine``.

    Uses the module-level LLM settings populated by ``initialize_config``:

    * ``MAX_CONTENT_LENGTH`` (config key ``max_content_length``), when set, is
      passed to the Ollama LLM as its ``context_window``.
    * ``SYSTEM_PROMPT_PATH``, when set, is read as text and becomes the chat
      engine's system prompt.

    The LLM is also installed as ``Settings.llm``, which ``from_defaults``
    uses because no explicit ``llm`` is passed.
    """
    global chat_engine

    ollama_llm_config = {
        "model": OLLAMA_MODEL,
        "base_url": OLLAMA_URL,
        "request_timeout": RESPONSE_TIMEOUT,
        "temperature": TEMPERATURE,
    }
    chat_engine_config = {}

    if MAX_CONTENT_LENGTH is not None:
        # The llama_index Ollama LLM has no ``max_content_length`` option; the
        # model's context size is configured through ``context_window``.
        ollama_llm_config["context_window"] = MAX_CONTENT_LENGTH

    if SYSTEM_PROMPT_PATH is not None:
        chat_engine_config["system_prompt"] = ToolFile(SYSTEM_PROMPT_PATH).LoadAsText()
        if VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "system_prompt loaded")

    ollama_llm = Ollama(**ollama_llm_config)
    Settings.llm = ollama_llm
    chat_engine = SimpleChatEngine.from_defaults(**chat_engine_config)

    if VERBOSE:
        PrintColorful(ConsoleFrontColor.LIGHTGREEN_EX, "Chat engine initialized")
|
||
|
||
def save_vocal_data(data: bytes) -> ToolFile:
    """Write *data* to a timestamped ``.wav`` file inside ``temp_dir``.

    The file name has microsecond resolution (``%Y%m%d_%H%M%S_%f.wav``).
    ``MustExistsPath`` is called before writing; presumably it creates any
    missing parent folders (TODO confirm). Returns the written ``ToolFile``.
    """
    target = temp_dir | f"{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}.wav"
    target.MustExistsPath()
    target.SaveAsBinary(data)
    return target
|
||
|
||
async def generate_tts_audio(text: str) -> Optional[bytes]:
    """Synthesize *text* on the TTS server and return the raw audio bytes.

    Posts ``text``, ``speaker_id`` and ``stream`` as form data to
    ``{TTS_SERVER_URL}/api/synthesis/sft``. With STREAM_ENABLE the streamed
    chunks are concatenated into one payload; otherwise the whole body is used.
    Every successful result is also saved via ``save_vocal_data``.

    Returns None when *text* is empty or None, when no prompt WAV is
    configured, when the server answers with a non-200 status, or when the
    request fails. The blocking HTTP call runs in a worker thread so the event
    loop is not blocked.
    """
    # ``not text`` covers both None and "". The previous ``len(text) == 0``
    # check ran first and raised TypeError for None.
    if not text:
        return None

    if TTS_PROMPT_WAV_PATH is None:
        return None

    tts_server_url = f"{TTS_SERVER_URL}/api/synthesis/sft"
    header = {
        "accept": "application/json"
    }
    data = {
        'text': text,
        'speaker_id': TTS_SPEAKER_ID,
        'stream': STREAM_ENABLE,
    }

    def _generate_sync() -> Optional[bytes]:
        """Blocking TTS request; executed via asyncio.to_thread."""
        try:
            # ``with`` closes the response, releasing the pooled connection even
            # if reading a streamed body fails part-way through.
            with requests.post(tts_server_url, data=data, stream=STREAM_ENABLE,
                               timeout=600, headers=header) as response:
                if response.status_code != 200:
                    if VERBOSE:
                        PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"TTS失败: {response.status_code} - {response.text}")
                    return None

                if STREAM_ENABLE:
                    wav_buffer = bytearray()
                    for chunk in response.iter_content(chunk_size=1024 * 256):
                        if chunk:
                            wav_buffer.extend(chunk)
                    if not wav_buffer:
                        return None
                    audio_data = bytes(wav_buffer)
                else:
                    audio_data = response.content

            save_vocal_data(audio_data)
            return audio_data
        except Exception as e:
            # Best-effort: TTS failures must not break the chat stream.
            if VERBOSE:
                PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"TTS异常: {e}")
            return None

    # Run the blocking request in the default thread pool.
    return await asyncio.to_thread(_generate_sync)
|
||
|
||
def add_speaker() -> None:
    """Register the configured voice prompt with the TTS server.

    If TTS_PROMPT_WAV_PATH or TTS_PROMPT_TEXT is missing, this is skipped with
    a notice that is printed only when VERBOSE. Otherwise it uploads the prompt
    audio as the multipart field ``prompt_wav`` to ``/api/speakers/add``, along
    with the speaker id, the prompt text and ``force_regenerate``. The outcome
    is printed; errors are reported and never raised.
    """
    missing_prompt = TTS_PROMPT_WAV_PATH is None or TTS_PROMPT_TEXT is None
    if missing_prompt:
        if VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "TTS音色配置不完整,跳过初始化")
        return

    endpoint = f"{TTS_SERVER_URL}/api/speakers/add"
    request_headers = {"accept": "application/json"}
    form_fields = {
        "speaker_id": TTS_SPEAKER_ID,
        "prompt_text": TTS_PROMPT_TEXT,
        "force_regenerate": True,
    }

    try:
        with open(TTS_PROMPT_WAV_PATH, 'rb') as prompt_file:
            ext = ToolFile(TTS_PROMPT_WAV_PATH).GetExtension().lower()
            upload = {'prompt_wav': (f'prompt.{ext}', prompt_file, f'audio/{ext}')}
            response = requests.post(endpoint, data=form_fields, files=upload,
                                     headers=request_headers, timeout=600)

        if response.status_code != 200:
            PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"添加音色失败: {response.status_code} - {response.text}")
        else:
            PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"音色可用: {response.text}")
    except Exception as e:
        PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"添加音色异常: {e}")
|
||
|
||
async def auto_speak_task_func():
    """Background loop that makes the AI keep talking when users go quiet.

    Every ``current_wait_second`` seconds it checks three conditions:

    * at least one client is connected,
    * no message is currently being processed, and
    * no user message has arrived for ``current_wait_second`` seconds.

    If all three hold, it feeds the fixed "nobody is talking" prompt through
    ``handle_chat_stream`` for every connected client. If no user message
    arrives between the start of that round and the end of a short grace
    period, the interval is multiplied by 1.5, capped at 3600 s.
    ``handle_chat_stream`` resets the interval when a real user message comes
    in. The loop runs until the task is cancelled.
    """
    global last_message_time, current_wait_second, last_user_message_time

    # Reset the activity timestamps and the wait interval.
    last_message_time = time.time()
    last_user_message_time = time.time()
    current_wait_second = AUTO_SPEAK_WAIT_SECOND

    while True:
        # Use the dynamic wait interval.
        await asyncio.sleep(current_wait_second)

        if chat_engine is None or config is None:
            continue

        # Only real user messages count as activity, not earlier auto-speech.
        idle_seconds = time.time() - last_user_message_time
        if not connected_clients or is_processing or idle_seconds < current_wait_second:
            continue

        try:
            auto_message = "(没有人说话, 请延续发言或是寻找新的话题)"
            # Remember when this round started, so a user message that arrives
            # at any point after it (even while the answer is still streaming)
            # counts as a response.
            round_started_at = time.time()

            for websocket in list(connected_clients):
                try:
                    await handle_chat_stream(websocket, auto_message)
                except Exception as e:
                    if VERBOSE:
                        PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"自动发言发送错误: {e}")

            # Give users a moment to respond: 10% of the interval, at least 1 s.
            check_wait_time = max(current_wait_second * 0.1, 1.0)
            await asyncio.sleep(check_wait_time)

            # handle_chat_stream bumps last_user_message_time for real user
            # messages. The previous "now - last > wait + 2s" test treated replies
            # sent more than ~2 s before streaming ended as silence.
            if last_user_message_time <= round_started_at:
                current_wait_second = min(current_wait_second * 1.5, 3600.0)
                if VERBOSE:
                    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"无用户回应,等待间隔调整为: {current_wait_second}秒")
            # On a user response, handle_chat_stream has already reset the interval.
        except Exception as e:
            if VERBOSE:
                PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"自动发言任务错误: {e}")
|
||
|
||
# FastAPI application
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook for the FastAPI app.

    Startup runs these steps in order:

    1. Load the configuration.
    2. Build the chat engine.
    3. Register the TTS speaker.
    4. Reset the activity timestamps.
    5. Start the auto-speak background task.

    Shutdown cancels and awaits that task, then saves the configuration. A
    ``finally`` block ensures the shutdown step also runs if an exception is
    thrown into the context at ``yield``.
    """
    # last_user_message_time was missing from this declaration, so the
    # assignment below only created a local variable.
    global auto_speak_task, last_message_time, last_user_message_time

    initialize_config()
    initialize_chat_engine()
    add_speaker()
    last_message_time = time.time()
    last_user_message_time = time.time()

    auto_speak_task = asyncio.create_task(auto_speak_task_func())

    try:
        yield
    finally:
        if auto_speak_task and not auto_speak_task.done():
            auto_speak_task.cancel()
            try:
                await auto_speak_task
            except asyncio.CancelledError:
                pass

        if config:
            config.SaveProperties()


app = FastAPI(lifespan=lifespan)
|
||
|
||
# WebSocket connection management
async def connect_client(websocket: WebSocket):
    """Complete the WebSocket handshake and start tracking the client.

    The socket is added to ``connected_clients`` only after ``accept`` succeeds.
    """
    await websocket.accept()
    connected_clients.add(websocket)
    if not VERBOSE:
        return
    PrintColorful(ConsoleFrontColor.LIGHTGREEN_EX, f"客户端已连接,当前连接数: {len(connected_clients)}")
|
||
|
||
async def disconnect_client(websocket: WebSocket):
    """Stop tracking *websocket*. Safe to call more than once for the same socket.

    Several paths may call this for one socket: send failures in
    ``safe_send_json``, and the endpoint's own ``WebSocketDisconnect`` handler
    that follows. Only the call that actually removes the socket logs, so the
    disconnect message (with an unchanged client count) is not repeated.
    """
    if websocket not in connected_clients:
        return
    connected_clients.discard(websocket)
    if VERBOSE:
        PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"客户端已断开,当前连接数: {len(connected_clients)}")
|
||
|
||
# API models
class ChatRequest(BaseModel):
    """Request body that carries a single chat message.

    NOTE(review): no route in this file references this model.
    """

    message: str  # the user's message text
|
||
|
||
async def safe_send_json(websocket: WebSocket, data: dict) -> bool:
    """Send *data* as JSON to *websocket* without letting send errors escape.

    Returns True on success. Returns False in three cases:

    * the socket is no longer tracked;
    * the connection turned out to be closed, in which case the socket is
      untracked via ``disconnect_client``;
    * any other send error occurred (logged when VERBOSE).
    """
    if websocket in connected_clients:
        try:
            await websocket.send_json(data)
        except (WebSocketDisconnect, RuntimeError, ConnectionError):
            if websocket in connected_clients:
                await disconnect_client(websocket)
        except Exception as e:
            if VERBOSE:
                PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"发送消息错误: {e}")
        else:
            return True
    return False
|
||
|
||
async def generate_and_send_audio(websocket: WebSocket, text: str):
    """Synthesize *text* and push it to *websocket* as an ``audio`` message.

    The payload is ``{"type": "audio", "audio": <base64-encoded audio bytes>}``.
    Nothing is sent if the client disconnected before or during synthesis, or
    if synthesis returned no audio. Closed-connection errors untrack the
    socket; any other error is only logged (when VERBOSE).
    """
    try:
        if websocket not in connected_clients:
            return

        audio_data = await generate_tts_audio(text)

        # The client may have left while the TTS request was running.
        if websocket not in connected_clients or not audio_data:
            return

        encoded = base64.b64encode(audio_data).decode('utf-8')
        await safe_send_json(websocket, {"type": "audio", "audio": encoded})
    except (WebSocketDisconnect, RuntimeError, ConnectionError):
        # The connection is already closed; just stop tracking it.
        if websocket in connected_clients:
            await disconnect_client(websocket)
    except Exception as e:
        if VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"音频生成错误: {e}")
|
||
|
||
# WebSocket endpoint
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint that drives the streaming chat protocol.

    Client-to-server frames are JSON objects:

    * ``{"type": "chat", "message": "..."}`` gets a streamed answer via
      ``handle_chat_stream``; an empty message gets an ``error`` reply.
    * ``{"type": "ping"}`` gets ``{"type": "pong"}``.

    Unknown types are ignored. A frame that is not a JSON object gets an
    ``error`` reply; previously it raised and tore down the whole connection.
    """
    await connect_client(websocket)
    try:
        while True:
            raw = await websocket.receive_text()
            try:
                message_data = json.loads(raw)
            except json.JSONDecodeError:
                message_data = None
            if not isinstance(message_data, dict):
                await safe_send_json(websocket, {
                    "type": "error",
                    "message": "消息格式无效"
                })
                continue

            message_type = message_data.get("type")
            if message_type == "chat":
                message = message_data.get("message", "")
                if not message:
                    await safe_send_json(websocket, {
                        "type": "error",
                        "message": "消息不能为空"
                    })
                    continue
                await handle_chat_stream(websocket, message)
            elif message_type == "ping":
                await safe_send_json(websocket, {"type": "pong"})
    except WebSocketDisconnect:
        await disconnect_client(websocket)
    except Exception as e:
        if VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"WebSocket错误: {e}")
        await disconnect_client(websocket)
|
||
|
||
# Strong references to fire-and-forget audio tasks. The event loop keeps only
# weak references to tasks, so an unreferenced task may be garbage-collected
# before it finishes (see the asyncio.create_task documentation).
_audio_tasks: Set[asyncio.Task] = set()


def _spawn_audio_task(coro) -> asyncio.Task:
    """Schedule *coro* as a task and keep it referenced until it completes."""
    task = asyncio.create_task(coro)
    _audio_tasks.add(task)
    task.add_done_callback(_audio_tasks.discard)
    return task


async def _send_audio_after(previous: Optional[asyncio.Task], websocket: WebSocket, text: str):
    """Run ``generate_and_send_audio`` for *text* once *previous* has finished.

    Chaining the segments of one answer keeps their audio in sentence order.
    Independent tasks could finish, and therefore be sent, in any order.
    """
    if previous is not None:
        # asyncio.wait does not re-raise the previous task's error or cancellation.
        await asyncio.wait({previous})
    await generate_and_send_audio(websocket, text)


async def handle_chat_stream(websocket: WebSocket, message: str):
    """Stream one chat turn for *message* to *websocket*.

    Server-to-client messages, in order:

    1. ``start``.
    2. One ``chunk`` per streamed text piece.
    3. ``complete`` carrying the full stripped answer.

    ``audio`` messages for each sentence (longer than 20 characters and ending
    in a terminator), plus any remaining text, follow asynchronously in
    sentence order. On failure an ``error`` message is sent instead.

    Side effects:

    * Always updates ``last_message_time``.
    * For anything other than the auto-speak prompt, also updates
      ``last_user_message_time`` and resets ``current_wait_second``.
    * ``is_processing`` is True while the turn runs.

    NOTE(review): concurrent turns (user + auto-speak) share that single flag,
    so the first one to finish clears it.
    """
    global last_message_time, is_processing, current_wait_second, last_user_message_time

    if chat_engine is None:
        await safe_send_json(websocket, {
            "type": "error",
            "message": "聊天引擎未初始化"
        })
        return

    # Record the activity time and mark a turn as in progress.
    last_message_time = time.time()
    is_processing = True

    # Real user messages (not the auto-speak prompt) reset the auto-speak backoff.
    if message != "(没有人说话, 请延续发言或是寻找新的话题)":
        last_user_message_time = time.time()
        current_wait_second = AUTO_SPEAK_WAIT_SECOND
        if VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTGREEN_EX, f"用户发送消息,等待间隔重置为: {current_wait_second}秒")

    user_message = message if message not in [None, ""] else "(没有人说话, 请延续发言或是寻找新的话题)"

    # Sentence terminators that flush buffered text to TTS (full-width and ASCII).
    end_symbol = ('。', '?', '!', '?', '!')
    min_segment_length = 20
    last_audio_task: Optional[asyncio.Task] = None

    def queue_audio(text: str) -> None:
        """Queue *text* for TTS behind the previously queued segment."""
        nonlocal last_audio_task
        last_audio_task = _spawn_audio_task(_send_audio_after(last_audio_task, websocket, text))

    try:
        streaming_response: StreamingAgentChatResponse = await chat_engine.astream_chat(user_message)
        buffer_response = ""
        audio_text_buffer = ""

        if not await safe_send_json(websocket, {
            "type": "start",
            "message": ""
        }):
            return

        async for chunk in streaming_response.async_response_gen():
            if websocket not in connected_clients:
                break

            await asyncio.sleep(0.01)
            buffer_response += chunk

            if not await safe_send_json(websocket, {
                "type": "chunk",
                "message": chunk
            }):
                break

            # Cut the text into sentence-sized pieces for speech synthesis.
            for ch in chunk:
                audio_text_buffer += ch
                if len(audio_text_buffer) > min_segment_length and ch in end_symbol:
                    text_to_speak = audio_text_buffer.strip()
                    if text_to_speak:
                        queue_audio(text_to_speak)
                    audio_text_buffer = ""

        if websocket not in connected_clients:
            return

        final_text = buffer_response.strip()
        await safe_send_json(websocket, {
            "type": "complete",
            "message": final_text
        })

        # Speak whatever trailing text did not end in a terminator.
        if final_text:
            remaining = audio_text_buffer.strip()
            if remaining:
                queue_audio(remaining)

    except (WebSocketDisconnect, RuntimeError, ConnectionError):
        if websocket in connected_clients:
            await disconnect_client(websocket)
    except Exception as e:
        if VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"聊天处理错误: {e}")
        await safe_send_json(websocket, {
            "type": "error",
            "message": f"处理请求时发生错误: {str(e)}"
        })
    finally:
        is_processing = False
|
||
|
||
# Front-end page
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve the single-page chat client built by ``get_html_content``."""
    page = get_html_content()
    return HTMLResponse(content=page)
|
||
|
||
def get_html_content() -> str:
    """Return the single-page chat client (HTML, CSS and JavaScript) served at ``/``.

    The page opens a WebSocket to ``/ws`` on the same host and reconnects 3 s
    after a close. It sends ``{"type": "chat", "message": ...}`` frames and
    renders ``start`` / ``chunk`` / ``complete`` / ``error`` messages. It
    queues ``audio`` payloads (base64, played as ``audio/wav``) so they play
    one after another.

    Fixes compared with the previous markup:

    * Stray ``|`` / ``||`` lines inside the string broke the CSS and JS.
    * The typing indicator is now moved to the bottom each time it is shown
      and removed when hidden. Before, it stayed at its first DOM position.
    * The AI bubble is created on the first chunk, not as an empty bubble
      before the dots.
    * ``error`` now ends the current AI message.
    """
    return """<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>VirtualChat - AI对话</title>
    <style>
        * {
            margin: 0;
            padding: 0;
            box-sizing: border-box;
        }

        body {
            font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            min-height: 100vh;
            display: flex;
            justify-content: center;
            align-items: center;
            padding: 20px;
        }

        .container {
            width: 100%;
            max-width: 800px;
            height: 90vh;
            background: rgba(255, 255, 255, 0.95);
            border-radius: 24px;
            box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3);
            display: flex;
            flex-direction: column;
            overflow: hidden;
            backdrop-filter: blur(10px);
        }

        .header {
            padding: 24px;
            border-bottom: 1px solid rgba(0, 0, 0, 0.1);
            background: rgba(255, 255, 255, 0.8);
            backdrop-filter: blur(10px);
        }

        .header h1 {
            font-size: 24px;
            font-weight: 600;
            color: #1d1d1f;
            text-align: center;
        }

        .messages {
            flex: 1;
            overflow-y: auto;
            padding: 24px;
            display: flex;
            flex-direction: column;
            gap: 16px;
        }

        .messages::-webkit-scrollbar {
            width: 6px;
        }

        .messages::-webkit-scrollbar-track {
            background: transparent;
        }

        .messages::-webkit-scrollbar-thumb {
            background: rgba(0, 0, 0, 0.2);
            border-radius: 3px;
        }

        .message {
            display: flex;
            gap: 12px;
            animation: fadeIn 0.3s ease-in;
        }

        @keyframes fadeIn {
            from {
                opacity: 0;
                transform: translateY(10px);
            }
            to {
                opacity: 1;
                transform: translateY(0);
            }
        }

        .message.user {
            justify-content: flex-end;
        }

        .message-content {
            max-width: 70%;
            padding: 14px 18px;
            border-radius: 20px;
            word-wrap: break-word;
            line-height: 1.5;
            font-size: 15px;
        }

        .message.user .message-content {
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            border-bottom-right-radius: 4px;
        }

        .message.ai .message-content {
            background: #f5f5f7;
            color: #1d1d1f;
            border-bottom-left-radius: 4px;
        }

        .input-area {
            padding: 24px;
            border-top: 1px solid rgba(0, 0, 0, 0.1);
            background: rgba(255, 255, 255, 0.8);
            backdrop-filter: blur(10px);
        }

        .input-container {
            display: flex;
            gap: 12px;
            align-items: flex-end;
        }

        .input-wrapper {
            flex: 1;
            position: relative;
        }

        #messageInput {
            width: 100%;
            padding: 14px 18px;
            border: 2px solid rgba(0, 0, 0, 0.1);
            border-radius: 24px;
            font-size: 15px;
            font-family: inherit;
            outline: none;
            transition: all 0.3s ease;
            background: white;
            resize: none;
            min-height: 50px;
            max-height: 120px;
        }

        #messageInput:focus {
            border-color: #667eea;
            box-shadow: 0 0 0 4px rgba(102, 126, 234, 0.1);
        }

        .send-button {
            padding: 14px 28px;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            border: none;
            border-radius: 24px;
            font-size: 15px;
            font-weight: 600;
            cursor: pointer;
            transition: all 0.3s ease;
            white-space: nowrap;
        }

        .send-button:hover:not(:disabled) {
            transform: translateY(-2px);
            box-shadow: 0 8px 20px rgba(102, 126, 234, 0.4);
        }

        .send-button:active:not(:disabled) {
            transform: translateY(0);
        }

        .send-button:disabled {
            opacity: 0.5;
            cursor: not-allowed;
        }

        .status {
            text-align: center;
            padding: 12px;
            color: #86868b;
            font-size: 13px;
        }

        .status.connected {
            color: #30d158;
        }

        .status.disconnected {
            color: #ff453a;
        }

        .typing-indicator {
            display: none;
            padding: 14px 18px;
            background: #f5f5f7;
            border-radius: 20px;
            border-bottom-left-radius: 4px;
            max-width: 70px;
        }

        .typing-indicator.active {
            display: block;
        }

        .typing-dots {
            display: flex;
            gap: 4px;
        }

        .typing-dot {
            width: 8px;
            height: 8px;
            background: #86868b;
            border-radius: 50%;
            animation: typing 1.4s infinite;
        }

        .typing-dot:nth-child(2) {
            animation-delay: 0.2s;
        }

        .typing-dot:nth-child(3) {
            animation-delay: 0.4s;
        }

        @keyframes typing {
            0%, 60%, 100% {
                transform: translateY(0);
            }
            30% {
                transform: translateY(-10px);
            }
        }
    </style>
</head>
<body>
    <div class="container">
        <div class="header">
            <h1>VirtualChat</h1>
        </div>
        <div class="messages" id="messages"></div>
        <div class="input-area">
            <div class="input-container">
                <div class="input-wrapper">
                    <textarea id="messageInput" placeholder="输入消息..." rows="1"></textarea>
                </div>
                <button class="send-button" id="sendButton">发送</button>
            </div>
            <div class="status" id="status">连接中...</div>
        </div>
    </div>

    <script>
        let ws = null;
        let currentAiMessage = null;
        let audioQueue = [];
        let isPlayingAudio = false;

        const messagesDiv = document.getElementById('messages');
        const messageInput = document.getElementById('messageInput');
        const sendButton = document.getElementById('sendButton');
        const statusDiv = document.getElementById('status');

        function connect() {
            const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
            const wsUrl = `${protocol}//${window.location.host}/ws`;
            ws = new WebSocket(wsUrl);

            ws.onopen = () => {
                statusDiv.textContent = '已连接';
                statusDiv.className = 'status connected';
                sendButton.disabled = false;
            };

            ws.onclose = () => {
                statusDiv.textContent = '已断开';
                statusDiv.className = 'status disconnected';
                sendButton.disabled = true;
                setTimeout(connect, 3000);
            };

            ws.onerror = (error) => {
                console.error('WebSocket error:', error);
                statusDiv.textContent = '连接错误';
                statusDiv.className = 'status disconnected';
            };

            ws.onmessage = (event) => {
                const data = JSON.parse(event.data);
                handleMessage(data);
            };
        }

        function handleMessage(data) {
            if (data.type === 'start') {
                // The AI bubble is created on the first chunk, so only the
                // typing dots are visible while the model is thinking.
                currentAiMessage = null;
                showTypingIndicator();
            } else if (data.type === 'chunk') {
                hideTypingIndicator();
                if (!currentAiMessage) {
                    currentAiMessage = createMessage('ai', '');
                }
                currentAiMessage.textContent += data.message;
                scrollToBottom();
            } else if (data.type === 'complete') {
                hideTypingIndicator();
                if (currentAiMessage && data.message) {
                    currentAiMessage.textContent = data.message;
                }
                currentAiMessage = null;
                scrollToBottom();
            } else if (data.type === 'audio') {
                playAudio(data.audio);
            } else if (data.type === 'error') {
                hideTypingIndicator();
                currentAiMessage = null;
                addMessage('ai', '错误: ' + data.message);
            } else if (data.type === 'auto_speak') {
                // Auto-speak message (handled like a normal message)
                if (data.message) {
                    addMessage('ai', data.message);
                }
                if (data.audio) {
                    playAudio(data.audio);
                }
            } else if (data.type === 'pong') {
                // Heartbeat response
            }
        }

        function createMessage(role, text) {
            const messageDiv = document.createElement('div');
            messageDiv.className = `message ${role}`;

            const contentDiv = document.createElement('div');
            contentDiv.className = 'message-content';
            contentDiv.textContent = text;

            messageDiv.appendChild(contentDiv);
            messagesDiv.appendChild(messageDiv);
            scrollToBottom();

            return contentDiv;
        }

        function addMessage(role, text) {
            createMessage(role, text);
        }

        function showTypingIndicator() {
            let indicator = document.getElementById('typingIndicator');
            if (!indicator) {
                indicator = document.createElement('div');
                indicator.id = 'typingIndicator';
                indicator.className = 'message ai';
                indicator.innerHTML = `
                    <div class="typing-indicator active">
                        <div class="typing-dots">
                            <div class="typing-dot"></div>
                            <div class="typing-dot"></div>
                            <div class="typing-dot"></div>
                        </div>
                    </div>
                `;
            }
            // appendChild moves an existing node, so the dots always sit below the newest message.
            messagesDiv.appendChild(indicator);
            scrollToBottom();
        }

        function hideTypingIndicator() {
            const indicator = document.getElementById('typingIndicator');
            if (indicator) {
                indicator.remove();
            }
        }

        function scrollToBottom() {
            messagesDiv.scrollTop = messagesDiv.scrollHeight;
        }

        function playAudio(audioBase64) {
            audioQueue.push(audioBase64);
            processAudioQueue();
        }

        function processAudioQueue() {
            if (isPlayingAudio || audioQueue.length === 0) {
                return;
            }

            isPlayingAudio = true;
            const audioBase64 = audioQueue.shift();
            const audio = new Audio('data:audio/wav;base64,' + audioBase64);

            audio.onended = () => {
                isPlayingAudio = false;
                processAudioQueue();
            };

            audio.onerror = () => {
                isPlayingAudio = false;
                processAudioQueue();
            };

            audio.play().catch(err => {
                console.error('Audio play error:', err);
                isPlayingAudio = false;
                processAudioQueue();
            });
        }

        function sendMessage() {
            const message = messageInput.value.trim();
            if (!message || !ws || ws.readyState !== WebSocket.OPEN) {
                return;
            }

            addMessage('user', message);
            messageInput.value = '';
            messageInput.style.height = 'auto';

            ws.send(JSON.stringify({
                type: 'chat',
                message: message
            }));
        }

        sendButton.addEventListener('click', sendMessage);

        messageInput.addEventListener('keydown', (e) => {
            if (e.key === 'Enter' && !e.shiftKey) {
                e.preventDefault();
                sendMessage();
            }
        });

        messageInput.addEventListener('input', () => {
            messageInput.style.height = 'auto';
            messageInput.style.height = messageInput.scrollHeight + 'px';
        });

        connect();
    </script>
</body>
</html>"""
|
||
|
||
# Health check
@app.get("/health")
def health_check():
    """Report liveness, whether the chat engine is ready, and how many clients are connected."""
    engine_ready = chat_engine is not None
    client_count = len(connected_clients)
    return {"status": "ok", "chat_engine_ready": engine_ready, "connected_clients": client_count}
|
||
|
||
# Program entry point
if __name__ == "__main__":
    import uvicorn

    # ``config`` is only assigned inside lifespan(), which has not run yet at
    # this point. The old "... if config else default" therefore always
    # ignored the configured bind address. Read it from the project config.
    startup_config = config if config is not None else ProjectConfig()
    server_host = startup_config.FindItem("server_host", "0.0.0.0")
    server_port = startup_config.FindItem("server_port", 11451)
    # Persist the keys, matching what initialize_config() does with its lookups.
    startup_config.SaveProperties()

    uvicorn.run(app, host=server_host, port=server_port)
|
||
|