mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-07 13:40:46 +00:00
[core] prompt changes (#15186)
change it to pass all variables through all the way in invoke
This commit is contained in:
parent
ccf9c8e0be
commit
4ad77f777e
@ -73,20 +73,19 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue:
|
def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue:
|
||||||
try:
|
if not isinstance(inner_input, dict):
|
||||||
input_dict = {key: inner_input[key] for key in self.input_variables}
|
|
||||||
except TypeError as e:
|
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Expected mapping type as input to {self.__class__.__name__}. "
|
f"Expected mapping type as input to {self.__class__.__name__}. "
|
||||||
f"Received {type(inner_input)}."
|
f"Received {type(inner_input)}."
|
||||||
) from e
|
)
|
||||||
except KeyError as e:
|
missing = set(self.input_variables).difference(inner_input)
|
||||||
|
if missing:
|
||||||
raise KeyError(
|
raise KeyError(
|
||||||
f"Input to {self.__class__.__name__} is missing variable {e}. "
|
f"Input to {self.__class__.__name__} is missing variables {missing}. "
|
||||||
f" Expected: {self.input_variables}"
|
f" Expected: {self.input_variables}"
|
||||||
f" Received: {list(inner_input.keys())}"
|
f" Received: {list(inner_input.keys())}"
|
||||||
) from e
|
)
|
||||||
return self.format_prompt(**input_dict)
|
return self.format_prompt(**inner_input)
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self, input: Dict, config: Optional[RunnableConfig] = None
|
self, input: Dict, config: Optional[RunnableConfig] = None
|
||||||
@ -100,7 +99,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||||
"""Create Chat Messages."""
|
"""Create Prompt Value."""
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_variable_names(cls, values: Dict) -> Dict:
|
def validate_variable_names(cls, values: Dict) -> Dict:
|
||||||
|
@ -133,7 +133,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
|||||||
Returns:
|
Returns:
|
||||||
List of input variable names.
|
List of input variable names.
|
||||||
"""
|
"""
|
||||||
return [self.variable_name]
|
return [self.variable_name] if not self.optional else []
|
||||||
|
|
||||||
|
|
||||||
MessagePromptTemplateT = TypeVar(
|
MessagePromptTemplateT = TypeVar(
|
||||||
@ -611,12 +611,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
|||||||
elif isinstance(
|
elif isinstance(
|
||||||
message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate)
|
message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate)
|
||||||
):
|
):
|
||||||
rel_params = {
|
message = message_template.format_messages(**kwargs)
|
||||||
k: v
|
|
||||||
for k, v in kwargs.items()
|
|
||||||
if k in message_template.input_variables
|
|
||||||
}
|
|
||||||
message = message_template.format_messages(**rel_params)
|
|
||||||
result.extend(message)
|
result.extend(message)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected input: {message_template}")
|
raise ValueError(f"Unexpected input: {message_template}")
|
||||||
|
@ -43,6 +43,8 @@ class LogEntry(TypedDict):
|
|||||||
|
|
||||||
streamed_output_str: List[str]
|
streamed_output_str: List[str]
|
||||||
"""List of LLM tokens streamed by this run, if applicable."""
|
"""List of LLM tokens streamed by this run, if applicable."""
|
||||||
|
streamed_output: List[Any]
|
||||||
|
"""List of output chunks streamed by this run, if available."""
|
||||||
final_output: Optional[Any]
|
final_output: Optional[Any]
|
||||||
"""Final output of this run.
|
"""Final output of this run.
|
||||||
Only available after the run has finished successfully."""
|
Only available after the run has finished successfully."""
|
||||||
@ -242,6 +244,7 @@ class LogStreamCallbackHandler(BaseTracer):
|
|||||||
tags=run.tags or [],
|
tags=run.tags or [],
|
||||||
metadata=(run.extra or {}).get("metadata", {}),
|
metadata=(run.extra or {}).get("metadata", {}),
|
||||||
start_time=run.start_time.isoformat(timespec="milliseconds"),
|
start_time=run.start_time.isoformat(timespec="milliseconds"),
|
||||||
|
streamed_output=[],
|
||||||
streamed_output_str=[],
|
streamed_output_str=[],
|
||||||
final_output=None,
|
final_output=None,
|
||||||
end_time=None,
|
end_time=None,
|
||||||
@ -298,6 +301,13 @@ class LogStreamCallbackHandler(BaseTracer):
|
|||||||
"op": "add",
|
"op": "add",
|
||||||
"path": f"/logs/{index}/streamed_output_str/-",
|
"path": f"/logs/{index}/streamed_output_str/-",
|
||||||
"value": token,
|
"value": token,
|
||||||
}
|
},
|
||||||
|
{
|
||||||
|
"op": "add",
|
||||||
|
"path": f"/logs/{index}/streamed_output/-",
|
||||||
|
"value": chunk.message
|
||||||
|
if isinstance(chunk, ChatGenerationChunk)
|
||||||
|
else token,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -1,22 +1,11 @@
|
|||||||
"""Utilities for formatting strings."""
|
"""Utilities for formatting strings."""
|
||||||
from string import Formatter
|
from string import Formatter
|
||||||
from typing import Any, List, Mapping, Sequence, Union
|
from typing import Any, List, Mapping, Sequence
|
||||||
|
|
||||||
|
|
||||||
class StrictFormatter(Formatter):
|
class StrictFormatter(Formatter):
|
||||||
"""A subclass of formatter that checks for extra keys."""
|
"""A subclass of formatter that checks for extra keys."""
|
||||||
|
|
||||||
def check_unused_args(
|
|
||||||
self,
|
|
||||||
used_args: Sequence[Union[int, str]],
|
|
||||||
args: Sequence,
|
|
||||||
kwargs: Mapping[str, Any],
|
|
||||||
) -> None:
|
|
||||||
"""Check to see if extra parameters are passed."""
|
|
||||||
extra = set(kwargs).difference(used_args)
|
|
||||||
if extra:
|
|
||||||
raise KeyError(extra)
|
|
||||||
|
|
||||||
def vformat(
|
def vformat(
|
||||||
self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
|
self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
|
||||||
) -> str:
|
) -> str:
|
||||||
|
@ -96,26 +96,6 @@ def test_prompt_missing_input_variables() -> None:
|
|||||||
).input_variables == ["foo"]
|
).input_variables == ["foo"]
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_extra_input_variables() -> None:
|
|
||||||
"""Test error is raised when there are too many input variables."""
|
|
||||||
template = "This is a {foo} test."
|
|
||||||
input_variables = ["foo", "bar"]
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
FewShotPromptTemplate(
|
|
||||||
input_variables=input_variables,
|
|
||||||
suffix=template,
|
|
||||||
examples=[],
|
|
||||||
example_prompt=EXAMPLE_PROMPT,
|
|
||||||
validate_template=True,
|
|
||||||
)
|
|
||||||
assert FewShotPromptTemplate(
|
|
||||||
input_variables=input_variables,
|
|
||||||
suffix=template,
|
|
||||||
examples=[],
|
|
||||||
example_prompt=EXAMPLE_PROMPT,
|
|
||||||
).input_variables == ["foo"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_few_shot_functionality() -> None:
|
def test_few_shot_functionality() -> None:
|
||||||
"""Test that few shot works with examples."""
|
"""Test that few shot works with examples."""
|
||||||
prefix = "This is a test about {content}."
|
prefix = "This is a test about {content}."
|
||||||
|
@ -53,19 +53,6 @@ def test_prompt_empty_input_variable() -> None:
|
|||||||
PromptTemplate(input_variables=[""], template="{}", validate_template=True)
|
PromptTemplate(input_variables=[""], template="{}", validate_template=True)
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_extra_input_variables() -> None:
|
|
||||||
"""Test error is raised when there are too many input variables."""
|
|
||||||
template = "This is a {foo} test."
|
|
||||||
input_variables = ["foo", "bar"]
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
PromptTemplate(
|
|
||||||
input_variables=input_variables, template=template, validate_template=True
|
|
||||||
)
|
|
||||||
assert PromptTemplate(
|
|
||||||
input_variables=input_variables, template=template
|
|
||||||
).input_variables == ["foo"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_wrong_input_variables() -> None:
|
def test_prompt_wrong_input_variables() -> None:
|
||||||
"""Test error is raised when name of input variable is wrong."""
|
"""Test error is raised when name of input variable is wrong."""
|
||||||
template = "This is a {foo} test."
|
template = "This is a {foo} test."
|
||||||
|
@ -2054,6 +2054,7 @@ async def test_prompt_with_llm(
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"name": "ChatPromptTemplate",
|
"name": "ChatPromptTemplate",
|
||||||
"start_time": "2023-01-01T00:00:00.000",
|
"start_time": "2023-01-01T00:00:00.000",
|
||||||
|
"streamed_output": [],
|
||||||
"streamed_output_str": [],
|
"streamed_output_str": [],
|
||||||
"tags": ["seq:step:1"],
|
"tags": ["seq:step:1"],
|
||||||
"type": "prompt",
|
"type": "prompt",
|
||||||
@ -2087,6 +2088,7 @@ async def test_prompt_with_llm(
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"name": "FakeListLLM",
|
"name": "FakeListLLM",
|
||||||
"start_time": "2023-01-01T00:00:00.000",
|
"start_time": "2023-01-01T00:00:00.000",
|
||||||
|
"streamed_output": [],
|
||||||
"streamed_output_str": [],
|
"streamed_output_str": [],
|
||||||
"tags": ["seq:step:2"],
|
"tags": ["seq:step:2"],
|
||||||
"type": "llm",
|
"type": "llm",
|
||||||
|
@ -18,8 +18,9 @@ def test_does_not_allow_args() -> None:
|
|||||||
formatter.format(template, "good")
|
formatter.format(template, "good")
|
||||||
|
|
||||||
|
|
||||||
def test_does_not_allow_extra_kwargs() -> None:
|
def test_allows_extra_kwargs() -> None:
|
||||||
"""Test formatting does not allow extra keyword arguments."""
|
"""Test formatting allows extra keyword arguments."""
|
||||||
template = "This is a {foo} test."
|
template = "This is a {foo} test."
|
||||||
with pytest.raises(KeyError):
|
output = formatter.format(template, foo="good", bar="oops")
|
||||||
formatter.format(template, foo="good", bar="oops")
|
expected_output = "This is a good test."
|
||||||
|
assert output == expected_output
|
||||||
|
Loading…
Reference in New Issue
Block a user