356 lines
12 KiB
Python
356 lines
12 KiB
Python
|
|
|
|||
|
|
from Convention.Runtime.GlobalConfig import *
|
|||
|
|
import asyncio
|
|||
|
|
from datetime import datetime
|
|||
|
|
from llama_index.llms.ollama import Ollama
|
|||
|
|
from llama_index.core.chat_engine import SimpleChatEngine
|
|||
|
|
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
|
|||
|
|
from llama_index.core import Settings
|
|||
|
|
from Convention.Runtime.File import ToolFile
|
|||
|
|
import requests
|
|||
|
|
import subprocess
|
|||
|
|
import wave
|
|||
|
|
import io
|
|||
|
|
import numpy as np
|
|||
|
|
import pyaudio
|
|||
|
|
|
|||
|
|
# Session identifier (start timestamp); names the per-session temp directory.
chat_start_id = datetime.now().strftime("%Y%m%d_%H%M%S")

config = ProjectConfig()
# --- LLM (Ollama) settings ---
OLLAMA_URL = config.FindItem("ollama_url", "http://localhost:11434")
OLLAMA_MODEL = config.FindItem("ollama_model", "gemma3:4b")
RESPONSE_TIMEOUT = config.FindItem("response_timeout", 60)
TEMPERATURE = config.FindItem("temperature", 1.3)
MAX_CONTENT_LENGTH = config.FindItem("max_content_length", None)
SYSTEM_PROMPT_PATH = config.FindItem("system_prompt_path", None)
# Seconds to wait for user input before the AI speaks on its own (see event_loop).
AUTO_SPEAK_WAIT_SECOND = config.FindItem("auto_speak_wait_second", 15.0)
# --- TTS server settings ---
TTS_SERVER_URL = config.FindItem("tts_server_url", "http://localhost:43400")
TTS_PROMPT_TEXT = config.FindItem("tts_prompt_text", None)
TTS_PROMPT_WAV_PATH = config.FindItem("tts_prompt_wav_path", None)
TTS_SPEAKER_ID = config.FindItem("tts_speaker_id", "tts_speaker")
SEED = config.FindItem("seed", 0)
STREAM_ENABLE = config.FindItem("stream_enable", False)
VERBOSE = config.FindItem("verbose", False)
if VERBOSE:
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"OLLAMA_URL: {OLLAMA_URL}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"OLLAMA_MODEL: {OLLAMA_MODEL}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"RESPONSE_TIMEOUT: {RESPONSE_TIMEOUT}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"TEMPERATURE: {TEMPERATURE}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"MAX_CONTENT_LENGTH: {MAX_CONTENT_LENGTH}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"SYSTEM_PROMPT_PATH: {SYSTEM_PROMPT_PATH}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"AUTO_SPEAK_WAIT_SECOND: {AUTO_SPEAK_WAIT_SECOND}")
# Per-session temp directory: <project temp>/<chat_start_id>.
# NOTE(review): the trailing `|None` presumably finalizes/normalizes the path —
# confirm against Convention.Runtime.File.ToolFile's `__or__` semantics.
temp_dir = config.GetFile("temp")|chat_start_id|None

# Keyword arguments forwarded verbatim to llama_index's Ollama constructor.
ollama_llm_config = {
    "model": OLLAMA_MODEL,
    "base_url": OLLAMA_URL,
    "request_timeout": RESPONSE_TIMEOUT,
    "temperature": TEMPERATURE,
}
# Keyword arguments forwarded to SimpleChatEngine.from_defaults.
chat_engine_config = {}

# Optional settings are only added when explicitly configured.
if MAX_CONTENT_LENGTH is not None:
    ollama_llm_config["max_content_length"] = MAX_CONTENT_LENGTH
    if VERBOSE:
        PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"MAX_CONTENT_LENGTH: {MAX_CONTENT_LENGTH}")

if SYSTEM_PROMPT_PATH is not None:
    system_prompt = ToolFile(SYSTEM_PROMPT_PATH).LoadAsText()
    chat_engine_config["system_prompt"] = system_prompt
    if VERBOSE:
        PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"system_prompt: {system_prompt}")

# Persist any defaults FindItem filled in back to the project config file.
config.SaveProperties()
|
|||
|
|
|
|||
|
|
def save_vocal_data(data:bytes) -> ToolFile:
    """Persist one WAV blob into the session temp directory.

    The file is named by a microsecond-resolution timestamp so concurrent
    saves within the same second cannot collide.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target = temp_dir | f"{stamp}.wav"
    # Ensure the parent directory exists before writing.
    target.MustExistsPath()
    target.SaveAsBinary(data)
    return target
|
|||
|
|
|
|||
|
|
# FIFO of complete WAV blobs awaiting ordered playback (consumed by audio_player_worker).
audio_play_queue: asyncio.Queue[bytes] = asyncio.Queue()
# Lazily-created PyAudio handle, reused across chunks (managed by play_audio_chunk).
_pyaudio_instance: pyaudio.PyAudio | None = None
# Open output stream bound to _current_sample_rate; rebuilt when the rate changes.
_pyaudio_stream: pyaudio.Stream | None = None
# Sample rate of the currently open stream, or None when no stream is open.
_current_sample_rate: int | None = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
# TOCHECK
|
|||
|
|
def parse_wav_chunk(wav_bytes: bytes) -> tuple[np.ndarray, int]:
    """
    Parse a complete WAV byte blob into a mono int16 audio array.

    Args:
        wav_bytes: Bytes of one complete RIFF/WAV file.

    Returns:
        Tuple of (audio_data, sample_rate) where audio_data is a 1-D
        np.int16 array; stereo input is downmixed by averaging channels.

    Raises:
        ValueError: If the sample width is neither 16-bit nor 32-bit.
        wave.Error: If the bytes are not a valid WAV stream.
    """
    # Context manager guarantees the reader is closed even when a getter
    # raises; the original leaked the handle on error paths.
    with wave.open(io.BytesIO(wav_bytes)) as wav_file:
        sample_rate = wav_file.getframerate()
        n_channels = wav_file.getnchannels()
        sample_width = wav_file.getsampwidth()
        audio_bytes = wav_file.readframes(wav_file.getnframes())

    if sample_width == 2:
        audio_data = np.frombuffer(audio_bytes, dtype=np.int16)
    elif sample_width == 4:
        # 32-bit PCM: peak-normalize into int16 range (scales by the loudest
        # sample, not a fixed >>16 shift — preserves quiet recordings).
        audio_data_32 = np.frombuffer(audio_bytes, dtype=np.int32)
        max_val = np.abs(audio_data_32).max()
        if max_val > 0:
            audio_data = (audio_data_32 / max_val * 32767).astype(np.int16)
        else:
            audio_data = np.zeros(len(audio_data_32), dtype=np.int16)
    else:
        raise ValueError(f"Unsupported sample width: {sample_width}")

    if n_channels == 2:
        # Downmix stereo to mono by averaging the interleaved channel pairs.
        audio_data = audio_data.reshape(-1, 2).mean(axis=1).astype(np.int16)

    audio_data = np.clip(audio_data, -32768, 32767).astype(np.int16)
    return audio_data, sample_rate
|
|||
|
|
|
|||
|
|
|
|||
|
|
# TOCHECK
|
|||
|
|
def play_audio_chunk(audio_data: np.ndarray, sample_rate: int) -> None:
    """
    Play a mono audio array through PyAudio.

    The module-level PyAudio instance and output stream are cached and
    reused; they are torn down and rebuilt on first use or whenever the
    sample rate changes.
    """
    global _pyaudio_instance, _pyaudio_stream, _current_sample_rate

    needs_reopen = _pyaudio_instance is None or _current_sample_rate != sample_rate
    if needs_reopen:
        # Best-effort teardown of any previous stream/instance.
        old_stream, old_instance = _pyaudio_stream, _pyaudio_instance
        if old_stream is not None:
            try:
                old_stream.stop_stream()
                old_stream.close()
            except Exception:
                pass
        if old_instance is not None:
            try:
                old_instance.terminate()
            except Exception:
                pass

        # ~20 ms buffers, but never smaller than 256 frames.
        buffer_frames = max(int(sample_rate * 0.02), 256)
        _pyaudio_instance = pyaudio.PyAudio()
        _pyaudio_stream = _pyaudio_instance.open(
            format=pyaudio.paInt16,
            channels=1,
            rate=sample_rate,
            output=True,
            frames_per_buffer=buffer_frames,
        )
        _current_sample_rate = sample_rate

    if audio_data.dtype != np.int16:
        # Peak-normalize non-int16 input into the int16 range.
        peak = np.abs(audio_data).max()
        if peak > 0:
            audio_data = (audio_data / peak * 32767).astype(np.int16)
        else:
            audio_data = np.zeros_like(audio_data, dtype=np.int16)

    # Feed the device in 4 KiB slices so writes stay responsive.
    raw = audio_data.tobytes()
    step = 4096
    for offset in range(0, len(raw), step):
        if _pyaudio_stream is not None:
            _pyaudio_stream.write(raw[offset:offset + step])
|
|||
|
|
|
|||
|
|
# TOCHECK
|
|||
|
|
def cleanup_audio() -> None:
    """
    Release the module-level PyAudio stream and instance.

    Safe to call repeatedly: close/terminate errors are swallowed and the
    globals are reset to None either way.
    """
    global _pyaudio_instance, _pyaudio_stream, _current_sample_rate
    stream = _pyaudio_stream
    if stream is not None:
        try:
            stream.stop_stream()
            stream.close()
        except Exception:
            # Best-effort cleanup; the handles are dropped regardless.
            pass
        _pyaudio_stream = None
    instance = _pyaudio_instance
    if instance is not None:
        try:
            instance.terminate()
        except Exception:
            pass
        _pyaudio_instance = None
    _current_sample_rate = None
|
|||
|
|
|
|||
|
|
# TOCHECK
|
|||
|
|
def play_audio_sync(audio_data: bytes) -> None:
    """Decode one complete WAV blob and play it synchronously.

    No-op on empty input.
    """
    if not audio_data:
        return
    samples, rate = parse_wav_chunk(audio_data)
    play_audio_chunk(samples, rate)
|
|||
|
|
|
|||
|
|
# TOCHECK
|
|||
|
|
async def audio_player_worker():
    """
    Background task that plays queued WAV blobs strictly in arrival order.

    A `None` item is the shutdown sentinel. Each played blob is also
    persisted to the session temp directory; failures are logged and the
    queue item is always marked done.
    """
    while True:
        item = await audio_play_queue.get()
        if item is None:
            audio_play_queue.task_done()
            break
        try:
            # Blocking audio/file work runs in threads so the event loop
            # stays responsive for text streaming.
            await asyncio.to_thread(play_audio_sync, item)
            await asyncio.to_thread(save_vocal_data, item)
        except Exception as exc:
            PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"音频播放失败: {exc}")
        finally:
            audio_play_queue.task_done()
|
|||
|
|
|
|||
|
|
# CHANGE TOCHECK
|
|||
|
|
async def play_vocal(text:str) -> None:
    """
    Synthesize `text` via the TTS server and enqueue the resulting WAV
    blob(s) for ordered playback. No-op for empty/None text.

    In streaming mode the HTTP body may contain several concatenated RIFF
    files; they are split on RIFF headers and enqueued one by one.

    Raises:
        Exception: If the TTS server responds with a non-200 status.
    """
    # Fix: the original evaluated `len(text)` before the falsiness check,
    # so a None argument raised TypeError instead of returning quietly.
    if not text:
        return
    tts_server_url = f"{TTS_SERVER_URL}/api/synthesis/sft"
    header = {
        "accept": "application/json"
    }
    data = {
        'text': text,
        'speaker_id': TTS_SPEAKER_ID,
        'stream': STREAM_ENABLE,
    }
    # NOTE(review): requests.post blocks the event loop while the request is
    # in flight; consider asyncio.to_thread if latency matters — left as-is.
    if STREAM_ENABLE:
        response = requests.post(tts_server_url, data=data, stream=True, timeout=600, headers=header)
        if response.status_code != 200:
            raise Exception(f"语音合成失败: {response.status_code} - {response.text}")
        wav_buffer = bytearray()
        for chunk in response.iter_content(chunk_size=1024 * 256):
            if not chunk:
                continue
            wav_buffer.extend(chunk)
            # Carve complete RIFF files out of the rolling buffer.
            while len(wav_buffer) > 12:
                if wav_buffer[:4] != b'RIFF':
                    # Resynchronize on the next RIFF magic; drop leading garbage.
                    riff_pos = wav_buffer.find(b'RIFF', 1)
                    if riff_pos == -1:
                        wav_buffer.clear()
                        break
                    wav_buffer = wav_buffer[riff_pos:]
                if len(wav_buffer) < 8:
                    break
                # The RIFF size field excludes the 8-byte RIFF header itself.
                file_size = int.from_bytes(wav_buffer[4:8], byteorder='little')
                expected_size = file_size + 8
                if len(wav_buffer) < expected_size:
                    break
                complete_wav = bytes(wav_buffer[:expected_size])
                del wav_buffer[:expected_size]
                await audio_play_queue.put(complete_wav)
        if wav_buffer:
            # Trailing partial data: enqueue only if it parses as a valid WAV.
            leftover = bytes(wav_buffer)
            try:
                parse_wav_chunk(leftover)
                await audio_play_queue.put(leftover)
            except Exception:
                PrintColorful(ConsoleFrontColor.LIGHTRED_EX, "剩余音频数据解析失败,已丢弃")
    else:
        response = requests.post(tts_server_url, data=data, timeout=600, headers=header)
        if response.status_code == 200:
            await audio_play_queue.put(response.content)
        else:
            raise Exception(f"语音合成失败: {response.status_code} - {response.text}")
|
|||
|
|
|
|||
|
|
async def ainput(wait_seconds:float) -> str:
    """
    Prompt for console input, giving the user at most `wait_seconds` seconds.

    Returns:
        The entered line, or "" when the wait times out or stdin is
        unreadable (EOF, or a captured/non-interactive stdin).
    """
    # Fix: get_event_loop() is deprecated inside coroutines; use the
    # running loop explicitly.
    loop = asyncio.get_running_loop()

    def get_input() -> str:
        try:
            return input("\n你: ")
        except (EOFError, OSError):
            # OSError covers non-interactive/captured stdin (e.g. test
            # runners), which the original's EOFError-only handler missed.
            return ""

    input_task = loop.run_in_executor(None, get_input)
    # Poll twice a second; on timeout the executor thread keeps blocking on
    # input() and a late answer is simply discarded.
    remaining = wait_seconds
    while remaining > 0:
        if input_task.done():
            return input_task.result()
        await asyncio.sleep(0.5)
        remaining -= 0.5
    return ""
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def achat(engine:SimpleChatEngine,message:str) -> None:
    """
    Stream one chat turn: print tokens as they arrive and speak
    sentence-sized chunks via play_vocal.

    An empty/None message is replaced by a prompt asking the AI to continue
    on its own. Speech is flushed at Chinese sentence-ending punctuation,
    but only once the pending buffer exceeds 20 characters; any remainder
    is spoken after the stream ends.
    """
    fallback = "(没有人说话, 请延续发言或是寻找新的话题)"
    user_message = message if message not in (None, "") else fallback
    streaming_response: StreamingAgentChatResponse = await engine.astream_chat(user_message)

    sentence_ends = ('。', '?', '!')
    pending = ""

    async for chunk in streaming_response.async_response_gen():
        # Small yield keeps playback and printing interleaved smoothly.
        await asyncio.sleep(0.01)
        print(chunk, end='', flush=True)
        for ch in chunk:
            pending += ch
            if len(pending) > 20 and ch in sentence_ends:
                await play_vocal(pending.strip())
                pending = ""

    pending = pending.strip()
    if pending:
        await play_vocal(pending)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def add_speaker() -> None:
    """
    Register (or refresh) the configured TTS voice on the TTS server by
    uploading the prompt WAV file and its transcript.

    Raises:
        ValueError: If `tts_prompt_wav_path` is not configured.
        Exception: If the server rejects the request (non-200 status).
    """
    # Fix: fail fast with a clear message; the default config leaves
    # TTS_PROMPT_WAV_PATH as None, and open(None) raised an opaque TypeError.
    if TTS_PROMPT_WAV_PATH is None:
        raise ValueError("tts_prompt_wav_path is not configured")
    url = f"{TTS_SERVER_URL}/api/speakers/add"
    headers = {
        "accept": "application/json"
    }
    data = {
        "speaker_id": TTS_SPEAKER_ID,
        "prompt_text": TTS_PROMPT_TEXT,
        "force_regenerate": True
    }
    with open(TTS_PROMPT_WAV_PATH, 'rb') as f:
        # Guess the MIME subtype from the file extension (e.g. audio/wav).
        extension = ToolFile(TTS_PROMPT_WAV_PATH).GetExtension().lower()
        files = {
            'prompt_wav': (f'prompt.{extension}', f, f'audio/{extension}')
        }
        response = requests.post(url, data=data, files=files, headers=headers, timeout=600)
    if response.status_code == 200:
        PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"音色可用: {response.text}")
    else:
        raise Exception(f"添加音色失败: {response.status_code} - {response.text}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def event_loop(engine:SimpleChatEngine) -> None:
    """
    Main conversation loop.

    Registers the TTS voice, starts the playback worker, then alternates
    user input and AI replies until the user types "quit" or "exit". When
    the user stays silent, the auto-speak wait grows by 1.5x per silent
    turn, capped at one hour. On exit, queued audio is drained and the
    worker is shut down cleanly.
    """
    add_speaker()
    audio_player_task = asyncio.create_task(audio_player_worker())
    message = input("请开始对话: ")
    wait_second = AUTO_SPEAK_WAIT_SECOND
    try:
        while message != "quit" and message != "exit":
            PrintColorful(ConsoleFrontColor.GREEN, "AI: ", is_reset=False, end='')
            await achat(engine, message)
            PrintColorful(ConsoleFrontColor.RESET,"")
            message = await ainput(wait_second)
            if not message:
                # Fix: the original used max(), which jumped straight to the
                # 3600 s ceiling on the first silent turn; min() grows the
                # wait gradually (x1.5) and caps it at one hour.
                wait_second = min(wait_second*1.5, 3600)
            else:
                wait_second = AUTO_SPEAK_WAIT_SECOND
    finally:
        # Drain any queued audio, then send the shutdown sentinel and wait
        # for the worker before releasing the audio device.
        await audio_play_queue.join()
        await audio_play_queue.put(None)
        await audio_player_task
        cleanup_audio()
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def main():
    """Wire up the Ollama LLM and chat engine, then run the event loop.

    Any failure is logged through the project config logger; audio
    resources are released on every exit path.
    """
    try:
        llm = Ollama(**ollama_llm_config)
        Settings.llm = llm
        engine = SimpleChatEngine.from_defaults(**chat_engine_config)
        await event_loop(engine)
    except Exception as e:
        config.Log("Error", f"Error: {e}")
        return
    finally:
        cleanup_audio()
|
|||
|
|
|
|||
|
|
# Script entry point: run the async main loop.
if __name__ == "__main__":
    asyncio.run(main())
|