langchain[patch]: Invoke chain prep_inputs and prep_outputs inside try block to catch validation errors (#16644)

- **Description:** The callback manager can't catch chain input or output
validation errors because `prep_inputs` and `prep_outputs` are called
outside the `try`/`except` block; this PR moves that validation inside it.
 
  - **Issue:** #15954
This commit is contained in:
Mo Latif 2024-02-13 23:23:11 -04:00 committed by GitHub
parent a8f530bc4d
commit 50b48a8e6a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 58 additions and 14 deletions

View File

@ -146,24 +146,28 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
self.metadata, self.metadata,
) )
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager") new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self), dumpd(self),
inputs, inputs,
name=run_name, name=run_name,
) )
try: try:
self._validate_inputs(inputs)
outputs = ( outputs = (
self._call(inputs, run_manager=run_manager) self._call(inputs, run_manager=run_manager)
if new_arg_supported if new_arg_supported
else self._call(inputs) else self._call(inputs)
) )
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
except BaseException as e: except BaseException as e:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
raise e raise e
run_manager.on_chain_end(outputs) run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info: if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id) final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs return final_outputs
@ -199,18 +203,20 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
name=run_name, name=run_name,
) )
try: try:
self._validate_inputs(inputs)
outputs = ( outputs = (
await self._acall(inputs, run_manager=run_manager) await self._acall(inputs, run_manager=run_manager)
if new_arg_supported if new_arg_supported
else await self._acall(inputs) else await self._acall(inputs)
) )
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
except BaseException as e: except BaseException as e:
await run_manager.on_chain_error(e) await run_manager.on_chain_error(e)
raise e raise e
await run_manager.on_chain_end(outputs) await run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info: if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id) final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs return final_outputs
@ -259,6 +265,20 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
    """Check that all inputs are present.

    Args:
        inputs: Normally a dict keyed by input name. A non-dict value is
            treated as a single bare input and is accepted only when exactly
            one input key is left after removing the keys that
            ``self.memory.memory_variables`` supplies.

    Raises:
        ValueError: If a bare value is passed but the chain does not expect
            exactly one non-memory input, or if a dict is passed that lacks
            one or more of ``self.input_keys``.
    """
    if not isinstance(inputs, dict):
        _input_keys = set(self.input_keys)
        if self.memory is not None:
            # If there are multiple input keys, but some get set by memory so that
            # only one is not set, we can still figure out which key it is.
            _input_keys = _input_keys.difference(self.memory.memory_variables)
        if len(_input_keys) != 1:
            raise ValueError(
                f"A single string input was passed in, but this chain expects "
                f"multiple inputs ({_input_keys}). When a chain expects "
                f"multiple inputs, please call it by passing in a dictionary, "
                "eg `chain({'foo': 1, 'bar': 2})`"
            )
        # A bare value has no keys to look up. Falling through would compare
        # the key names against the characters of a string (or fail with
        # TypeError on a non-iterable), so stop here.
        return

    missing_keys = set(self.input_keys).difference(inputs)
    if missing_keys:
        raise ValueError(f"Missing some input keys: {missing_keys}")
@ -461,18 +481,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
# If there are multiple input keys, but some get set by memory so that # If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is. # only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables) _input_keys = _input_keys.difference(self.memory.memory_variables)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
inputs = {list(_input_keys)[0]: inputs} inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None: if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs) external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context) inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
return inputs return inputs
@property @property

View File

@ -162,3 +162,35 @@ def test_run_with_callback() -> None:
assert handler.starts == 1 assert handler.starts == 1
assert handler.ends == 1 assert handler.ends == 1
assert handler.errors == 0 assert handler.errors == 0
def test_run_with_callback_and_input_error() -> None:
    """Check the callback handler records an error when input validation fails."""
    callback_handler = FakeCallbackHandler()
    chain_kwargs = {
        "the_input_keys": ["foo", "bar"],
        "callbacks": [callback_handler],
    }
    failing_chain = FakeChain(**chain_kwargs)

    # Only "bar" is supplied, so the call is expected to raise.
    with pytest.raises(ValueError):
        failing_chain({"bar": "foo"})

    # The run started once, never ended cleanly, and reported one error.
    assert callback_handler.starts == 1
    assert callback_handler.ends == 0
    assert callback_handler.errors == 1
def test_run_with_callback_and_output_error() -> None:
    """Check the callback handler records an error when output validation fails."""
    callback_handler = FakeCallbackHandler()
    chain_kwargs = {
        "the_output_keys": ["foo", "bar"],
        "callbacks": [callback_handler],
    }
    failing_chain = FakeChain(**chain_kwargs)

    # The chain is configured with output keys "foo" and "bar"; the call is
    # expected to raise.
    with pytest.raises(ValueError):
        failing_chain("foo")

    # The run started once, never ended cleanly, and reported one error.
    assert callback_handler.starts == 1
    assert callback_handler.ends == 0
    assert callback_handler.errors == 1