Files
2025-07-25 10:48:07 +08:00

354 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from .Config import *
from .Reflection import *
from collections import defaultdict
import asyncio
import threading
from typing import Optional
from pydantic import BaseModel
from abc import ABC, abstractmethod
class AsyncContextDetector:
"""异步上下文检测工具类"""
@staticmethod
def is_in_async_context() -> bool:
"""检查是否在异步上下文中运行"""
try:
asyncio.current_task()
return True
except RuntimeError:
return False
@staticmethod
def get_current_loop() -> Optional[asyncio.AbstractEventLoop]:
"""获取当前事件循环如果没有则返回None"""
try:
return asyncio.get_running_loop()
except RuntimeError:
return None
@staticmethod
def ensure_async_context_safe(operation_name: str) -> None:
"""确保在异步上下文中执行是安全的"""
if AsyncContextDetector.is_in_async_context():
raise RuntimeError(
f"Cannot perform '{operation_name}' from within an async context. "
f"Use await or async methods instead."
)
class AsyncFieldAccessor:
"""异步字段访问器,封装字段访问逻辑"""
def __init__(
self,
async_fields: Dict[str, 'AsynchronyExpression'],
origin_fields: Dict[str, FieldInfo]
) -> None:
self._async_fields = async_fields
self._origin_fields = origin_fields
async def get_field_value_async(self, field_name: str):
"""异步获取字段值"""
if field_name not in self._origin_fields:
raise AttributeError(f"No async field '{field_name}' found")
return await self._async_fields[field_name].get_value()
def get_field_value_sync(self, field_name: str):
"""同步获取字段值(仅在非异步上下文中使用)"""
AsyncContextDetector.ensure_async_context_safe(f"sync access to field '{field_name}'")
if field_name not in self._origin_fields:
raise AttributeError(f"No async field '{field_name}' found")
async_expr = self._async_fields[field_name]
if not async_expr.is_initialize and async_expr.timeout > 0:
# 需要等待但在同步上下文中使用run_async
return run_async(async_expr.get_value())
elif not async_expr.is_initialize:
raise RuntimeError(f"Field '{field_name}' is not initialized and has no timeout")
else:
return run_async(async_expr.get_value())
def is_field_initialized(self, field_name: str) -> bool:
"""检查字段是否已初始化"""
if field_name not in self._origin_fields:
raise AttributeError(f"No async field '{field_name}' found")
return self._async_fields[field_name].is_initialize
def set_field_value(self, field_name: str, value: Any) -> None:
"""设置字段值"""
if field_name not in self._origin_fields:
raise AttributeError(f"No async field '{field_name}' found")
self._async_fields[field_name].set_value(value)
class AsynchronyUninitialized:
"""表示未初始化状态的单例类"""
__instance__ = None
_lock = threading.Lock()
def __new__(cls, *args, **kwargs):
if cls.__instance__ is None:
with cls._lock:
if cls.__instance__ is None:
cls.__instance__ = super().__new__(cls)
return cls.__instance__
def __repr__(self):
return "uninitialized"
def __str__(self):
return "None"
class AsynchronyExpression:
def __init__(
self,
field: FieldInfo,
value: Any = AsynchronyUninitialized(),
*,
time_wait: float = 0.1,
timeout: float = 0,
callback: Optional[Action] = None,
):
'''
参数:
field: 字段
value: 初始化, 默认为AsynchronyUninitialized, 即无初始化
time_wait: 等待时间, 默认为0.1秒
timeout: 超时时间, 默认为0秒
callback: 回调函数, 默认为None, 当状态为无初始化时get_value会调用callback
'''
self.field = field
self._value = value
self.callback = callback
self.is_initialize = not isinstance(value, AsynchronyUninitialized)
self.time_wait = time_wait
self.timeout = timeout
def get_value_sync(self):
if self.is_initialize:
return self._value
elif self.callback is not None:
self.callback()
if self.is_initialize:
return self._value
else:
raise RuntimeError(f"Field {self.field.FieldName} is not initialized")
async def get_value(self):
"""异步获取字段值,改进的超时机制"""
if self.is_initialize:
return self._value
elif self.callback is not None:
self.callback()
if self.timeout > 0:
try:
# 使用 asyncio.wait_for 提供更精确的超时控制
async def wait_for_initialization():
while not self.is_initialize:
await asyncio.sleep(self.time_wait)
return self._value
return await asyncio.wait_for(wait_for_initialization(), timeout=self.timeout)
except asyncio.TimeoutError:
raise TimeoutError(f"Timeout waiting for uninitialized field {self.field.FieldName}")
else:
# 无超时,一直等待
while not self.is_initialize:
await asyncio.sleep(self.time_wait)
return self._value
def set_value(self, value: Any) -> None:
"""设置字段值"""
if isinstance(value, AsynchronyUninitialized):
self.set_uninitialized()
elif self.field.Verify(type(value)):
self._value = value
self.is_initialize = True
else:
raise ValueError(f"Value {value} is not valid for field {self.field.FieldName}")
def SetUninitialized(self) -> None:
"""设置为未初始化状态(保持兼容性的旧方法名)"""
self.set_uninitialized()
def set_uninitialized(self) -> None:
"""设置为未初始化状态"""
if self.is_initialize:
del self._value
self._value = AsynchronyUninitialized()
self.is_initialize = False
class Asynchronous(ABC):
__Asynchronous_Origin_Fields__: Dict[Type, Dict[str, FieldInfo]] = defaultdict(dict)
_fields_lock = threading.Lock()
def _GetAsynchronousOriginFields(self) -> Dict[str, FieldInfo]:
return Asynchronous.__Asynchronous_Origin_Fields__[type(self)]
def __init__(self, **kwargs: Dict[str, dict]):
super().__init__()
self.__Asynchronous_Fields__: Dict[str, AsynchronyExpression] = {}
# 使用线程锁保护类变量访问
with Asynchronous._fields_lock:
origin_fields = self._GetAsynchronousOriginFields()
for field_info in TypeManager.GetInstance().CreateOrGetRefTypeFromType(type(self)).GetAllFields():
if field_info.FieldName == "__Asynchronous_Origin_Fields__":
continue
origin_fields[field_info.FieldName] = field_info
self.__Asynchronous_Fields__[field_info.FieldName] = AsynchronyExpression(
field_info, **kwargs.get(field_info.FieldName, {})
)
# 创建字段访问器以提升性能
self._field_accessor = AsyncFieldAccessor(self.__Asynchronous_Fields__, origin_fields)
def __getattribute__(self, name: str) -> Any:
# 快速路径:非异步字段直接返回
if name in ("__Asynchronous_Fields__", "_GetAsynchronousOriginFields", "_field_accessor"):
return super().__getattribute__(name)
# 一次性获取所需属性,避免重复调用
try:
field_accessor:AsyncFieldAccessor = super().__getattribute__("_field_accessor")
origin_fields:Dict[str, FieldInfo] = super().__getattribute__("_GetAsynchronousOriginFields")()
except AttributeError:
# 对象可能尚未完全初始化
return super().__getattribute__(name)
if name in origin_fields:
# 这是一个异步字段
if AsyncContextDetector.is_in_async_context():
# 在异步上下文中,提供友好的错误提示
async_fields:Dict[str, AsynchronyExpression] = super().__getattribute__("__Asynchronous_Fields__")
async_expr = async_fields[name]
if not async_expr.is_initialize:
timeout_info = f" with {async_expr.timeout}s timeout" if async_expr.timeout > 0 else ""
raise RuntimeError(
f"Field '{name}' is not initialized{timeout_info}. "
)
else:
# 字段已初始化,直接返回值
return async_expr.get_value_sync()
else:
# 在同步上下文中,使用字段访问器
try:
return field_accessor.get_field_value_sync(name)
except RuntimeError as e:
if "Cannot perform" in str(e):
# 重新包装错误信息,提供更友好的提示
raise RuntimeError(
f"Cannot access async field '{name}' from sync context when it requires initialization. "
f"Use async context or ensure field is pre-initialized."
) from e
else:
raise
return super().__getattribute__(name)
def __setattr__(self, name: str, value: Any) -> None:
if name in ("__Asynchronous_Fields__", "_GetAsynchronousOriginFields", "_field_accessor"):
super().__setattr__(name, value)
elif hasattr(self, '_field_accessor'):
# 对象已初始化,使用字段访问器
try:
field_accessor = super().__getattribute__("_field_accessor")
field_accessor.set_field_value(name, value)
return
except AttributeError:
# 不是异步字段
pass
super().__setattr__(name, value)
def __delattr__(self, name: str) -> None:
if name in ("__Asynchronous_Fields__", "_GetAsynchronousOriginFields", "_field_accessor"):
super().__delattr__(name)
elif hasattr(self, '_field_accessor'):
# 对象已初始化,使用字段访问器
try:
field_accessor = super().__getattribute__("_field_accessor")
origin_fields = super().__getattribute__("_GetAsynchronousOriginFields")()
if name in origin_fields:
async_fields = super().__getattribute__("__Asynchronous_Fields__")
async_fields[name].set_uninitialized()
return
except AttributeError:
# 不是异步字段
pass
super().__delattr__(name)
def is_field_initialized(self, field_name: str) -> bool:
"""检查字段是否已初始化"""
return self._field_accessor.is_field_initialized(field_name)
def run_until_complete(coro: Coroutine) -> Any:
"""Gets an existing event loop to run the coroutine.
If there is no existing event loop, creates a new one.
"""
try:
# Check if there's an existing event loop
loop = asyncio.get_event_loop()
# If we're here, there's an existing loop but it's not running
return loop.run_until_complete(coro)
except RuntimeError:
# If we can't get the event loop, we're likely in a different thread, or its already running
try:
return asyncio.run(coro)
except RuntimeError:
raise RuntimeError(
"Detected nested async. Please use nest_asyncio.apply() to allow nested event loops."
"Or, use async entry methods like `aquery()`, `aretriever`, `achat`, etc."
)
def run_async_coroutine(coro: Coroutine) -> Any:
try:
# Check if there's an existing event loop
loop = asyncio.get_event_loop()
# If we're here, there's an existing loop but it's not running
return loop.create_task(coro)
except RuntimeError:
# If we can't get the event loop, we're likely in a different thread, or its already running
try:
return asyncio.run(coro)
except RuntimeError:
raise RuntimeError(
"Detected nested async. Please use nest_asyncio.apply() to allow nested event loops."
"Or, use async entry methods like `aquery()`, `aretriever`, `achat`, etc."
)
def run_async(coro: Coroutine):
"""安全地运行异步协程,避免事件循环死锁"""
# 使用统一的异步上下文检测
AsyncContextDetector.ensure_async_context_safe("run_async")
# 尝试获取当前事件循环
current_loop = AsyncContextDetector.get_current_loop()
if current_loop is not None and not current_loop.is_running():
# 有事件循环但未运行,直接使用
return current_loop.run_until_complete(coro)
elif current_loop is None:
# 没有事件循环,创建新的
try:
return asyncio.run(coro)
except RuntimeError as e:
raise RuntimeError(
"Failed to run async coroutine. "
"Please ensure proper async environment or use nest_asyncio.apply() for nested loops."
) from e
else:
# 事件循环正在运行这种情况应该被AsyncContextDetector捕获
raise RuntimeError(
"Unexpected state: running event loop detected but context check passed. "
"This should not happen."
)