mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-06 05:08:20 +00:00
[core] prompt changes (#15186)
change it to pass all variables through all the way in invoke
This commit is contained in:
parent
ccf9c8e0be
commit
4ad77f777e
@ -73,20 +73,19 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
||||
)
|
||||
|
||||
def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue:
|
||||
try:
|
||||
input_dict = {key: inner_input[key] for key in self.input_variables}
|
||||
except TypeError as e:
|
||||
if not isinstance(inner_input, dict):
|
||||
raise TypeError(
|
||||
f"Expected mapping type as input to {self.__class__.__name__}. "
|
||||
f"Received {type(inner_input)}."
|
||||
) from e
|
||||
except KeyError as e:
|
||||
)
|
||||
missing = set(self.input_variables).difference(inner_input)
|
||||
if missing:
|
||||
raise KeyError(
|
||||
f"Input to {self.__class__.__name__} is missing variable {e}. "
|
||||
f"Input to {self.__class__.__name__} is missing variables {missing}. "
|
||||
f" Expected: {self.input_variables}"
|
||||
f" Received: {list(inner_input.keys())}"
|
||||
) from e
|
||||
return self.format_prompt(**input_dict)
|
||||
)
|
||||
return self.format_prompt(**inner_input)
|
||||
|
||||
def invoke(
|
||||
self, input: Dict, config: Optional[RunnableConfig] = None
|
||||
@ -100,7 +99,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
||||
|
||||
@abstractmethod
|
||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||
"""Create Chat Messages."""
|
||||
"""Create Prompt Value."""
|
||||
|
||||
@root_validator()
|
||||
def validate_variable_names(cls, values: Dict) -> Dict:
|
||||
|
@ -133,7 +133,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
Returns:
|
||||
List of input variable names.
|
||||
"""
|
||||
return [self.variable_name]
|
||||
return [self.variable_name] if not self.optional else []
|
||||
|
||||
|
||||
MessagePromptTemplateT = TypeVar(
|
||||
@ -611,12 +611,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
elif isinstance(
|
||||
message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate)
|
||||
):
|
||||
rel_params = {
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k in message_template.input_variables
|
||||
}
|
||||
message = message_template.format_messages(**rel_params)
|
||||
message = message_template.format_messages(**kwargs)
|
||||
result.extend(message)
|
||||
else:
|
||||
raise ValueError(f"Unexpected input: {message_template}")
|
||||
|
@ -43,6 +43,8 @@ class LogEntry(TypedDict):
|
||||
|
||||
streamed_output_str: List[str]
|
||||
"""List of LLM tokens streamed by this run, if applicable."""
|
||||
streamed_output: List[Any]
|
||||
"""List of output chunks streamed by this run, if available."""
|
||||
final_output: Optional[Any]
|
||||
"""Final output of this run.
|
||||
Only available after the run has finished successfully."""
|
||||
@ -242,6 +244,7 @@ class LogStreamCallbackHandler(BaseTracer):
|
||||
tags=run.tags or [],
|
||||
metadata=(run.extra or {}).get("metadata", {}),
|
||||
start_time=run.start_time.isoformat(timespec="milliseconds"),
|
||||
streamed_output=[],
|
||||
streamed_output_str=[],
|
||||
final_output=None,
|
||||
end_time=None,
|
||||
@ -298,6 +301,13 @@ class LogStreamCallbackHandler(BaseTracer):
|
||||
"op": "add",
|
||||
"path": f"/logs/{index}/streamed_output_str/-",
|
||||
"value": token,
|
||||
}
|
||||
},
|
||||
{
|
||||
"op": "add",
|
||||
"path": f"/logs/{index}/streamed_output/-",
|
||||
"value": chunk.message
|
||||
if isinstance(chunk, ChatGenerationChunk)
|
||||
else token,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
@ -1,22 +1,11 @@
|
||||
"""Utilities for formatting strings."""
|
||||
from string import Formatter
|
||||
from typing import Any, List, Mapping, Sequence, Union
|
||||
from typing import Any, List, Mapping, Sequence
|
||||
|
||||
|
||||
class StrictFormatter(Formatter):
|
||||
"""A subclass of formatter that checks for extra keys."""
|
||||
|
||||
def check_unused_args(
|
||||
self,
|
||||
used_args: Sequence[Union[int, str]],
|
||||
args: Sequence,
|
||||
kwargs: Mapping[str, Any],
|
||||
) -> None:
|
||||
"""Check to see if extra parameters are passed."""
|
||||
extra = set(kwargs).difference(used_args)
|
||||
if extra:
|
||||
raise KeyError(extra)
|
||||
|
||||
def vformat(
|
||||
self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
|
||||
) -> str:
|
||||
|
@ -96,26 +96,6 @@ def test_prompt_missing_input_variables() -> None:
|
||||
).input_variables == ["foo"]
|
||||
|
||||
|
||||
def test_prompt_extra_input_variables() -> None:
|
||||
"""Test error is raised when there are too many input variables."""
|
||||
template = "This is a {foo} test."
|
||||
input_variables = ["foo", "bar"]
|
||||
with pytest.raises(ValueError):
|
||||
FewShotPromptTemplate(
|
||||
input_variables=input_variables,
|
||||
suffix=template,
|
||||
examples=[],
|
||||
example_prompt=EXAMPLE_PROMPT,
|
||||
validate_template=True,
|
||||
)
|
||||
assert FewShotPromptTemplate(
|
||||
input_variables=input_variables,
|
||||
suffix=template,
|
||||
examples=[],
|
||||
example_prompt=EXAMPLE_PROMPT,
|
||||
).input_variables == ["foo"]
|
||||
|
||||
|
||||
def test_few_shot_functionality() -> None:
|
||||
"""Test that few shot works with examples."""
|
||||
prefix = "This is a test about {content}."
|
||||
|
@ -53,19 +53,6 @@ def test_prompt_empty_input_variable() -> None:
|
||||
PromptTemplate(input_variables=[""], template="{}", validate_template=True)
|
||||
|
||||
|
||||
def test_prompt_extra_input_variables() -> None:
|
||||
"""Test error is raised when there are too many input variables."""
|
||||
template = "This is a {foo} test."
|
||||
input_variables = ["foo", "bar"]
|
||||
with pytest.raises(ValueError):
|
||||
PromptTemplate(
|
||||
input_variables=input_variables, template=template, validate_template=True
|
||||
)
|
||||
assert PromptTemplate(
|
||||
input_variables=input_variables, template=template
|
||||
).input_variables == ["foo"]
|
||||
|
||||
|
||||
def test_prompt_wrong_input_variables() -> None:
|
||||
"""Test error is raised when name of input variable is wrong."""
|
||||
template = "This is a {foo} test."
|
||||
|
@ -2054,6 +2054,7 @@ async def test_prompt_with_llm(
|
||||
"metadata": {},
|
||||
"name": "ChatPromptTemplate",
|
||||
"start_time": "2023-01-01T00:00:00.000",
|
||||
"streamed_output": [],
|
||||
"streamed_output_str": [],
|
||||
"tags": ["seq:step:1"],
|
||||
"type": "prompt",
|
||||
@ -2087,6 +2088,7 @@ async def test_prompt_with_llm(
|
||||
"metadata": {},
|
||||
"name": "FakeListLLM",
|
||||
"start_time": "2023-01-01T00:00:00.000",
|
||||
"streamed_output": [],
|
||||
"streamed_output_str": [],
|
||||
"tags": ["seq:step:2"],
|
||||
"type": "llm",
|
||||
|
@ -18,8 +18,9 @@ def test_does_not_allow_args() -> None:
|
||||
formatter.format(template, "good")
|
||||
|
||||
|
||||
def test_does_not_allow_extra_kwargs() -> None:
|
||||
"""Test formatting does not allow extra keyword arguments."""
|
||||
def test_allows_extra_kwargs() -> None:
|
||||
"""Test formatting allows extra keyword arguments."""
|
||||
template = "This is a {foo} test."
|
||||
with pytest.raises(KeyError):
|
||||
formatter.format(template, foo="good", bar="oops")
|
||||
output = formatter.format(template, foo="good", bar="oops")
|
||||
expected_output = "This is a good test."
|
||||
assert output == expected_output
|
||||
|
Loading…
Reference in New Issue
Block a user