mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-05 22:53:30 +00:00
langchain[patch]: Invoke chain prep_inputs and prep_outputs inside try block to catch validation errors (#16644)
- **Description:** Callback manager can't catch chain input or output validation errors because `prepare_input` and `prepare_output` are not part of the try/raise logic, this PR fixes that logic. - **Issue:** #15954
This commit is contained in:
parent
a8f530bc4d
commit
50b48a8e6a
@ -146,24 +146,28 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
self.metadata,
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
|
||||
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
inputs,
|
||||
name=run_name,
|
||||
)
|
||||
try:
|
||||
self._validate_inputs(inputs)
|
||||
outputs = (
|
||||
self._call(inputs, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else self._call(inputs)
|
||||
)
|
||||
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs
|
||||
)
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise e
|
||||
run_manager.on_chain_end(outputs)
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs
|
||||
)
|
||||
|
||||
if include_run_info:
|
||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||
return final_outputs
|
||||
@ -199,18 +203,20 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
name=run_name,
|
||||
)
|
||||
try:
|
||||
self._validate_inputs(inputs)
|
||||
outputs = (
|
||||
await self._acall(inputs, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else await self._acall(inputs)
|
||||
)
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs
|
||||
)
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise e
|
||||
await run_manager.on_chain_end(outputs)
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs
|
||||
)
|
||||
|
||||
if include_run_info:
|
||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||
return final_outputs
|
||||
@ -259,6 +265,20 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
|
||||
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||
"""Check that all inputs are present."""
|
||||
if not isinstance(inputs, dict):
|
||||
_input_keys = set(self.input_keys)
|
||||
if self.memory is not None:
|
||||
# If there are multiple input keys, but some get set by memory so that
|
||||
# only one is not set, we can still figure out which key it is.
|
||||
_input_keys = _input_keys.difference(self.memory.memory_variables)
|
||||
if len(_input_keys) != 1:
|
||||
raise ValueError(
|
||||
f"A single string input was passed in, but this chain expects "
|
||||
f"multiple inputs ({_input_keys}). When a chain expects "
|
||||
f"multiple inputs, please call it by passing in a dictionary, "
|
||||
"eg `chain({'foo': 1, 'bar': 2})`"
|
||||
)
|
||||
|
||||
missing_keys = set(self.input_keys).difference(inputs)
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing some input keys: {missing_keys}")
|
||||
@ -461,18 +481,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
# If there are multiple input keys, but some get set by memory so that
|
||||
# only one is not set, we can still figure out which key it is.
|
||||
_input_keys = _input_keys.difference(self.memory.memory_variables)
|
||||
if len(_input_keys) != 1:
|
||||
raise ValueError(
|
||||
f"A single string input was passed in, but this chain expects "
|
||||
f"multiple inputs ({_input_keys}). When a chain expects "
|
||||
f"multiple inputs, please call it by passing in a dictionary, "
|
||||
"eg `chain({'foo': 1, 'bar': 2})`"
|
||||
)
|
||||
inputs = {list(_input_keys)[0]: inputs}
|
||||
if self.memory is not None:
|
||||
external_context = self.memory.load_memory_variables(inputs)
|
||||
inputs = dict(inputs, **external_context)
|
||||
self._validate_inputs(inputs)
|
||||
return inputs
|
||||
|
||||
@property
|
||||
|
@ -162,3 +162,35 @@ def test_run_with_callback() -> None:
|
||||
assert handler.starts == 1
|
||||
assert handler.ends == 1
|
||||
assert handler.errors == 0
|
||||
|
||||
|
||||
def test_run_with_callback_and_input_error() -> None:
|
||||
"""Test callback manager catches run validation input error."""
|
||||
handler = FakeCallbackHandler()
|
||||
chain = FakeChain(
|
||||
the_input_keys=["foo", "bar"],
|
||||
callbacks=[handler],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
chain({"bar": "foo"})
|
||||
|
||||
assert handler.starts == 1
|
||||
assert handler.ends == 0
|
||||
assert handler.errors == 1
|
||||
|
||||
|
||||
def test_run_with_callback_and_output_error() -> None:
|
||||
"""Test callback manager catches run validation output error."""
|
||||
handler = FakeCallbackHandler()
|
||||
chain = FakeChain(
|
||||
the_output_keys=["foo", "bar"],
|
||||
callbacks=[handler],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
chain("foo")
|
||||
|
||||
assert handler.starts == 1
|
||||
assert handler.ends == 0
|
||||
assert handler.errors == 1
|
||||
|
Loading…
Reference in New Issue
Block a user