From 4d0f24fd0c50ae12ba28badab2452ad098e06ecb Mon Sep 17 00:00:00 2001 From: ninemine <1371605831@qq.com> Date: Thu, 24 Jul 2025 11:45:44 +0800 Subject: [PATCH 1/2] =?UTF-8?q?EP=20Asynchrony=20=E9=9D=9E=E5=9B=A0?= =?UTF-8?q?=E6=9E=9C=E5=9E=8B=E5=BC=82=E6=AD=A5=E6=9E=B6=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Convention/Runtime/Asynchrony.py | 346 +++++++++++++++++++++++++++++++ Convention/Runtime/Reflection.py | 32 ++- [Test]/test.json | 20 -- [Test]/test.py | 136 +++++++++++- [Test]/test0.py | 3 - 5 files changed, 498 insertions(+), 39 deletions(-) create mode 100644 Convention/Runtime/Asynchrony.py delete mode 100644 [Test]/test.json delete mode 100644 [Test]/test0.py diff --git a/Convention/Runtime/Asynchrony.py b/Convention/Runtime/Asynchrony.py new file mode 100644 index 0000000..4e08082 --- /dev/null +++ b/Convention/Runtime/Asynchrony.py @@ -0,0 +1,346 @@ +from .Config import * +from .Reflection import * +from collections import defaultdict +import asyncio +import threading +from typing import Optional +from pydantic import BaseModel +from abc import ABC, abstractmethod + +class AsyncContextDetector: + """异步上下文检测工具类""" + + @staticmethod + def is_in_async_context() -> bool: + """检查是否在异步上下文中运行""" + try: + asyncio.current_task() + return True + except RuntimeError: + return False + + @staticmethod + def get_current_loop() -> Optional[asyncio.AbstractEventLoop]: + """获取当前事件循环,如果没有则返回None""" + try: + return asyncio.get_running_loop() + except RuntimeError: + return None + + @staticmethod + def ensure_async_context_safe(operation_name: str) -> None: + """确保在异步上下文中执行是安全的""" + if AsyncContextDetector.is_in_async_context(): + raise RuntimeError( + f"Cannot perform '{operation_name}' from within an async context. " + f"Use await or async methods instead." + ) + +class AsyncFieldAccessor: + """异步字段访问器,封装字段访问逻辑""" + + def __init__(self, async_fields: Dict[str, 'AsynchronyExpression'], origin_fields: Dict[str, FieldInfo]): + self._async_fields = async_fields + self._origin_fields = origin_fields + + async def get_field_value_async(self, field_name: str) -> Any: + """异步获取字段值""" + if field_name not in self._origin_fields: + raise AttributeError(f"No async field '{field_name}' found") + return await self._async_fields[field_name].get_value() + + def get_field_value_sync(self, field_name: str) -> Any: + """同步获取字段值(仅在非异步上下文中使用)""" + AsyncContextDetector.ensure_async_context_safe(f"sync access to field '{field_name}'") + + if field_name not in self._origin_fields: + raise AttributeError(f"No async field '{field_name}' found") + + async_expr = self._async_fields[field_name] + if not async_expr.is_initialize and async_expr.timeout > 0: + # 需要等待但在同步上下文中,使用run_async + return run_async(async_expr.get_value()) + elif not async_expr.is_initialize: + raise RuntimeError(f"Field '{field_name}' is not initialized and has no timeout") + else: + return async_expr.value + + def is_field_initialized(self, field_name: str) -> bool: + """检查字段是否已初始化""" + if field_name not in self._origin_fields: + raise AttributeError(f"No async field '{field_name}' found") + return self._async_fields[field_name].is_initialize + + def set_field_value(self, field_name: str, value: Any) -> None: + """设置字段值""" + if field_name not in self._origin_fields: + raise AttributeError(f"No async field '{field_name}' found") + self._async_fields[field_name].set_value(value) + +class AsynchronyUninitialized: + """表示未初始化状态的单例类""" + __instance__ = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + if cls.__instance__ is None: + with cls._lock: + if cls.__instance__ is None: + cls.__instance__ = super().__new__(cls) + return cls.__instance__ + +class AsynchronyExpression: + def __init__( + self, + field: FieldInfo, + value: Any = None, + *, + time_wait: float = 0.1, + timeout: float = 0 + ): + self.field = field + self.value = value + self.is_initialize = False + self.time_wait = time_wait + self.timeout = timeout + + async def GetValue(self) -> Any: + """获取字段值(保持兼容性的旧方法名)""" + return await self.get_value() + + async def get_value(self) -> Any: + """异步获取字段值,改进的超时机制""" + if self.is_initialize: + return self.value + + if self.timeout > 0: + try: + # 使用 asyncio.wait_for 提供更精确的超时控制 + async def wait_for_initialization(): + while not self.is_initialize: + await asyncio.sleep(self.time_wait) + return self.value + + return await asyncio.wait_for(wait_for_initialization(), timeout=self.timeout) + except asyncio.TimeoutError: + raise TimeoutError(f"Timeout waiting for uninitialized field {self.field.FieldName}") + else: + # 无超时,一直等待 + while not self.is_initialize: + await asyncio.sleep(self.time_wait) + return self.value + + def SetValue(self, value: Any) -> None: + """设置字段值(保持兼容性的旧方法名)""" + self.set_value(value) + + def set_value(self, value: Any) -> None: + """设置字段值""" + if isinstance(value, AsynchronyUninitialized): + self.set_uninitialized() + elif self.field.Verify(type(value)): + self.value = value + self.is_initialize = True + else: + raise ValueError(f"Value {value} is not valid for field {self.field.FieldName}") + + def SetUninitialized(self) -> None: + """设置为未初始化状态(保持兼容性的旧方法名)""" + self.set_uninitialized() + + def set_uninitialized(self) -> None: + """设置为未初始化状态""" + if self.is_initialize: + if hasattr(self, 'value'): + del self.value + self.value = None + self.is_initialize = False + +class Asynchronous(ABC): + __Asynchronous_Origin_Fields__: Dict[Type, Dict[str, FieldInfo]] = defaultdict(dict) + _fields_lock = threading.Lock() + + def _GetAsynchronousOriginFields(self) -> Dict[str, FieldInfo]: + return Asynchronous.__Asynchronous_Origin_Fields__[type(self)] + + def __init__(self, **kwargs: Dict[str, dict]): + super().__init__() + self.__Asynchronous_Fields__: Dict[str, AsynchronyExpression] = {} + + # 使用线程锁保护类变量访问 + with Asynchronous._fields_lock: + origin_fields = self._GetAsynchronousOriginFields() + for field_info in TypeManager.GetInstance().CreateOrGetRefTypeFromType(type(self)).GetAllFields(): + if field_info.FieldName == "__Asynchronous_Origin_Fields__": + continue + origin_fields[field_info.FieldName] = field_info + self.__Asynchronous_Fields__[field_info.FieldName] = AsynchronyExpression( + field_info, **kwargs.get(field_info.FieldName, {}) + ) + + # 创建字段访问器以提升性能 + self._field_accessor = AsyncFieldAccessor(self.__Asynchronous_Fields__, origin_fields) + + def __getattribute__(self, name: str) -> Any: + # 快速路径:非异步字段直接返回 + if name in ("__Asynchronous_Fields__", "_GetAsynchronousOriginFields", "_field_accessor"): + return super().__getattribute__(name) + + # 一次性获取所需属性,避免重复调用 + try: + field_accessor = super().__getattribute__("_field_accessor") + origin_fields = super().__getattribute__("_GetAsynchronousOriginFields")() + except AttributeError: + # 对象可能尚未完全初始化 + return super().__getattribute__(name) + + if name in origin_fields: + # 这是一个异步字段 + if AsyncContextDetector.is_in_async_context(): + # 在异步上下文中,提供友好的错误提示 + async_fields = super().__getattribute__("__Asynchronous_Fields__") + async_expr = async_fields[name] + + if not async_expr.is_initialize: + timeout_info = f" with {async_expr.timeout}s timeout" if async_expr.timeout > 0 else "" + raise RuntimeError( + f"Field '{name}' is not initialized{timeout_info}. " + f"In async context, use 'await obj.get_field_async(\"{name}\")' instead." + ) + else: + # 字段已初始化,直接返回值 + return async_expr.value + else: + # 在同步上下文中,使用字段访问器 + try: + return field_accessor.get_field_value_sync(name) + except RuntimeError as e: + if "Cannot perform" in str(e): + # 重新包装错误信息,提供更友好的提示 + raise RuntimeError( + f"Cannot access async field '{name}' from sync context when it requires initialization. " + f"Use async context or ensure field is pre-initialized." + ) from e + else: + raise + + return super().__getattribute__(name) + + def __setattr__(self, name: str, value: Any) -> None: + if name in ("__Asynchronous_Fields__", "_GetAsynchronousOriginFields", "_field_accessor"): + super().__setattr__(name, value) + elif hasattr(self, '_field_accessor'): + # 对象已初始化,使用字段访问器 + try: + field_accessor = super().__getattribute__("_field_accessor") + field_accessor.set_field_value(name, value) + return + except AttributeError: + # 不是异步字段 + pass + + super().__setattr__(name, value) + + def __delattr__(self, name: str) -> None: + if name in ("__Asynchronous_Fields__", "_GetAsynchronousOriginFields", "_field_accessor"): + super().__delattr__(name) + elif hasattr(self, '_field_accessor'): + # 对象已初始化,使用字段访问器 + try: + field_accessor = super().__getattribute__("_field_accessor") + origin_fields = super().__getattribute__("_GetAsynchronousOriginFields")() + if name in origin_fields: + async_fields = super().__getattribute__("__Asynchronous_Fields__") + async_fields[name].set_uninitialized() + return + except AttributeError: + # 不是异步字段 + pass + + super().__delattr__(name) + + async def get_field_async(self, field_name: str) -> Any: + """异步获取字段值,适用于异步上下文""" + return await self._field_accessor.get_field_value_async(field_name) + + def is_field_initialized(self, field_name: str) -> bool: + """检查字段是否已初始化""" + return self._field_accessor.is_field_initialized(field_name) + + def set_field_value(self, field_name: str, value: Any) -> None: + """设置字段值""" + self._field_accessor.set_field_value(field_name, value) + + def get_field_timeout(self, field_name: str) -> float: + """获取字段的超时设置""" + origin_fields = self._GetAsynchronousOriginFields() + if field_name not in origin_fields: + raise AttributeError(f"'{type(self).__name__}' object has no async field '{field_name}'") + return self.__Asynchronous_Fields__[field_name].timeout + +def run_until_complete(coro: Coroutine) -> Any: + """Gets an existing event loop to run the coroutine. + + If there is no existing event loop, creates a new one. + """ + try: + # Check if there's an existing event loop + loop = asyncio.get_event_loop() + + # If we're here, there's an existing loop but it's not running + return loop.run_until_complete(coro) + + except RuntimeError: + # If we can't get the event loop, we're likely in a different thread, or its already running + try: + return asyncio.run(coro) + except RuntimeError: + raise RuntimeError( + "Detected nested async. Please use nest_asyncio.apply() to allow nested event loops." + "Or, use async entry methods like `aquery()`, `aretriever`, `achat`, etc." + ) + +def run_async_coroutine(coro: Coroutine) -> Any: + try: + # Check if there's an existing event loop + loop = asyncio.get_event_loop() + + # If we're here, there's an existing loop but it's not running + return loop.create_task(coro) + + except RuntimeError: + # If we can't get the event loop, we're likely in a different thread, or its already running + try: + return asyncio.run(coro) + except RuntimeError: + raise RuntimeError( + "Detected nested async. Please use nest_asyncio.apply() to allow nested event loops." + "Or, use async entry methods like `aquery()`, `aretriever`, `achat`, etc." + ) + +def run_async(coro: Coroutine): + """安全地运行异步协程,避免事件循环死锁""" + # 使用统一的异步上下文检测 + AsyncContextDetector.ensure_async_context_safe("run_async") + + # 尝试获取当前事件循环 + current_loop = AsyncContextDetector.get_current_loop() + + if current_loop is not None and not current_loop.is_running(): + # 有事件循环但未运行,直接使用 + return current_loop.run_until_complete(coro) + elif current_loop is None: + # 没有事件循环,创建新的 + try: + return asyncio.run(coro) + except RuntimeError as e: + raise RuntimeError( + "Failed to run async coroutine. " + "Please ensure proper async environment or use nest_asyncio.apply() for nested loops." + ) from e + else: + # 事件循环正在运行,这种情况应该被AsyncContextDetector捕获 + raise RuntimeError( + "Unexpected state: running event loop detected but context check passed. " + "This should not happen." + ) diff --git a/Convention/Runtime/Reflection.py b/Convention/Runtime/Reflection.py index e7c2b8b..967c656 100644 --- a/Convention/Runtime/Reflection.py +++ b/Convention/Runtime/Reflection.py @@ -565,7 +565,7 @@ class ValueInfo(BaseInfo): return ValueInfo(metaType, **kwargs) else: return ValueInfo(type_, **kwargs) - elif isinstance(metaType, Self):#metaType is Self: + elif metaType is Self: if SelfType is None: raise ReflectionException("SelfType is required when metaType is ") return ValueInfo.Create(SelfType, **kwargs) @@ -1198,7 +1198,7 @@ class RefType(ValueInfo): # 确保正确地实现所有GetBase*方法 @functools.lru_cache(maxsize=128) - def GetBaseFields(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[FieldInfo]: + def _GetBaseFields(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[FieldInfo]: if self._BaseTypes is None: self._InitBaseTypesIfNeeded() result = [] @@ -1206,8 +1206,11 @@ class RefType(ValueInfo): result.extend(baseType.GetFields(flag)) return result + def GetBaseFields(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[FieldInfo]: + return self._GetBaseFields(flag) + @functools.lru_cache(maxsize=128) - def GetAllBaseFields(self) -> List[FieldInfo]: + def _GetAllBaseFields(self) -> List[FieldInfo]: if self._BaseTypes is None: self._InitBaseTypesIfNeeded() result = [] @@ -1215,9 +1218,12 @@ class RefType(ValueInfo): result.extend(baseType.GetAllFields()) return result + def GetAllBaseFields(self) -> List[FieldInfo]: + return self._GetAllBaseFields() + # 修改所有的GetBase*方法 @functools.lru_cache(maxsize=128) - def GetBaseMethods(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[MethodInfo]: + def _GetBaseMethods(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[MethodInfo]: if self._BaseTypes is None: self._InitBaseTypesIfNeeded() result = [] @@ -1225,8 +1231,11 @@ class RefType(ValueInfo): result.extend(baseType.GetMethods(flag)) return result + def GetBaseMethods(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[MethodInfo]: + return self._GetBaseMethods(flag) + @functools.lru_cache(maxsize=128) - def GetAllBaseMethods(self) -> List[MethodInfo]: + def _GetAllBaseMethods(self) -> List[MethodInfo]: if self._BaseTypes is None: self._InitBaseTypesIfNeeded() result = [] @@ -1234,8 +1243,11 @@ class RefType(ValueInfo): result.extend(baseType.GetAllMethods()) return result + def GetAllBaseMethods(self) -> List[MethodInfo]: + return self._GetAllBaseMethods() + @functools.lru_cache(maxsize=128) - def GetBaseMembers(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[MemberInfo]: + def _GetBaseMembers(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[MemberInfo]: if self._BaseTypes is None: self._InitBaseTypesIfNeeded() result = [] @@ -1243,8 +1255,11 @@ class RefType(ValueInfo): result.extend(baseType.GetMembers(flag)) return result + def GetBaseMembers(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[MemberInfo]: + return self._GetBaseMembers(flag) + @functools.lru_cache(maxsize=128) - def GetAllBaseMembers(self) -> List[MemberInfo]: + def _GetAllBaseMembers(self) -> List[MemberInfo]: if self._BaseTypes is None: self._InitBaseTypesIfNeeded() result = [] @@ -1252,6 +1267,9 @@ class RefType(ValueInfo): result.extend(baseType.GetAllMembers()) return result + def GetAllBaseMembers(self) -> List[MemberInfo]: + return self._GetAllBaseMembers() + def GetFields(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[FieldInfo]: self._ensure_initialized() if flag == RefTypeFlag.Default: diff --git a/[Test]/test.json b/[Test]/test.json deleted file mode 100644 index 997f24b..0000000 --- a/[Test]/test.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "easy": { - "__type": "__main__.test_log, Global", - "value": { - "__type": "__main__.test_log, Global", - "model_computed_fields": { - "__type": "typing.Any, Global" - }, - "model_extra": null, - "model_fields": { - "__type": "typing.Any, Global" - }, - "model_fields_set": { - "__type": "typing.Any, Global" - }, - "test_field": 1, - "test_field_2": "test" - } - } -} \ No newline at end of file diff --git a/[Test]/test.py b/[Test]/test.py index 401610f..378de61 100644 --- a/[Test]/test.py +++ b/[Test]/test.py @@ -2,18 +2,136 @@ import sys import os sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from Convention.Runtime.Config import * -from Convention.Runtime.EasySave import * +from Convention.Runtime.Asynchrony import * -class Test: - test_field:int = 10 - class_test_field:int = 20 +class Test(Asynchronous): + a:int + b:int + c:int def __init__(self): - self.test_field:int = 0 + super().__init__(c={"timeout":2},b={"timeout":10}) + self.a = 1 -def run(): - print(Test.__annotations__) +async def geta(obj:Test) -> int: + # 字段a已在__init__中初始化,可以直接访问 + if obj.is_field_initialized('a'): + print(f"geta:{obj.a}") + return obj.a + else: + # 使用异步方法获取 + value = await obj.get_field_async('a') + print(f"geta:{value}") + return value + +async def getb(obj:Test) -> int: + # 字段b有超时设置但未初始化,使用异步方法获取 + try: + value = await obj.get_field_async('b') + print(f"getb:{value}") + return value + except TimeoutError as e: + print(f"getb timeout: {e}") + raise + +async def getc(obj:Test) -> int: + # 字段c有超时设置但未初始化,使用异步方法获取 + try: + value = await obj.get_field_async('c') + print(f"getc:{value}") + return value + except TimeoutError as e: + print(f"getc timeout: {e}") + raise + +async def setb(obj:Test, value:int, delay:float = 1) -> None: + await asyncio.sleep(delay) + obj.b = value + +async def run(): + print("=== 测试优化后的异步字段系统 ===") + test = Test() + + # 测试字段状态检查 + print(f"字段初始化状态 - a: {test.is_field_initialized('a')}, b: {test.is_field_initialized('b')}, c: {test.is_field_initialized('c')}") + + # 测试超时设置查询 + print(f"字段超时设置 - b: {test.get_field_timeout('b')}s, c: {test.get_field_timeout('c')}s") + + print("\n=== 测试1:未设置值的情况(应该超时)===") + try: + print("开始并发获取字段值...") + r = await asyncio.gather(geta(test), getb(test), getc(test)) + print(f"结果: {r}") + except Exception as e: + print(f"捕获到异常: {e}") + + print("\n=== 测试2:在超时前设置字段b的值 ===") + # 创建新的测试实例 + test2 = Test() + print(f"设置前字段b初始化状态: {test2.is_field_initialized('b')}") + + try: + # 启动并发任务:设置b的值和获取b的值 + print("启动并发任务:0.5秒后设置b=42,同时尝试获取b的值(超时10秒)") + + # 并发执行:设置b值(延迟0.5秒)和获取b值 + results = await asyncio.gather( + setb(test2, 42, delay=0.5), # 0.5秒后设置b=42 + getb(test2), # 尝试获取b值(会等待直到被设置) + return_exceptions=True + ) + + print(f"并发任务结果: {results}") + print(f"设置后字段b初始化状态: {test2.is_field_initialized('b')}") + + # 再次访问b,应该能立即获取到值 + print("再次访问字段b(应该立即返回):") + b_value = await test2.get_field_async('b') + print(f"字段b的值: {b_value}") + + except Exception as e: + print(f"测试2出现异常: {e}") + + print("\n=== 测试3:使用同步方式设置,异步方式获取 ===") + test3 = Test() + print("使用同步方式设置字段b = 100") + test3.b = 100 + print(f"设置后字段b初始化状态: {test3.is_field_initialized('b')}") + + # 异步获取值 + b_sync_set_value = await test3.get_field_async('b') + print(f"异步获取同步设置的值: {b_sync_set_value}") + + print("\n=== 测试4:测试字段c(短超时,应该仍然超时)===") + try: + print("尝试单独访问字段c(2秒超时)...") + c_value = await test.get_field_async('c') + print(f"字段c的值: {c_value}") + except TimeoutError as timeout_e: + print(f"字段c访问超时(预期): {timeout_e}") + +def test_sync_access(): + """测试同步访问(在非异步上下文中)""" + print("\n=== 测试同步访问 ===") + test = Test() + + # 测试已初始化字段的同步访问 + try: + print(f"同步访问字段a: {test.a}") + except Exception as e: + print(f"同步访问字段a失败: {e}") + + # 测试未初始化字段的同步访问(应该有更友好的错误提示) + try: + print(f"同步访问字段c: {test.c}") + except Exception as e: + print(f"同步访问字段c失败 (预期): {e}") if __name__ == "__main__": - run() + # 测试同步访问 + test_sync_access() + + # 测试异步访问 + print("\n=== 开始异步测试 ===") + run_until_complete(run()) diff --git a/[Test]/test0.py b/[Test]/test0.py deleted file mode 100644 index 1de0101..0000000 --- a/[Test]/test0.py +++ /dev/null @@ -1,3 +0,0 @@ -import math -import r -print(re.findall(r"\d+[.\d]?", "xxxxx$19.99")) From 2bb6f924dfd92308fc123de72a64354cfa733591 Mon Sep 17 00:00:00 2001 From: ninemine <1371605831@qq.com> Date: Fri, 25 Jul 2025 10:48:07 +0800 Subject: [PATCH 2/2] EP Asynchrony --- Convention/Runtime/Asynchrony.py | 101 +++++++++++++++------------- [Test]/test.py | 109 ++++++++++--------------------- 2 files changed, 88 insertions(+), 122 deletions(-) diff --git a/Convention/Runtime/Asynchrony.py b/Convention/Runtime/Asynchrony.py index 4e08082..7d69027 100644 --- a/Convention/Runtime/Asynchrony.py +++ b/Convention/Runtime/Asynchrony.py @@ -39,17 +39,21 @@ class AsyncContextDetector: class AsyncFieldAccessor: """异步字段访问器,封装字段访问逻辑""" - def __init__(self, async_fields: Dict[str, 'AsynchronyExpression'], origin_fields: Dict[str, FieldInfo]): + def __init__( + self, + async_fields: Dict[str, 'AsynchronyExpression'], + origin_fields: Dict[str, FieldInfo] + ) -> None: self._async_fields = async_fields self._origin_fields = origin_fields - async def get_field_value_async(self, field_name: str) -> Any: + async def get_field_value_async(self, field_name: str): """异步获取字段值""" if field_name not in self._origin_fields: raise AttributeError(f"No async field '{field_name}' found") return await self._async_fields[field_name].get_value() - def get_field_value_sync(self, field_name: str) -> Any: + def get_field_value_sync(self, field_name: str): """同步获取字段值(仅在非异步上下文中使用)""" AsyncContextDetector.ensure_async_context_safe(f"sync access to field '{field_name}'") @@ -63,7 +67,7 @@ class AsyncFieldAccessor: elif not async_expr.is_initialize: raise RuntimeError(f"Field '{field_name}' is not initialized and has no timeout") else: - return async_expr.value + return run_async(async_expr.get_value()) def is_field_initialized(self, field_name: str) -> bool: """检查字段是否已初始化""" @@ -89,29 +93,53 @@ class AsynchronyUninitialized: cls.__instance__ = super().__new__(cls) return cls.__instance__ + def __repr__(self): + return "uninitialized" + + def __str__(self): + return "None" + class AsynchronyExpression: def __init__( self, - field: FieldInfo, - value: Any = None, + field: FieldInfo, + value: Any = AsynchronyUninitialized(), *, - time_wait: float = 0.1, - timeout: float = 0 + time_wait: float = 0.1, + timeout: float = 0, + callback: Optional[Action] = None, ): + ''' + 参数: + field: 字段 + value: 初始化, 默认为AsynchronyUninitialized, 即无初始化 + time_wait: 等待时间, 默认为0.1秒 + timeout: 超时时间, 默认为0秒 + callback: 回调函数, 默认为None, 当状态为无初始化时get_value会调用callback + ''' self.field = field - self.value = value - self.is_initialize = False + self._value = value + self.callback = callback + self.is_initialize = not isinstance(value, AsynchronyUninitialized) self.time_wait = time_wait self.timeout = timeout - async def GetValue(self) -> Any: - """获取字段值(保持兼容性的旧方法名)""" - return await self.get_value() - - async def get_value(self) -> Any: + def get_value_sync(self): + if self.is_initialize: + return self._value + elif self.callback is not None: + self.callback() + if self.is_initialize: + return self._value + else: + raise RuntimeError(f"Field {self.field.FieldName} is not initialized") + + async def get_value(self): """异步获取字段值,改进的超时机制""" if self.is_initialize: - return self.value + return self._value + elif self.callback is not None: + self.callback() if self.timeout > 0: try: @@ -119,7 +147,7 @@ class AsynchronyExpression: async def wait_for_initialization(): while not self.is_initialize: await asyncio.sleep(self.time_wait) - return self.value + return self._value return await asyncio.wait_for(wait_for_initialization(), timeout=self.timeout) except asyncio.TimeoutError: @@ -128,18 +156,14 @@ class AsynchronyExpression: # 无超时,一直等待 while not self.is_initialize: await asyncio.sleep(self.time_wait) - return self.value - - def SetValue(self, value: Any) -> None: - """设置字段值(保持兼容性的旧方法名)""" - self.set_value(value) - + return self._value + def set_value(self, value: Any) -> None: """设置字段值""" if isinstance(value, AsynchronyUninitialized): self.set_uninitialized() elif self.field.Verify(type(value)): - self.value = value + self._value = value self.is_initialize = True else: raise ValueError(f"Value {value} is not valid for field {self.field.FieldName}") @@ -151,9 +175,8 @@ class AsynchronyExpression: def set_uninitialized(self) -> None: """设置为未初始化状态""" if self.is_initialize: - if hasattr(self, 'value'): - del self.value - self.value = None + del self._value + self._value = AsynchronyUninitialized() self.is_initialize = False class Asynchronous(ABC): @@ -188,8 +211,8 @@ class Asynchronous(ABC): # 一次性获取所需属性,避免重复调用 try: - field_accessor = super().__getattribute__("_field_accessor") - origin_fields = super().__getattribute__("_GetAsynchronousOriginFields")() + field_accessor:AsyncFieldAccessor = super().__getattribute__("_field_accessor") + origin_fields:Dict[str, FieldInfo] = super().__getattribute__("_GetAsynchronousOriginFields")() except AttributeError: # 对象可能尚未完全初始化 return super().__getattribute__(name) @@ -198,18 +221,17 @@ class Asynchronous(ABC): # 这是一个异步字段 if AsyncContextDetector.is_in_async_context(): # 在异步上下文中,提供友好的错误提示 - async_fields = super().__getattribute__("__Asynchronous_Fields__") + async_fields:Dict[str, AsynchronyExpression] = super().__getattribute__("__Asynchronous_Fields__") async_expr = async_fields[name] if not async_expr.is_initialize: timeout_info = f" with {async_expr.timeout}s timeout" if async_expr.timeout > 0 else "" raise RuntimeError( f"Field '{name}' is not initialized{timeout_info}. " - f"In async context, use 'await obj.get_field_async(\"{name}\")' instead." ) else: # 字段已初始化,直接返回值 - return async_expr.value + return async_expr.get_value_sync() else: # 在同步上下文中,使用字段访问器 try: @@ -259,24 +281,9 @@ class Asynchronous(ABC): super().__delattr__(name) - async def get_field_async(self, field_name: str) -> Any: - """异步获取字段值,适用于异步上下文""" - return await self._field_accessor.get_field_value_async(field_name) - def is_field_initialized(self, field_name: str) -> bool: """检查字段是否已初始化""" return self._field_accessor.is_field_initialized(field_name) - - def set_field_value(self, field_name: str, value: Any) -> None: - """设置字段值""" - self._field_accessor.set_field_value(field_name, value) - - def get_field_timeout(self, field_name: str) -> float: - """获取字段的超时设置""" - origin_fields = self._GetAsynchronousOriginFields() - if field_name not in origin_fields: - raise AttributeError(f"'{type(self).__name__}' object has no async field '{field_name}'") - return self.__Asynchronous_Fields__[field_name].timeout def run_until_complete(coro: Coroutine) -> Any: """Gets an existing event loop to run the coroutine. diff --git a/[Test]/test.py b/[Test]/test.py index 378de61..4ea96ce 100644 --- a/[Test]/test.py +++ b/[Test]/test.py @@ -13,103 +13,61 @@ class Test(Asynchronous): super().__init__(c={"timeout":2},b={"timeout":10}) self.a = 1 -async def geta(obj:Test) -> int: - # 字段a已在__init__中初始化,可以直接访问 - if obj.is_field_initialized('a'): - print(f"geta:{obj.a}") - return obj.a - else: - # 使用异步方法获取 - value = await obj.get_field_async('a') - print(f"geta:{value}") - return value -async def getb(obj:Test) -> int: - # 字段b有超时设置但未初始化,使用异步方法获取 - try: - value = await obj.get_field_async('b') - print(f"getb:{value}") - return value - except TimeoutError as e: - print(f"getb timeout: {e}") - raise - -async def getc(obj:Test) -> int: - # 字段c有超时设置但未初始化,使用异步方法获取 - try: - value = await obj.get_field_async('c') - print(f"getc:{value}") - return value - except TimeoutError as e: - print(f"getc timeout: {e}") - raise +async def seta(obj:Test, value:int, delay:float = 1) -> None: + await asyncio.sleep(delay) + obj.a = value async def setb(obj:Test, value:int, delay:float = 1) -> None: await asyncio.sleep(delay) obj.b = value + +async def setc(obj:Test, value:int, delay:float = 1) -> None: + await asyncio.sleep(delay) + obj.c = value async def run(): print("=== 测试优化后的异步字段系统 ===") test = Test() # 测试字段状态检查 - print(f"字段初始化状态 - a: {test.is_field_initialized('a')}, b: {test.is_field_initialized('b')}, c: {test.is_field_initialized('c')}") - - # 测试超时设置查询 - print(f"字段超时设置 - b: {test.get_field_timeout('b')}s, c: {test.get_field_timeout('c')}s") + assert test.is_field_initialized('a') + assert not test.is_field_initialized('b') + assert not test.is_field_initialized('c') print("\n=== 测试1:未设置值的情况(应该超时)===") try: - print("开始并发获取字段值...") - r = await asyncio.gather(geta(test), getb(test), getc(test)) - print(f"结果: {r}") + print(f"失败: {test.a,test.b,test.c}") + raise RuntimeError("测试1应该超时") except Exception as e: - print(f"捕获到异常: {e}") + print(f"成功: {e}") print("\n=== 测试2:在超时前设置字段b的值 ===") # 创建新的测试实例 test2 = Test() - print(f"设置前字段b初始化状态: {test2.is_field_initialized('b')}") + assert not test2.is_field_initialized('b') - try: - # 启动并发任务:设置b的值和获取b的值 - print("启动并发任务:0.5秒后设置b=42,同时尝试获取b的值(超时10秒)") - - # 并发执行:设置b值(延迟0.5秒)和获取b值 - results = await asyncio.gather( - setb(test2, 42, delay=0.5), # 0.5秒后设置b=42 - getb(test2), # 尝试获取b值(会等待直到被设置) - return_exceptions=True - ) - - print(f"并发任务结果: {results}") - print(f"设置后字段b初始化状态: {test2.is_field_initialized('b')}") - - # 再次访问b,应该能立即获取到值 - print("再次访问字段b(应该立即返回):") - b_value = await test2.get_field_async('b') - print(f"字段b的值: {b_value}") - - except Exception as e: - print(f"测试2出现异常: {e}") + # 启动并发任务:设置b的值和获取b的值 + # 并发执行:设置b值(延迟0.5秒)和获取b值 + await asyncio.gather( + setb(test2, 42, delay=0.5), # 0.5秒后设置b=42 + return_exceptions=True + ) + + assert test2.b == 42 + assert test2.is_field_initialized('b') + assert test2.b == 42 - print("\n=== 测试3:使用同步方式设置,异步方式获取 ===") test3 = Test() - print("使用同步方式设置字段b = 100") test3.b = 100 - print(f"设置后字段b初始化状态: {test3.is_field_initialized('b')}") + assert test3.is_field_initialized('b') + assert test3.b == 100 - # 异步获取值 - b_sync_set_value = await test3.get_field_async('b') - print(f"异步获取同步设置的值: {b_sync_set_value}") - - print("\n=== 测试4:测试字段c(短超时,应该仍然超时)===") + print("\n=== 测试3:测试字段c(短超时,应该仍然超时)===") try: - print("尝试单独访问字段c(2秒超时)...") - c_value = await test.get_field_async('c') - print(f"字段c的值: {c_value}") + print(f"失败: {test.c}") except TimeoutError as timeout_e: - print(f"字段c访问超时(预期): {timeout_e}") + print(f"成功: {timeout_e}") def test_sync_access(): """测试同步访问(在非异步上下文中)""" @@ -118,15 +76,16 @@ def test_sync_access(): # 测试已初始化字段的同步访问 try: - print(f"同步访问字段a: {test.a}") + print(f"成功: a = {test.a}") except Exception as e: - print(f"同步访问字段a失败: {e}") + raise # 测试未初始化字段的同步访问(应该有更友好的错误提示) try: - print(f"同步访问字段c: {test.c}") + print(f"失败: c = {test.c}") + raise RuntimeError("字段c此时不应该能够被访问") except Exception as e: - print(f"同步访问字段c失败 (预期): {e}") + print(f"成功: {e}") if __name__ == "__main__": # 测试同步访问