EP Asynchrony

This commit is contained in:
2025-07-25 10:48:07 +08:00
parent 4d0f24fd0c
commit 2bb6f924df
2 changed files with 88 additions and 122 deletions

View File

@@ -39,17 +39,21 @@ class AsyncContextDetector:
class AsyncFieldAccessor:
"""异步字段访问器,封装字段访问逻辑"""
def __init__(self, async_fields: Dict[str, 'AsynchronyExpression'], origin_fields: Dict[str, FieldInfo]):
def __init__(
self,
async_fields: Dict[str, 'AsynchronyExpression'],
origin_fields: Dict[str, FieldInfo]
) -> None:
self._async_fields = async_fields
self._origin_fields = origin_fields
async def get_field_value_async(self, field_name: str) -> Any:
async def get_field_value_async(self, field_name: str):
"""异步获取字段值"""
if field_name not in self._origin_fields:
raise AttributeError(f"No async field '{field_name}' found")
return await self._async_fields[field_name].get_value()
def get_field_value_sync(self, field_name: str) -> Any:
def get_field_value_sync(self, field_name: str):
"""同步获取字段值(仅在非异步上下文中使用)"""
AsyncContextDetector.ensure_async_context_safe(f"sync access to field '{field_name}'")
@@ -63,7 +67,7 @@ class AsyncFieldAccessor:
elif not async_expr.is_initialize:
raise RuntimeError(f"Field '{field_name}' is not initialized and has no timeout")
else:
return async_expr.value
return run_async(async_expr.get_value())
def is_field_initialized(self, field_name: str) -> bool:
"""检查字段是否已初始化"""
@@ -89,29 +93,53 @@ class AsynchronyUninitialized:
cls.__instance__ = super().__new__(cls)
return cls.__instance__
def __repr__(self):
return "uninitialized"
def __str__(self):
return "None"
class AsynchronyExpression:
def __init__(
self,
field: FieldInfo,
value: Any = None,
field: FieldInfo,
value: Any = AsynchronyUninitialized(),
*,
time_wait: float = 0.1,
timeout: float = 0
time_wait: float = 0.1,
timeout: float = 0,
callback: Optional[Action] = None,
):
'''
参数:
field: 字段
value: 初始化, 默认为AsynchronyUninitialized, 即无初始化
time_wait: 等待时间, 默认为0.1秒
timeout: 超时时间, 默认为0秒
callback: 回调函数, 默认为None, 当状态为无初始化时get_value会调用callback
'''
self.field = field
self.value = value
self.is_initialize = False
self._value = value
self.callback = callback
self.is_initialize = not isinstance(value, AsynchronyUninitialized)
self.time_wait = time_wait
self.timeout = timeout
async def GetValue(self) -> Any:
"""获取字段值(保持兼容性的旧方法名)"""
return await self.get_value()
async def get_value(self) -> Any:
def get_value_sync(self):
if self.is_initialize:
return self._value
elif self.callback is not None:
self.callback()
if self.is_initialize:
return self._value
else:
raise RuntimeError(f"Field {self.field.FieldName} is not initialized")
async def get_value(self):
"""异步获取字段值,改进的超时机制"""
if self.is_initialize:
return self.value
return self._value
elif self.callback is not None:
self.callback()
if self.timeout > 0:
try:
@@ -119,7 +147,7 @@ class AsynchronyExpression:
async def wait_for_initialization():
while not self.is_initialize:
await asyncio.sleep(self.time_wait)
return self.value
return self._value
return await asyncio.wait_for(wait_for_initialization(), timeout=self.timeout)
except asyncio.TimeoutError:
@@ -128,18 +156,14 @@ class AsynchronyExpression:
# 无超时,一直等待
while not self.is_initialize:
await asyncio.sleep(self.time_wait)
return self.value
def SetValue(self, value: Any) -> None:
"""设置字段值(保持兼容性的旧方法名)"""
self.set_value(value)
return self._value
def set_value(self, value: Any) -> None:
"""设置字段值"""
if isinstance(value, AsynchronyUninitialized):
self.set_uninitialized()
elif self.field.Verify(type(value)):
self.value = value
self._value = value
self.is_initialize = True
else:
raise ValueError(f"Value {value} is not valid for field {self.field.FieldName}")
@@ -151,9 +175,8 @@ class AsynchronyExpression:
def set_uninitialized(self) -> None:
"""设置为未初始化状态"""
if self.is_initialize:
if hasattr(self, 'value'):
del self.value
self.value = None
del self._value
self._value = AsynchronyUninitialized()
self.is_initialize = False
class Asynchronous(ABC):
@@ -188,8 +211,8 @@ class Asynchronous(ABC):
# 一次性获取所需属性,避免重复调用
try:
field_accessor = super().__getattribute__("_field_accessor")
origin_fields = super().__getattribute__("_GetAsynchronousOriginFields")()
field_accessor:AsyncFieldAccessor = super().__getattribute__("_field_accessor")
origin_fields:Dict[str, FieldInfo] = super().__getattribute__("_GetAsynchronousOriginFields")()
except AttributeError:
# 对象可能尚未完全初始化
return super().__getattribute__(name)
@@ -198,18 +221,17 @@ class Asynchronous(ABC):
# 这是一个异步字段
if AsyncContextDetector.is_in_async_context():
# 在异步上下文中,提供友好的错误提示
async_fields = super().__getattribute__("__Asynchronous_Fields__")
async_fields:Dict[str, AsynchronyExpression] = super().__getattribute__("__Asynchronous_Fields__")
async_expr = async_fields[name]
if not async_expr.is_initialize:
timeout_info = f" with {async_expr.timeout}s timeout" if async_expr.timeout > 0 else ""
raise RuntimeError(
f"Field '{name}' is not initialized{timeout_info}. "
f"In async context, use 'await obj.get_field_async(\"{name}\")' instead."
)
else:
# 字段已初始化,直接返回值
return async_expr.value
return async_expr.get_value_sync()
else:
# 在同步上下文中,使用字段访问器
try:
@@ -259,24 +281,9 @@ class Asynchronous(ABC):
super().__delattr__(name)
async def get_field_async(self, field_name: str) -> Any:
"""异步获取字段值,适用于异步上下文"""
return await self._field_accessor.get_field_value_async(field_name)
def is_field_initialized(self, field_name: str) -> bool:
"""检查字段是否已初始化"""
return self._field_accessor.is_field_initialized(field_name)
def set_field_value(self, field_name: str, value: Any) -> None:
"""设置字段值"""
self._field_accessor.set_field_value(field_name, value)
def get_field_timeout(self, field_name: str) -> float:
"""获取字段的超时设置"""
origin_fields = self._GetAsynchronousOriginFields()
if field_name not in origin_fields:
raise AttributeError(f"'{type(self).__name__}' object has no async field '{field_name}'")
return self.__Asynchronous_Fields__[field_name].timeout
def run_until_complete(coro: Coroutine) -> Any:
"""Gets an existing event loop to run the coroutine.