from __future__ import annotations
import asyncio
from concurrent.futures import Executor, Future, ThreadPoolExecutor
from contextlib import contextmanager
from contextvars import ContextVar, copy_context
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
TypeVar,
Union,
cast,
)
from typing_extensions import ParamSpec, TypedDict
from langchain_core.runnables.utils import (
Input,
Output,
accepts_config,
accepts_run_manager,
)
if TYPE_CHECKING:
from langchain_core.callbacks.base import BaseCallbackManager, Callbacks
from langchain_core.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
CallbackManager,
CallbackManagerForChainRun,
)
else:
# Pydantic validates through typed dicts, but
# the callbacks need forward refs updated
Callbacks = Optional[Union[List, Any]]
class EmptyDict(TypedDict, total=False):
    """A TypedDict that declares no keys at all."""
class RunnableConfig(TypedDict, total=False):
    """Per-call configuration accepted by a Runnable.

    Every key is optional (``total=False``).
    """

    tags: List[str]
    """
    Labels attached to this call and to every sub-call it makes (e.g. a Chain
    invoking an LLM). They can be used to filter calls.
    """

    metadata: Dict[str, Any]
    """
    Extra data attached to this call and to every sub-call it makes (e.g. a
    Chain invoking an LLM). Keys should be strings and values should be
    JSON-serializable.
    """

    callbacks: Callbacks
    """
    Callbacks for this call and every sub-call (e.g. a Chain invoking an LLM).
    Tags are passed to all callbacks; metadata is passed to handle*Start callbacks.
    """

    run_name: str
    """
    Name of the tracer run for this call. Defaults to the name of the class.
    """

    max_concurrency: Optional[int]
    """
    Upper bound on the number of parallel calls. When not provided,
    ThreadPoolExecutor's default is used.
    """

    recursion_limit: int
    """
    Maximum number of times a call may recurse. Defaults to 25 when not provided.
    """

    configurable: Dict[str, Any]
    """
    Runtime values for attributes previously made configurable on this Runnable,
    or on sub-Runnables, through .configurable_fields() or
    .configurable_alternatives(). Check .output_schema() for a description of the
    attributes that have been made configurable.
    """
# Context-local config that ensure_config() layers underneath any explicitly
# passed config. The default is an empty RunnableConfig.
var_child_runnable_config = ContextVar(
    "child_runnable_config",
    default=RunnableConfig(),
)
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
    """Build a fully-populated config from defaults, context and ``config``.

    Values are layered in increasing priority: built-in defaults, then the
    config stored in ``var_child_runnable_config``, then ``config``. Keys whose
    value is ``None`` in a layer are ignored for that layer.

    List and dict values (e.g. ``tags``, ``metadata``, ``configurable``, a
    list of callbacks) are shallow-copied into the result, so mutating the
    returned config never leaks into the caller's config, the context
    variable, or other configs built from the same source.

    Args:
        config (Optional[RunnableConfig], optional): The config to ensure.
            Defaults to None.

    Returns:
        RunnableConfig: A new config containing at least ``tags``,
            ``metadata``, ``callbacks`` and ``recursion_limit``.
    """
    ensured = RunnableConfig(
        tags=[],
        metadata={},
        callbacks=None,
        recursion_limit=25,
    )
    # Later layers override earlier ones; empty / missing layers are skipped.
    for layer in (var_child_runnable_config.get(), config):
        if not layer:
            continue
        ensured.update(
            cast(
                RunnableConfig,
                {
                    # Copy containers so the result does not alias its source.
                    key: value.copy() if isinstance(value, (list, dict)) else value
                    for key, value in layer.items()
                    if value is not None
                },
            )
        )
    return ensured
def get_config_list(
    config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
) -> List[RunnableConfig]:
    """Expand a config, or validate a list of configs, to ``length`` entries.

    Handy for subclasses that override batch() or abatch().

    Args:
        config (Optional[Union[RunnableConfig, List[RunnableConfig]]]):
            A single config (applied to every input) or one config per input.
        length (int): The number of configs wanted.

    Returns:
        List[RunnableConfig]: ``length`` configs, each passed through
            ensure_config().

    Raises:
        ValueError: If ``length`` is negative, or if ``config`` is a list
            whose length differs from ``length``.
    """
    if length < 0:
        raise ValueError(f"length must be >= 0, but got {length}")

    # A single (or missing) config is ensured once per requested slot.
    if not isinstance(config, list):
        return [ensure_config(config) for _ in range(length)]

    if len(config) != length:
        raise ValueError(
            f"config must be a list of the same length as inputs, "
            f"but got {len(config)} configs for {length} inputs"
        )
    return [ensure_config(item) for item in config]
def patch_config(
    config: Optional[RunnableConfig],
    *,
    callbacks: Optional[BaseCallbackManager] = None,
    recursion_limit: Optional[int] = None,
    max_concurrency: Optional[int] = None,
    run_name: Optional[str] = None,
    configurable: Optional[Dict[str, Any]] = None,
) -> RunnableConfig:
    """Return an ensured copy of ``config`` with the given overrides applied.

    Only arguments that are not None are applied.

    Args:
        config (Optional[RunnableConfig]): The config to patch.
        callbacks (Optional[BaseCallbackManager], optional): Replacement
            callbacks. Setting this also removes any ``run_name`` carried by
            ``config``. Defaults to None.
        recursion_limit (Optional[int], optional): The recursion limit to set.
            Defaults to None.
        max_concurrency (Optional[int], optional): The max concurrency to set.
            Defaults to None.
        run_name (Optional[str], optional): The run name to set. Defaults to None.
        configurable (Optional[Dict[str, Any]], optional): Entries merged over
            the existing ``configurable`` mapping. Defaults to None.

    Returns:
        RunnableConfig: The patched config.
    """
    patched = ensure_config(config)

    if callbacks is not None:
        # A run_name belongs to the run that owned the original callbacks, so
        # it must not be carried over when the callbacks are replaced.
        patched["callbacks"] = callbacks
        patched.pop("run_name", None)

    simple_overrides = (
        ("recursion_limit", recursion_limit),
        ("max_concurrency", max_concurrency),
        ("run_name", run_name),
    )
    for key, value in simple_overrides:
        if value is not None:
            patched[key] = value  # type: ignore[literal-required]

    if configurable is not None:
        existing = patched.get("configurable", {})
        patched["configurable"] = {**existing, **configurable}

    return patched
def _dedupe_preserving_order(items: Iterable[Any]) -> List[Any]:
    """Drop duplicate items, keeping the first occurrence of each in order."""
    return list(dict.fromkeys(items))


def _merge_callbacks(base_callbacks: Any, these_callbacks: Any) -> Any:
    """Combine two callbacks values; ``these_callbacks`` must not be None.

    Each value is either a list of handlers or a callback manager
    (``base_callbacks`` may also be None).
    """
    if base_callbacks is None:
        return these_callbacks

    if isinstance(these_callbacks, list):
        if isinstance(base_callbacks, list):
            return base_callbacks + these_callbacks
        # base_callbacks is a manager: add the new handlers to a copy of it
        mngr = base_callbacks.copy()
        for callback in these_callbacks:
            mngr.add_handler(callback, inherit=True)
        return mngr

    # these_callbacks is a manager
    if isinstance(base_callbacks, list):
        mngr = these_callbacks.copy()
        for callback in base_callbacks:
            mngr.add_handler(callback, inherit=True)
        return mngr

    # Both are managers: build a new manager of base's class from both
    return base_callbacks.__class__(
        parent_run_id=base_callbacks.parent_run_id or these_callbacks.parent_run_id,
        handlers=base_callbacks.handlers + these_callbacks.handlers,
        inheritable_handlers=base_callbacks.inheritable_handlers
        + these_callbacks.inheritable_handlers,
        tags=_dedupe_preserving_order(base_callbacks.tags + these_callbacks.tags),
        inheritable_tags=_dedupe_preserving_order(
            base_callbacks.inheritable_tags + these_callbacks.inheritable_tags
        ),
        metadata={
            **base_callbacks.metadata,
            **these_callbacks.metadata,
        },
    )


def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
    """Merge multiple configs into one.

    Configs are applied left to right and ``None`` entries are skipped.
    Per key:

    * ``metadata`` / ``configurable``: dicts are merged, later configs win.
    * ``tags``: lists are concatenated and de-duplicated; the first
      occurrence of each tag keeps its position, so the result is
      deterministic.
    * ``callbacks``: combined across the list / manager cases; a ``None``
      value leaves the callbacks merged so far untouched.
    * any other key: the later value wins unless it is falsy, in which case
      the value merged so far is kept.

    Args:
        *configs (Optional[RunnableConfig]): The configs to merge.

    Returns:
        RunnableConfig: The merged config (a new dict).
    """
    base: RunnableConfig = {}
    # Even though the keys aren't literals, this is correct
    # because both dicts are the same type
    for config in configs:
        if config is None:
            continue
        for key in config:
            if key in ("metadata", "configurable"):
                base[key] = {  # type: ignore
                    **base.get(key, {}),  # type: ignore
                    **(config.get(key) or {}),  # type: ignore
                }
            elif key == "tags":
                base[key] = _dedupe_preserving_order(  # type: ignore
                    base.get(key, []) + (config.get(key) or []),  # type: ignore
                )
            elif key == "callbacks":
                these_callbacks = config["callbacks"]
                # callbacks can be either None, list[handler] or manager;
                # None contributes nothing, so only merge real values.
                if these_callbacks is not None:
                    base["callbacks"] = _merge_callbacks(
                        base.get("callbacks"), these_callbacks
                    )
            else:
                base[key] = config[key] or base.get(key)  # type: ignore
    return base
def call_func_with_variable_args(
    func: Union[
        Callable[[Input], Output],
        Callable[[Input, RunnableConfig], Output],
        Callable[[Input, CallbackManagerForChainRun], Output],
        Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
    ],
    input: Input,
    config: RunnableConfig,
    run_manager: Optional[CallbackManagerForChainRun] = None,
    **kwargs: Any,
) -> Output:
    """Invoke ``func``, passing ``config`` / ``run_manager`` only if it accepts them.

    Args:
        func: The function to call. It always receives ``input`` and may
            additionally accept ``config`` and/or ``run_manager`` keyword
            arguments.
        input (Input): The input to the function.
        config (RunnableConfig): The config to pass to the function. When a
            run manager is given, the function receives this config patched
            with the run manager's child callbacks instead.
        run_manager (CallbackManagerForChainRun): The run manager to
            pass to the function.
        **kwargs (Any): Additional keyword arguments forwarded to the function.

    Returns:
        Output: The output of the function.
    """
    if accepts_config(func):
        kwargs["config"] = (
            config
            if run_manager is None
            else patch_config(config, callbacks=run_manager.get_child())
        )

    wants_run_manager = run_manager is not None and accepts_run_manager(func)
    if wants_run_manager:
        kwargs["run_manager"] = run_manager

    return func(input, **kwargs)  # type: ignore[call-arg]
def acall_func_with_variable_args(
    func: Union[
        Callable[[Input], Awaitable[Output]],
        Callable[[Input, RunnableConfig], Awaitable[Output]],
        Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]],
        Callable[
            [Input, AsyncCallbackManagerForChainRun, RunnableConfig],
            Awaitable[Output],
        ],
    ],
    input: Input,
    config: RunnableConfig,
    run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    **kwargs: Any,
) -> Awaitable[Output]:
    """Invoke an async ``func``, passing ``config`` / ``run_manager`` if accepted.

    The call itself is not awaited here; whatever ``func`` returns is handed
    back to the caller.

    Args:
        func: The function to call. It always receives ``input`` and may
            additionally accept ``config`` and/or ``run_manager`` keyword
            arguments.
        input (Input): The input to the function.
        config (RunnableConfig): The config to pass to the function. When a
            run manager is given, the function receives this config patched
            with the run manager's child callbacks instead.
        run_manager (AsyncCallbackManagerForChainRun): The run manager
            to pass to the function.
        **kwargs (Any): Additional keyword arguments forwarded to the function.

    Returns:
        Awaitable[Output]: The value returned by calling the function.
    """
    if accepts_config(func):
        kwargs["config"] = (
            config
            if run_manager is None
            else patch_config(config, callbacks=run_manager.get_child())
        )

    wants_run_manager = run_manager is not None and accepts_run_manager(func)
    if wants_run_manager:
        kwargs["run_manager"] = run_manager

    return func(input, **kwargs)  # type: ignore[call-arg]
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
    """Build a CallbackManager from a config.

    The config's ``callbacks``, ``tags`` and ``metadata`` entries are passed
    to ``CallbackManager.configure`` as the inheritable values.

    Args:
        config (RunnableConfig): The config.

    Returns:
        CallbackManager: The callback manager.
    """
    # Imported inside the function rather than at module import time.
    from langchain_core.callbacks.manager import CallbackManager

    inheritable = {
        "inheritable_callbacks": config.get("callbacks"),
        "inheritable_tags": config.get("tags"),
        "inheritable_metadata": config.get("metadata"),
    }
    return CallbackManager.configure(**inheritable)
def get_async_callback_manager_for_config(
    config: RunnableConfig,
) -> AsyncCallbackManager:
    """Build an AsyncCallbackManager from a config.

    The config's ``callbacks``, ``tags`` and ``metadata`` entries are passed
    to ``AsyncCallbackManager.configure`` as the inheritable values.

    Args:
        config (RunnableConfig): The config.

    Returns:
        AsyncCallbackManager: The async callback manager.
    """
    # Imported inside the function rather than at module import time.
    from langchain_core.callbacks.manager import AsyncCallbackManager

    inheritable = {
        "inheritable_callbacks": config.get("callbacks"),
        "inheritable_tags": config.get("tags"),
        "inheritable_metadata": config.get("metadata"),
    }
    return AsyncCallbackManager.configure(**inheritable)
# Parameter specification used to type the (*args, **kwargs) forwarded to a
# submitted callable.
P = ParamSpec("P")
# Return type of a submitted callable.
T = TypeVar("T")
class ContextThreadPoolExecutor(ThreadPoolExecutor):
    """ThreadPoolExecutor that copies the context to the child thread."""

    def submit(  # type: ignore[override]
        self,
        func: Callable[P, T],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Future[T]:
        """Submit a function to the executor.

        The caller's context is copied at submission time and the function is
        run inside that copy on the worker thread.

        Args:
            func (Callable[..., T]): The function to submit.
            *args (Any): The positional arguments to the function.
            **kwargs (Any): The keyword arguments to the function.

        Returns:
            Future[T]: The future for the function.
        """
        return super().submit(
            cast(
                "Callable[..., T]",
                partial(copy_context().run, func, *args, **kwargs),
            )
        )

    def map(
        self,
        fn: Callable[..., T],
        *iterables: Iterable[Any],
        timeout: float | None = None,
        chunksize: int = 1,
    ) -> Iterator[T]:
        """Map a function over iterables, running each call in a copied context.

        The iterables are consumed immediately. One copy of the caller's
        context is taken per call, in the calling thread. Any iterable is
        accepted, including generators and other unsized iterables, and
        passing no iterables yields no results.

        Args:
            fn (Callable[..., T]): The function to call.
            *iterables (Iterable[Any]): Argument iterables, zipped together.
            timeout (float | None): Forwarded to ``Executor.map``.
            chunksize (int): Forwarded to ``Executor.map``.

        Returns:
            Iterator[T]: The results, in input order.
        """
        # Pair every argument tuple with its own context snapshot so that no
        # length has to be known up front.
        calls = [(copy_context(), args) for args in zip(*iterables)]

        def _run_in_context(call: Any) -> T:
            context, args = call
            return context.run(fn, *args)

        return super().map(
            _run_in_context,
            calls,
            timeout=timeout,
            chunksize=chunksize,
        )
@contextmanager
def get_executor_for_config(
    config: Optional[RunnableConfig],
) -> Generator[Executor, None, None]:
    """Provide a context-copying thread pool sized from a config.

    The pool's ``max_workers`` is the config's ``max_concurrency`` (None when
    the config is missing, empty, or has no such key). The pool is entered as
    a context manager around the ``yield``.

    Args:
        config (RunnableConfig): The config.

    Yields:
        Generator[Executor, None, None]: The executor.
    """
    max_workers = (config or {}).get("max_concurrency")
    pool = ContextThreadPoolExecutor(max_workers=max_workers)
    with pool as executor:
        yield executor
async def run_in_executor(
    executor_or_config: Optional[Union[Executor, RunnableConfig]],
    func: Callable[P, T],
    *args: P.args,
    **kwargs: P.kwargs,
) -> T:
    """Run a function in an executor from the running event loop.

    Args:
        executor_or_config (Optional[Union[Executor, RunnableConfig]]): The
            executor to run in. When this is None or a config dict, the event
            loop's default executor is used and the function runs inside a
            copy of the current context.
        func (Callable[P, Output]): The function.
        *args (Any): The positional arguments to the function.
        **kwargs (Any): The keyword arguments to the function.

    Returns:
        Output: The output of the function.
    """
    loop = asyncio.get_running_loop()

    has_explicit_executor = executor_or_config is not None and not isinstance(
        executor_or_config, dict
    )
    if has_explicit_executor:
        # Keyword arguments are bound up front; positional ones go through
        # the loop's run_in_executor call.
        bound = partial(func, **kwargs)
        return await loop.run_in_executor(executor_or_config, bound, *args)

    # Use default executor with context copied from current context
    in_context = partial(copy_context().run, func, *args, **kwargs)
    return await loop.run_in_executor(None, cast("Callable[..., T]", in_context))