diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 1bb6d1636b2..7f8aec6d21d 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -146,24 +146,28 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): self.metadata, ) new_arg_supported = inspect.signature(self._call).parameters.get("run_manager") + run_manager = callback_manager.on_chain_start( dumpd(self), inputs, name=run_name, ) try: + self._validate_inputs(inputs) outputs = ( self._call(inputs, run_manager=run_manager) if new_arg_supported else self._call(inputs) ) + + final_outputs: Dict[str, Any] = self.prep_outputs( + inputs, outputs, return_only_outputs + ) except BaseException as e: run_manager.on_chain_error(e) raise e run_manager.on_chain_end(outputs) - final_outputs: Dict[str, Any] = self.prep_outputs( - inputs, outputs, return_only_outputs - ) + if include_run_info: final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id) return final_outputs @@ -199,18 +203,20 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): name=run_name, ) try: + self._validate_inputs(inputs) outputs = ( await self._acall(inputs, run_manager=run_manager) if new_arg_supported else await self._acall(inputs) ) + final_outputs: Dict[str, Any] = self.prep_outputs( + inputs, outputs, return_only_outputs + ) except BaseException as e: await run_manager.on_chain_error(e) raise e await run_manager.on_chain_end(outputs) - final_outputs: Dict[str, Any] = self.prep_outputs( - inputs, outputs, return_only_outputs - ) + if include_run_info: final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id) return final_outputs @@ -259,6 +265,20 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): def _validate_inputs(self, inputs: Dict[str, Any]) -> None: """Check that all inputs are present.""" + if not isinstance(inputs, dict): + _input_keys = set(self.input_keys) + if self.memory is not None: + # If there are multiple input keys, but some get set by memory so that + # only one is not set, we can still figure out which key it is. + _input_keys = _input_keys.difference(self.memory.memory_variables) + if len(_input_keys) != 1: + raise ValueError( + f"A single string input was passed in, but this chain expects " + f"multiple inputs ({_input_keys}). When a chain expects " + f"multiple inputs, please call it by passing in a dictionary, " + "eg `chain({'foo': 1, 'bar': 2})`" + ) + missing_keys = set(self.input_keys).difference(inputs) if missing_keys: raise ValueError(f"Missing some input keys: {missing_keys}") @@ -461,18 +481,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): # If there are multiple input keys, but some get set by memory so that # only one is not set, we can still figure out which key it is. _input_keys = _input_keys.difference(self.memory.memory_variables) - if len(_input_keys) != 1: - raise ValueError( - f"A single string input was passed in, but this chain expects " - f"multiple inputs ({_input_keys}). When a chain expects " - f"multiple inputs, please call it by passing in a dictionary, " - "eg `chain({'foo': 1, 'bar': 2})`" - ) inputs = {list(_input_keys)[0]: inputs} if self.memory is not None: external_context = self.memory.load_memory_variables(inputs) inputs = dict(inputs, **external_context) - self._validate_inputs(inputs) return inputs @property diff --git a/libs/langchain/tests/unit_tests/chains/test_base.py b/libs/langchain/tests/unit_tests/chains/test_base.py index 2c410e2337d..c96f1d945b1 100644 --- a/libs/langchain/tests/unit_tests/chains/test_base.py +++ b/libs/langchain/tests/unit_tests/chains/test_base.py @@ -162,3 +162,35 @@ def test_run_with_callback() -> None: assert handler.starts == 1 assert handler.ends == 1 assert handler.errors == 0 + + +def test_run_with_callback_and_input_error() -> None: + """Test callback manager catches run validation input error.""" + handler = FakeCallbackHandler() + chain = FakeChain( + the_input_keys=["foo", "bar"], + callbacks=[handler], + ) + + with pytest.raises(ValueError): + chain({"bar": "foo"}) + + assert handler.starts == 1 + assert handler.ends == 0 + assert handler.errors == 1 + + +def test_run_with_callback_and_output_error() -> None: + """Test callback manager catches run validation output error.""" + handler = FakeCallbackHandler() + chain = FakeChain( + the_output_keys=["foo", "bar"], + callbacks=[handler], + ) + + with pytest.raises(ValueError): + chain("foo") + + assert handler.starts == 1 + assert handler.ends == 0 + assert handler.errors == 1