Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 135 additions & 29 deletions libs/core/langchain_core/runnables/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@

import ast
import asyncio
import dis
import inspect
import sys
import textwrap

# Cannot move to TYPE_CHECKING as Mapping and Sequence are needed at runtime by
# RunnableConfigurableFields.
from collections.abc import Mapping, Sequence # noqa: TC003
from contextlib import suppress
from functools import lru_cache
from inspect import signature
from itertools import groupby
Expand Down Expand Up @@ -38,6 +40,7 @@
Iterable,
)
from contextvars import Context
from types import CodeType

from langchain_core.runnables.schema import StreamEvent

Expand Down Expand Up @@ -404,6 +407,63 @@ def get_lambda_source(func: Callable) -> str | None:
return visitor.source if visitor.count == 1 else name


@lru_cache(maxsize=256)
def _nonlocal_access_plan(
code: CodeType,
) -> tuple[tuple[str, ...], tuple[tuple[str, tuple[str, ...]], ...]]:
"""Compute a nonlocal access plan from bytecode.

Args:
code: Code object to scan.

Returns:
A tuple ``(plain_roots, chains)``, where:
- plain_roots: Names loaded without any attribute access.
- chains: ``(root, path)`` pairs for attribute-access chains.
"""
root_ops = {"LOAD_GLOBAL", "LOAD_DEREF", "LOAD_NAME"}
attr_ops = {"LOAD_ATTR", "LOAD_METHOD"}

plain_roots: set[str] = set()
chains: list[tuple[str, tuple[str, ...]]] = []

base: str | None = None
attrs: list[str] = []

def flush() -> None:
nonlocal base, attrs
if base is not None:
if attrs:
chains.append((base, tuple(attrs)))
else:
plain_roots.add(base)
base = None
attrs = []

for ins in dis.get_instructions(code):
op = ins.opname
if op in root_ops and isinstance(ins.argval, str):
flush()
base = ins.argval
continue
if op in attr_ops and isinstance(ins.argval, str):
if base is not None:
attrs.append(ins.argval)
continue
flush()

flush()

deduped = []
seen = set()
for c in chains:
if c not in seen:
seen.add(c)
deduped.append(c)

return tuple(sorted(plain_roots)), tuple(deduped)


@lru_cache(maxsize=256)
def get_function_nonlocals(func: Callable) -> list[Any]:
    """Get the nonlocal variables accessed by a function.

    The function's code object is analysed via ``_nonlocal_access_plan``; the
    source text is never read. Names loaded on their own are resolved against
    the function's closure cells first and its globals second. Attribute
    chains are only followed when they end in one of the runnable entry-point
    method names (``invoke``, ``stream``, ...), in which case the value
    obtained after following the whole chain is reported.

    Args:
        func: The function to check.

    Returns:
        The nonlocal variables accessed by the function, without duplicates
        (by identity). An empty list if no code object can be found for
        ``func``.
    """
    terminal_methods = {
        "invoke",
        "ainvoke",
        "batch",
        "abatch",
        "stream",
        "astream",
        "transform",
        "atransform",
    }

    # Follow ``__wrapped__`` links down to the innermost callable, guarding
    # against cycles by remembering the ids already visited.
    target = func
    seen_wrapped: set[int] = set()
    while True:
        w = getattr(target, "__wrapped__", None)
        if not callable(w):
            break
        wid = id(w)
        if wid in seen_wrapped:
            break
        seen_wrapped.add(wid)
        target = w

    # Objects without ``__code__`` may still expose the underlying function
    # through ``__func__``.
    if getattr(target, "__code__", None) is None:
        target = getattr(target, "__func__", target)

    code = getattr(target, "__code__", None)
    if code is None:
        return []

    # Map free-variable names to the current contents of their closure cells.
    nonlocals_dict: dict[str, Any] = {}
    freevars = code.co_freevars
    closure = getattr(target, "__closure__", None)
    if closure and freevars:
        for name, cell in zip(freevars, closure, strict=False):
            # Reading an empty cell raises ValueError; such names are skipped.
            with suppress(ValueError):
                nonlocals_dict[name] = cell.cell_contents

    globals_dict = getattr(target, "__globals__", {})

    plain_roots, chains = _nonlocal_access_plan(code)

    out: list[Any] = []
    seen_ids: set[int] = set()

    def add(v: Any) -> None:
        """Append ``v`` unless the same object was already collected."""
        vid = id(v)
        if vid not in seen_ids:
            seen_ids.add(vid)
            out.append(v)

    def resolve_root(name: str) -> Any | None:
        """Look ``name`` up in the closure cells, then in the globals."""
        if name in nonlocals_dict:
            return nonlocals_dict[name]
        return globals_dict.get(name)

    for name in plain_roots:
        v = resolve_root(name)
        if v is not None:
            add(v)

    for base, attrs in chains:
        if not attrs or attrs[-1] not in terminal_methods:
            continue
        v = resolve_root(base)
        if v is None:
            continue
        for a in attrs:
            try:
                v = getattr(v, a)
            except Exception:
                # Best effort: attribute access can run arbitrary code, so a
                # chain that fails to resolve for any reason is dropped.
                break
        else:
            if v is not None:
                add(v)

    return out


def indent_lines_after_first(text: str, prefix: str) -> str:
Expand Down
3 changes: 0 additions & 3 deletions libs/core/tests/unit_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ def blockbuster() -> Iterator[BlockBuster]:
bb.functions[func]
.can_block_in("langchain_core/_api/internal.py", "is_caller_internal")
.can_block_in("langchain_core/runnables/base.py", "__repr__")
.can_block_in(
"langchain_core/beta/runnables/context.py", "aconfig_with_context"
)
)

for func in ["os.stat", "io.TextIOWrapper.read"]:
Expand Down
35 changes: 34 additions & 1 deletion libs/core/tests/unit_tests/runnables/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
from collections.abc import Callable
from typing import Any
from typing import Any, NoReturn

import pytest

Expand Down Expand Up @@ -73,3 +74,35 @@ def my_func6(value: str) -> str:
assert RunnableLambda(my_func3).deps == [agent]
assert RunnableLambda(my_func4).deps == [global_agent]
assert RunnableLambda(func).deps == [nl]


def test_deps_does_not_call_inspect_getsource() -> None:
    """Accessing ``deps`` must work while ``inspect.getsource`` is unusable."""
    error_message = "inspect.getsource was called while computing deps"
    saved_getsource = inspect.getsource

    def explode(*_args: Any, **_kwargs: Any) -> NoReturn:
        raise AssertionError(error_message)

    # Swap in a stand-in that fails loudly; the real function is restored in
    # the ``finally`` clause no matter how the body exits.
    inspect.getsource = explode
    try:
        inner: RunnableLambda[str, str] = RunnableLambda(lambda x: x)

        class Box:
            def __init__(self, wrapped: RunnableLambda[str, str]) -> None:
                self.agent: RunnableLambda[str, str] = wrapped

        box = Box(inner)

        def my_func(x: str) -> str:
            return box.agent.invoke(x)

        runnable: RunnableLambda[str, str] = RunnableLambda(my_func)
        # Only the access matters here; the value itself is discarded.
        _ = runnable.deps
    finally:
        inspect.getsource = saved_getsource


def test_deps_is_cached_on_instance() -> None:
    """After the first access, ``deps`` is stored in the instance dict."""
    runnable = RunnableLambda(lambda x: x)
    # Trigger the computation once; the result itself is not inspected.
    _ = runnable.deps
    instance_attrs = vars(runnable)
    assert "deps" in instance_attrs