[core] prompt changes (#15186)

change it to pass all variables through all the way in invoke
This commit is contained in:
Harrison Chase 2023-12-26 15:52:17 -08:00 committed by GitHub
parent ccf9c8e0be
commit 4ad77f777e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 29 additions and 66 deletions

View File

@ -73,20 +73,19 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
) )
def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue: def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue:
try: if not isinstance(inner_input, dict):
input_dict = {key: inner_input[key] for key in self.input_variables}
except TypeError as e:
raise TypeError( raise TypeError(
f"Expected mapping type as input to {self.__class__.__name__}. " f"Expected mapping type as input to {self.__class__.__name__}. "
f"Received {type(inner_input)}." f"Received {type(inner_input)}."
) from e )
except KeyError as e: missing = set(self.input_variables).difference(inner_input)
if missing:
raise KeyError( raise KeyError(
f"Input to {self.__class__.__name__} is missing variable {e}. " f"Input to {self.__class__.__name__} is missing variables {missing}. "
f" Expected: {self.input_variables}" f" Expected: {self.input_variables}"
f" Received: {list(inner_input.keys())}" f" Received: {list(inner_input.keys())}"
) from e )
return self.format_prompt(**input_dict) return self.format_prompt(**inner_input)
def invoke( def invoke(
self, input: Dict, config: Optional[RunnableConfig] = None self, input: Dict, config: Optional[RunnableConfig] = None
@ -100,7 +99,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
@abstractmethod @abstractmethod
def format_prompt(self, **kwargs: Any) -> PromptValue: def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages.""" """Create Prompt Value."""
@root_validator() @root_validator()
def validate_variable_names(cls, values: Dict) -> Dict: def validate_variable_names(cls, values: Dict) -> Dict:

View File

@ -133,7 +133,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
Returns: Returns:
List of input variable names. List of input variable names.
""" """
return [self.variable_name] return [self.variable_name] if not self.optional else []
MessagePromptTemplateT = TypeVar( MessagePromptTemplateT = TypeVar(
@ -611,12 +611,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
elif isinstance( elif isinstance(
message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate) message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate)
): ):
rel_params = { message = message_template.format_messages(**kwargs)
k: v
for k, v in kwargs.items()
if k in message_template.input_variables
}
message = message_template.format_messages(**rel_params)
result.extend(message) result.extend(message)
else: else:
raise ValueError(f"Unexpected input: {message_template}") raise ValueError(f"Unexpected input: {message_template}")

View File

@ -43,6 +43,8 @@ class LogEntry(TypedDict):
streamed_output_str: List[str] streamed_output_str: List[str]
"""List of LLM tokens streamed by this run, if applicable.""" """List of LLM tokens streamed by this run, if applicable."""
streamed_output: List[Any]
"""List of output chunks streamed by this run, if available."""
final_output: Optional[Any] final_output: Optional[Any]
"""Final output of this run. """Final output of this run.
Only available after the run has finished successfully.""" Only available after the run has finished successfully."""
@ -242,6 +244,7 @@ class LogStreamCallbackHandler(BaseTracer):
tags=run.tags or [], tags=run.tags or [],
metadata=(run.extra or {}).get("metadata", {}), metadata=(run.extra or {}).get("metadata", {}),
start_time=run.start_time.isoformat(timespec="milliseconds"), start_time=run.start_time.isoformat(timespec="milliseconds"),
streamed_output=[],
streamed_output_str=[], streamed_output_str=[],
final_output=None, final_output=None,
end_time=None, end_time=None,
@ -298,6 +301,13 @@ class LogStreamCallbackHandler(BaseTracer):
"op": "add", "op": "add",
"path": f"/logs/{index}/streamed_output_str/-", "path": f"/logs/{index}/streamed_output_str/-",
"value": token, "value": token,
} },
{
"op": "add",
"path": f"/logs/{index}/streamed_output/-",
"value": chunk.message
if isinstance(chunk, ChatGenerationChunk)
else token,
},
) )
) )

View File

@ -1,22 +1,11 @@
"""Utilities for formatting strings.""" """Utilities for formatting strings."""
from string import Formatter from string import Formatter
from typing import Any, List, Mapping, Sequence, Union from typing import Any, List, Mapping, Sequence
class StrictFormatter(Formatter): class StrictFormatter(Formatter):
"""A subclass of formatter that checks for extra keys.""" """A subclass of formatter that checks for extra keys."""
def check_unused_args(
self,
used_args: Sequence[Union[int, str]],
args: Sequence,
kwargs: Mapping[str, Any],
) -> None:
"""Check to see if extra parameters are passed."""
extra = set(kwargs).difference(used_args)
if extra:
raise KeyError(extra)
def vformat( def vformat(
self, format_string: str, args: Sequence, kwargs: Mapping[str, Any] self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
) -> str: ) -> str:

View File

@ -96,26 +96,6 @@ def test_prompt_missing_input_variables() -> None:
).input_variables == ["foo"] ).input_variables == ["foo"]
def test_prompt_extra_input_variables() -> None:
"""Test error is raised when there are too many input variables."""
template = "This is a {foo} test."
input_variables = ["foo", "bar"]
with pytest.raises(ValueError):
FewShotPromptTemplate(
input_variables=input_variables,
suffix=template,
examples=[],
example_prompt=EXAMPLE_PROMPT,
validate_template=True,
)
assert FewShotPromptTemplate(
input_variables=input_variables,
suffix=template,
examples=[],
example_prompt=EXAMPLE_PROMPT,
).input_variables == ["foo"]
def test_few_shot_functionality() -> None: def test_few_shot_functionality() -> None:
"""Test that few shot works with examples.""" """Test that few shot works with examples."""
prefix = "This is a test about {content}." prefix = "This is a test about {content}."

View File

@ -53,19 +53,6 @@ def test_prompt_empty_input_variable() -> None:
PromptTemplate(input_variables=[""], template="{}", validate_template=True) PromptTemplate(input_variables=[""], template="{}", validate_template=True)
def test_prompt_extra_input_variables() -> None:
"""Test error is raised when there are too many input variables."""
template = "This is a {foo} test."
input_variables = ["foo", "bar"]
with pytest.raises(ValueError):
PromptTemplate(
input_variables=input_variables, template=template, validate_template=True
)
assert PromptTemplate(
input_variables=input_variables, template=template
).input_variables == ["foo"]
def test_prompt_wrong_input_variables() -> None: def test_prompt_wrong_input_variables() -> None:
"""Test error is raised when name of input variable is wrong.""" """Test error is raised when name of input variable is wrong."""
template = "This is a {foo} test." template = "This is a {foo} test."

View File

@ -2054,6 +2054,7 @@ async def test_prompt_with_llm(
"metadata": {}, "metadata": {},
"name": "ChatPromptTemplate", "name": "ChatPromptTemplate",
"start_time": "2023-01-01T00:00:00.000", "start_time": "2023-01-01T00:00:00.000",
"streamed_output": [],
"streamed_output_str": [], "streamed_output_str": [],
"tags": ["seq:step:1"], "tags": ["seq:step:1"],
"type": "prompt", "type": "prompt",
@ -2087,6 +2088,7 @@ async def test_prompt_with_llm(
"metadata": {}, "metadata": {},
"name": "FakeListLLM", "name": "FakeListLLM",
"start_time": "2023-01-01T00:00:00.000", "start_time": "2023-01-01T00:00:00.000",
"streamed_output": [],
"streamed_output_str": [], "streamed_output_str": [],
"tags": ["seq:step:2"], "tags": ["seq:step:2"],
"type": "llm", "type": "llm",

View File

@ -18,8 +18,9 @@ def test_does_not_allow_args() -> None:
formatter.format(template, "good") formatter.format(template, "good")
def test_does_not_allow_extra_kwargs() -> None: def test_allows_extra_kwargs() -> None:
"""Test formatting does not allow extra keyword arguments.""" """Test formatting allows extra keyword arguments."""
template = "This is a {foo} test." template = "This is a {foo} test."
with pytest.raises(KeyError): output = formatter.format(template, foo="good", bar="oops")
formatter.format(template, foo="good", bar="oops") expected_output = "This is a good test."
assert output == expected_output