commit 5601d60df0 (parent 1cd9fb6444)
Author: Eugene Yurtsev
Date: 2023-09-08 16:40:27 -04:00
9 changed files with 60 additions and 49 deletions
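
Summary: this commit renames the automaton message types so that request/response pairs share one naming scheme: FunctionCall becomes FunctionCallRequest, FunctionResult becomes FunctionCallResponse, and RetrievalResult becomes RetrievalResponse. All call sites and tests are updated, and the __str__ methods in typedefs gain return-type annotations and docstrings.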


@@ -6,8 +6,8 @@ from typing import List, Sequence
 from langchain.automaton.chat_agent import ChatAgent
 from langchain.automaton.typedefs import (
     AgentFinish,
-    FunctionCall,
-    FunctionResult,
+    FunctionCallRequest,
+    FunctionCallResponse,
     MessageLike,
 )
 from langchain.chat_models.openai import ChatOpenAI
@@ -34,7 +34,7 @@ class OpenAIFunctionsParser(BaseGenerationOutputParser):
         except Exception as e:
             raise RuntimeError(f"Error parsing result: {result} {repr(e)}") from e
-        return FunctionCall(
+        return FunctionCallRequest(
             name=function_request["name"],
             named_arguments=function_request["arguments"],
         )
@@ -46,7 +46,7 @@ def prompt_generator(input_messages: Sequence[MessageLike]) -> List[BaseMessage]
     for message in input_messages:
         if isinstance(message, BaseMessage):
             messages.append(message)
-        elif isinstance(message, FunctionResult):
+        elif isinstance(message, FunctionCallResponse):
             messages.append(
                 FunctionMessage(name=message.name, content=json.dumps(message.result))
             )
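
For illustration, a minimal sketch of the mapping in the hunk above: a FunctionCallResponse is serialized and wrapped in a FunctionMessage for the chat model. The values are hypothetical; the imports assume this branch's langchain.automaton module.

import json

from langchain.automaton.typedefs import FunctionCallResponse
from langchain.schema.messages import FunctionMessage

# A hypothetical tool result, as the tool invoker would produce it.
response = FunctionCallResponse(name="get_time", result="9 PM", error=None)

# prompt_generator serializes the result and wraps it in a FunctionMessage.
message = FunctionMessage(name=response.name, content=json.dumps(response.result))
assert message.content == '"9 PM"'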


@@ -16,8 +16,8 @@ from langchain.automaton.typedefs import (
     AdHocMessage,
     Agent,
     AgentFinish,
-    FunctionCall,
-    FunctionResult,
+    FunctionCallRequest,
+    FunctionCallResponse,
     MessageLike,
 )
 from langchain.prompts import SystemMessagePromptTemplate
@@ -116,7 +116,7 @@ class ActionParser:
                     data=f"Invalid action blob {action_blob}, action_input must be a dict",
                 )
-            return FunctionCall(
+            return FunctionCallRequest(
                 name=data["action"], named_arguments=named_arguments or {}
             )
         else:
@@ -155,13 +155,13 @@ class ThinkActPromptGenerator(PromptValue):
                 finalized.append(component)
                 continue
-            if isinstance(message, FunctionResult):
+            if isinstance(message, FunctionCallResponse):
                 component = f"Observation: {message.result}"
             elif isinstance(message, HumanMessage):
                 component = f"Question: {message.content.strip()}"
             elif isinstance(message, (AIMessage, SystemMessage)):
                 component = message.content.strip()
-            elif isinstance(message, FunctionCall):
+            elif isinstance(message, FunctionCallRequest):
                 # This is an internal message, and should not be returned to the user.
                 continue
             elif isinstance(message, AgentFinish):
@@ -177,7 +177,7 @@ class ThinkActPromptGenerator(PromptValue):
         for message in self.messages:
             if isinstance(message, BaseMessage):
                 messages.append(message)
-            elif isinstance(message, FunctionResult):
+            elif isinstance(message, FunctionCallResponse):
                 messages.append(
                     SystemMessage(content=f"Observation: `{message.result}`")
                 )
@@ -222,7 +222,7 @@ class ThinkActAgent(Agent):
                 break
             if all_messages and isinstance(
-                all_messages[-1], (FunctionResult, HumanMessage)
+                all_messages[-1], (FunctionCallResponse, HumanMessage)
             ):
                 all_messages.append(AdHocMessage(type="prime", data="Thought:"))
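
For context, a sketch of the priming step above: after a tool response (or a human message), the agent appends an ad hoc "Thought:" message so the next completion continues the Think/Act transcript. Values are hypothetical; types are those defined on this branch.

from langchain.automaton.typedefs import (
    AdHocMessage,
    FunctionCallRequest,
    FunctionCallResponse,
)
from langchain.schema.messages import HumanMessage

messages = [
    HumanMessage(content="What time is it?"),
    # The request is internal and hidden when rendering the prompt.
    FunctionCallRequest(name="get_time", named_arguments={}),
    FunctionCallResponse(name="get_time", result="9 PM", error=None),
]
if messages and isinstance(messages[-1], (FunctionCallResponse, HumanMessage)):
    messages.append(AdHocMessage(type="prime", data="Thought:"))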


@@ -8,8 +8,8 @@ from langchain.automaton.chat_agent import ChatAgent
 from langchain.automaton.tool_utils import generate_tool_info
 from langchain.automaton.typedefs import (
     AgentFinish,
-    FunctionCall,
-    FunctionResult,
+    FunctionCallRequest,
+    FunctionCallResponse,
     MessageLike,
 )
 from langchain.prompts import SystemMessagePromptTemplate
@@ -66,7 +66,7 @@ def _decode(text: Union[BaseMessage, str]) -> MessageLike:
         name = data["action"]
         if name == "Final Answer":  # Special cased "tool" for final answer
             return AgentFinish(result=data["action_input"])
-        return FunctionCall(
+        return FunctionCallRequest(
             name=data["action"], named_arguments=data["action_input"] or {}
         )
     else:
@@ -79,7 +79,7 @@ def generate_prompt(current_messages: Sequence[MessageLike]) -> List[BaseMessage]
     for message in current_messages:
         if isinstance(message, BaseMessage):
             messages.append(message)
-        elif isinstance(message, FunctionResult):
+        elif isinstance(message, FunctionCallResponse):
             messages.append(
                 HumanMessage(
                     content=f"Observation: {message.result}",


@@ -4,7 +4,7 @@ from __future__ import annotations
 import abc
 from typing import Mapping, Any, Callable, List, Sequence, Optional
-from langchain.automaton.typedefs import MessageLike, FunctionResult
+from langchain.automaton.typedefs import MessageLike, FunctionCallResponse
 from langchain.schema import BaseMessage, PromptValue, HumanMessage


@@ -17,7 +17,7 @@ from langchain.automaton.typedefs import (
     AgentFinish,
     MessageLike,
     RetrievalRequest,
-    RetrievalResult,
+    RetrievalResponse,
 )
 from langchain.schema import PromptValue, BaseRetriever
 from langchain.schema.language_model import (
@@ -41,7 +41,7 @@ def prompt_generator(input_messages: Sequence[MessageLike]) -> List[BaseMessage]
     for message in input_messages:
         if isinstance(message, BaseMessage):
             messages.append(message)
-        elif isinstance(message, RetrievalResult):
+        elif isinstance(message, RetrievalResponse):
             prompt = ""
             if message.results:


@@ -5,10 +5,10 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union
 from langchain.schema.retriever import BaseRetriever
 from langchain.automaton.typedefs import (
-    FunctionCall,
-    FunctionResult,
+    FunctionCallRequest,
+    FunctionCallResponse,
     MessageLike,
-    RetrievalResult,
+    RetrievalResponse,
     RetrievalRequest,
 )
 from langchain.callbacks.manager import CallbackManagerForChainRun
@@ -111,7 +111,7 @@ def _to_retriever_input(message: MessageLike) -> str:
 def create_tool_invoker(
     tools: Sequence[BaseTool],
-) -> Runnable[MessageLike, Optional[FunctionResult]]:
+) -> Runnable[MessageLike, Optional[FunctionCallResponse]]:
     """See if possible to re-write with router

     TODO:
@@ -125,10 +125,10 @@ def create_tool_invoker(
         function_call: MessageLike,
         run_manager: CallbackManagerForChainRun,
         config: RunnableConfig,
-    ) -> Optional[FunctionResult]:
+    ) -> Optional[FunctionCallResponse]:
         """A function that can invoke a tool using .run"""
         if not isinstance(
-            function_call, FunctionCall
+            function_call, FunctionCallRequest
         ):  # TODO(Hack): Workaround lack of conditional apply
             return None
         try:
@@ -145,16 +145,16 @@ def create_tool_invoker(
             result = None
             error = repr(e) + repr(function_call.named_arguments)
-        return FunctionResult(name=function_call.name, result=result, error=error)
+        return FunctionCallResponse(name=function_call.name, result=result, error=error)

     async def afunc(
         function_call: MessageLike,
         run_manager: CallbackManagerForChainRun,
         config: RunnableConfig,
-    ) -> Optional[FunctionResult]:
+    ) -> Optional[FunctionCallResponse]:
         """A function that can invoke a tool using .run"""
         if not isinstance(
-            function_call, FunctionCall
+            function_call, FunctionCallRequest
         ):  # TODO(Hack): Workaround lack of conditional apply
             return None
         try:
@@ -171,7 +171,7 @@ def create_tool_invoker(
             result = None
             error = repr(e) + repr(function_call.named_arguments)
-        return FunctionResult(name=function_call.name, result=result, error=error)
+        return FunctionCallResponse(name=function_call.name, result=result, error=error)

     return RunnableLambda(func=func, afunc=afunc)
@@ -240,16 +240,16 @@ def create_llm_program(
 def create_retriever(
     base_retriever: BaseRetriever,
-) -> Runnable[RetrievalRequest, RetrievalResult]:
+) -> Runnable[RetrievalRequest, RetrievalResponse]:
     """Create a runnable retriever that uses messages."""

     def _from_retrieval_request(request: RetrievalRequest) -> str:
         """Extract the query string from a retrieval request."""
         return request.query

-    def _to_retrieval_result(docs: List[Document]) -> RetrievalResult:
+    def _to_retrieval_result(docs: List[Document]) -> RetrievalResponse:
         """Convert a list of documents to a message."""
-        return RetrievalResult(results=docs)
+        return RetrievalResponse(results=docs)

     return (
         RunnableLambda(_from_retrieval_request)
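
A minimal usage sketch of the retriever runnable after the rename; the retriever argument is an assumption, and the snippet is untested against this branch.

from langchain.automaton.runnables import create_retriever
from langchain.automaton.typedefs import RetrievalRequest, RetrievalResponse
from langchain.schema import BaseRetriever

def demo(retriever: BaseRetriever) -> RetrievalResponse:
    # create_retriever maps a RetrievalRequest to a RetrievalResponse of documents.
    retriever_runnable = create_retriever(retriever)
    return retriever_runnable.invoke(RetrievalRequest(query="automaton"))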


@@ -11,8 +11,8 @@ from langchain.automaton.tests.utils import (
 )
 from langchain.automaton.typedefs import (
     AgentFinish,
-    FunctionCall,
-    FunctionResult,
+    FunctionCallRequest,
+    FunctionCallResponse,
     MessageLog,
 )
 from langchain.schema.messages import (
@@ -74,11 +74,11 @@ def test_openai_agent(tools: List[Tool]) -> None:
                 }
             },
         ),
-        FunctionCall(
+        FunctionCallRequest(
             name="get_time",
             named_arguments={},
         ),
-        FunctionResult(
+        FunctionCallResponse(
             name="get_time",
             result="9 PM",
             error=None,


@@ -13,7 +13,7 @@ from langchain.automaton.runnables import (
 from langchain.automaton.tests.utils import (
     FakeChatModel,
 )
-from langchain.automaton.typedefs import FunctionCall, FunctionResult, MessageLike
+from langchain.automaton.typedefs import FunctionCallRequest, FunctionCallResponse, MessageLike
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
 from langchain.schema.runnable import RunnableLambda
@@ -127,11 +127,11 @@ def test_llm_program_with_parser(fake_llm: BaseLanguageModel) -> None:
             [AIMessage(content="Hello"), AIMessage(content="Goodbye")],
         ),
         (
-            RunnableLambda(lambda msg: FunctionCall(name="get_time")),
+            RunnableLambda(lambda msg: FunctionCallRequest(name="get_time")),
             [
                 AIMessage(content="Hello"),
-                FunctionCall(name="get_time"),
-                FunctionResult(result="9 PM", name="get_time"),
+                FunctionCallRequest(name="get_time"),
+                FunctionCallResponse(result="9 PM", name="get_time"),
             ],
         ),
     ],
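
The parametrized case above exercises the full loop: the parser turns the raw generation into a FunctionCallRequest, and the program records the matching FunctionCallResponse from the tool invoker. A sketch of the expected message sequence, mirroring the fixture values:

from langchain.automaton.typedefs import FunctionCallRequest, FunctionCallResponse
from langchain.schema.messages import AIMessage

expected = [
    AIMessage(content="Hello"),  # raw LLM output
    FunctionCallRequest(name="get_time"),  # parsed from the generation
    FunctionCallResponse(result="9 PM", name="get_time"),  # tool invoker's reply
]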


@@ -13,7 +13,9 @@ class InternalMessage(Serializable):
         return True


-class FunctionCall(InternalMessage):  # TODO(Eugene): Rename as FunctionCallRequest
+class FunctionCallRequest(
+    InternalMessage
+):  # TODO(Eugene): Rename as FunctionCallRequest
     """A request for a function invocation.

     This message can be used to request a function invocation
@@ -28,19 +30,24 @@ class FunctionCall(InternalMessage):  # TODO(Eugene): Rename as FunctionCallRequ
     class Config:
         extra = "forbid"

-    def __str__(self):
+    def __str__(self) -> str:
+        """Return a string representation of the object."""
         return f"FunctionCall(name={self.name}, named_arguments={self.named_arguments})"


-class FunctionResult(InternalMessage):  # Rename as FunctionCallResult
+class FunctionCallResponse(InternalMessage):  # Rename as FunctionCallResult
     """A result of a function invocation."""

     name: str
     result: Any
     error: Optional[str] = None

-    def __str__(self):
-        return f"FunctionResult(name={self.name}, result={self.result}, error={self.error})"
+    def __str__(self) -> str:
+        """Return a string representation of the object."""
+        return (
+            f"FunctionResult(name={self.name}, result={self.result}, "
+            f"error={self.error})"
+        )


 class RetrievalRequest(InternalMessage):
@@ -49,16 +56,18 @@ class RetrievalRequest(InternalMessage):
     query: str
     """The query to use for the retrieval."""

-    def __str__(self):
+    def __str__(self) -> str:
+        """Return a string representation of the object."""
         return f"RetrievalRequest(query={self.query})"


-class RetrievalResult(InternalMessage):
+class RetrievalResponse(InternalMessage):
     """A result of a retrieval."""

     results: Sequence[Document]

-    def __str__(self):
+    def __str__(self) -> str:
+        """Return a string representation of the object."""
         return f"RetrievalResults(results={self.results})"
@@ -68,7 +77,8 @@ class AdHocMessage(InternalMessage):
     type: str
     data: Any  # Make sure this is serializable

-    def __str__(self):
+    def __str__(self) -> str:
+        """Return a string representation of the object."""
         return f"AdHocMessage(type={self.type}, data={self.data})"
@@ -77,7 +87,8 @@ class AgentFinish(InternalMessage):
     result: Any

-    def __str__(self):
+    def __str__(self) -> str:
+        """Return a string representation of the object."""
         return f"AgentFinish(result={self.result})"