可用版本v1.0
This commit is contained in:
355
cli.py
Normal file
355
cli.py
Normal file
@@ -0,0 +1,355 @@
|
||||
|
||||
from Convention.Runtime.GlobalConfig import *
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from llama_index.llms.ollama import Ollama
|
||||
from llama_index.core.chat_engine import SimpleChatEngine
|
||||
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
|
||||
from llama_index.core import Settings
|
||||
from Convention.Runtime.File import ToolFile
|
||||
import requests
|
||||
import subprocess
|
||||
import wave
|
||||
import io
|
||||
import numpy as np
|
||||
import pyaudio
|
||||
|
||||
chat_start_id = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
|
||||
config = ProjectConfig()
|
||||
OLLAMA_URL = config.FindItem("ollama_url", "http://localhost:11434")
|
||||
OLLAMA_MODEL = config.FindItem("ollama_model", "gemma3:4b")
|
||||
RESPONSE_TIMEOUT = config.FindItem("response_timeout", 60)
|
||||
TEMPERATURE = config.FindItem("temperature", 1.3)
|
||||
MAX_CONTENT_LENGTH = config.FindItem("max_content_length", None)
|
||||
SYSTEM_PROMPT_PATH = config.FindItem("system_prompt_path", None)
|
||||
AUTO_SPEAK_WAIT_SECOND = config.FindItem("auto_speak_wait_second", 15.0)
|
||||
TTS_SERVER_URL = config.FindItem("tts_server_url", "http://localhost:43400")
|
||||
TTS_PROMPT_TEXT = config.FindItem("tts_prompt_text", None)
|
||||
TTS_PROMPT_WAV_PATH = config.FindItem("tts_prompt_wav_path", None)
|
||||
TTS_SPEAKER_ID = config.FindItem("tts_speaker_id", "tts_speaker")
|
||||
SEED = config.FindItem("seed", 0)
|
||||
STREAM_ENABLE = config.FindItem("stream_enable", False)
|
||||
VERBOSE = config.FindItem("verbose", False)
|
||||
if VERBOSE:
|
||||
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"OLLAMA_URL: {OLLAMA_URL}")
|
||||
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"OLLAMA_MODEL: {OLLAMA_MODEL}")
|
||||
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"RESPONSE_TIMEOUT: {RESPONSE_TIMEOUT}")
|
||||
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"TEMPERATURE: {TEMPERATURE}")
|
||||
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"MAX_CONTENT_LENGTH: {MAX_CONTENT_LENGTH}")
|
||||
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"SYSTEM_PROMPT_PATH: {SYSTEM_PROMPT_PATH}")
|
||||
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"AUTO_SPEAK_WAIT_SECOND: {AUTO_SPEAK_WAIT_SECOND}")
|
||||
temp_dir = config.GetFile("temp")|chat_start_id|None
|
||||
|
||||
ollama_llm_config = {
|
||||
"model": OLLAMA_MODEL,
|
||||
"base_url": OLLAMA_URL,
|
||||
"request_timeout": RESPONSE_TIMEOUT,
|
||||
"temperature": TEMPERATURE,
|
||||
}
|
||||
chat_engine_config = {}
|
||||
|
||||
if MAX_CONTENT_LENGTH is not None:
|
||||
ollama_llm_config["max_content_length"] = MAX_CONTENT_LENGTH
|
||||
if VERBOSE:
|
||||
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"MAX_CONTENT_LENGTH: {MAX_CONTENT_LENGTH}")
|
||||
|
||||
if SYSTEM_PROMPT_PATH is not None:
|
||||
system_prompt = ToolFile(SYSTEM_PROMPT_PATH).LoadAsText()
|
||||
chat_engine_config["system_prompt"] = system_prompt
|
||||
if VERBOSE:
|
||||
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"system_prompt: {system_prompt}")
|
||||
|
||||
config.SaveProperties()
|
||||
|
||||
def save_vocal_data(data:bytes) -> ToolFile:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
filename = f"{timestamp}.wav"
|
||||
file = temp_dir|filename
|
||||
file.MustExistsPath()
|
||||
file.SaveAsBinary(data)
|
||||
return file
|
||||
|
||||
audio_play_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
_pyaudio_instance: pyaudio.PyAudio | None = None
|
||||
_pyaudio_stream: pyaudio.Stream | None = None
|
||||
_current_sample_rate: int | None = None
|
||||
|
||||
|
||||
# TOCHECK
|
||||
def parse_wav_chunk(wav_bytes: bytes) -> tuple[np.ndarray, int]:
|
||||
"""
|
||||
解析WAV数据,返回音频数组和采样率
|
||||
"""
|
||||
wav_file = wave.open(io.BytesIO(wav_bytes))
|
||||
sample_rate = wav_file.getframerate()
|
||||
n_channels = wav_file.getnchannels()
|
||||
sample_width = wav_file.getsampwidth()
|
||||
n_frames = wav_file.getnframes()
|
||||
audio_bytes = wav_file.readframes(n_frames)
|
||||
wav_file.close()
|
||||
|
||||
if sample_width == 2:
|
||||
audio_data = np.frombuffer(audio_bytes, dtype=np.int16)
|
||||
elif sample_width == 4:
|
||||
audio_data_32 = np.frombuffer(audio_bytes, dtype=np.int32)
|
||||
max_val = np.abs(audio_data_32).max()
|
||||
if max_val > 0:
|
||||
audio_data = (audio_data_32 / max_val * 32767).astype(np.int16)
|
||||
else:
|
||||
audio_data = np.zeros(len(audio_data_32), dtype=np.int16)
|
||||
else:
|
||||
raise ValueError(f"Unsupported sample width: {sample_width}")
|
||||
|
||||
if n_channels == 2:
|
||||
audio_data = audio_data.reshape(-1, 2).mean(axis=1).astype(np.int16)
|
||||
|
||||
audio_data = np.clip(audio_data, -32768, 32767).astype(np.int16)
|
||||
return audio_data, sample_rate
|
||||
|
||||
|
||||
# TOCHECK
|
||||
def play_audio_chunk(audio_data: np.ndarray, sample_rate: int) -> None:
|
||||
"""
|
||||
使用PyAudio播放音频数组
|
||||
"""
|
||||
global _pyaudio_instance, _pyaudio_stream, _current_sample_rate
|
||||
|
||||
if _pyaudio_instance is None or _current_sample_rate != sample_rate:
|
||||
if _pyaudio_stream is not None:
|
||||
try:
|
||||
_pyaudio_stream.stop_stream()
|
||||
_pyaudio_stream.close()
|
||||
except Exception:
|
||||
pass
|
||||
if _pyaudio_instance is not None:
|
||||
try:
|
||||
_pyaudio_instance.terminate()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
frames_per_buffer = max(int(sample_rate * 0.02), 256)
|
||||
_pyaudio_instance = pyaudio.PyAudio()
|
||||
_pyaudio_stream = _pyaudio_instance.open(
|
||||
format=pyaudio.paInt16,
|
||||
channels=1,
|
||||
rate=sample_rate,
|
||||
output=True,
|
||||
frames_per_buffer=frames_per_buffer
|
||||
)
|
||||
_current_sample_rate = sample_rate
|
||||
|
||||
if audio_data.dtype != np.int16:
|
||||
max_val = np.abs(audio_data).max()
|
||||
if max_val > 0:
|
||||
audio_data = (audio_data / max_val * 32767).astype(np.int16)
|
||||
else:
|
||||
audio_data = np.zeros_like(audio_data, dtype=np.int16)
|
||||
|
||||
chunk_size = 4096
|
||||
audio_bytes = audio_data.tobytes()
|
||||
for i in range(0, len(audio_bytes), chunk_size):
|
||||
chunk = audio_bytes[i:i + chunk_size]
|
||||
if _pyaudio_stream is not None:
|
||||
_pyaudio_stream.write(chunk)
|
||||
|
||||
# TOCHECK
|
||||
def cleanup_audio() -> None:
|
||||
"""
|
||||
释放PyAudio资源
|
||||
"""
|
||||
global _pyaudio_instance, _pyaudio_stream, _current_sample_rate
|
||||
if _pyaudio_stream is not None:
|
||||
try:
|
||||
_pyaudio_stream.stop_stream()
|
||||
_pyaudio_stream.close()
|
||||
except Exception:
|
||||
pass
|
||||
_pyaudio_stream = None
|
||||
if _pyaudio_instance is not None:
|
||||
try:
|
||||
_pyaudio_instance.terminate()
|
||||
except Exception:
|
||||
pass
|
||||
_pyaudio_instance = None
|
||||
_current_sample_rate = None
|
||||
|
||||
# TOCHECK
|
||||
def play_audio_sync(audio_data: bytes) -> None:
|
||||
if not audio_data:
|
||||
return
|
||||
audio_array, sample_rate = parse_wav_chunk(audio_data)
|
||||
play_audio_chunk(audio_array, sample_rate)
|
||||
|
||||
# TOCHECK
|
||||
async def audio_player_worker():
|
||||
"""
|
||||
音频播放后台任务,确保音频按顺序播放
|
||||
"""
|
||||
while True:
|
||||
audio_data = await audio_play_queue.get()
|
||||
if audio_data is None:
|
||||
audio_play_queue.task_done()
|
||||
break
|
||||
try:
|
||||
await asyncio.to_thread(play_audio_sync, audio_data)
|
||||
await asyncio.to_thread(save_vocal_data, audio_data)
|
||||
except Exception as exc:
|
||||
PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"音频播放失败: {exc}")
|
||||
finally:
|
||||
audio_play_queue.task_done()
|
||||
|
||||
# CHANGE TOCHECK
|
||||
async def play_vocal(text:str) -> None:
|
||||
if len(text) == 0 or not text:
|
||||
return
|
||||
tts_server_url = f"{TTS_SERVER_URL}/api/synthesis/sft"
|
||||
# 准备请求数据
|
||||
header = {
|
||||
"accept": "application/json"
|
||||
}
|
||||
data = {
|
||||
'text': text,
|
||||
'speaker_id': TTS_SPEAKER_ID,
|
||||
'stream': STREAM_ENABLE,
|
||||
}
|
||||
# 发送POST请求
|
||||
if STREAM_ENABLE:
|
||||
response = requests.post(tts_server_url, data=data, stream=True, timeout=600, headers=header)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"语音合成失败: {response.status_code} - {response.text}")
|
||||
wav_buffer = bytearray()
|
||||
for chunk in response.iter_content(chunk_size=1024 * 256):
|
||||
if not chunk:
|
||||
continue
|
||||
wav_buffer.extend(chunk)
|
||||
while len(wav_buffer) > 12:
|
||||
if wav_buffer[:4] != b'RIFF':
|
||||
riff_pos = wav_buffer.find(b'RIFF', 1)
|
||||
if riff_pos == -1:
|
||||
wav_buffer.clear()
|
||||
break
|
||||
wav_buffer = wav_buffer[riff_pos:]
|
||||
if len(wav_buffer) < 8:
|
||||
break
|
||||
file_size = int.from_bytes(wav_buffer[4:8], byteorder='little')
|
||||
expected_size = file_size + 8
|
||||
if len(wav_buffer) < expected_size:
|
||||
break
|
||||
complete_wav = bytes(wav_buffer[:expected_size])
|
||||
del wav_buffer[:expected_size]
|
||||
await audio_play_queue.put(complete_wav)
|
||||
if wav_buffer:
|
||||
leftover = bytes(wav_buffer)
|
||||
try:
|
||||
parse_wav_chunk(leftover)
|
||||
await audio_play_queue.put(leftover)
|
||||
except Exception:
|
||||
PrintColorful(ConsoleFrontColor.LIGHTRED_EX, "剩余音频数据解析失败,已丢弃")
|
||||
else:
|
||||
response = requests.post(tts_server_url, data=data, timeout=600, headers=header)
|
||||
if response.status_code == 200:
|
||||
await audio_play_queue.put(response.content)
|
||||
else:
|
||||
raise Exception(f"语音合成失败: {response.status_code} - {response.text}")
|
||||
|
||||
async def ainput(wait_seconds:float) -> str:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def get_input():
|
||||
try:
|
||||
return input("\n你: ")
|
||||
except EOFError:
|
||||
return ""
|
||||
|
||||
input_task = loop.run_in_executor(None, get_input)
|
||||
while wait_seconds > 0:
|
||||
if input_task.done():
|
||||
return input_task.result()
|
||||
await asyncio.sleep(0.5)
|
||||
wait_seconds -= 0.5
|
||||
return ""
|
||||
|
||||
|
||||
async def achat(engine:SimpleChatEngine,message:str) -> None:
|
||||
user_message = message if message not in [None, ""] else "(没有人说话, 请延续发言或是寻找新的话题)"
|
||||
streaming_response: StreamingAgentChatResponse = await engine.astream_chat(user_message)
|
||||
buffer_response = ""
|
||||
|
||||
end_symbol = ['。', '?', '!']
|
||||
|
||||
# 实时输出流式文本
|
||||
async for chunk in streaming_response.async_response_gen():
|
||||
await asyncio.sleep(0.01)
|
||||
print(chunk, end='', flush=True)
|
||||
for ch in chunk:
|
||||
buffer_response += ch
|
||||
if len(buffer_response) > 20:
|
||||
if ch in end_symbol:
|
||||
await play_vocal(buffer_response.strip())
|
||||
buffer_response = ""
|
||||
buffer_response = buffer_response.strip()
|
||||
if len(buffer_response) > 0:
|
||||
await play_vocal(buffer_response)
|
||||
|
||||
|
||||
def add_speaker() -> None:
|
||||
url = f"{TTS_SERVER_URL}/api/speakers/add"
|
||||
headers = {
|
||||
"accept": "application/json"
|
||||
}
|
||||
data = {
|
||||
"speaker_id": TTS_SPEAKER_ID,
|
||||
"prompt_text": TTS_PROMPT_TEXT,
|
||||
"force_regenerate": True
|
||||
}
|
||||
with open(TTS_PROMPT_WAV_PATH, 'rb') as f:
|
||||
extension = ToolFile(TTS_PROMPT_WAV_PATH).GetExtension().lower()
|
||||
files = {
|
||||
'prompt_wav': (f'prompt.{extension}', f, f'audio/{extension}')
|
||||
}
|
||||
response = requests.post(url, data=data, files=files, headers=headers, timeout=600)
|
||||
if response.status_code == 200:
|
||||
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"音色可用: {response.text}")
|
||||
else:
|
||||
raise Exception(f"添加音色失败: {response.status_code} - {response.text}")
|
||||
|
||||
|
||||
async def event_loop(engine:SimpleChatEngine) -> None:
|
||||
add_speaker()
|
||||
audio_player_task = asyncio.create_task(audio_player_worker())
|
||||
message = input("请开始对话: ")
|
||||
wait_second = AUTO_SPEAK_WAIT_SECOND
|
||||
try:
|
||||
while message != "quit" and message != "exit":
|
||||
PrintColorful(ConsoleFrontColor.GREEN, "AI: ", is_reset=False, end='')
|
||||
await achat(engine, message)
|
||||
PrintColorful(ConsoleFrontColor.RESET,"")
|
||||
message = await ainput(wait_second)
|
||||
if not message:
|
||||
wait_second = max(wait_second*1.5, 3600)
|
||||
else:
|
||||
wait_second = AUTO_SPEAK_WAIT_SECOND
|
||||
finally:
|
||||
await audio_play_queue.join()
|
||||
await audio_play_queue.put(None)
|
||||
await audio_player_task
|
||||
cleanup_audio()
|
||||
|
||||
|
||||
async def main():
|
||||
# Initialize
|
||||
try:
|
||||
ollama_llm = Ollama(**ollama_llm_config)
|
||||
Settings.llm = ollama_llm
|
||||
chat_engine = SimpleChatEngine.from_defaults(**chat_engine_config)
|
||||
await event_loop(chat_engine)
|
||||
except Exception as e:
|
||||
config.Log("Error", f"Error: {e}")
|
||||
return
|
||||
finally:
|
||||
cleanup_audio()
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user