langchain[patch]: Invoke chain prep_inputs and prep_outputs inside try block to catch validation errors (#16644)

- **Description:** Callback manager can't catch chain input or output
validation errors because `prepare_input` and `prepare_output` are not
part of the try/raise logic, this PR fixes that logic.
 
  - **Issue:** #15954
This commit is contained in:
Mo Latif 2024-02-13 23:23:11 -04:00 committed by GitHub
parent a8f530bc4d
commit 50b48a8e6a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 58 additions and 14 deletions

View File

@ -146,24 +146,28 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
self.metadata,
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
name=run_name,
)
try:
self._validate_inputs(inputs)
outputs = (
self._call(inputs, run_manager=run_manager)
if new_arg_supported
else self._call(inputs)
)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
except BaseException as e:
run_manager.on_chain_error(e)
raise e
run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
@ -199,18 +203,20 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
name=run_name,
)
try:
self._validate_inputs(inputs)
outputs = (
await self._acall(inputs, run_manager=run_manager)
if new_arg_supported
else await self._acall(inputs)
)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
except BaseException as e:
await run_manager.on_chain_error(e)
raise e
await run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
@ -259,6 +265,20 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
"""Check that all inputs are present."""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
missing_keys = set(self.input_keys).difference(inputs)
if missing_keys:
raise ValueError(f"Missing some input keys: {missing_keys}")
@ -461,18 +481,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
return inputs
@property

View File

@ -162,3 +162,35 @@ def test_run_with_callback() -> None:
assert handler.starts == 1
assert handler.ends == 1
assert handler.errors == 0
def test_run_with_callback_and_input_error() -> None:
"""Test callback manager catches run validation input error."""
handler = FakeCallbackHandler()
chain = FakeChain(
the_input_keys=["foo", "bar"],
callbacks=[handler],
)
with pytest.raises(ValueError):
chain({"bar": "foo"})
assert handler.starts == 1
assert handler.ends == 0
assert handler.errors == 1
def test_run_with_callback_and_output_error() -> None:
"""Test callback manager catches run validation output error."""
handler = FakeCallbackHandler()
chain = FakeChain(
the_output_keys=["foo", "bar"],
callbacks=[handler],
)
with pytest.raises(ValueError):
chain("foo")
assert handler.starts == 1
assert handler.ends == 0
assert handler.errors == 1