EP Asynchrony
This commit is contained in:
@@ -39,17 +39,21 @@ class AsyncContextDetector:
|
||||
class AsyncFieldAccessor:
|
||||
"""异步字段访问器,封装字段访问逻辑"""
|
||||
|
||||
def __init__(self, async_fields: Dict[str, 'AsynchronyExpression'], origin_fields: Dict[str, FieldInfo]):
|
||||
def __init__(
|
||||
self,
|
||||
async_fields: Dict[str, 'AsynchronyExpression'],
|
||||
origin_fields: Dict[str, FieldInfo]
|
||||
) -> None:
|
||||
self._async_fields = async_fields
|
||||
self._origin_fields = origin_fields
|
||||
|
||||
async def get_field_value_async(self, field_name: str) -> Any:
|
||||
async def get_field_value_async(self, field_name: str):
|
||||
"""异步获取字段值"""
|
||||
if field_name not in self._origin_fields:
|
||||
raise AttributeError(f"No async field '{field_name}' found")
|
||||
return await self._async_fields[field_name].get_value()
|
||||
|
||||
def get_field_value_sync(self, field_name: str) -> Any:
|
||||
def get_field_value_sync(self, field_name: str):
|
||||
"""同步获取字段值(仅在非异步上下文中使用)"""
|
||||
AsyncContextDetector.ensure_async_context_safe(f"sync access to field '{field_name}'")
|
||||
|
||||
@@ -63,7 +67,7 @@ class AsyncFieldAccessor:
|
||||
elif not async_expr.is_initialize:
|
||||
raise RuntimeError(f"Field '{field_name}' is not initialized and has no timeout")
|
||||
else:
|
||||
return async_expr.value
|
||||
return run_async(async_expr.get_value())
|
||||
|
||||
def is_field_initialized(self, field_name: str) -> bool:
|
||||
"""检查字段是否已初始化"""
|
||||
@@ -89,29 +93,53 @@ class AsynchronyUninitialized:
|
||||
cls.__instance__ = super().__new__(cls)
|
||||
return cls.__instance__
|
||||
|
||||
def __repr__(self):
|
||||
return "uninitialized"
|
||||
|
||||
def __str__(self):
|
||||
return "None"
|
||||
|
||||
class AsynchronyExpression:
|
||||
def __init__(
|
||||
self,
|
||||
field: FieldInfo,
|
||||
value: Any = None,
|
||||
field: FieldInfo,
|
||||
value: Any = AsynchronyUninitialized(),
|
||||
*,
|
||||
time_wait: float = 0.1,
|
||||
timeout: float = 0
|
||||
time_wait: float = 0.1,
|
||||
timeout: float = 0,
|
||||
callback: Optional[Action] = None,
|
||||
):
|
||||
'''
|
||||
参数:
|
||||
field: 字段
|
||||
value: 初始化, 默认为AsynchronyUninitialized, 即无初始化
|
||||
time_wait: 等待时间, 默认为0.1秒
|
||||
timeout: 超时时间, 默认为0秒
|
||||
callback: 回调函数, 默认为None, 当状态为无初始化时get_value会调用callback
|
||||
'''
|
||||
self.field = field
|
||||
self.value = value
|
||||
self.is_initialize = False
|
||||
self._value = value
|
||||
self.callback = callback
|
||||
self.is_initialize = not isinstance(value, AsynchronyUninitialized)
|
||||
self.time_wait = time_wait
|
||||
self.timeout = timeout
|
||||
|
||||
async def GetValue(self) -> Any:
|
||||
"""获取字段值(保持兼容性的旧方法名)"""
|
||||
return await self.get_value()
|
||||
|
||||
async def get_value(self) -> Any:
|
||||
def get_value_sync(self):
|
||||
if self.is_initialize:
|
||||
return self._value
|
||||
elif self.callback is not None:
|
||||
self.callback()
|
||||
if self.is_initialize:
|
||||
return self._value
|
||||
else:
|
||||
raise RuntimeError(f"Field {self.field.FieldName} is not initialized")
|
||||
|
||||
async def get_value(self):
|
||||
"""异步获取字段值,改进的超时机制"""
|
||||
if self.is_initialize:
|
||||
return self.value
|
||||
return self._value
|
||||
elif self.callback is not None:
|
||||
self.callback()
|
||||
|
||||
if self.timeout > 0:
|
||||
try:
|
||||
@@ -119,7 +147,7 @@ class AsynchronyExpression:
|
||||
async def wait_for_initialization():
|
||||
while not self.is_initialize:
|
||||
await asyncio.sleep(self.time_wait)
|
||||
return self.value
|
||||
return self._value
|
||||
|
||||
return await asyncio.wait_for(wait_for_initialization(), timeout=self.timeout)
|
||||
except asyncio.TimeoutError:
|
||||
@@ -128,18 +156,14 @@ class AsynchronyExpression:
|
||||
# 无超时,一直等待
|
||||
while not self.is_initialize:
|
||||
await asyncio.sleep(self.time_wait)
|
||||
return self.value
|
||||
|
||||
def SetValue(self, value: Any) -> None:
|
||||
"""设置字段值(保持兼容性的旧方法名)"""
|
||||
self.set_value(value)
|
||||
|
||||
return self._value
|
||||
|
||||
def set_value(self, value: Any) -> None:
|
||||
"""设置字段值"""
|
||||
if isinstance(value, AsynchronyUninitialized):
|
||||
self.set_uninitialized()
|
||||
elif self.field.Verify(type(value)):
|
||||
self.value = value
|
||||
self._value = value
|
||||
self.is_initialize = True
|
||||
else:
|
||||
raise ValueError(f"Value {value} is not valid for field {self.field.FieldName}")
|
||||
@@ -151,9 +175,8 @@ class AsynchronyExpression:
|
||||
def set_uninitialized(self) -> None:
|
||||
"""设置为未初始化状态"""
|
||||
if self.is_initialize:
|
||||
if hasattr(self, 'value'):
|
||||
del self.value
|
||||
self.value = None
|
||||
del self._value
|
||||
self._value = AsynchronyUninitialized()
|
||||
self.is_initialize = False
|
||||
|
||||
class Asynchronous(ABC):
|
||||
@@ -188,8 +211,8 @@ class Asynchronous(ABC):
|
||||
|
||||
# 一次性获取所需属性,避免重复调用
|
||||
try:
|
||||
field_accessor = super().__getattribute__("_field_accessor")
|
||||
origin_fields = super().__getattribute__("_GetAsynchronousOriginFields")()
|
||||
field_accessor:AsyncFieldAccessor = super().__getattribute__("_field_accessor")
|
||||
origin_fields:Dict[str, FieldInfo] = super().__getattribute__("_GetAsynchronousOriginFields")()
|
||||
except AttributeError:
|
||||
# 对象可能尚未完全初始化
|
||||
return super().__getattribute__(name)
|
||||
@@ -198,18 +221,17 @@ class Asynchronous(ABC):
|
||||
# 这是一个异步字段
|
||||
if AsyncContextDetector.is_in_async_context():
|
||||
# 在异步上下文中,提供友好的错误提示
|
||||
async_fields = super().__getattribute__("__Asynchronous_Fields__")
|
||||
async_fields:Dict[str, AsynchronyExpression] = super().__getattribute__("__Asynchronous_Fields__")
|
||||
async_expr = async_fields[name]
|
||||
|
||||
if not async_expr.is_initialize:
|
||||
timeout_info = f" with {async_expr.timeout}s timeout" if async_expr.timeout > 0 else ""
|
||||
raise RuntimeError(
|
||||
f"Field '{name}' is not initialized{timeout_info}. "
|
||||
f"In async context, use 'await obj.get_field_async(\"{name}\")' instead."
|
||||
)
|
||||
else:
|
||||
# 字段已初始化,直接返回值
|
||||
return async_expr.value
|
||||
return async_expr.get_value_sync()
|
||||
else:
|
||||
# 在同步上下文中,使用字段访问器
|
||||
try:
|
||||
@@ -259,24 +281,9 @@ class Asynchronous(ABC):
|
||||
|
||||
super().__delattr__(name)
|
||||
|
||||
async def get_field_async(self, field_name: str) -> Any:
|
||||
"""异步获取字段值,适用于异步上下文"""
|
||||
return await self._field_accessor.get_field_value_async(field_name)
|
||||
|
||||
def is_field_initialized(self, field_name: str) -> bool:
|
||||
"""检查字段是否已初始化"""
|
||||
return self._field_accessor.is_field_initialized(field_name)
|
||||
|
||||
def set_field_value(self, field_name: str, value: Any) -> None:
|
||||
"""设置字段值"""
|
||||
self._field_accessor.set_field_value(field_name, value)
|
||||
|
||||
def get_field_timeout(self, field_name: str) -> float:
|
||||
"""获取字段的超时设置"""
|
||||
origin_fields = self._GetAsynchronousOriginFields()
|
||||
if field_name not in origin_fields:
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no async field '{field_name}'")
|
||||
return self.__Asynchronous_Fields__[field_name].timeout
|
||||
|
||||
def run_until_complete(coro: Coroutine) -> Any:
|
||||
"""Gets an existing event loop to run the coroutine.
|
||||
|
109
[Test]/test.py
109
[Test]/test.py
@@ -13,103 +13,61 @@ class Test(Asynchronous):
|
||||
super().__init__(c={"timeout":2},b={"timeout":10})
|
||||
self.a = 1
|
||||
|
||||
async def geta(obj:Test) -> int:
|
||||
# 字段a已在__init__中初始化,可以直接访问
|
||||
if obj.is_field_initialized('a'):
|
||||
print(f"geta:{obj.a}")
|
||||
return obj.a
|
||||
else:
|
||||
# 使用异步方法获取
|
||||
value = await obj.get_field_async('a')
|
||||
print(f"geta:{value}")
|
||||
return value
|
||||
|
||||
async def getb(obj:Test) -> int:
|
||||
# 字段b有超时设置但未初始化,使用异步方法获取
|
||||
try:
|
||||
value = await obj.get_field_async('b')
|
||||
print(f"getb:{value}")
|
||||
return value
|
||||
except TimeoutError as e:
|
||||
print(f"getb timeout: {e}")
|
||||
raise
|
||||
|
||||
async def getc(obj:Test) -> int:
|
||||
# 字段c有超时设置但未初始化,使用异步方法获取
|
||||
try:
|
||||
value = await obj.get_field_async('c')
|
||||
print(f"getc:{value}")
|
||||
return value
|
||||
except TimeoutError as e:
|
||||
print(f"getc timeout: {e}")
|
||||
raise
|
||||
async def seta(obj:Test, value:int, delay:float = 1) -> None:
|
||||
await asyncio.sleep(delay)
|
||||
obj.a = value
|
||||
|
||||
async def setb(obj:Test, value:int, delay:float = 1) -> None:
|
||||
await asyncio.sleep(delay)
|
||||
obj.b = value
|
||||
|
||||
async def setc(obj:Test, value:int, delay:float = 1) -> None:
|
||||
await asyncio.sleep(delay)
|
||||
obj.c = value
|
||||
|
||||
async def run():
|
||||
print("=== 测试优化后的异步字段系统 ===")
|
||||
test = Test()
|
||||
|
||||
# 测试字段状态检查
|
||||
print(f"字段初始化状态 - a: {test.is_field_initialized('a')}, b: {test.is_field_initialized('b')}, c: {test.is_field_initialized('c')}")
|
||||
|
||||
# 测试超时设置查询
|
||||
print(f"字段超时设置 - b: {test.get_field_timeout('b')}s, c: {test.get_field_timeout('c')}s")
|
||||
assert test.is_field_initialized('a')
|
||||
assert not test.is_field_initialized('b')
|
||||
assert not test.is_field_initialized('c')
|
||||
|
||||
print("\n=== 测试1:未设置值的情况(应该超时)===")
|
||||
try:
|
||||
print("开始并发获取字段值...")
|
||||
r = await asyncio.gather(geta(test), getb(test), getc(test))
|
||||
print(f"结果: {r}")
|
||||
print(f"失败: {test.a,test.b,test.c}")
|
||||
raise RuntimeError("测试1应该超时")
|
||||
except Exception as e:
|
||||
print(f"捕获到异常: {e}")
|
||||
print(f"成功: {e}")
|
||||
|
||||
print("\n=== 测试2:在超时前设置字段b的值 ===")
|
||||
# 创建新的测试实例
|
||||
test2 = Test()
|
||||
print(f"设置前字段b初始化状态: {test2.is_field_initialized('b')}")
|
||||
assert not test2.is_field_initialized('b')
|
||||
|
||||
try:
|
||||
# 启动并发任务:设置b的值和获取b的值
|
||||
print("启动并发任务:0.5秒后设置b=42,同时尝试获取b的值(超时10秒)")
|
||||
|
||||
# 并发执行:设置b值(延迟0.5秒)和获取b值
|
||||
results = await asyncio.gather(
|
||||
setb(test2, 42, delay=0.5), # 0.5秒后设置b=42
|
||||
getb(test2), # 尝试获取b值(会等待直到被设置)
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
print(f"并发任务结果: {results}")
|
||||
print(f"设置后字段b初始化状态: {test2.is_field_initialized('b')}")
|
||||
|
||||
# 再次访问b,应该能立即获取到值
|
||||
print("再次访问字段b(应该立即返回):")
|
||||
b_value = await test2.get_field_async('b')
|
||||
print(f"字段b的值: {b_value}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"测试2出现异常: {e}")
|
||||
# 启动并发任务:设置b的值和获取b的值
|
||||
# 并发执行:设置b值(延迟0.5秒)和获取b值
|
||||
await asyncio.gather(
|
||||
setb(test2, 42, delay=0.5), # 0.5秒后设置b=42
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
assert test2.b == 42
|
||||
assert test2.is_field_initialized('b')
|
||||
assert test2.b == 42
|
||||
|
||||
print("\n=== 测试3:使用同步方式设置,异步方式获取 ===")
|
||||
test3 = Test()
|
||||
print("使用同步方式设置字段b = 100")
|
||||
test3.b = 100
|
||||
print(f"设置后字段b初始化状态: {test3.is_field_initialized('b')}")
|
||||
assert test3.is_field_initialized('b')
|
||||
assert test3.b == 100
|
||||
|
||||
# 异步获取值
|
||||
b_sync_set_value = await test3.get_field_async('b')
|
||||
print(f"异步获取同步设置的值: {b_sync_set_value}")
|
||||
|
||||
print("\n=== 测试4:测试字段c(短超时,应该仍然超时)===")
|
||||
print("\n=== 测试3:测试字段c(短超时,应该仍然超时)===")
|
||||
try:
|
||||
print("尝试单独访问字段c(2秒超时)...")
|
||||
c_value = await test.get_field_async('c')
|
||||
print(f"字段c的值: {c_value}")
|
||||
print(f"失败: {test.c}")
|
||||
except TimeoutError as timeout_e:
|
||||
print(f"字段c访问超时(预期): {timeout_e}")
|
||||
print(f"成功: {timeout_e}")
|
||||
|
||||
def test_sync_access():
|
||||
"""测试同步访问(在非异步上下文中)"""
|
||||
@@ -118,15 +76,16 @@ def test_sync_access():
|
||||
|
||||
# 测试已初始化字段的同步访问
|
||||
try:
|
||||
print(f"同步访问字段a: {test.a}")
|
||||
print(f"成功: a = {test.a}")
|
||||
except Exception as e:
|
||||
print(f"同步访问字段a失败: {e}")
|
||||
raise
|
||||
|
||||
# 测试未初始化字段的同步访问(应该有更友好的错误提示)
|
||||
try:
|
||||
print(f"同步访问字段c: {test.c}")
|
||||
print(f"失败: c = {test.c}")
|
||||
raise RuntimeError("字段c此时不应该能够被访问")
|
||||
except Exception as e:
|
||||
print(f"同步访问字段c失败 (预期): {e}")
|
||||
print(f"成功: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试同步访问
|
||||
|
Reference in New Issue
Block a user