mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 06:14:37 +00:00
langchain[patch]: Invoke chain prep_inputs and prep_outputs inside try block to catch validation errors (#16644)
- **Description:** Callback manager can't catch chain input or output validation errors because `prepare_input` and `prepare_output` are not part of the try/raise logic, this PR fixes that logic. - **Issue:** #15954
This commit is contained in:
parent
a8f530bc4d
commit
50b48a8e6a
@ -146,24 +146,28 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
self.metadata,
|
self.metadata,
|
||||||
)
|
)
|
||||||
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
|
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
|
||||||
|
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
inputs,
|
inputs,
|
||||||
name=run_name,
|
name=run_name,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
|
self._validate_inputs(inputs)
|
||||||
outputs = (
|
outputs = (
|
||||||
self._call(inputs, run_manager=run_manager)
|
self._call(inputs, run_manager=run_manager)
|
||||||
if new_arg_supported
|
if new_arg_supported
|
||||||
else self._call(inputs)
|
else self._call(inputs)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||||
|
inputs, outputs, return_only_outputs
|
||||||
|
)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
raise e
|
raise e
|
||||||
run_manager.on_chain_end(outputs)
|
run_manager.on_chain_end(outputs)
|
||||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
|
||||||
inputs, outputs, return_only_outputs
|
|
||||||
)
|
|
||||||
if include_run_info:
|
if include_run_info:
|
||||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||||
return final_outputs
|
return final_outputs
|
||||||
@ -199,18 +203,20 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
name=run_name,
|
name=run_name,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
|
self._validate_inputs(inputs)
|
||||||
outputs = (
|
outputs = (
|
||||||
await self._acall(inputs, run_manager=run_manager)
|
await self._acall(inputs, run_manager=run_manager)
|
||||||
if new_arg_supported
|
if new_arg_supported
|
||||||
else await self._acall(inputs)
|
else await self._acall(inputs)
|
||||||
)
|
)
|
||||||
|
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||||
|
inputs, outputs, return_only_outputs
|
||||||
|
)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
await run_manager.on_chain_error(e)
|
await run_manager.on_chain_error(e)
|
||||||
raise e
|
raise e
|
||||||
await run_manager.on_chain_end(outputs)
|
await run_manager.on_chain_end(outputs)
|
||||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
|
||||||
inputs, outputs, return_only_outputs
|
|
||||||
)
|
|
||||||
if include_run_info:
|
if include_run_info:
|
||||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||||
return final_outputs
|
return final_outputs
|
||||||
@ -259,6 +265,20 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
|
|
||||||
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
|
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||||
"""Check that all inputs are present."""
|
"""Check that all inputs are present."""
|
||||||
|
if not isinstance(inputs, dict):
|
||||||
|
_input_keys = set(self.input_keys)
|
||||||
|
if self.memory is not None:
|
||||||
|
# If there are multiple input keys, but some get set by memory so that
|
||||||
|
# only one is not set, we can still figure out which key it is.
|
||||||
|
_input_keys = _input_keys.difference(self.memory.memory_variables)
|
||||||
|
if len(_input_keys) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"A single string input was passed in, but this chain expects "
|
||||||
|
f"multiple inputs ({_input_keys}). When a chain expects "
|
||||||
|
f"multiple inputs, please call it by passing in a dictionary, "
|
||||||
|
"eg `chain({'foo': 1, 'bar': 2})`"
|
||||||
|
)
|
||||||
|
|
||||||
missing_keys = set(self.input_keys).difference(inputs)
|
missing_keys = set(self.input_keys).difference(inputs)
|
||||||
if missing_keys:
|
if missing_keys:
|
||||||
raise ValueError(f"Missing some input keys: {missing_keys}")
|
raise ValueError(f"Missing some input keys: {missing_keys}")
|
||||||
@ -461,18 +481,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
# If there are multiple input keys, but some get set by memory so that
|
# If there are multiple input keys, but some get set by memory so that
|
||||||
# only one is not set, we can still figure out which key it is.
|
# only one is not set, we can still figure out which key it is.
|
||||||
_input_keys = _input_keys.difference(self.memory.memory_variables)
|
_input_keys = _input_keys.difference(self.memory.memory_variables)
|
||||||
if len(_input_keys) != 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"A single string input was passed in, but this chain expects "
|
|
||||||
f"multiple inputs ({_input_keys}). When a chain expects "
|
|
||||||
f"multiple inputs, please call it by passing in a dictionary, "
|
|
||||||
"eg `chain({'foo': 1, 'bar': 2})`"
|
|
||||||
)
|
|
||||||
inputs = {list(_input_keys)[0]: inputs}
|
inputs = {list(_input_keys)[0]: inputs}
|
||||||
if self.memory is not None:
|
if self.memory is not None:
|
||||||
external_context = self.memory.load_memory_variables(inputs)
|
external_context = self.memory.load_memory_variables(inputs)
|
||||||
inputs = dict(inputs, **external_context)
|
inputs = dict(inputs, **external_context)
|
||||||
self._validate_inputs(inputs)
|
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -162,3 +162,35 @@ def test_run_with_callback() -> None:
|
|||||||
assert handler.starts == 1
|
assert handler.starts == 1
|
||||||
assert handler.ends == 1
|
assert handler.ends == 1
|
||||||
assert handler.errors == 0
|
assert handler.errors == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_with_callback_and_input_error() -> None:
|
||||||
|
"""Test callback manager catches run validation input error."""
|
||||||
|
handler = FakeCallbackHandler()
|
||||||
|
chain = FakeChain(
|
||||||
|
the_input_keys=["foo", "bar"],
|
||||||
|
callbacks=[handler],
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
chain({"bar": "foo"})
|
||||||
|
|
||||||
|
assert handler.starts == 1
|
||||||
|
assert handler.ends == 0
|
||||||
|
assert handler.errors == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_with_callback_and_output_error() -> None:
|
||||||
|
"""Test callback manager catches run validation output error."""
|
||||||
|
handler = FakeCallbackHandler()
|
||||||
|
chain = FakeChain(
|
||||||
|
the_output_keys=["foo", "bar"],
|
||||||
|
callbacks=[handler],
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
chain("foo")
|
||||||
|
|
||||||
|
assert handler.starts == 1
|
||||||
|
assert handler.ends == 0
|
||||||
|
assert handler.errors == 1
|
||||||
|
Loading…
Reference in New Issue
Block a user