diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index 1cff41200f7..2e1878ed27e 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -73,20 +73,19 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC): ) def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue: - try: - input_dict = {key: inner_input[key] for key in self.input_variables} - except TypeError as e: + if not isinstance(inner_input, dict): raise TypeError( f"Expected mapping type as input to {self.__class__.__name__}. " f"Received {type(inner_input)}." - ) from e - except KeyError as e: + ) + missing = set(self.input_variables).difference(inner_input) + if missing: raise KeyError( - f"Input to {self.__class__.__name__} is missing variable {e}. " + f"Input to {self.__class__.__name__} is missing variables {missing}. " f" Expected: {self.input_variables}" f" Received: {list(inner_input.keys())}" - ) from e - return self.format_prompt(**input_dict) + ) + return self.format_prompt(**inner_input) def invoke( self, input: Dict, config: Optional[RunnableConfig] = None @@ -100,7 +99,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC): @abstractmethod def format_prompt(self, **kwargs: Any) -> PromptValue: - """Create Chat Messages.""" + """Create Prompt Value.""" @root_validator() def validate_variable_names(cls, values: Dict) -> Dict: diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index e2978d383bf..94ce5e6040d 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -133,7 +133,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate): Returns: List of input variable names. """ - return [self.variable_name] + return [self.variable_name] if not self.optional else [] MessagePromptTemplateT = TypeVar( @@ -611,12 +611,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate): elif isinstance( message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate) ): - rel_params = { - k: v - for k, v in kwargs.items() - if k in message_template.input_variables - } - message = message_template.format_messages(**rel_params) + message = message_template.format_messages(**kwargs) result.extend(message) else: raise ValueError(f"Unexpected input: {message_template}") diff --git a/libs/core/langchain_core/tracers/log_stream.py b/libs/core/langchain_core/tracers/log_stream.py index 6d2b7e43a06..2d4a512edb1 100644 --- a/libs/core/langchain_core/tracers/log_stream.py +++ b/libs/core/langchain_core/tracers/log_stream.py @@ -43,6 +43,8 @@ class LogEntry(TypedDict): streamed_output_str: List[str] """List of LLM tokens streamed by this run, if applicable.""" + streamed_output: List[Any] + """List of output chunks streamed by this run, if available.""" final_output: Optional[Any] """Final output of this run. Only available after the run has finished successfully.""" @@ -242,6 +244,7 @@ class LogStreamCallbackHandler(BaseTracer): tags=run.tags or [], metadata=(run.extra or {}).get("metadata", {}), start_time=run.start_time.isoformat(timespec="milliseconds"), + streamed_output=[], streamed_output_str=[], final_output=None, end_time=None, @@ -298,6 +301,13 @@ class LogStreamCallbackHandler(BaseTracer): "op": "add", "path": f"/logs/{index}/streamed_output_str/-", "value": token, - } + }, + { + "op": "add", + "path": f"/logs/{index}/streamed_output/-", + "value": chunk.message + if isinstance(chunk, ChatGenerationChunk) + else token, + }, ) ) diff --git a/libs/core/langchain_core/utils/formatting.py b/libs/core/langchain_core/utils/formatting.py index 3b3b597b083..83fe2aba927 100644 --- a/libs/core/langchain_core/utils/formatting.py +++ b/libs/core/langchain_core/utils/formatting.py @@ -1,22 +1,11 @@ """Utilities for formatting strings.""" from string import Formatter -from typing import Any, List, Mapping, Sequence, Union +from typing import Any, List, Mapping, Sequence class StrictFormatter(Formatter): """A subclass of formatter that checks for extra keys.""" - def check_unused_args( - self, - used_args: Sequence[Union[int, str]], - args: Sequence, - kwargs: Mapping[str, Any], - ) -> None: - """Check to see if extra parameters are passed.""" - extra = set(kwargs).difference(used_args) - if extra: - raise KeyError(extra) - def vformat( self, format_string: str, args: Sequence, kwargs: Mapping[str, Any] ) -> str: diff --git a/libs/core/tests/unit_tests/prompts/test_few_shot.py b/libs/core/tests/unit_tests/prompts/test_few_shot.py index 1a94b1c6ca2..7b4bc2378e2 100644 --- a/libs/core/tests/unit_tests/prompts/test_few_shot.py +++ b/libs/core/tests/unit_tests/prompts/test_few_shot.py @@ -96,26 +96,6 @@ def test_prompt_missing_input_variables() -> None: ).input_variables == ["foo"] -def test_prompt_extra_input_variables() -> None: - """Test error is raised when there are too many input variables.""" - template = "This is a {foo} test." - input_variables = ["foo", "bar"] - with pytest.raises(ValueError): - FewShotPromptTemplate( - input_variables=input_variables, - suffix=template, - examples=[], - example_prompt=EXAMPLE_PROMPT, - validate_template=True, - ) - assert FewShotPromptTemplate( - input_variables=input_variables, - suffix=template, - examples=[], - example_prompt=EXAMPLE_PROMPT, - ).input_variables == ["foo"] - - def test_few_shot_functionality() -> None: """Test that few shot works with examples.""" prefix = "This is a test about {content}." diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py index 1b63b8859e1..ec01e426f78 100644 --- a/libs/core/tests/unit_tests/prompts/test_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_prompt.py @@ -53,19 +53,6 @@ def test_prompt_empty_input_variable() -> None: PromptTemplate(input_variables=[""], template="{}", validate_template=True) -def test_prompt_extra_input_variables() -> None: - """Test error is raised when there are too many input variables.""" - template = "This is a {foo} test." - input_variables = ["foo", "bar"] - with pytest.raises(ValueError): - PromptTemplate( - input_variables=input_variables, template=template, validate_template=True - ) - assert PromptTemplate( - input_variables=input_variables, template=template - ).input_variables == ["foo"] - - def test_prompt_wrong_input_variables() -> None: """Test error is raised when name of input variable is wrong.""" template = "This is a {foo} test." diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 4723be01159..4a3adc05cef 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -2054,6 +2054,7 @@ async def test_prompt_with_llm( "metadata": {}, "name": "ChatPromptTemplate", "start_time": "2023-01-01T00:00:00.000", + "streamed_output": [], "streamed_output_str": [], "tags": ["seq:step:1"], "type": "prompt", @@ -2087,6 +2088,7 @@ async def test_prompt_with_llm( "metadata": {}, "name": "FakeListLLM", "start_time": "2023-01-01T00:00:00.000", + "streamed_output": [], "streamed_output_str": [], "tags": ["seq:step:2"], "type": "llm", diff --git a/libs/langchain/tests/unit_tests/test_formatting.py b/libs/langchain/tests/unit_tests/test_formatting.py index 096fd13d306..1da6210fdda 100644 --- a/libs/langchain/tests/unit_tests/test_formatting.py +++ b/libs/langchain/tests/unit_tests/test_formatting.py @@ -18,8 +18,9 @@ def test_does_not_allow_args() -> None: formatter.format(template, "good") -def test_does_not_allow_extra_kwargs() -> None: - """Test formatting does not allow extra keyword arguments.""" +def test_allows_extra_kwargs() -> None: + """Test formatting allows extra keyword arguments.""" template = "This is a {foo} test." - with pytest.raises(KeyError): - formatter.format(template, foo="good", bar="oops") + output = formatter.format(template, foo="good", bar="oops") + expected_output = "This is a good test." + assert output == expected_output