This commit is contained in:
Eugene Yurtsev
2023-08-31 12:21:47 -04:00
parent 8e8c368019
commit 5a68ba4174
2 changed files with 33 additions and 12 deletions

View File

@@ -129,7 +129,7 @@ class ThinkActPromptGenerator(PromptValue):
"""The string variant of the prompt."""
finalized = []
messages = self.to_messages()
for message in messages:
for idx, message in enumerate(messages):
if isinstance(message, FunctionResult):
component = f"Observation: {message.result}"
elif isinstance(
@@ -140,16 +140,14 @@ class ThinkActPromptGenerator(PromptValue):
elif isinstance(message, HumanMessage):
component = f"Question: {message.content}"
elif isinstance(message, (AIMessage, SystemMessage)):
component = message.content
if idx > 0 and isinstance(messages[idx], (HumanMessage, FunctionCall)):
# Priming the LLM with the word" Thought"
component = f"Thought: {message.content}"
else:
component = message.content
else:
raise NotImplementedError()
finalized.append(component)
if messages and isinstance(messages[-1], (HumanMessage, FunctionCall)):
# Prime the AI to think
finalized.append("Thought:")
return "\n".join(finalized)
def to_messages(self) -> List[BaseMessage]:

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import dataclasses
from typing import Any, Optional, Sequence, List, Mapping
from typing import Any, Optional, Sequence, List, Mapping, overload, Union
from langchain.schema import (
BaseMessage,
@@ -26,18 +26,41 @@ class AgentFinish:
result: Any
MessageLike = BaseMessage | FunctionCall | FunctionResult | AgentFinish
MessageLike = Union[BaseMessage, FunctionCall, FunctionResult, AgentFinish]
class MessageLog:
"""A generalized message log for message like items."""
def __init__(self, messages: Sequence[MessageLike]):
def __init__(self, messages: Sequence[MessageLike]) -> None:
"""Initialize the message log."""
self.messages = list(messages)
def add_messages(self, messages: Sequence[MessageLike]):
def add_messages(self, messages: Sequence[MessageLike]) -> None:
"""Add messages to the message log."""
self.messages.extend(messages)
@overload
def __getitem__(self, index: int) -> MessageLike:
...
@overload
def __getitem__(self, index: slice) -> MessageLog:
...
def __getitem__(self, index: Union[int, slice]) -> Union[MessageLike, MessageLog]:
"""Use to index into the chat template."""
if isinstance(index, slice):
start, stop, step = index.indices(len(self.messages))
messages = self.messages[start:stop:step]
return MessageLog(messages=messages)
else:
return self.messages[index]
def __len__(self) -> int:
"""Get the length of the chat template."""
return len(self.messages)
class MessageLogPromptValue(PromptValue):
"""Base abstract class for inputs to any language model.