From 4d0f24fd0c50ae12ba28badab2452ad098e06ecb Mon Sep 17 00:00:00 2001 From: ninemine <1371605831@qq.com> Date: Thu, 24 Jul 2025 11:45:44 +0800 Subject: [PATCH] =?UTF-8?q?EP=20Asynchrony=20=E9=9D=9E=E5=9B=A0=E6=9E=9C?= =?UTF-8?q?=E5=9E=8B=E5=BC=82=E6=AD=A5=E6=9E=B6=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Convention/Runtime/Asynchrony.py | 346 +++++++++++++++++++++++++++++++ Convention/Runtime/Reflection.py | 32 ++- [Test]/test.json | 20 -- [Test]/test.py | 136 +++++++++++- [Test]/test0.py | 3 - 5 files changed, 498 insertions(+), 39 deletions(-) create mode 100644 Convention/Runtime/Asynchrony.py delete mode 100644 [Test]/test.json delete mode 100644 [Test]/test0.py diff --git a/Convention/Runtime/Asynchrony.py b/Convention/Runtime/Asynchrony.py new file mode 100644 index 0000000..4e08082 --- /dev/null +++ b/Convention/Runtime/Asynchrony.py @@ -0,0 +1,346 @@ +from .Config import * +from .Reflection import * +from collections import defaultdict +import asyncio +import threading +from typing import Optional +from pydantic import BaseModel +from abc import ABC, abstractmethod + +class AsyncContextDetector: + """异步上下文检测工具类""" + + @staticmethod + def is_in_async_context() -> bool: + """检查是否在异步上下文中运行""" + try: + asyncio.current_task() + return True + except RuntimeError: + return False + + @staticmethod + def get_current_loop() -> Optional[asyncio.AbstractEventLoop]: + """获取当前事件循环,如果没有则返回None""" + try: + return asyncio.get_running_loop() + except RuntimeError: + return None + + @staticmethod + def ensure_async_context_safe(operation_name: str) -> None: + """确保在异步上下文中执行是安全的""" + if AsyncContextDetector.is_in_async_context(): + raise RuntimeError( + f"Cannot perform '{operation_name}' from within an async context. " + f"Use await or async methods instead." + ) + +class AsyncFieldAccessor: + """异步字段访问器,封装字段访问逻辑""" + + def __init__(self, async_fields: Dict[str, 'AsynchronyExpression'], origin_fields: Dict[str, FieldInfo]): + self._async_fields = async_fields + self._origin_fields = origin_fields + + async def get_field_value_async(self, field_name: str) -> Any: + """异步获取字段值""" + if field_name not in self._origin_fields: + raise AttributeError(f"No async field '{field_name}' found") + return await self._async_fields[field_name].get_value() + + def get_field_value_sync(self, field_name: str) -> Any: + """同步获取字段值(仅在非异步上下文中使用)""" + AsyncContextDetector.ensure_async_context_safe(f"sync access to field '{field_name}'") + + if field_name not in self._origin_fields: + raise AttributeError(f"No async field '{field_name}' found") + + async_expr = self._async_fields[field_name] + if not async_expr.is_initialize and async_expr.timeout > 0: + # 需要等待但在同步上下文中,使用run_async + return run_async(async_expr.get_value()) + elif not async_expr.is_initialize: + raise RuntimeError(f"Field '{field_name}' is not initialized and has no timeout") + else: + return async_expr.value + + def is_field_initialized(self, field_name: str) -> bool: + """检查字段是否已初始化""" + if field_name not in self._origin_fields: + raise AttributeError(f"No async field '{field_name}' found") + return self._async_fields[field_name].is_initialize + + def set_field_value(self, field_name: str, value: Any) -> None: + """设置字段值""" + if field_name not in self._origin_fields: + raise AttributeError(f"No async field '{field_name}' found") + self._async_fields[field_name].set_value(value) + +class AsynchronyUninitialized: + """表示未初始化状态的单例类""" + __instance__ = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + if cls.__instance__ is None: + with cls._lock: + if cls.__instance__ is None: + cls.__instance__ = super().__new__(cls) + return cls.__instance__ + +class AsynchronyExpression: + def __init__( + self, + field: FieldInfo, + value: Any = None, + *, + time_wait: float = 0.1, + timeout: float = 0 + ): + self.field = field + self.value = value + self.is_initialize = False + self.time_wait = time_wait + self.timeout = timeout + + async def GetValue(self) -> Any: + """获取字段值(保持兼容性的旧方法名)""" + return await self.get_value() + + async def get_value(self) -> Any: + """异步获取字段值,改进的超时机制""" + if self.is_initialize: + return self.value + + if self.timeout > 0: + try: + # 使用 asyncio.wait_for 提供更精确的超时控制 + async def wait_for_initialization(): + while not self.is_initialize: + await asyncio.sleep(self.time_wait) + return self.value + + return await asyncio.wait_for(wait_for_initialization(), timeout=self.timeout) + except asyncio.TimeoutError: + raise TimeoutError(f"Timeout waiting for uninitialized field {self.field.FieldName}") + else: + # 无超时,一直等待 + while not self.is_initialize: + await asyncio.sleep(self.time_wait) + return self.value + + def SetValue(self, value: Any) -> None: + """设置字段值(保持兼容性的旧方法名)""" + self.set_value(value) + + def set_value(self, value: Any) -> None: + """设置字段值""" + if isinstance(value, AsynchronyUninitialized): + self.set_uninitialized() + elif self.field.Verify(type(value)): + self.value = value + self.is_initialize = True + else: + raise ValueError(f"Value {value} is not valid for field {self.field.FieldName}") + + def SetUninitialized(self) -> None: + """设置为未初始化状态(保持兼容性的旧方法名)""" + self.set_uninitialized() + + def set_uninitialized(self) -> None: + """设置为未初始化状态""" + if self.is_initialize: + if hasattr(self, 'value'): + del self.value + self.value = None + self.is_initialize = False + +class Asynchronous(ABC): + __Asynchronous_Origin_Fields__: Dict[Type, Dict[str, FieldInfo]] = defaultdict(dict) + _fields_lock = threading.Lock() + + def _GetAsynchronousOriginFields(self) -> Dict[str, FieldInfo]: + return Asynchronous.__Asynchronous_Origin_Fields__[type(self)] + + def __init__(self, **kwargs: Dict[str, dict]): + super().__init__() + self.__Asynchronous_Fields__: Dict[str, AsynchronyExpression] = {} + + # 使用线程锁保护类变量访问 + with Asynchronous._fields_lock: + origin_fields = self._GetAsynchronousOriginFields() + for field_info in TypeManager.GetInstance().CreateOrGetRefTypeFromType(type(self)).GetAllFields(): + if field_info.FieldName == "__Asynchronous_Origin_Fields__": + continue + origin_fields[field_info.FieldName] = field_info + self.__Asynchronous_Fields__[field_info.FieldName] = AsynchronyExpression( + field_info, **kwargs.get(field_info.FieldName, {}) + ) + + # 创建字段访问器以提升性能 + self._field_accessor = AsyncFieldAccessor(self.__Asynchronous_Fields__, origin_fields) + + def __getattribute__(self, name: str) -> Any: + # 快速路径:非异步字段直接返回 + if name in ("__Asynchronous_Fields__", "_GetAsynchronousOriginFields", "_field_accessor"): + return super().__getattribute__(name) + + # 一次性获取所需属性,避免重复调用 + try: + field_accessor = super().__getattribute__("_field_accessor") + origin_fields = super().__getattribute__("_GetAsynchronousOriginFields")() + except AttributeError: + # 对象可能尚未完全初始化 + return super().__getattribute__(name) + + if name in origin_fields: + # 这是一个异步字段 + if AsyncContextDetector.is_in_async_context(): + # 在异步上下文中,提供友好的错误提示 + async_fields = super().__getattribute__("__Asynchronous_Fields__") + async_expr = async_fields[name] + + if not async_expr.is_initialize: + timeout_info = f" with {async_expr.timeout}s timeout" if async_expr.timeout > 0 else "" + raise RuntimeError( + f"Field '{name}' is not initialized{timeout_info}. " + f"In async context, use 'await obj.get_field_async(\"{name}\")' instead." + ) + else: + # 字段已初始化,直接返回值 + return async_expr.value + else: + # 在同步上下文中,使用字段访问器 + try: + return field_accessor.get_field_value_sync(name) + except RuntimeError as e: + if "Cannot perform" in str(e): + # 重新包装错误信息,提供更友好的提示 + raise RuntimeError( + f"Cannot access async field '{name}' from sync context when it requires initialization. " + f"Use async context or ensure field is pre-initialized." + ) from e + else: + raise + + return super().__getattribute__(name) + + def __setattr__(self, name: str, value: Any) -> None: + if name in ("__Asynchronous_Fields__", "_GetAsynchronousOriginFields", "_field_accessor"): + super().__setattr__(name, value) + elif hasattr(self, '_field_accessor'): + # 对象已初始化,使用字段访问器 + try: + field_accessor = super().__getattribute__("_field_accessor") + field_accessor.set_field_value(name, value) + return + except AttributeError: + # 不是异步字段 + pass + + super().__setattr__(name, value) + + def __delattr__(self, name: str) -> None: + if name in ("__Asynchronous_Fields__", "_GetAsynchronousOriginFields", "_field_accessor"): + super().__delattr__(name) + elif hasattr(self, '_field_accessor'): + # 对象已初始化,使用字段访问器 + try: + field_accessor = super().__getattribute__("_field_accessor") + origin_fields = super().__getattribute__("_GetAsynchronousOriginFields")() + if name in origin_fields: + async_fields = super().__getattribute__("__Asynchronous_Fields__") + async_fields[name].set_uninitialized() + return + except AttributeError: + # 不是异步字段 + pass + + super().__delattr__(name) + + async def get_field_async(self, field_name: str) -> Any: + """异步获取字段值,适用于异步上下文""" + return await self._field_accessor.get_field_value_async(field_name) + + def is_field_initialized(self, field_name: str) -> bool: + """检查字段是否已初始化""" + return self._field_accessor.is_field_initialized(field_name) + + def set_field_value(self, field_name: str, value: Any) -> None: + """设置字段值""" + self._field_accessor.set_field_value(field_name, value) + + def get_field_timeout(self, field_name: str) -> float: + """获取字段的超时设置""" + origin_fields = self._GetAsynchronousOriginFields() + if field_name not in origin_fields: + raise AttributeError(f"'{type(self).__name__}' object has no async field '{field_name}'") + return self.__Asynchronous_Fields__[field_name].timeout + +def run_until_complete(coro: Coroutine) -> Any: + """Gets an existing event loop to run the coroutine. + + If there is no existing event loop, creates a new one. + """ + try: + # Check if there's an existing event loop + loop = asyncio.get_event_loop() + + # If we're here, there's an existing loop but it's not running + return loop.run_until_complete(coro) + + except RuntimeError: + # If we can't get the event loop, we're likely in a different thread, or its already running + try: + return asyncio.run(coro) + except RuntimeError: + raise RuntimeError( + "Detected nested async. Please use nest_asyncio.apply() to allow nested event loops." + "Or, use async entry methods like `aquery()`, `aretriever`, `achat`, etc." + ) + +def run_async_coroutine(coro: Coroutine) -> Any: + try: + # Check if there's an existing event loop + loop = asyncio.get_event_loop() + + # If we're here, there's an existing loop but it's not running + return loop.create_task(coro) + + except RuntimeError: + # If we can't get the event loop, we're likely in a different thread, or its already running + try: + return asyncio.run(coro) + except RuntimeError: + raise RuntimeError( + "Detected nested async. Please use nest_asyncio.apply() to allow nested event loops." + "Or, use async entry methods like `aquery()`, `aretriever`, `achat`, etc." + ) + +def run_async(coro: Coroutine): + """安全地运行异步协程,避免事件循环死锁""" + # 使用统一的异步上下文检测 + AsyncContextDetector.ensure_async_context_safe("run_async") + + # 尝试获取当前事件循环 + current_loop = AsyncContextDetector.get_current_loop() + + if current_loop is not None and not current_loop.is_running(): + # 有事件循环但未运行,直接使用 + return current_loop.run_until_complete(coro) + elif current_loop is None: + # 没有事件循环,创建新的 + try: + return asyncio.run(coro) + except RuntimeError as e: + raise RuntimeError( + "Failed to run async coroutine. " + "Please ensure proper async environment or use nest_asyncio.apply() for nested loops." + ) from e + else: + # 事件循环正在运行,这种情况应该被AsyncContextDetector捕获 + raise RuntimeError( + "Unexpected state: running event loop detected but context check passed. " + "This should not happen." + ) diff --git a/Convention/Runtime/Reflection.py b/Convention/Runtime/Reflection.py index e7c2b8b..967c656 100644 --- a/Convention/Runtime/Reflection.py +++ b/Convention/Runtime/Reflection.py @@ -565,7 +565,7 @@ class ValueInfo(BaseInfo): return ValueInfo(metaType, **kwargs) else: return ValueInfo(type_, **kwargs) - elif isinstance(metaType, Self):#metaType is Self: + elif metaType is Self: if SelfType is None: raise ReflectionException("SelfType is required when metaType is ") return ValueInfo.Create(SelfType, **kwargs) @@ -1198,7 +1198,7 @@ class RefType(ValueInfo): # 确保正确地实现所有GetBase*方法 @functools.lru_cache(maxsize=128) - def GetBaseFields(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[FieldInfo]: + def _GetBaseFields(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[FieldInfo]: if self._BaseTypes is None: self._InitBaseTypesIfNeeded() result = [] @@ -1206,8 +1206,11 @@ class RefType(ValueInfo): result.extend(baseType.GetFields(flag)) return result + def GetBaseFields(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[FieldInfo]: + return self._GetBaseFields(flag) + @functools.lru_cache(maxsize=128) - def GetAllBaseFields(self) -> List[FieldInfo]: + def _GetAllBaseFields(self) -> List[FieldInfo]: if self._BaseTypes is None: self._InitBaseTypesIfNeeded() result = [] @@ -1215,9 +1218,12 @@ class RefType(ValueInfo): result.extend(baseType.GetAllFields()) return result + def GetAllBaseFields(self) -> List[FieldInfo]: + return self._GetAllBaseFields() + # 修改所有的GetBase*方法 @functools.lru_cache(maxsize=128) - def GetBaseMethods(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[MethodInfo]: + def _GetBaseMethods(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[MethodInfo]: if self._BaseTypes is None: self._InitBaseTypesIfNeeded() result = [] @@ -1225,8 +1231,11 @@ class RefType(ValueInfo): result.extend(baseType.GetMethods(flag)) return result + def GetBaseMethods(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[MethodInfo]: + return self._GetBaseMethods(flag) + @functools.lru_cache(maxsize=128) - def GetAllBaseMethods(self) -> List[MethodInfo]: + def _GetAllBaseMethods(self) -> List[MethodInfo]: if self._BaseTypes is None: self._InitBaseTypesIfNeeded() result = [] @@ -1234,8 +1243,11 @@ class RefType(ValueInfo): result.extend(baseType.GetAllMethods()) return result + def GetAllBaseMethods(self) -> List[MethodInfo]: + return self._GetAllBaseMethods() + @functools.lru_cache(maxsize=128) - def GetBaseMembers(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[MemberInfo]: + def _GetBaseMembers(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[MemberInfo]: if self._BaseTypes is None: self._InitBaseTypesIfNeeded() result = [] @@ -1243,8 +1255,11 @@ class RefType(ValueInfo): result.extend(baseType.GetMembers(flag)) return result + def GetBaseMembers(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[MemberInfo]: + return self._GetBaseMembers(flag) + @functools.lru_cache(maxsize=128) - def GetAllBaseMembers(self) -> List[MemberInfo]: + def _GetAllBaseMembers(self) -> List[MemberInfo]: if self._BaseTypes is None: self._InitBaseTypesIfNeeded() result = [] @@ -1252,6 +1267,9 @@ class RefType(ValueInfo): result.extend(baseType.GetAllMembers()) return result + def GetAllBaseMembers(self) -> List[MemberInfo]: + return self._GetAllBaseMembers() + def GetFields(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[FieldInfo]: self._ensure_initialized() if flag == RefTypeFlag.Default: diff --git a/[Test]/test.json b/[Test]/test.json deleted file mode 100644 index 997f24b..0000000 --- a/[Test]/test.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "easy": { - "__type": "__main__.test_log, Global", - "value": { - "__type": "__main__.test_log, Global", - "model_computed_fields": { - "__type": "typing.Any, Global" - }, - "model_extra": null, - "model_fields": { - "__type": "typing.Any, Global" - }, - "model_fields_set": { - "__type": "typing.Any, Global" - }, - "test_field": 1, - "test_field_2": "test" - } - } -} \ No newline at end of file diff --git a/[Test]/test.py b/[Test]/test.py index 401610f..378de61 100644 --- a/[Test]/test.py +++ b/[Test]/test.py @@ -2,18 +2,136 @@ import sys import os sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from Convention.Runtime.Config import * -from Convention.Runtime.EasySave import * +from Convention.Runtime.Asynchrony import * -class Test: - test_field:int = 10 - class_test_field:int = 20 +class Test(Asynchronous): + a:int + b:int + c:int def __init__(self): - self.test_field:int = 0 + super().__init__(c={"timeout":2},b={"timeout":10}) + self.a = 1 -def run(): - print(Test.__annotations__) +async def geta(obj:Test) -> int: + # 字段a已在__init__中初始化,可以直接访问 + if obj.is_field_initialized('a'): + print(f"geta:{obj.a}") + return obj.a + else: + # 使用异步方法获取 + value = await obj.get_field_async('a') + print(f"geta:{value}") + return value + +async def getb(obj:Test) -> int: + # 字段b有超时设置但未初始化,使用异步方法获取 + try: + value = await obj.get_field_async('b') + print(f"getb:{value}") + return value + except TimeoutError as e: + print(f"getb timeout: {e}") + raise + +async def getc(obj:Test) -> int: + # 字段c有超时设置但未初始化,使用异步方法获取 + try: + value = await obj.get_field_async('c') + print(f"getc:{value}") + return value + except TimeoutError as e: + print(f"getc timeout: {e}") + raise + +async def setb(obj:Test, value:int, delay:float = 1) -> None: + await asyncio.sleep(delay) + obj.b = value + +async def run(): + print("=== 测试优化后的异步字段系统 ===") + test = Test() + + # 测试字段状态检查 + print(f"字段初始化状态 - a: {test.is_field_initialized('a')}, b: {test.is_field_initialized('b')}, c: {test.is_field_initialized('c')}") + + # 测试超时设置查询 + print(f"字段超时设置 - b: {test.get_field_timeout('b')}s, c: {test.get_field_timeout('c')}s") + + print("\n=== 测试1:未设置值的情况(应该超时)===") + try: + print("开始并发获取字段值...") + r = await asyncio.gather(geta(test), getb(test), getc(test)) + print(f"结果: {r}") + except Exception as e: + print(f"捕获到异常: {e}") + + print("\n=== 测试2:在超时前设置字段b的值 ===") + # 创建新的测试实例 + test2 = Test() + print(f"设置前字段b初始化状态: {test2.is_field_initialized('b')}") + + try: + # 启动并发任务:设置b的值和获取b的值 + print("启动并发任务:0.5秒后设置b=42,同时尝试获取b的值(超时10秒)") + + # 并发执行:设置b值(延迟0.5秒)和获取b值 + results = await asyncio.gather( + setb(test2, 42, delay=0.5), # 0.5秒后设置b=42 + getb(test2), # 尝试获取b值(会等待直到被设置) + return_exceptions=True + ) + + print(f"并发任务结果: {results}") + print(f"设置后字段b初始化状态: {test2.is_field_initialized('b')}") + + # 再次访问b,应该能立即获取到值 + print("再次访问字段b(应该立即返回):") + b_value = await test2.get_field_async('b') + print(f"字段b的值: {b_value}") + + except Exception as e: + print(f"测试2出现异常: {e}") + + print("\n=== 测试3:使用同步方式设置,异步方式获取 ===") + test3 = Test() + print("使用同步方式设置字段b = 100") + test3.b = 100 + print(f"设置后字段b初始化状态: {test3.is_field_initialized('b')}") + + # 异步获取值 + b_sync_set_value = await test3.get_field_async('b') + print(f"异步获取同步设置的值: {b_sync_set_value}") + + print("\n=== 测试4:测试字段c(短超时,应该仍然超时)===") + try: + print("尝试单独访问字段c(2秒超时)...") + c_value = await test.get_field_async('c') + print(f"字段c的值: {c_value}") + except TimeoutError as timeout_e: + print(f"字段c访问超时(预期): {timeout_e}") + +def test_sync_access(): + """测试同步访问(在非异步上下文中)""" + print("\n=== 测试同步访问 ===") + test = Test() + + # 测试已初始化字段的同步访问 + try: + print(f"同步访问字段a: {test.a}") + except Exception as e: + print(f"同步访问字段a失败: {e}") + + # 测试未初始化字段的同步访问(应该有更友好的错误提示) + try: + print(f"同步访问字段c: {test.c}") + except Exception as e: + print(f"同步访问字段c失败 (预期): {e}") if __name__ == "__main__": - run() + # 测试同步访问 + test_sync_access() + + # 测试异步访问 + print("\n=== 开始异步测试 ===") + run_until_complete(run()) diff --git a/[Test]/test0.py b/[Test]/test0.py deleted file mode 100644 index 1de0101..0000000 --- a/[Test]/test0.py +++ /dev/null @@ -1,3 +0,0 @@ -import math -import r -print(re.findall(r"\d+[.\d]?", "xxxxx$19.99"))