Files
TheVirtualOne/cli.py
2025-12-18 01:52:40 +08:00

356 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from Convention.Runtime.GlobalConfig import *
import asyncio
from datetime import datetime
from llama_index.llms.ollama import Ollama
from llama_index.core.chat_engine import SimpleChatEngine
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
from llama_index.core import Settings
from Convention.Runtime.File import ToolFile
import requests
import subprocess
import wave
import io
import numpy as np
import pyaudio
# Timestamp identifying this chat session; also names the per-session temp audio folder.
chat_start_id = datetime.now().strftime("%Y%m%d_%H%M%S")
config = ProjectConfig()
# --- Settings loaded from the project config file (defaults used when absent) ---
OLLAMA_URL = config.FindItem("ollama_url", "http://localhost:11434")
OLLAMA_MODEL = config.FindItem("ollama_model", "gemma3:4b")
RESPONSE_TIMEOUT = config.FindItem("response_timeout", 60)
TEMPERATURE = config.FindItem("temperature", 1.3)
MAX_CONTENT_LENGTH = config.FindItem("max_content_length", None)
SYSTEM_PROMPT_PATH = config.FindItem("system_prompt_path", None)
AUTO_SPEAK_WAIT_SECOND = config.FindItem("auto_speak_wait_second", 15.0)
TTS_SERVER_URL = config.FindItem("tts_server_url", "http://localhost:43400")
TTS_PROMPT_TEXT = config.FindItem("tts_prompt_text", None)
TTS_PROMPT_WAV_PATH = config.FindItem("tts_prompt_wav_path", None)
TTS_SPEAKER_ID = config.FindItem("tts_speaker_id", "tts_speaker")
SEED = config.FindItem("seed", 0)
STREAM_ENABLE = config.FindItem("stream_enable", False)
VERBOSE = config.FindItem("verbose", False)
if VERBOSE:
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"OLLAMA_URL: {OLLAMA_URL}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"OLLAMA_MODEL: {OLLAMA_MODEL}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"RESPONSE_TIMEOUT: {RESPONSE_TIMEOUT}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"TEMPERATURE: {TEMPERATURE}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"MAX_CONTENT_LENGTH: {MAX_CONTENT_LENGTH}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"SYSTEM_PROMPT_PATH: {SYSTEM_PROMPT_PATH}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"AUTO_SPEAK_WAIT_SECOND: {AUTO_SPEAK_WAIT_SECOND}")
# Per-session directory where synthesized speech WAV files are archived.
temp_dir = config.GetFile("temp")|chat_start_id|None
# Keyword arguments for the Ollama LLM constructor.
ollama_llm_config = {
    "model": OLLAMA_MODEL,
    "base_url": OLLAMA_URL,
    "request_timeout": RESPONSE_TIMEOUT,
    "temperature": TEMPERATURE,
}
# Keyword arguments for SimpleChatEngine.from_defaults().
chat_engine_config = {}
if MAX_CONTENT_LENGTH is not None:
    ollama_llm_config["max_content_length"] = MAX_CONTENT_LENGTH
    if VERBOSE:
        PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"MAX_CONTENT_LENGTH: {MAX_CONTENT_LENGTH}")
if SYSTEM_PROMPT_PATH is not None:
    # Optional system prompt loaded from disk and handed to the chat engine.
    system_prompt = ToolFile(SYSTEM_PROMPT_PATH).LoadAsText()
    chat_engine_config["system_prompt"] = system_prompt
    if VERBOSE:
        PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"system_prompt: {system_prompt}")
# Persist any newly-created default values back to the config file.
config.SaveProperties()
def save_vocal_data(data:bytes) -> ToolFile:
    """Archive one WAV payload in the session temp directory.

    The file is named with a microsecond-resolution timestamp so successive
    chunks never collide. Returns the ToolFile pointing at the saved file.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target = temp_dir | f"{stamp}.wav"
    target.MustExistsPath()
    target.SaveAsBinary(data)
    return target
# FIFO of complete WAV blobs awaiting sequential playback (a None item is the stop sentinel).
audio_play_queue: asyncio.Queue[bytes] = asyncio.Queue()
# Shared PyAudio output state; rebuilt lazily whenever the sample rate changes.
_pyaudio_instance: pyaudio.PyAudio | None = None
_pyaudio_stream: pyaudio.Stream | None = None
_current_sample_rate: int | None = None
# TOCHECK
def parse_wav_chunk(wav_bytes: bytes) -> tuple[np.ndarray, int]:
    """Decode an in-memory WAV file into a mono int16 sample array.

    Args:
        wav_bytes: complete RIFF/WAV file contents.

    Returns:
        Tuple of (samples, sample_rate) where samples is a 1-D int16 array
        (stereo input is downmixed to mono by averaging channels).

    Raises:
        ValueError: if the sample width is not 16-bit or 32-bit PCM.
        wave.Error: if the bytes are not a valid WAV file.
    """
    # Context manager guarantees the reader is closed even when decoding
    # fails (the original leaked the handle on the ValueError path below).
    with wave.open(io.BytesIO(wav_bytes)) as wav_file:
        sample_rate = wav_file.getframerate()
        n_channels = wav_file.getnchannels()
        sample_width = wav_file.getsampwidth()
        audio_bytes = wav_file.readframes(wav_file.getnframes())
    if sample_width == 2:
        audio_data = np.frombuffer(audio_bytes, dtype=np.int16)
    elif sample_width == 4:
        # 32-bit PCM: rescale into the int16 range by this chunk's own peak.
        audio_data_32 = np.frombuffer(audio_bytes, dtype=np.int32)
        # Guard the empty case: .max() on an empty array raises.
        max_val = np.abs(audio_data_32).max() if audio_data_32.size else 0
        if max_val > 0:
            audio_data = (audio_data_32 / max_val * 32767).astype(np.int16)
        else:
            audio_data = np.zeros(len(audio_data_32), dtype=np.int16)
    else:
        raise ValueError(f"Unsupported sample width: {sample_width}")
    if n_channels == 2:
        # Downmix stereo to mono by averaging the two channels.
        audio_data = audio_data.reshape(-1, 2).mean(axis=1).astype(np.int16)
    audio_data = np.clip(audio_data, -32768, 32767).astype(np.int16)
    return audio_data, sample_rate
# TOCHECK
def play_audio_chunk(audio_data: np.ndarray, sample_rate: int) -> None:
    """Play a mono int16 sample array through the shared PyAudio stream.

    The output stream is created lazily and rebuilt whenever the requested
    sample rate differs from the current one; samples are written in
    fixed-size chunks to keep the write calls bounded.
    """
    global _pyaudio_instance, _pyaudio_stream, _current_sample_rate
    needs_new_stream = _pyaudio_instance is None or _current_sample_rate != sample_rate
    if needs_new_stream:
        # Best-effort teardown of any previous stream/instance first.
        if _pyaudio_stream is not None:
            try:
                _pyaudio_stream.stop_stream()
                _pyaudio_stream.close()
            except Exception:
                pass
        if _pyaudio_instance is not None:
            try:
                _pyaudio_instance.terminate()
            except Exception:
                pass
        # ~20 ms of audio per buffer, but never fewer than 256 frames.
        buffer_frames = max(int(sample_rate * 0.02), 256)
        _pyaudio_instance = pyaudio.PyAudio()
        _pyaudio_stream = _pyaudio_instance.open(
            format=pyaudio.paInt16,
            channels=1,
            rate=sample_rate,
            output=True,
            frames_per_buffer=buffer_frames
        )
        _current_sample_rate = sample_rate
    if audio_data.dtype != np.int16:
        # Non-int16 input: normalize into int16 range by its peak value.
        peak = np.abs(audio_data).max()
        if peak > 0:
            audio_data = (audio_data / peak * 32767).astype(np.int16)
        else:
            audio_data = np.zeros_like(audio_data, dtype=np.int16)
    raw = audio_data.tobytes()
    step = 4096
    for offset in range(0, len(raw), step):
        piece = raw[offset:offset + step]
        if _pyaudio_stream is not None:
            _pyaudio_stream.write(piece)
# TOCHECK
def cleanup_audio() -> None:
    """Release the shared PyAudio stream and instance (best effort).

    Errors during teardown are swallowed; all module-level audio state is
    reset to None so a later play re-initializes from scratch.
    """
    global _pyaudio_instance, _pyaudio_stream, _current_sample_rate
    stream = _pyaudio_stream
    if stream is not None:
        try:
            stream.stop_stream()
            stream.close()
        except Exception:
            pass
    _pyaudio_stream = None
    instance = _pyaudio_instance
    if instance is not None:
        try:
            instance.terminate()
        except Exception:
            pass
    _pyaudio_instance = None
    _current_sample_rate = None
# TOCHECK
def play_audio_sync(audio_data: bytes) -> None:
    """Decode one complete WAV blob and play it, blocking until done.

    Empty input is a no-op.
    """
    if not audio_data:
        return
    samples, rate = parse_wav_chunk(audio_data)
    play_audio_chunk(samples, rate)
# TOCHECK
async def audio_player_worker():
    """Background task that plays queued WAV blobs strictly in order.

    Consumes `audio_play_queue` until a None sentinel arrives. Each blob is
    played and then archived on a worker thread so the event loop stays
    responsive; playback failures are reported but do not stop the worker.
    """
    while True:
        item = await audio_play_queue.get()
        if item is None:
            # Stop sentinel: acknowledge it and exit the loop.
            audio_play_queue.task_done()
            break
        try:
            await asyncio.to_thread(play_audio_sync, item)
            await asyncio.to_thread(save_vocal_data, item)
        except Exception as exc:
            PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"音频播放失败: {exc}")
        finally:
            audio_play_queue.task_done()
# CHANGE TOCHECK
async def play_vocal(text:str) -> None:
    """Synthesize *text* via the TTS server and enqueue the audio for playback.

    In streaming mode the HTTP body is scanned for complete RIFF/WAV files,
    each of which is queued for playback as soon as it is fully received;
    otherwise the whole response body is queued at once.

    Args:
        text: text to speak; None or empty input is a no-op.

    Raises:
        Exception: if the TTS server answers with a non-200 status.
    """
    # `not text` alone covers both None and "". The original tested
    # len(text) == 0 first, which raised TypeError when text was None.
    if not text:
        return
    tts_server_url = f"{TTS_SERVER_URL}/api/synthesis/sft"
    # Request payload shared by both streaming and non-streaming modes.
    header = {
        "accept": "application/json"
    }
    data = {
        'text': text,
        'speaker_id': TTS_SPEAKER_ID,
        'stream': STREAM_ENABLE,
    }
    if STREAM_ENABLE:
        response = requests.post(tts_server_url, data=data, stream=True, timeout=600, headers=header)
        if response.status_code != 200:
            raise Exception(f"语音合成失败: {response.status_code} - {response.text}")
        wav_buffer = bytearray()
        for chunk in response.iter_content(chunk_size=1024 * 256):
            if not chunk:
                continue
            wav_buffer.extend(chunk)
            # Extract every complete WAV file currently held in the buffer.
            while len(wav_buffer) > 12:
                if wav_buffer[:4] != b'RIFF':
                    # Resynchronize on the next RIFF signature, dropping garbage.
                    riff_pos = wav_buffer.find(b'RIFF', 1)
                    if riff_pos == -1:
                        wav_buffer.clear()
                        break
                    wav_buffer = wav_buffer[riff_pos:]
                    if len(wav_buffer) < 8:
                        break
                # The RIFF size field excludes the 8-byte RIFF header itself.
                file_size = int.from_bytes(wav_buffer[4:8], byteorder='little')
                expected_size = file_size + 8
                if len(wav_buffer) < expected_size:
                    break  # incomplete file; wait for more data
                complete_wav = bytes(wav_buffer[:expected_size])
                del wav_buffer[:expected_size]
                await audio_play_queue.put(complete_wav)
        if wav_buffer:
            # Trailing partial data: queue it only if it parses as valid WAV.
            leftover = bytes(wav_buffer)
            try:
                parse_wav_chunk(leftover)
                await audio_play_queue.put(leftover)
            except Exception:
                PrintColorful(ConsoleFrontColor.LIGHTRED_EX, "剩余音频数据解析失败,已丢弃")
    else:
        response = requests.post(tts_server_url, data=data, timeout=600, headers=header)
        if response.status_code == 200:
            await audio_play_queue.put(response.content)
        else:
            raise Exception(f"语音合成失败: {response.status_code} - {response.text}")
async def ainput(wait_seconds:float) -> str:
    """Read a line from stdin with a timeout, without blocking the event loop.

    Args:
        wait_seconds: how long to wait for input before giving up.

    Returns:
        The entered line, or "" on timeout or EOF. On timeout the pending
        executor read is deliberately left running (original behavior): the
        user's next line will still be consumed by it.
    """
    # get_running_loop() is the non-deprecated way to obtain the loop from
    # inside a coroutine (get_event_loop() emits a DeprecationWarning here).
    loop = asyncio.get_running_loop()

    def get_input() -> str:
        try:
            return input("\n你: ")
        except EOFError:
            return ""

    input_task = loop.run_in_executor(None, get_input)
    # Poll in 0.5s ticks; input() blocks a worker thread, not the loop.
    remaining = wait_seconds
    while remaining > 0:
        if input_task.done():
            return input_task.result()
        await asyncio.sleep(0.5)
        remaining -= 0.5
    return ""
async def achat(engine:SimpleChatEngine,message:str) -> None:
    """Stream one chat turn from *engine*, printing and speaking the reply.

    Tokens are printed as they arrive; text accumulates in a buffer and is
    flushed to TTS at a sentence boundary once it exceeds 20 characters,
    with any remainder spoken after the stream ends.
    """
    user_message = message if message not in [None, ""] else "(没有人说话, 请延续发言或是寻找新的话题)"
    streaming_response: StreamingAgentChatResponse = await engine.astream_chat(user_message)
    buffer_response = ""
    # Sentence-ending punctuation that triggers a TTS flush. The committed
    # list had degraded to empty strings, which never compare equal to a
    # single character — so speech was deferred to the very end of every
    # reply. Restored to CJK + ASCII terminators.
    # NOTE(review): confirm this is the originally intended symbol set.
    end_symbol = ['。', '!', '?', '.', '!', '?']
    # Stream the response token by token.
    async for chunk in streaming_response.async_response_gen():
        await asyncio.sleep(0.01)  # yield briefly so the audio worker can run
        print(chunk, end='', flush=True)
        for ch in chunk:
            buffer_response += ch
            if len(buffer_response) > 20 and ch in end_symbol:
                await play_vocal(buffer_response.strip())
                buffer_response = ""
    buffer_response = buffer_response.strip()
    if len(buffer_response) > 0:
        await play_vocal(buffer_response)
def add_speaker() -> None:
    """Register the configured reference voice with the TTS server.

    Uploads the prompt WAV file together with its transcript so the server
    can (re)generate the speaker embedding.

    Raises:
        Exception: if the server answers with a non-200 status.
    """
    endpoint = f"{TTS_SERVER_URL}/api/speakers/add"
    headers = {
        "accept": "application/json"
    }
    payload = {
        "speaker_id": TTS_SPEAKER_ID,
        "prompt_text": TTS_PROMPT_TEXT,
        "force_regenerate": True
    }
    extension = ToolFile(TTS_PROMPT_WAV_PATH).GetExtension().lower()
    with open(TTS_PROMPT_WAV_PATH, 'rb') as wav_handle:
        files = {
            'prompt_wav': (f'prompt.{extension}', wav_handle, f'audio/{extension}')
        }
        response = requests.post(endpoint, data=payload, files=files, headers=headers, timeout=600)
    if response.status_code != 200:
        raise Exception(f"添加音色失败: {response.status_code} - {response.text}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"音色可用: {response.text}")
async def event_loop(engine:SimpleChatEngine) -> None:
    """Main conversation loop: chat until the user types "quit" or "exit".

    Registers the TTS speaker, starts the playback worker, then alternates
    AI replies with user input. When the user stays silent, the wait before
    the AI speaks again grows by 1.5x per silent turn, capped at one hour.
    On exit, pending audio is drained and the worker is shut down cleanly.
    """
    add_speaker()
    audio_player_task = asyncio.create_task(audio_player_worker())
    message = input("请开始对话: ")
    wait_second = AUTO_SPEAK_WAIT_SECOND
    try:
        while message != "quit" and message != "exit":
            PrintColorful(ConsoleFrontColor.GREEN, "AI: ", is_reset=False, end='')
            await achat(engine, message)
            PrintColorful(ConsoleFrontColor.RESET,"")
            message = await ainput(wait_second)
            if not message:
                # Exponential backoff capped at 3600s. The original used
                # max(), which jumped straight to an hour after the first
                # silent turn and made the 1.5x growth dead code.
                wait_second = min(wait_second*1.5, 3600)
            else:
                wait_second = AUTO_SPEAK_WAIT_SECOND
    finally:
        # Drain queued audio, then stop the worker with the None sentinel.
        await audio_play_queue.join()
        await audio_play_queue.put(None)
        await audio_player_task
        cleanup_audio()
async def main():
    """Wire up the Ollama LLM and chat engine, then run the event loop.

    Any exception is logged via the project config and swallowed; audio
    resources are released unconditionally on the way out.
    """
    try:
        llm = Ollama(**ollama_llm_config)
        Settings.llm = llm
        engine = SimpleChatEngine.from_defaults(**chat_engine_config)
        await event_loop(engine)
    except Exception as e:
        config.Log("Error", f"Error: {e}")
        return
    finally:
        cleanup_audio()
if __name__ == "__main__":
    asyncio.run(main())