mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-08 14:05:16 +00:00
Implement better reprs for Runnables
This commit is contained in:
parent
cfa2203c62
commit
5c1f462bb9
@ -68,6 +68,13 @@ class Serializable(BaseModel, ABC):
|
||||
class Config:
|
||||
extra = "ignore"
|
||||
|
||||
def __repr_args__(self) -> Any:
|
||||
return [
|
||||
(k, v)
|
||||
for k, v in super().__repr_args__()
|
||||
if (k not in self.__fields__ or self.__fields__[k].get_default() != v)
|
||||
]
|
||||
|
||||
_lc_kwargs = PrivateAttr(default_factory=dict)
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
|
@ -59,6 +59,8 @@ from langchain.schema.runnable.utils import (
|
||||
accepts_run_manager,
|
||||
gather_with_concurrency,
|
||||
get_function_first_arg_dict_keys,
|
||||
get_lambda_source,
|
||||
indent_lines_after_first,
|
||||
)
|
||||
from langchain.utils.aiter import atee, py_anext
|
||||
from langchain.utils.iter import safetee
|
||||
@ -1298,6 +1300,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
return self.last.output_schema
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "\n| ".join(
|
||||
repr(s) if i == 0 else indent_lines_after_first(repr(s), "| ")
|
||||
for i, s in enumerate(self.steps)
|
||||
)
|
||||
|
||||
def __or__(
|
||||
self,
|
||||
other: Union[
|
||||
@ -1819,6 +1827,13 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
**{k: (v.OutputType, None) for k, v in self.steps.items()},
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
map_for_repr = ",\n ".join(
|
||||
f"{k}: {indent_lines_after_first(repr(v), ' ' + k + ': ')}"
|
||||
for k, v in self.steps.items()
|
||||
)
|
||||
return "{\n " + map_for_repr + "\n}"
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Dict[str, Any]:
|
||||
@ -2134,7 +2149,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
return False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "RunnableLambda(...)"
|
||||
return f"RunnableLambda({get_lambda_source(self.func) or '...'})"
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
|
@ -87,6 +87,14 @@ class IsFunctionArgDict(ast.NodeVisitor):
|
||||
IsLocalDict(input_arg_name, self.keys).visit(node)
|
||||
|
||||
|
||||
class GetLambdaSource(ast.NodeVisitor):
|
||||
def __init__(self) -> None:
|
||||
self.source: Optional[str] = None
|
||||
|
||||
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
||||
self.source = ast.unparse(node)
|
||||
|
||||
|
||||
def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
|
||||
try:
|
||||
code = inspect.getsource(func)
|
||||
@ -94,5 +102,23 @@ def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
|
||||
visitor = IsFunctionArgDict()
|
||||
visitor.visit(tree)
|
||||
return list(visitor.keys) if visitor.keys else None
|
||||
except (TypeError, OSError):
|
||||
except (SyntaxError, TypeError, OSError):
|
||||
return None
|
||||
|
||||
|
||||
def get_lambda_source(func: Callable) -> Optional[str]:
|
||||
try:
|
||||
code = inspect.getsource(func)
|
||||
tree = ast.parse(textwrap.dedent(code))
|
||||
visitor = GetLambdaSource()
|
||||
visitor.visit(tree)
|
||||
return visitor.source
|
||||
except (SyntaxError, TypeError, OSError):
|
||||
return None
|
||||
|
||||
|
||||
def indent_lines_after_first(text: str, prefix: str) -> str:
|
||||
n_spaces = len(prefix)
|
||||
spaces = " " * n_spaces
|
||||
lines = text.splitlines()
|
||||
return "\n".join([lines[0]] + [spaces + line for line in lines[1:]])
|
||||
|
File diff suppressed because one or more lines are too long
@ -867,6 +867,7 @@ async def test_prompt_with_chat_model(
|
||||
|
||||
chain = prompt | chat
|
||||
|
||||
assert repr(chain) == snapshot
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
assert chain.first == prompt
|
||||
assert chain.middle == []
|
||||
@ -1350,6 +1351,7 @@ Question:
|
||||
| parser
|
||||
)
|
||||
|
||||
assert repr(chain) == snapshot
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
assert isinstance(chain.first, RunnableMap)
|
||||
assert chain.middle == [prompt, chat]
|
||||
@ -1375,7 +1377,7 @@ Question:
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessage(
|
||||
content="""Context:
|
||||
[Document(page_content='foo', metadata={}), Document(page_content='bar', metadata={})]
|
||||
[Document(page_content='foo'), Document(page_content='bar')]
|
||||
|
||||
Question:
|
||||
What is your name?"""
|
||||
@ -1413,6 +1415,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
|
||||
}
|
||||
)
|
||||
|
||||
assert repr(chain) == snapshot
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
assert chain.first == prompt
|
||||
assert chain.middle == [RunnableLambda(passthrough)]
|
||||
@ -2196,6 +2199,7 @@ def test_retrying(mocker: MockerFixture) -> None:
|
||||
with pytest.raises(RuntimeError):
|
||||
runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
wait_exponential_jitter=False,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
).invoke(2)
|
||||
|
||||
@ -2205,6 +2209,7 @@ def test_retrying(mocker: MockerFixture) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
wait_exponential_jitter=False,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
).batch([1, 2, 0])
|
||||
|
||||
@ -2214,6 +2219,7 @@ def test_retrying(mocker: MockerFixture) -> None:
|
||||
|
||||
output = runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
wait_exponential_jitter=False,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
).batch([1, 2, 0], return_exceptions=True)
|
||||
|
||||
@ -2248,6 +2254,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
await runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
wait_exponential_jitter=False,
|
||||
retry_if_exception_type=(ValueError, KeyError),
|
||||
).ainvoke(1)
|
||||
|
||||
@ -2257,6 +2264,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
|
||||
with pytest.raises(RuntimeError):
|
||||
await runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
wait_exponential_jitter=False,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
).ainvoke(2)
|
||||
|
||||
@ -2266,6 +2274,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
await runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
wait_exponential_jitter=False,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
).abatch([1, 2, 0])
|
||||
|
||||
@ -2275,6 +2284,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None:
|
||||
|
||||
output = await runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
wait_exponential_jitter=False,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
).abatch([1, 2, 0], return_exceptions=True)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user