mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 17:38:36 +00:00
Implement better reprs for Runnables (#11175)
``` ChatPromptTemplate(messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='You are a nice assistant.')), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['question'], template='{question}'))]) | RunnableLambda(lambda x: x) | { chat: FakeListChatModel(responses=["i'm a chatbot"]), llm: FakeListLLM(responses=["i'm a textbot"]) } ``` <!-- Thank you for contributing to LangChain! Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** any dependencies required for this change, - **Tag maintainer:** for a quicker response, tag the relevant maintainer (see below), - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/extras` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
commit
61b5942adf
@ -77,6 +77,13 @@ class Serializable(BaseModel, ABC):
|
||||
class Config:
|
||||
extra = "ignore"
|
||||
|
||||
def __repr_args__(self) -> Any:
|
||||
return [
|
||||
(k, v)
|
||||
for k, v in super().__repr_args__()
|
||||
if (k not in self.__fields__ or self.__fields__[k].get_default() != v)
|
||||
]
|
||||
|
||||
_lc_kwargs = PrivateAttr(default_factory=dict)
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
|
@ -59,6 +59,8 @@ from langchain.schema.runnable.utils import (
|
||||
accepts_run_manager,
|
||||
gather_with_concurrency,
|
||||
get_function_first_arg_dict_keys,
|
||||
get_lambda_source,
|
||||
indent_lines_after_first,
|
||||
)
|
||||
from langchain.utils.aiter import atee, py_anext
|
||||
from langchain.utils.iter import safetee
|
||||
@ -1298,6 +1300,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
return self.last.output_schema
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "\n| ".join(
|
||||
repr(s) if i == 0 else indent_lines_after_first(repr(s), "| ")
|
||||
for i, s in enumerate(self.steps)
|
||||
)
|
||||
|
||||
def __or__(
|
||||
self,
|
||||
other: Union[
|
||||
@ -1819,6 +1827,13 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
**{k: (v.OutputType, None) for k, v in self.steps.items()},
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
map_for_repr = ",\n ".join(
|
||||
f"{k}: {indent_lines_after_first(repr(v), ' ' + k + ': ')}"
|
||||
for k, v in self.steps.items()
|
||||
)
|
||||
return "{\n " + map_for_repr + "\n}"
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Dict[str, Any]:
|
||||
@ -2134,7 +2149,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
return False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "RunnableLambda(...)"
|
||||
return f"RunnableLambda({get_lambda_source(self.func) or '...'})"
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
|
@ -87,6 +87,17 @@ class IsFunctionArgDict(ast.NodeVisitor):
|
||||
IsLocalDict(input_arg_name, self.keys).visit(node)
|
||||
|
||||
|
||||
class GetLambdaSource(ast.NodeVisitor):
|
||||
def __init__(self) -> None:
|
||||
self.source: Optional[str] = None
|
||||
self.count = 0
|
||||
|
||||
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
||||
self.count += 1
|
||||
if hasattr(ast, "unparse"):
|
||||
self.source = ast.unparse(node)
|
||||
|
||||
|
||||
def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
|
||||
try:
|
||||
code = inspect.getsource(func)
|
||||
@ -94,5 +105,40 @@ def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
|
||||
visitor = IsFunctionArgDict()
|
||||
visitor.visit(tree)
|
||||
return list(visitor.keys) if visitor.keys else None
|
||||
except (TypeError, OSError):
|
||||
except (SyntaxError, TypeError, OSError):
|
||||
return None
|
||||
|
||||
|
||||
def get_lambda_source(func: Callable) -> Optional[str]:
|
||||
"""Get the source code of a lambda function.
|
||||
|
||||
Args:
|
||||
func: a callable that can be a lambda function
|
||||
|
||||
Returns:
|
||||
str: the source code of the lambda function
|
||||
"""
|
||||
try:
|
||||
code = inspect.getsource(func)
|
||||
tree = ast.parse(textwrap.dedent(code))
|
||||
visitor = GetLambdaSource()
|
||||
visitor.visit(tree)
|
||||
return visitor.source if visitor.count == 1 else None
|
||||
except (SyntaxError, TypeError, OSError):
|
||||
return None
|
||||
|
||||
|
||||
def indent_lines_after_first(text: str, prefix: str) -> str:
|
||||
"""Indent all lines of text after the first line.
|
||||
|
||||
Args:
|
||||
text: The text to indent
|
||||
prefix: Used to determine the number of spaces to indent
|
||||
|
||||
Returns:
|
||||
str: The indented text
|
||||
"""
|
||||
n_spaces = len(prefix)
|
||||
spaces = " " * n_spaces
|
||||
lines = text.splitlines()
|
||||
return "\n".join([lines[0]] + [spaces + line for line in lines[1:]])
|
||||
|
File diff suppressed because one or more lines are too long
@ -867,6 +867,7 @@ async def test_prompt_with_chat_model(
|
||||
|
||||
chain = prompt | chat
|
||||
|
||||
assert repr(chain) == snapshot
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
assert chain.first == prompt
|
||||
assert chain.middle == []
|
||||
@ -1276,6 +1277,7 @@ def test_combining_sequences(
|
||||
assert chain.first == prompt
|
||||
assert chain.middle == [chat]
|
||||
assert chain.last == parser
|
||||
if sys.version_info >= (3, 9):
|
||||
assert dumps(chain, pretty=True) == snapshot
|
||||
|
||||
prompt2 = (
|
||||
@ -1294,6 +1296,7 @@ def test_combining_sequences(
|
||||
assert chain2.first == input_formatter
|
||||
assert chain2.middle == [prompt2, chat2]
|
||||
assert chain2.last == parser2
|
||||
if sys.version_info >= (3, 9):
|
||||
assert dumps(chain2, pretty=True) == snapshot
|
||||
|
||||
combined_chain = chain | chain2
|
||||
@ -1307,6 +1310,7 @@ def test_combining_sequences(
|
||||
chat2,
|
||||
]
|
||||
assert combined_chain.last == parser2
|
||||
if sys.version_info >= (3, 9):
|
||||
assert dumps(combined_chain, pretty=True) == snapshot
|
||||
|
||||
# Test invoke
|
||||
@ -1315,6 +1319,7 @@ def test_combining_sequences(
|
||||
{"question": "What is your name?"}, dict(callbacks=[tracer])
|
||||
) == ["baz", "qux"]
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
assert tracer.runs == snapshot
|
||||
|
||||
|
||||
@ -1350,6 +1355,7 @@ Question:
|
||||
| parser
|
||||
)
|
||||
|
||||
assert repr(chain) == snapshot
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
assert isinstance(chain.first, RunnableMap)
|
||||
assert chain.middle == [prompt, chat]
|
||||
@ -1375,7 +1381,7 @@ Question:
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessage(
|
||||
content="""Context:
|
||||
[Document(page_content='foo', metadata={}), Document(page_content='bar', metadata={})]
|
||||
[Document(page_content='foo'), Document(page_content='bar')]
|
||||
|
||||
Question:
|
||||
What is your name?"""
|
||||
@ -1413,6 +1419,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
|
||||
}
|
||||
)
|
||||
|
||||
assert repr(chain) == snapshot
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
assert chain.first == prompt
|
||||
assert chain.middle == [RunnableLambda(passthrough)]
|
||||
@ -2098,6 +2105,7 @@ async def test_llm_with_fallbacks(
|
||||
assert await runnable.ainvoke("hello") == "bar"
|
||||
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
|
||||
assert list(await runnable.ainvoke("hello")) == list("bar")
|
||||
if sys.version_info >= (3, 9):
|
||||
assert dumps(runnable, pretty=True) == snapshot
|
||||
|
||||
|
||||
@ -2196,6 +2204,7 @@ def test_retrying(mocker: MockerFixture) -> None:
|
||||
with pytest.raises(RuntimeError):
|
||||
runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
wait_exponential_jitter=False,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
).invoke(2)
|
||||
|
||||
@ -2205,6 +2214,7 @@ def test_retrying(mocker: MockerFixture) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
wait_exponential_jitter=False,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
).batch([1, 2, 0])
|
||||
|
||||
@ -2214,6 +2224,7 @@ def test_retrying(mocker: MockerFixture) -> None:
|
||||
|
||||
output = runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
wait_exponential_jitter=False,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
).batch([1, 2, 0], return_exceptions=True)
|
||||
|
||||
@ -2248,6 +2259,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
await runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
wait_exponential_jitter=False,
|
||||
retry_if_exception_type=(ValueError, KeyError),
|
||||
).ainvoke(1)
|
||||
|
||||
@ -2257,6 +2269,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
|
||||
with pytest.raises(RuntimeError):
|
||||
await runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
wait_exponential_jitter=False,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
).ainvoke(2)
|
||||
|
||||
@ -2266,6 +2279,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
await runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
wait_exponential_jitter=False,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
).abatch([1, 2, 0])
|
||||
|
||||
@ -2275,6 +2289,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
|
||||
|
||||
output = await runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
wait_exponential_jitter=False,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
).abatch([1, 2, 0], return_exceptions=True)
|
||||
|
||||
@ -2729,3 +2744,38 @@ async def test_runnable_branch_abatch() -> None:
|
||||
)
|
||||
|
||||
assert await branch.abatch([1, 10, 0]) == [2, 100, -1]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||
)
|
||||
def test_representation_of_runnables() -> None:
|
||||
"""Test representation of runnables."""
|
||||
runnable = RunnableLambda(lambda x: x * 2)
|
||||
assert repr(runnable) == "RunnableLambda(lambda x: x * 2)"
|
||||
|
||||
def f(x: int) -> int:
|
||||
"""Return 2."""
|
||||
return 2
|
||||
|
||||
assert repr(RunnableLambda(func=f)) == "RunnableLambda(...)"
|
||||
|
||||
async def af(x: int) -> int:
|
||||
"""Return 2."""
|
||||
return 2
|
||||
|
||||
assert repr(RunnableLambda(func=f, afunc=af)) == "RunnableLambda(...)"
|
||||
|
||||
assert repr(
|
||||
RunnableLambda(lambda x: x + 2)
|
||||
| {
|
||||
"a": RunnableLambda(lambda x: x * 2),
|
||||
"b": RunnableLambda(lambda x: x * 3),
|
||||
}
|
||||
) == (
|
||||
"RunnableLambda(...)\n"
|
||||
"| {\n"
|
||||
" a: RunnableLambda(...),\n"
|
||||
" b: RunnableLambda(...)\n"
|
||||
" }"
|
||||
), "repr where code string contains multiple lambdas gives up"
|
||||
|
@ -0,0 +1,39 @@
|
||||
import sys
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.schema.runnable.utils import (
|
||||
get_lambda_source,
|
||||
indent_lines_after_first,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"func, expected_source",
|
||||
[
|
||||
(lambda x: x * 2, "lambda x: x * 2"),
|
||||
(lambda a, b: a + b, "lambda a, b: a + b"),
|
||||
(lambda x: x if x > 0 else 0, "lambda x: x if x > 0 else 0"),
|
||||
],
|
||||
)
|
||||
def test_get_lambda_source(func: Callable, expected_source: str) -> None:
|
||||
"""Test get_lambda_source function"""
|
||||
source = get_lambda_source(func)
|
||||
assert source == expected_source
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text,prefix,expected_output",
|
||||
[
|
||||
("line 1\nline 2\nline 3", "1", "line 1\n line 2\n line 3"),
|
||||
("line 1\nline 2\nline 3", "ax", "line 1\n line 2\n line 3"),
|
||||
],
|
||||
)
|
||||
def test_indent_lines_after_first(text: str, prefix: str, expected_output: str) -> None:
|
||||
"""Test indent_lines_after_first function"""
|
||||
indented_text = indent_lines_after_first(text, prefix)
|
||||
assert indented_text == expected_output
|
Loading…
Reference in New Issue
Block a user