Merge pull request #1 from NINEMINEsigma/EP-Asynchrony

EP Asynchrony
This commit is contained in:
ninemine
2025-07-25 10:54:01 +08:00
committed by GitHub
5 changed files with 464 additions and 39 deletions

View File

@@ -0,0 +1,353 @@
from .Config import *
from .Reflection import *
from collections import defaultdict
import asyncio
import threading
from typing import Optional
from pydantic import BaseModel
from abc import ABC, abstractmethod
class AsyncContextDetector:
"""异步上下文检测工具类"""
@staticmethod
def is_in_async_context() -> bool:
"""检查是否在异步上下文中运行"""
try:
asyncio.current_task()
return True
except RuntimeError:
return False
@staticmethod
def get_current_loop() -> Optional[asyncio.AbstractEventLoop]:
"""获取当前事件循环如果没有则返回None"""
try:
return asyncio.get_running_loop()
except RuntimeError:
return None
@staticmethod
def ensure_async_context_safe(operation_name: str) -> None:
"""确保在异步上下文中执行是安全的"""
if AsyncContextDetector.is_in_async_context():
raise RuntimeError(
f"Cannot perform '{operation_name}' from within an async context. "
f"Use await or async methods instead."
)
class AsyncFieldAccessor:
"""异步字段访问器,封装字段访问逻辑"""
def __init__(
self,
async_fields: Dict[str, 'AsynchronyExpression'],
origin_fields: Dict[str, FieldInfo]
) -> None:
self._async_fields = async_fields
self._origin_fields = origin_fields
async def get_field_value_async(self, field_name: str):
"""异步获取字段值"""
if field_name not in self._origin_fields:
raise AttributeError(f"No async field '{field_name}' found")
return await self._async_fields[field_name].get_value()
def get_field_value_sync(self, field_name: str):
"""同步获取字段值(仅在非异步上下文中使用)"""
AsyncContextDetector.ensure_async_context_safe(f"sync access to field '{field_name}'")
if field_name not in self._origin_fields:
raise AttributeError(f"No async field '{field_name}' found")
async_expr = self._async_fields[field_name]
if not async_expr.is_initialize and async_expr.timeout > 0:
# 需要等待但在同步上下文中使用run_async
return run_async(async_expr.get_value())
elif not async_expr.is_initialize:
raise RuntimeError(f"Field '{field_name}' is not initialized and has no timeout")
else:
return run_async(async_expr.get_value())
def is_field_initialized(self, field_name: str) -> bool:
"""检查字段是否已初始化"""
if field_name not in self._origin_fields:
raise AttributeError(f"No async field '{field_name}' found")
return self._async_fields[field_name].is_initialize
def set_field_value(self, field_name: str, value: Any) -> None:
"""设置字段值"""
if field_name not in self._origin_fields:
raise AttributeError(f"No async field '{field_name}' found")
self._async_fields[field_name].set_value(value)
class AsynchronyUninitialized:
"""表示未初始化状态的单例类"""
__instance__ = None
_lock = threading.Lock()
def __new__(cls, *args, **kwargs):
if cls.__instance__ is None:
with cls._lock:
if cls.__instance__ is None:
cls.__instance__ = super().__new__(cls)
return cls.__instance__
def __repr__(self):
return "uninitialized"
def __str__(self):
return "None"
class AsynchronyExpression:
def __init__(
self,
field: FieldInfo,
value: Any = AsynchronyUninitialized(),
*,
time_wait: float = 0.1,
timeout: float = 0,
callback: Optional[Action] = None,
):
'''
参数:
field: 字段
value: 初始化, 默认为AsynchronyUninitialized, 即无初始化
time_wait: 等待时间, 默认为0.1秒
timeout: 超时时间, 默认为0秒
callback: 回调函数, 默认为None, 当状态为无初始化时get_value会调用callback
'''
self.field = field
self._value = value
self.callback = callback
self.is_initialize = not isinstance(value, AsynchronyUninitialized)
self.time_wait = time_wait
self.timeout = timeout
def get_value_sync(self):
if self.is_initialize:
return self._value
elif self.callback is not None:
self.callback()
if self.is_initialize:
return self._value
else:
raise RuntimeError(f"Field {self.field.FieldName} is not initialized")
async def get_value(self):
"""异步获取字段值,改进的超时机制"""
if self.is_initialize:
return self._value
elif self.callback is not None:
self.callback()
if self.timeout > 0:
try:
# 使用 asyncio.wait_for 提供更精确的超时控制
async def wait_for_initialization():
while not self.is_initialize:
await asyncio.sleep(self.time_wait)
return self._value
return await asyncio.wait_for(wait_for_initialization(), timeout=self.timeout)
except asyncio.TimeoutError:
raise TimeoutError(f"Timeout waiting for uninitialized field {self.field.FieldName}")
else:
# 无超时,一直等待
while not self.is_initialize:
await asyncio.sleep(self.time_wait)
return self._value
def set_value(self, value: Any) -> None:
"""设置字段值"""
if isinstance(value, AsynchronyUninitialized):
self.set_uninitialized()
elif self.field.Verify(type(value)):
self._value = value
self.is_initialize = True
else:
raise ValueError(f"Value {value} is not valid for field {self.field.FieldName}")
def SetUninitialized(self) -> None:
"""设置为未初始化状态(保持兼容性的旧方法名)"""
self.set_uninitialized()
def set_uninitialized(self) -> None:
"""设置为未初始化状态"""
if self.is_initialize:
del self._value
self._value = AsynchronyUninitialized()
self.is_initialize = False
class Asynchronous(ABC):
__Asynchronous_Origin_Fields__: Dict[Type, Dict[str, FieldInfo]] = defaultdict(dict)
_fields_lock = threading.Lock()
def _GetAsynchronousOriginFields(self) -> Dict[str, FieldInfo]:
return Asynchronous.__Asynchronous_Origin_Fields__[type(self)]
def __init__(self, **kwargs: Dict[str, dict]):
super().__init__()
self.__Asynchronous_Fields__: Dict[str, AsynchronyExpression] = {}
# 使用线程锁保护类变量访问
with Asynchronous._fields_lock:
origin_fields = self._GetAsynchronousOriginFields()
for field_info in TypeManager.GetInstance().CreateOrGetRefTypeFromType(type(self)).GetAllFields():
if field_info.FieldName == "__Asynchronous_Origin_Fields__":
continue
origin_fields[field_info.FieldName] = field_info
self.__Asynchronous_Fields__[field_info.FieldName] = AsynchronyExpression(
field_info, **kwargs.get(field_info.FieldName, {})
)
# 创建字段访问器以提升性能
self._field_accessor = AsyncFieldAccessor(self.__Asynchronous_Fields__, origin_fields)
def __getattribute__(self, name: str) -> Any:
# 快速路径:非异步字段直接返回
if name in ("__Asynchronous_Fields__", "_GetAsynchronousOriginFields", "_field_accessor"):
return super().__getattribute__(name)
# 一次性获取所需属性,避免重复调用
try:
field_accessor:AsyncFieldAccessor = super().__getattribute__("_field_accessor")
origin_fields:Dict[str, FieldInfo] = super().__getattribute__("_GetAsynchronousOriginFields")()
except AttributeError:
# 对象可能尚未完全初始化
return super().__getattribute__(name)
if name in origin_fields:
# 这是一个异步字段
if AsyncContextDetector.is_in_async_context():
# 在异步上下文中,提供友好的错误提示
async_fields:Dict[str, AsynchronyExpression] = super().__getattribute__("__Asynchronous_Fields__")
async_expr = async_fields[name]
if not async_expr.is_initialize:
timeout_info = f" with {async_expr.timeout}s timeout" if async_expr.timeout > 0 else ""
raise RuntimeError(
f"Field '{name}' is not initialized{timeout_info}. "
)
else:
# 字段已初始化,直接返回值
return async_expr.get_value_sync()
else:
# 在同步上下文中,使用字段访问器
try:
return field_accessor.get_field_value_sync(name)
except RuntimeError as e:
if "Cannot perform" in str(e):
# 重新包装错误信息,提供更友好的提示
raise RuntimeError(
f"Cannot access async field '{name}' from sync context when it requires initialization. "
f"Use async context or ensure field is pre-initialized."
) from e
else:
raise
return super().__getattribute__(name)
def __setattr__(self, name: str, value: Any) -> None:
if name in ("__Asynchronous_Fields__", "_GetAsynchronousOriginFields", "_field_accessor"):
super().__setattr__(name, value)
elif hasattr(self, '_field_accessor'):
# 对象已初始化,使用字段访问器
try:
field_accessor = super().__getattribute__("_field_accessor")
field_accessor.set_field_value(name, value)
return
except AttributeError:
# 不是异步字段
pass
super().__setattr__(name, value)
def __delattr__(self, name: str) -> None:
if name in ("__Asynchronous_Fields__", "_GetAsynchronousOriginFields", "_field_accessor"):
super().__delattr__(name)
elif hasattr(self, '_field_accessor'):
# 对象已初始化,使用字段访问器
try:
field_accessor = super().__getattribute__("_field_accessor")
origin_fields = super().__getattribute__("_GetAsynchronousOriginFields")()
if name in origin_fields:
async_fields = super().__getattribute__("__Asynchronous_Fields__")
async_fields[name].set_uninitialized()
return
except AttributeError:
# 不是异步字段
pass
super().__delattr__(name)
def is_field_initialized(self, field_name: str) -> bool:
"""检查字段是否已初始化"""
return self._field_accessor.is_field_initialized(field_name)
def run_until_complete(coro: Coroutine) -> Any:
"""Gets an existing event loop to run the coroutine.
If there is no existing event loop, creates a new one.
"""
try:
# Check if there's an existing event loop
loop = asyncio.get_event_loop()
# If we're here, there's an existing loop but it's not running
return loop.run_until_complete(coro)
except RuntimeError:
# If we can't get the event loop, we're likely in a different thread, or its already running
try:
return asyncio.run(coro)
except RuntimeError:
raise RuntimeError(
"Detected nested async. Please use nest_asyncio.apply() to allow nested event loops."
"Or, use async entry methods like `aquery()`, `aretriever`, `achat`, etc."
)
def run_async_coroutine(coro: Coroutine) -> Any:
try:
# Check if there's an existing event loop
loop = asyncio.get_event_loop()
# If we're here, there's an existing loop but it's not running
return loop.create_task(coro)
except RuntimeError:
# If we can't get the event loop, we're likely in a different thread, or its already running
try:
return asyncio.run(coro)
except RuntimeError:
raise RuntimeError(
"Detected nested async. Please use nest_asyncio.apply() to allow nested event loops."
"Or, use async entry methods like `aquery()`, `aretriever`, `achat`, etc."
)
def run_async(coro: Coroutine):
"""安全地运行异步协程,避免事件循环死锁"""
# 使用统一的异步上下文检测
AsyncContextDetector.ensure_async_context_safe("run_async")
# 尝试获取当前事件循环
current_loop = AsyncContextDetector.get_current_loop()
if current_loop is not None and not current_loop.is_running():
# 有事件循环但未运行,直接使用
return current_loop.run_until_complete(coro)
elif current_loop is None:
# 没有事件循环,创建新的
try:
return asyncio.run(coro)
except RuntimeError as e:
raise RuntimeError(
"Failed to run async coroutine. "
"Please ensure proper async environment or use nest_asyncio.apply() for nested loops."
) from e
else:
# 事件循环正在运行这种情况应该被AsyncContextDetector捕获
raise RuntimeError(
"Unexpected state: running event loop detected but context check passed. "
"This should not happen."
)

View File

@@ -565,7 +565,7 @@ class ValueInfo(BaseInfo):
return ValueInfo(metaType, **kwargs)
else:
return ValueInfo(type_, **kwargs)
elif isinstance(metaType, Self):#metaType is Self:
elif metaType is Self:
if SelfType is None:
raise ReflectionException("SelfType is required when metaType is <Self>")
return ValueInfo.Create(SelfType, **kwargs)
@@ -1198,7 +1198,7 @@ class RefType(ValueInfo):
# 确保正确地实现所有GetBase*方法
@functools.lru_cache(maxsize=128)
def GetBaseFields(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[FieldInfo]:
def _GetBaseFields(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[FieldInfo]:
if self._BaseTypes is None:
self._InitBaseTypesIfNeeded()
result = []
@@ -1206,8 +1206,11 @@ class RefType(ValueInfo):
result.extend(baseType.GetFields(flag))
return result
def GetBaseFields(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[FieldInfo]:
return self._GetBaseFields(flag)
@functools.lru_cache(maxsize=128)
def GetAllBaseFields(self) -> List[FieldInfo]:
def _GetAllBaseFields(self) -> List[FieldInfo]:
if self._BaseTypes is None:
self._InitBaseTypesIfNeeded()
result = []
@@ -1215,9 +1218,12 @@ class RefType(ValueInfo):
result.extend(baseType.GetAllFields())
return result
def GetAllBaseFields(self) -> List[FieldInfo]:
return self._GetAllBaseFields()
# 修改所有的GetBase*方法
@functools.lru_cache(maxsize=128)
def GetBaseMethods(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[MethodInfo]:
def _GetBaseMethods(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[MethodInfo]:
if self._BaseTypes is None:
self._InitBaseTypesIfNeeded()
result = []
@@ -1225,8 +1231,11 @@ class RefType(ValueInfo):
result.extend(baseType.GetMethods(flag))
return result
def GetBaseMethods(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[MethodInfo]:
return self._GetBaseMethods(flag)
@functools.lru_cache(maxsize=128)
def GetAllBaseMethods(self) -> List[MethodInfo]:
def _GetAllBaseMethods(self) -> List[MethodInfo]:
if self._BaseTypes is None:
self._InitBaseTypesIfNeeded()
result = []
@@ -1234,8 +1243,11 @@ class RefType(ValueInfo):
result.extend(baseType.GetAllMethods())
return result
def GetAllBaseMethods(self) -> List[MethodInfo]:
return self._GetAllBaseMethods()
@functools.lru_cache(maxsize=128)
def GetBaseMembers(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[MemberInfo]:
def _GetBaseMembers(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[MemberInfo]:
if self._BaseTypes is None:
self._InitBaseTypesIfNeeded()
result = []
@@ -1243,8 +1255,11 @@ class RefType(ValueInfo):
result.extend(baseType.GetMembers(flag))
return result
def GetBaseMembers(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[MemberInfo]:
return self._GetBaseMembers(flag)
@functools.lru_cache(maxsize=128)
def GetAllBaseMembers(self) -> List[MemberInfo]:
def _GetAllBaseMembers(self) -> List[MemberInfo]:
if self._BaseTypes is None:
self._InitBaseTypesIfNeeded()
result = []
@@ -1252,6 +1267,9 @@ class RefType(ValueInfo):
result.extend(baseType.GetAllMembers())
return result
def GetAllBaseMembers(self) -> List[MemberInfo]:
return self._GetAllBaseMembers()
def GetFields(self, flag:RefTypeFlag=RefTypeFlag.Default) -> List[FieldInfo]:
self._ensure_initialized()
if flag == RefTypeFlag.Default:

View File

@@ -1,20 +0,0 @@
{
"easy": {
"__type": "__main__.test_log, Global",
"value": {
"__type": "__main__.test_log, Global",
"model_computed_fields": {
"__type": "typing.Any, Global"
},
"model_extra": null,
"model_fields": {
"__type": "typing.Any, Global"
},
"model_fields_set": {
"__type": "typing.Any, Global"
},
"test_field": 1,
"test_field_2": "test"
}
}
}

View File

@@ -2,18 +2,95 @@ import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from Convention.Runtime.Config import *
from Convention.Runtime.EasySave import *
from Convention.Runtime.Asynchrony import *
class Test:
test_field:int = 10
class_test_field:int = 20
class Test(Asynchronous):
a:int
b:int
c:int
def __init__(self):
self.test_field:int = 0
super().__init__(c={"timeout":2},b={"timeout":10})
self.a = 1
def run():
print(Test.__annotations__)
async def seta(obj:Test, value:int, delay:float = 1) -> None:
await asyncio.sleep(delay)
obj.a = value
async def setb(obj:Test, value:int, delay:float = 1) -> None:
await asyncio.sleep(delay)
obj.b = value
async def setc(obj:Test, value:int, delay:float = 1) -> None:
await asyncio.sleep(delay)
obj.c = value
async def run():
print("=== 测试优化后的异步字段系统 ===")
test = Test()
# 测试字段状态检查
assert test.is_field_initialized('a')
assert not test.is_field_initialized('b')
assert not test.is_field_initialized('c')
print("\n=== 测试1未设置值的情况应该超时===")
try:
print(f"失败: {test.a,test.b,test.c}")
raise RuntimeError("测试1应该超时")
except Exception as e:
print(f"成功: {e}")
print("\n=== 测试2在超时前设置字段b的值 ===")
# 创建新的测试实例
test2 = Test()
assert not test2.is_field_initialized('b')
# 启动并发任务设置b的值和获取b的值
# 并发执行设置b值延迟0.5秒和获取b值
await asyncio.gather(
setb(test2, 42, delay=0.5), # 0.5秒后设置b=42
return_exceptions=True
)
assert test2.b == 42
assert test2.is_field_initialized('b')
assert test2.b == 42
test3 = Test()
test3.b = 100
assert test3.is_field_initialized('b')
assert test3.b == 100
print("\n=== 测试3测试字段c短超时应该仍然超时===")
try:
print(f"失败: {test.c}")
except TimeoutError as timeout_e:
print(f"成功: {timeout_e}")
def test_sync_access():
"""测试同步访问(在非异步上下文中)"""
print("\n=== 测试同步访问 ===")
test = Test()
# 测试已初始化字段的同步访问
try:
print(f"成功: a = {test.a}")
except Exception as e:
raise
# 测试未初始化字段的同步访问(应该有更友好的错误提示)
try:
print(f"失败: c = {test.c}")
raise RuntimeError("字段c此时不应该能够被访问")
except Exception as e:
print(f"成功: {e}")
if __name__ == "__main__":
run()
# 测试同步访问
test_sync_access()
# 测试异步访问
print("\n=== 开始异步测试 ===")
run_until_complete(run())

View File

@@ -1,3 +0,0 @@
import math
import r
print(re.findall(r"\d+[.\d]?", "xxxxx$19.99"))