mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-03 15:55:44 +00:00
x
This commit is contained in:
@@ -6,8 +6,8 @@ from typing import List, Sequence
|
||||
from langchain.automaton.chat_agent import ChatAgent
|
||||
from langchain.automaton.typedefs import (
|
||||
AgentFinish,
|
||||
FunctionCall,
|
||||
FunctionResult,
|
||||
FunctionCallRequest,
|
||||
FunctionCallResponse,
|
||||
MessageLike,
|
||||
)
|
||||
from langchain.chat_models.openai import ChatOpenAI
|
||||
@@ -34,7 +34,7 @@ class OpenAIFunctionsParser(BaseGenerationOutputParser):
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error parsing result: {result} {repr(e)}") from e
|
||||
|
||||
return FunctionCall(
|
||||
return FunctionCallRequest(
|
||||
name=function_request["name"],
|
||||
named_arguments=function_request["arguments"],
|
||||
)
|
||||
@@ -46,7 +46,7 @@ def prompt_generator(input_messages: Sequence[MessageLike]) -> List[BaseMessage]
|
||||
for message in input_messages:
|
||||
if isinstance(message, BaseMessage):
|
||||
messages.append(message)
|
||||
elif isinstance(message, FunctionResult):
|
||||
elif isinstance(message, FunctionCallResponse):
|
||||
messages.append(
|
||||
FunctionMessage(name=message.name, content=json.dumps(message.result))
|
||||
)
|
||||
|
||||
@@ -16,8 +16,8 @@ from langchain.automaton.typedefs import (
|
||||
AdHocMessage,
|
||||
Agent,
|
||||
AgentFinish,
|
||||
FunctionCall,
|
||||
FunctionResult,
|
||||
FunctionCallRequest,
|
||||
FunctionCallResponse,
|
||||
MessageLike,
|
||||
)
|
||||
from langchain.prompts import SystemMessagePromptTemplate
|
||||
@@ -116,7 +116,7 @@ class ActionParser:
|
||||
data=f"Invalid action blob {action_blob}, action_input must be a dict",
|
||||
)
|
||||
|
||||
return FunctionCall(
|
||||
return FunctionCallRequest(
|
||||
name=data["action"], named_arguments=named_arguments or {}
|
||||
)
|
||||
else:
|
||||
@@ -155,13 +155,13 @@ class ThinkActPromptGenerator(PromptValue):
|
||||
finalized.append(component)
|
||||
continue
|
||||
|
||||
if isinstance(message, FunctionResult):
|
||||
if isinstance(message, FunctionCallResponse):
|
||||
component = f"Observation: {message.result}"
|
||||
elif isinstance(message, HumanMessage):
|
||||
component = f"Question: {message.content.strip()}"
|
||||
elif isinstance(message, (AIMessage, SystemMessage)):
|
||||
component = message.content.strip()
|
||||
elif isinstance(message, FunctionCall):
|
||||
elif isinstance(message, FunctionCallRequest):
|
||||
# This is an internal message, and should not be returned to the user.
|
||||
continue
|
||||
elif isinstance(message, AgentFinish):
|
||||
@@ -177,7 +177,7 @@ class ThinkActPromptGenerator(PromptValue):
|
||||
for message in self.messages:
|
||||
if isinstance(message, BaseMessage):
|
||||
messages.append(message)
|
||||
elif isinstance(message, FunctionResult):
|
||||
elif isinstance(message, FunctionCallResponse):
|
||||
messages.append(
|
||||
SystemMessage(content=f"Observation: `{message.result}`")
|
||||
)
|
||||
@@ -222,7 +222,7 @@ class ThinkActAgent(Agent):
|
||||
break
|
||||
|
||||
if all_messages and isinstance(
|
||||
all_messages[-1], (FunctionResult, HumanMessage)
|
||||
all_messages[-1], (FunctionCallResponse, HumanMessage)
|
||||
):
|
||||
all_messages.append(AdHocMessage(type="prime", data="Thought:"))
|
||||
|
||||
|
||||
@@ -8,8 +8,8 @@ from langchain.automaton.chat_agent import ChatAgent
|
||||
from langchain.automaton.tool_utils import generate_tool_info
|
||||
from langchain.automaton.typedefs import (
|
||||
AgentFinish,
|
||||
FunctionCall,
|
||||
FunctionResult,
|
||||
FunctionCallRequest,
|
||||
FunctionCallResponse,
|
||||
MessageLike,
|
||||
)
|
||||
from langchain.prompts import SystemMessagePromptTemplate
|
||||
@@ -66,7 +66,7 @@ def _decode(text: Union[BaseMessage, str]) -> MessageLike:
|
||||
name = data["action"]
|
||||
if name == "Final Answer": # Special cased "tool" for final answer
|
||||
return AgentFinish(result=data["action_input"])
|
||||
return FunctionCall(
|
||||
return FunctionCallRequest(
|
||||
name=data["action"], named_arguments=data["action_input"] or {}
|
||||
)
|
||||
else:
|
||||
@@ -79,7 +79,7 @@ def generate_prompt(current_messages: Sequence[MessageLike]) -> List[BaseMessage
|
||||
for message in current_messages:
|
||||
if isinstance(message, BaseMessage):
|
||||
messages.append(message)
|
||||
elif isinstance(message, FunctionResult):
|
||||
elif isinstance(message, FunctionCallResponse):
|
||||
messages.append(
|
||||
HumanMessage(
|
||||
content=f"Observation: {message.result}",
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
import abc
|
||||
from typing import Mapping, Any, Callable, List, Sequence, Optional
|
||||
|
||||
from langchain.automaton.typedefs import MessageLike, FunctionResult
|
||||
from langchain.automaton.typedefs import MessageLike, FunctionCallResponse
|
||||
from langchain.schema import BaseMessage, PromptValue, HumanMessage
|
||||
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from langchain.automaton.typedefs import (
|
||||
AgentFinish,
|
||||
MessageLike,
|
||||
RetrievalRequest,
|
||||
RetrievalResult,
|
||||
RetrievalResponse,
|
||||
)
|
||||
from langchain.schema import PromptValue, BaseRetriever
|
||||
from langchain.schema.language_model import (
|
||||
@@ -41,7 +41,7 @@ def prompt_generator(input_messages: Sequence[MessageLike]) -> List[BaseMessage]
|
||||
for message in input_messages:
|
||||
if isinstance(message, BaseMessage):
|
||||
messages.append(message)
|
||||
elif isinstance(message, RetrievalResult):
|
||||
elif isinstance(message, RetrievalResponse):
|
||||
prompt = ""
|
||||
|
||||
if message.results:
|
||||
|
||||
@@ -5,10 +5,10 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union
|
||||
|
||||
from langchain.schema.retriever import BaseRetriever
|
||||
from langchain.automaton.typedefs import (
|
||||
FunctionCall,
|
||||
FunctionResult,
|
||||
FunctionCallRequest,
|
||||
FunctionCallResponse,
|
||||
MessageLike,
|
||||
RetrievalResult,
|
||||
RetrievalResponse,
|
||||
RetrievalRequest,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
@@ -111,7 +111,7 @@ def _to_retriever_input(message: MessageLike) -> str:
|
||||
|
||||
def create_tool_invoker(
|
||||
tools: Sequence[BaseTool],
|
||||
) -> Runnable[MessageLike, Optional[FunctionResult]]:
|
||||
) -> Runnable[MessageLike, Optional[FunctionCallResponse]]:
|
||||
"""See if possible to re-write with router
|
||||
|
||||
TODO:
|
||||
@@ -125,10 +125,10 @@ def create_tool_invoker(
|
||||
function_call: MessageLike,
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
) -> Optional[FunctionResult]:
|
||||
) -> Optional[FunctionCallResponse]:
|
||||
"""A function that can invoke a tool using .run"""
|
||||
if not isinstance(
|
||||
function_call, FunctionCall
|
||||
function_call, FunctionCallRequest
|
||||
): # TODO(Hack): Workaround lack of conditional apply
|
||||
return None
|
||||
try:
|
||||
@@ -145,16 +145,16 @@ def create_tool_invoker(
|
||||
result = None
|
||||
error = repr(e) + repr(function_call.named_arguments)
|
||||
|
||||
return FunctionResult(name=function_call.name, result=result, error=error)
|
||||
return FunctionCallResponse(name=function_call.name, result=result, error=error)
|
||||
|
||||
async def afunc(
|
||||
function_call: MessageLike,
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
) -> Optional[FunctionResult]:
|
||||
) -> Optional[FunctionCallResponse]:
|
||||
"""A function that can invoke a tool using .run"""
|
||||
if not isinstance(
|
||||
function_call, FunctionCall
|
||||
function_call, FunctionCallRequest
|
||||
): # TODO(Hack): Workaround lack of conditional apply
|
||||
return None
|
||||
try:
|
||||
@@ -171,7 +171,7 @@ def create_tool_invoker(
|
||||
result = None
|
||||
error = repr(e) + repr(function_call.named_arguments)
|
||||
|
||||
return FunctionResult(name=function_call.name, result=result, error=error)
|
||||
return FunctionCallResponse(name=function_call.name, result=result, error=error)
|
||||
|
||||
return RunnableLambda(func=func, afunc=afunc)
|
||||
|
||||
@@ -240,16 +240,16 @@ def create_llm_program(
|
||||
|
||||
def create_retriever(
|
||||
base_retriever: BaseRetriever,
|
||||
) -> Runnable[RetrievalRequest, RetrievalResult]:
|
||||
) -> Runnable[RetrievalRequest, RetrievalResponse]:
|
||||
"""Create a runnable retriever that uses messages."""
|
||||
|
||||
def _from_retrieval_request(request: RetrievalRequest) -> str:
|
||||
"""Convert a message to a list of documents."""
|
||||
return request.query
|
||||
|
||||
def _to_retrieval_result(docs: List[Document]) -> RetrievalResult:
|
||||
def _to_retrieval_result(docs: List[Document]) -> RetrievalResponse:
|
||||
"""Convert a list of documents to a message."""
|
||||
return RetrievalResult(results=docs)
|
||||
return RetrievalResponse(results=docs)
|
||||
|
||||
return (
|
||||
RunnableLambda(_from_retrieval_request)
|
||||
|
||||
@@ -11,8 +11,8 @@ from langchain.automaton.tests.utils import (
|
||||
)
|
||||
from langchain.automaton.typedefs import (
|
||||
AgentFinish,
|
||||
FunctionCall,
|
||||
FunctionResult,
|
||||
FunctionCallRequest,
|
||||
FunctionCallResponse,
|
||||
MessageLog,
|
||||
)
|
||||
from langchain.schema.messages import (
|
||||
@@ -74,11 +74,11 @@ def test_openai_agent(tools: List[Tool]) -> None:
|
||||
}
|
||||
},
|
||||
),
|
||||
FunctionCall(
|
||||
FunctionCallRequest(
|
||||
name="get_time",
|
||||
arguments={},
|
||||
),
|
||||
FunctionResult(
|
||||
FunctionCallResponse(
|
||||
name="get_time",
|
||||
result="9 PM",
|
||||
error=None,
|
||||
|
||||
@@ -13,7 +13,7 @@ from langchain.automaton.runnables import (
|
||||
from langchain.automaton.tests.utils import (
|
||||
FakeChatModel,
|
||||
)
|
||||
from langchain.automaton.typedefs import FunctionCall, FunctionResult, MessageLike
|
||||
from langchain.automaton.typedefs import FunctionCallRequest, FunctionCallResponse, MessageLike
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain.schema.runnable import RunnableLambda
|
||||
@@ -127,11 +127,11 @@ def test_llm_program_with_parser(fake_llm: BaseLanguageModel) -> None:
|
||||
[AIMessage(content="Hello"), AIMessage(content="Goodbye")],
|
||||
),
|
||||
(
|
||||
RunnableLambda(lambda msg: FunctionCall(name="get_time")),
|
||||
RunnableLambda(lambda msg: FunctionCallRequest(name="get_time")),
|
||||
[
|
||||
AIMessage(content="Hello"),
|
||||
FunctionCall(name="get_time"),
|
||||
FunctionResult(result="9 PM", name="get_time"),
|
||||
FunctionCallRequest(name="get_time"),
|
||||
FunctionCallResponse(result="9 PM", name="get_time"),
|
||||
],
|
||||
),
|
||||
],
|
||||
|
||||
@@ -13,7 +13,9 @@ class InternalMessage(Serializable):
|
||||
return True
|
||||
|
||||
|
||||
class FunctionCall(InternalMessage): # TODO(Eugene): Rename as FunctionCallRequest
|
||||
class FunctionCallRequest(
|
||||
InternalMessage
|
||||
): # TODO(Eugene): Rename as FunctionCallRequest
|
||||
"""A request for a function invocation.
|
||||
|
||||
This message can be used to request a function invocation
|
||||
@@ -28,19 +30,24 @@ class FunctionCall(InternalMessage): # TODO(Eugene): Rename as FunctionCallRequ
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
"""Return a string representation of the object."""
|
||||
return f"FunctionCall(name={self.name}, named_arguments={self.named_arguments})"
|
||||
|
||||
|
||||
class FunctionResult(InternalMessage): # Rename as FunctionCallResult
|
||||
class FunctionCallResponse(InternalMessage): # Rename as FunctionCallResult
|
||||
"""A result of a function invocation."""
|
||||
|
||||
name: str
|
||||
result: Any
|
||||
error: Optional[str] = None
|
||||
|
||||
def __str__(self):
|
||||
return f"FunctionResult(name={self.name}, result={self.result}, error={self.error})"
|
||||
def __str__(self) -> str:
|
||||
"""Return a string representation of the object."""
|
||||
return (
|
||||
f"FunctionResult(name={self.name}, result={self.result}, "
|
||||
f"error={self.error})"
|
||||
)
|
||||
|
||||
|
||||
class RetrievalRequest(InternalMessage):
|
||||
@@ -49,16 +56,18 @@ class RetrievalRequest(InternalMessage):
|
||||
query: str
|
||||
"""The query to use for the retrieval."""
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
"""Return a string representation of the object."""
|
||||
return f"RetrievalRequest(query={self.query})"
|
||||
|
||||
|
||||
class RetrievalResult(InternalMessage):
|
||||
class RetrievalResponse(InternalMessage):
|
||||
"""A result of a retrieval."""
|
||||
|
||||
results: Sequence[Document]
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
"""Return a string representation of the object."""
|
||||
return f"RetrievalResults(results={self.results})"
|
||||
|
||||
|
||||
@@ -68,7 +77,8 @@ class AdHocMessage(InternalMessage):
|
||||
type: str
|
||||
data: Any # Make sure this is serializable
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
"""Return a string representation of the object."""
|
||||
return f"AdHocMessage(type={self.type}, data={self.data})"
|
||||
|
||||
|
||||
@@ -77,7 +87,8 @@ class AgentFinish(InternalMessage):
|
||||
|
||||
result: Any
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
"""Return a string representation of the object."""
|
||||
return f"AgentFinish(result={self.result})"
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user