Here are a bunch of tricks / idioms I learnt mostly from working with pydantic-ai’s source code
use repr=False and compare=False for private fields
@dataclasses.dataclass
class AgentRunResult(Generic[OutputDataT]):
    """The final result of an agent run."""

    output: OutputDataT
    """The output data from the agent run."""

    # Internal bookkeeping fields, `_`-prefixed to mark them private.
    # repr=False keeps them out of the generated __repr__; compare=False
    # excludes them from the generated __eq__, so two results compare equal
    # based on `output` alone.
    _output_tool_name: str | None = dataclasses.field(
        repr=False, compare=False, default=None
    )
    # Mutable graph state, hence default_factory rather than a shared default.
    _state: _agent_graph.GraphAgentState = dataclasses.field(
        repr=False, compare=False, default_factory=_agent_graph.GraphAgentState
    )
    # NOTE(review): presumably the index into message history where this run's
    # new messages begin — confirm against the caller.
    _new_message_index: int = dataclasses.field(
        repr=False, compare=False, default=0
    )
    # NOTE(review): presumably a W3C traceparent header value for tracing —
    # verify; not demonstrated by this snippet.
    _traceparent_value: str | None = dataclasses.field(
        repr=False, compare=False, default=None
    )
Testing against a dependency that may not be installed
@contextmanager
def try_import() -> Iterator[Callable[[], bool]]:
    """Guard a block of optional-dependency imports.

    Yields a zero-argument callable; after the ``with`` block has exited it
    reports True when every import inside the block succeeded. An ImportError
    raised inside the block is swallowed; any other exception propagates.
    """
    succeeded = False

    def was_successful() -> bool:
        return succeeded

    try:
        yield was_successful
    except ImportError:
        pass
    else:
        # Only reached when the with-block body finished without raising.
        succeeded = True
# Guard the optional imports: if any line below raises ImportError it is
# swallowed by try_import, and imports_successful() will report False.
with try_import() as imports_successful:
    import logfire
    from logfire.testing import CaptureLogfire
    from pydantic_evals.otel._context_subtree import (
        context_subtree,
    )
    from pydantic_evals.otel.span_tree import SpanQuery, SpanTree

# Applied to every test in this module: skip all of them when pydantic-evals
# is not installed, and run async tests under anyio.
pytestmark = [pytest.mark.skipif(not imports_successful(), reason='pydantic-evals not installed'), pytest.mark.anyio]
UserError
class UserError(RuntimeError):
    """Error caused by a usage mistake by the application developer — You!"""

    message: str
    """Description of the mistake."""

    def __init__(self, message: str):
        # Initialise RuntimeError first so str(err) renders the message,
        # then keep a dedicated attribute for programmatic access.
        super().__init__(message)
        self.message = message
option type
T = TypeVar('T')


@dataclass
class Some(Generic[T]):
    """Analogous to Rust's `Option::Some` type."""

    value: T


# A plain union alias, not a runtime wrapper: `None` plays the role of
# Rust's `Option::None`, `Some(x)` the role of `Option::Some(x)`.
Option: TypeAlias = Some[T] | None
"""Analogous to Rust's `Option` type, usage: `Option[Thing]` is equivalent to `Some[Thing] | None`."""
Unset singleton (distinct from None, which may be a legitimate value)
class Unset:
    """A singleton to represent an unset value."""


UNSET = Unset()


def is_set(t_or_unset: T | Unset) -> TypeGuard[T]:
    """True when the value is an actual T rather than the UNSET sentinel; narrows the type for checkers."""
    return t_or_unset is not UNSET
Peekable async stream
class PeekableAsyncStream(Generic[T]):
    """Wraps an async iterable of type T and allows peeking at the *next* item without consuming it.

    We only buffer one item at a time (the next item). Once that item is yielded, it is discarded.
    This is a single-pass stream.
    """

    def __init__(self, source: AsyncIterable[T]):
        self._source = source
        # Iterator over `source`, created lazily on first peek()/__anext__.
        self._source_iter: AsyncIterator[T] | None = None
        # The buffered next item; UNSET means nothing is currently buffered.
        self._buffer: T | Unset = UNSET
        # Set once the underlying iterator raises StopAsyncIteration.
        self._exhausted = False

    async def peek(self) -> T | Unset:
        """Returns the next item that would be yielded without consuming it.

        Returns UNSET if the stream is exhausted (UNSET rather than None, so
        that None remains usable as a legitimate stream item).
        """
        if self._exhausted:
            return UNSET
        # If we already have a buffered item, just return it.
        if not isinstance(self._buffer, Unset):
            return self._buffer
        # Otherwise, we need to fetch the next item from the underlying iterator.
        if self._source_iter is None:
            self._source_iter = aiter(self._source)
        try:
            self._buffer = await anext(self._source_iter)
        except StopAsyncIteration:
            self._exhausted = True
            return UNSET
        return self._buffer

    async def is_exhausted(self) -> bool:
        """Returns True if the stream is exhausted, False otherwise."""
        # peek() returns UNSET (an Unset instance) exactly when the stream is done.
        return isinstance(await self.peek(), Unset)

    def __aiter__(self) -> AsyncIterator[T]:
        # For a single-pass iteration, we can return self as the iterator.
        return self

    async def __anext__(self) -> T:
        """Yields the buffered item if present, otherwise fetches the next item from the underlying source.

        Raises StopAsyncIteration if the stream is exhausted.
        """
        if self._exhausted:
            raise StopAsyncIteration
        # If we have a buffered item, yield it and clear the buffer.
        if not isinstance(self._buffer, Unset):
            item = self._buffer
            self._buffer = UNSET
            return item
        # Otherwise, fetch the next item from the source.
        if self._source_iter is None:
            self._source_iter = aiter(self._source)
        try:
            return await anext(self._source_iter)
        except StopAsyncIteration:
            # Remember exhaustion so later peek()/__anext__ calls short-circuit.
            self._exhausted = True
            raise
check if a callable is async
AwaitableCallable = Callable[..., Awaitable[T]]


@overload
def is_async_callable(obj: AwaitableCallable[T]) -> TypeIs[AwaitableCallable[T]]: ...
@overload
def is_async_callable(obj: Any) -> TypeIs[AwaitableCallable[Any]]: ...


def is_async_callable(obj: Any) -> Any:
    """Correctly check if a callable is async.

    Unwraps any functools.partial layers first, then treats the target as async
    when it is a coroutine function itself, or is a callable object whose
    `__call__` method is a coroutine function.

    This function was copied from Starlette:
    https://github.com/encode/starlette/blob/78da9b9e218ab289117df7d62aee200ed4c59617/starlette/_utils.py#L36-L40
    """
    target = obj
    while isinstance(target, functools.partial):
        target = target.func
    if inspect.iscoroutinefunction(target):
        return True
    return callable(target) and inspect.iscoroutinefunction(target.__call__)  # type: ignore
extract the arguments of a Union type
def _unwrap_annotated(tp: Any) -> Any:
    """Strip any (possibly nested) `Annotated[...]` wrappers from `tp`, returning the underlying type."""
    # For Annotated types, `__origin__` is the wrapped type; loop to peel
    # nested Annotated[Annotated[...]] layers.
    while typing_objects.is_annotated(get_origin(tp)):
        tp = tp.__origin__
    return tp
def get_union_args(tp: Any) -> tuple[Any, ...]:
    """Extract the arguments of a Union type if `tp` is a union, otherwise return an empty tuple."""
    # Resolve PEP 695 `type X = ...` aliases to their underlying value first.
    if typing_objects.is_typealiastype(tp):
        tp = tp.__value__
    unwrapped = _unwrap_annotated(tp)
    if not is_union_origin(get_origin(unwrapped)):
        return ()
    # Each union member may itself be Annotated; unwrap those too.
    return tuple(_unwrap_annotated(member) for member in get_args(unwrapped))
get variable name
def infer_obj_name(obj: Any, *, depth: int) -> str | None:
"""Infer the variable name of an object from the calling frame's scope.
This function examines the call stack to find what variable name was used
for the given object in the calling scope. This is useful for automatic
naming of objects based on their variable names.
Args:
obj: The object whose variable name to infer.
depth: Number of stack frames to traverse upward from the current frame.
Returns:
The inferred variable name if found, None otherwise.
Example:
Usage should generally look like `infer_name(self, depth=2)` or similar.
"""
target_frame = inspect.currentframe()
if target_frame is None:
return None # pragma: no cover
for _ in range(depth):
target_frame = target_frame.f_back
if target_frame is None:
return None
for name, item in target_frame.f_locals.items():
if item is obj:
return name
if target_frame.f_locals != target_frame.f_globals: # pragma: no branch
# if we couldn't find the agent in locals and globals are a different dict, try globals
for name, item in target_frame.f_globals.items():
if item is obj:
return name
return None
__test__ = False
- The `__test__` attribute is a special attribute recognized by pytest (Python's testing framework).
- When `__test__ = False` is set on a class, it tells pytest to skip that class during test discovery and collection. This prevents pytest from treating it as a test class, even though the class name starts with "Test", which would normally signal to pytest that it contains tests.
# Minimal illustration: despite the `Test` name prefix, pytest will not
# collect this class as a test class because `__test__` is False.
class TestEnv:
    __test__ = False
    ...
TestEnv
class TestEnv:
    """Helper for tests that mutate `os.environ`, restoring it on `reset()`.

    `__test__ = False` stops pytest from collecting this class as a test class
    despite its `Test` name prefix.
    """

    __test__ = False

    def __init__(self):
        # Maps each touched variable name to its ORIGINAL value;
        # None means the variable was not set before we touched it.
        self.envars: dict[str, str | None] = {}

    def set(self, name: str, value: str) -> None:
        """Set `name` to `value`, remembering its pre-existing value for `reset()`."""
        # setdefault: record the original value only the FIRST time a name is
        # touched — a plain assignment here would clobber the saved original on
        # repeated set() calls, making reset() restore an intermediate value.
        self.envars.setdefault(name, os.getenv(name))
        os.environ[name] = value

    def remove(self, name: str) -> None:
        """Remove `name` from the environment, remembering its value for `reset()`."""
        # Same first-touch-wins rule as set(): keep the true original even if
        # the variable was already modified through this instance.
        self.envars.setdefault(name, os.getenv(name))
        os.environ.pop(name, None)

    def reset(self) -> None:
        """Restore every touched variable to its original value."""
        for name, original in self.envars.items():
            if original is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = original
        # Drop the log so reset() is idempotent and the instance is reusable.
        self.envars.clear()
- Usage
def test_cerebras_provider_need_api_key(env: TestEnv) -> None:
    """The provider must raise UserError when no Cerebras API key is configured."""
    # Ensure the key is not inherited from the developer's real environment;
    # the `env` fixture restores it after the test.
    env.remove('CEREBRAS_API_KEY')
    with pytest.raises(
        UserError,
        match=re.escape(
            'Set the `CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)` '
            'to use the Cerebras provider.'
        ),
    ):
        CerebrasProvider()
def test_bedrock_provider_model_profile(env: TestEnv, mocker: MockerFixture):
    """Construct a BedrockProvider with a region configured.

    NOTE(review): this snippet appears truncated — the mocking and the
    assertions on the model profile presumably follow in the full test.
    """
    # Provide the region boto3 requires, without touching real AWS credentials.
    env.set('AWS_DEFAULT_REGION', 'us-east-1')
    provider = BedrockProvider()
`from __future__ import annotations as _annotations` — why alias it as `_annotations`? To mark the imported name as private: for `__all__`-style re-export hygiene? For doc generation? For IDE autocomplete hints?