Mirror of https://github.com/hwchase17/langchain.git
Synced 2026-02-03 15:55:44 +00:00

Compare commits: cc/oai_ima...eugene/aut

21 Commits
| SHA1 |
|---|
| 15d5c49076 |
| 65660535bc |
| 6c41dd82f0 |
| 78d788c28c |
| e9deeab37f |
| 4d595eec5b |
| 047b001336 |
| 840e936c7c |
| 6cc6b490be |
| f45d1ed4f5 |
| 4ffc417858 |
| 9e74a70859 |
| 0997f2c0f1 |
| 437b545426 |
| 9b9d07572b |
| 6a90c6c2c8 |
| 8371187689 |
| 4309c17ffa |
| 183a9d4e66 |
| c1b444e1e7 |
| 5f117384c0 |
libs/langchain/langchain/automaton/__init__.py (new file, 0 lines)
libs/langchain/langchain/automaton/automaton.py (new file, 39 lines)
@@ -0,0 +1,39 @@
"""Module containing an automaton definition."""
from __future__ import annotations

import abc
from typing import (
    TypedDict,
    Mapping,
    Any,
    Protocol,
)

from langchain.automaton.typedefs import Memory


class ExecutedState(TypedDict):
    """The response of an action taking LLM."""

    id: str  # the ID of the state that was just executed
    data: Mapping[str, Any]


class State(Protocol):
    """Automaton state protocol."""

    def execute(self, memory: Memory) -> ExecutedState:
        """Execute the state, returning the result."""
        ...


class Automaton:
    @abc.abstractmethod
    def get_start_state(self, *args: Any, **kwargs: Any) -> State:
        """Get the start state."""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_next_state(self, executed_state: ExecutedState) -> State:
        """Get the next state."""
        raise NotImplementedError()
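For orientation, here is a minimal sketch of a custom state that satisfies the `State` protocol above. `EchoState` is hypothetical and not part of this diff; it assumes only the definitions above plus `Memory` from `typedefs.py`.

```python
# Hypothetical EchoState: the smallest thing that satisfies the State
# protocol. It appends a canned AI message to memory and reports an
# ExecutedState whose "id" drives the automaton's transition function.
from langchain.automaton.automaton import ExecutedState
from langchain.automaton.typedefs import Memory
from langchain.schema.messages import AIMessage


class EchoState:
    def execute(self, memory: Memory) -> ExecutedState:
        message = AIMessage(content="echo")
        memory.add_message(message)
        return {"id": "echo", "data": {"message": message}}
```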
libs/langchain/langchain/automaton/chat_automaton.py (new file, 71 lines)
@@ -0,0 +1,71 @@
from __future__ import annotations

from typing import Any, Sequence

from langchain.automaton.automaton import ExecutedState, State, Automaton
from langchain.automaton.typedefs import (
    MessageType,
    infer_message_type,
    PromptGenerator,
)
from langchain.automaton.well_known_states import LLMProgram, UserInputState
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool


class ChatAutomaton(Automaton):
    def __init__(
        self,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        prompt_generator: PromptGenerator,
    ) -> None:
        """Initialize the chat automaton."""
        super().__init__()
        self.llm = llm
        self.tools = tools
        # TODO: Fix mutability of chat template, potentially add factory method
        self.prompt_generator = prompt_generator
        self.llm_program_state = LLMProgram(
            llm=self.llm,
            tools=self.tools,
            prompt_generator=self.prompt_generator,
        )

    def get_start_state(self, *args: Any, **kwargs: Any) -> State:
        """Get the start state."""
        return self.llm_program_state

    def get_next_state(
        self, executed_state: ExecutedState  # Add memory for transition functions?
    ) -> State:
        """Get the next state."""
        previous_state_id = executed_state["id"]
        data = executed_state["data"]

        if previous_state_id == "user_input":
            return self.llm_program_state
        elif previous_state_id == "llm_program":
            message_type = infer_message_type(data["message"])
            if message_type == MessageType.AI:
                return UserInputState()
            elif message_type == MessageType.FUNCTION:
                return self.llm_program_state
            else:
                raise AssertionError(f"Unknown message type: {message_type}")
        else:
            raise ValueError(f"Unknown state ID: {previous_state_id}")


# This is transition matrix syntax
# transition_matrix = {
#     ("user_input", "*"): LLMProgram,
#     ("llm_program", MessageType.AI): UserInputState,
#     ("llm_program", MessageType.AI_SELF): LLMProgram,
#     (
#         "llm_program",
#         MessageType.AI_INVOKE,
#     ): FuncInvocationState,  # But must add function message
#     ("func_invocation", MessageType.FUNCTION): LLMProgram,
# }
#
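As a sanity check on the routing above, a small sketch that drives `get_next_state` by hand. It assumes only the modules added in this diff; the fake model, empty tool list, and `prompt_generator` stub are stand-ins and are never exercised here.

```python
# Sketch: exercising ChatAutomaton.get_next_state directly.
# FakeChatOpenAI comes from the test utilities added in this diff.
from langchain.automaton.chat_automaton import ChatAutomaton
from langchain.automaton.tests.utils import FakeChatOpenAI
from langchain.automaton.well_known_states import UserInputState
from langchain.schema.messages import AIMessage, FunctionMessage

automaton = ChatAutomaton(
    llm=FakeChatOpenAI(message_iter=iter([])),
    tools=[],
    prompt_generator=lambda memory: None,  # stub; never invoked in this sketch
)

# A plain AI reply hands control back to the user...
state = automaton.get_next_state(
    {"id": "llm_program", "data": {"message": AIMessage(content="hi")}}
)
assert isinstance(state, UserInputState)

# ...while a function result loops straight back into the LLM program.
state = automaton.get_next_state(
    {"id": "llm_program", "data": {"message": FunctionMessage(name="f", content="ok")}}
)
assert state is automaton.llm_program_state
```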
libs/langchain/langchain/automaton/executor.py (new file, 38 lines)
@@ -0,0 +1,38 @@
from __future__ import annotations

from typing import (
    List,
    Tuple,
)

from langchain.automaton.automaton import ExecutedState, State, Automaton
from langchain.automaton.typedefs import Memory


# Need to make into runnable
# This is a for looping runnable... :)
class Executor:
    def __init__(
        self, automaton: Automaton, memory: Memory, max_iterations: int
    ) -> None:
        """Initialize the executor."""
        self.automaton = automaton
        self.max_iterations = max_iterations
        self.memory = memory

    def run(self) -> Tuple[State, List[ExecutedState]]:
        """Run the automaton.

        Returns:
            The final state and the list of executed states.
        """
        state = self.automaton.get_start_state()
        executed_states = []

        for _ in range(self.max_iterations):
            executed_state = state.execute(memory=self.memory)
            executed_states.append(executed_state)
            # Should the transition function get memory?
            state = self.automaton.get_next_state(executed_state)

        return state, executed_states
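A self-contained sketch of the run loop, using a hypothetical single-state automaton built only from the modules in this diff:

```python
# Hypothetical TickAutomaton: one state that keeps re-entering itself,
# just to show Executor's execute -> transition loop.
from typing import Any

from langchain.automaton.automaton import Automaton, ExecutedState, State
from langchain.automaton.executor import Executor
from langchain.automaton.typedefs import Memory
from langchain.schema.messages import AIMessage


class TickState:
    def execute(self, memory: Memory) -> ExecutedState:
        message = AIMessage(content=f"tick {len(memory.messages)}")
        memory.add_message(message)
        return {"id": "tick", "data": {"message": message}}


class TickAutomaton(Automaton):
    def get_start_state(self, *args: Any, **kwargs: Any) -> State:
        return TickState()

    def get_next_state(self, executed_state: ExecutedState) -> State:
        return TickState()  # single-state loop


memory = Memory(messages=[])
final_state, executed = Executor(TickAutomaton(), memory, max_iterations=3).run()
assert [e["data"]["message"].content for e in executed] == ["tick 0", "tick 1", "tick 2"]
```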
libs/langchain/langchain/automaton/open_ai_functions.py (new file, 136 lines)
@@ -0,0 +1,136 @@
from __future__ import annotations

import json
from operator import itemgetter
from typing import Any, Callable, List, Mapping, TypedDict, Union, Sequence, Optional

from langchain.base_language import BaseLanguageModel
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.schema import BaseMessage, AIMessage
from langchain.schema.output import ChatGeneration
from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding
from langchain.tools.base import BaseTool
from langchain.tools.convert_to_openai import format_tool_to_openai_function


class OpenAIFunction(TypedDict):
    """A function to call on the OpenAI API."""

    name: str
    """The name of the function."""
    description: str
    """The description of the function."""
    parameters: dict
    """The parameters to the function."""


class OpenAIFunctionsRouter(RunnableBinding[ChatGeneration, Any]):
    """A runnable that routes to the selected function."""

    functions: List[OpenAIFunction]

    def __init__(
        self,
        functions: List[OpenAIFunction],
        runnables: Mapping[
            str,
            Union[
                Runnable[dict, Any],
                Callable[[dict], Any],
            ],
        ],
    ):
        assert len(functions) == len(runnables)
        assert all(func["name"] in runnables for func in functions)
        router = (
            JsonOutputFunctionsParser(args_only=False)
            | {"key": itemgetter("name"), "input": itemgetter("arguments")}
            | RouterRunnable(runnables)
        )
        super().__init__(bound=router, kwargs={}, functions=functions)


class FunctionCall(TypedDict):
    name: str
    """The name of the function."""
    arguments: dict
    """The arguments to the function."""
    result: Any  # Need to denote not invoked yet as well
    """The result of the function call."""


class ActingResult(TypedDict):
    """The result of an action."""

    message: BaseMessage
    """The message that was passed to the action."""
    function_call: Optional[FunctionCall]


def create_action_taking_llm(
    llm: BaseLanguageModel,
    *,
    tools: Sequence[BaseTool] = (),
    stop: Sequence[str] | None = None,
    invoke_function: bool = True,
) -> Runnable:
    """A chain that can create an action.

    Args:
        llm: The language model to use.
        tools: The tools to use.
        stop: The stop tokens to use.
        invoke_function: Whether to invoke the function.

    Returns:
        A segment of a runnable that takes an action.
    """
    openai_funcs = [format_tool_to_openai_function(tool_) for tool_ in tools]

    def _interpret_message(message: BaseMessage) -> ActingResult:
        """Interpret a message."""
        if (
            isinstance(message, AIMessage)
            and "function_call" in message.additional_kwargs
        ):
            if invoke_function:
                result = invoke_from_function.invoke(message)  # TODO: fixme using invoke
            else:
                result = None
            return {
                "message": message,
                "function_call": {
                    "name": message.additional_kwargs["function_call"]["name"],
                    "arguments": json.loads(
                        message.additional_kwargs["function_call"]["arguments"]
                    ),
                    "result": result,
                    # Check this works.
                    # "result": message.additional_kwargs["function_call"]
                    # | invoke_from_function,
                },
            }
        else:
            return {
                "message": message,
                "function_call": None,
            }

    invoke_from_function = OpenAIFunctionsRouter(
        functions=openai_funcs,
        runnables={
            openai_func["name"]: tool_
            for openai_func, tool_ in zip(openai_funcs, tools)
        },
    )

    if stop:
        _llm = llm.bind(stop=stop)
    else:
        _llm = llm

    chain = _llm.bind(functions=openai_funcs) | _interpret_message
    return chain
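To illustrate the `ActingResult` shape, a sketch that wires `create_action_taking_llm` to the `FakeChatOpenAI` helper defined later in this diff; `get_time` is a stand-in tool, and the assertions mirror what `test_automaton` below verifies end to end.

```python
# Sketch: the ActingResult produced when the model emits a function call.
# FakeChatOpenAI and construct_func_invocation_message are the test
# helpers added in this diff; get_time is a stand-in tool.
from langchain.automaton.open_ai_functions import create_action_taking_llm
from langchain.automaton.tests.utils import (
    FakeChatOpenAI,
    construct_func_invocation_message,
)
from langchain.tools.base import tool


@tool
def get_time() -> str:
    """Get time."""
    return "9 PM"


llm = FakeChatOpenAI(
    message_iter=iter([construct_func_invocation_message(get_time, {})])
)
chain = create_action_taking_llm(llm, tools=[get_time])
result = chain.invoke("What time is it?")
# result["message"] is the raw AIMessage; the parsed call sits alongside it.
assert result["function_call"]["name"] == "get_time"
assert result["function_call"]["result"] == "9 PM"
```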
libs/langchain/langchain/automaton/tests/test_automaton.py (new file, 99 lines)
@@ -0,0 +1,99 @@
from __future__ import annotations

from typing import (
    List,
    Sequence,
)

from langchain.automaton.chat_automaton import ChatAutomaton
from langchain.automaton.executor import Executor
from langchain.automaton.tests.utils import (
    FakeChatOpenAI,
    construct_func_invocation_message,
)
from langchain.automaton.typedefs import Memory
from langchain.schema import PromptValue
from langchain.schema.messages import (
    AIMessage,
    BaseMessage,
    SystemMessage,
    FunctionMessage,
)
from langchain.tools.base import tool as tool_maker


class MessageBasedPromptValue(PromptValue):
    """Prompt Value populated from messages."""

    messages: List[BaseMessage]

    @classmethod
    def from_messages(cls, messages: Sequence[BaseMessage]) -> MessageBasedPromptValue:
        return cls(messages=messages)

    def to_messages(self) -> List[BaseMessage]:
        return self.messages

    def to_string(self) -> str:
        return " ".join([message.content for message in self.messages])


def prompt_generator(memory: Memory) -> PromptValue:
    """Generate a prompt."""
    if not memory.messages:
        raise AssertionError("Memory is empty")
    return MessageBasedPromptValue.from_messages(messages=memory.messages)


def test_automaton() -> None:
    """Run the automaton."""

    @tool_maker
    def get_time() -> str:
        """Get time."""
        return "9 PM"

    @tool_maker
    def get_location() -> str:
        """Get location."""
        return "the park"

    tools = [get_time, get_location]
    llm = FakeChatOpenAI(
        message_iter=iter(
            [
                construct_func_invocation_message(get_time, {}),
                AIMessage(
                    content="The time is 9 PM.",
                ),
            ]
        )
    )

    # TODO(FIX MUTABILITY)
    memory = Memory(
        messages=[
            SystemMessage(
                content=(
                    "Hello! I'm a chatbot that can help you write a letter. "
                    "What would you like to do?"
                ),
            )
        ]
    )
    chat_automaton = ChatAutomaton(
        llm=llm, tools=tools, prompt_generator=prompt_generator
    )
    executor = Executor(chat_automaton, memory, max_iterations=1)
    state, executed_states = executor.run()
    assert executed_states == [
        {
            "data": {
                "message": FunctionMessage(
                    content="9 PM", additional_kwargs={}, name="get_time"
                )
            },
            "id": "llm_program",
        }
    ]
libs/langchain/langchain/automaton/tests/test_functions.py (new file, 73 lines)
@@ -0,0 +1,73 @@
from __future__ import annotations

from langchain.automaton.open_ai_functions import OpenAIFunctionsRouter
from langchain.automaton.tests.utils import FakeChatOpenAI
from langchain.schema import AIMessage
from langchain.schema.runnable import RunnableLambda


def test_openai_functions_router() -> None:
    """Test the OpenAIFunctionsRouter."""

    def revise(notes: str) -> str:
        """Revises the draft."""
        return f"Revised draft: {notes}!"

    def accept(draft: str) -> str:
        """Accepts the draft."""
        return f"Accepted draft: {draft}!"

    router = OpenAIFunctionsRouter(
        functions=[
            {
                "name": "revise",
                "description": "Sends the draft for revision.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "notes": {
                            "type": "string",
                            "description": "The editor's notes to guide the revision.",
                        },
                    },
                },
            },
            {
                "name": "accept",
                "description": "Accepts the draft.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "draft": {
                            "type": "string",
                            "description": "The draft to accept.",
                        },
                    },
                },
            },
        ],
        runnables={
            # The router passes the parsed arguments dict, keyed per the schema.
            "revise": RunnableLambda(lambda x: revise(x["notes"])),
            "accept": RunnableLambda(lambda x: accept(x["draft"])),
        },
    )

    model = FakeChatOpenAI(
        message_iter=iter(
            [
                AIMessage(
                    content="",
                    additional_kwargs={
                        "function_call": {
                            "name": "accept",
                            "arguments": '{\n "draft": "turtles"\n}',
                        }
                    },
                )
            ]
        )
    )

    chain = model.bind(functions=router.functions) | router

    assert chain.invoke("Something about turtles?") == "Accepted draft: turtles!"
libs/langchain/langchain/automaton/tests/utils.py (new file, 46 lines)
@@ -0,0 +1,46 @@
from __future__ import annotations

import json
from typing import Iterator, List, Any, Mapping

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage, ChatResult, ChatGeneration, AIMessage
from langchain.tools import BaseTool


class FakeChatOpenAI(BaseChatModel):
    """A fake chat model that returns a pre-defined response."""

    message_iter: Iterator[BaseMessage]

    @property
    def _llm_type(self) -> str:
        """The type of the model."""
        return "fake-openai-chat-model"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: List[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a response to the given messages."""
        message = next(self.message_iter)
        return ChatResult(generations=[ChatGeneration(message=message)])


def construct_func_invocation_message(
    tool: BaseTool, args: Mapping[str, Any]
) -> AIMessage:
    """Construct a function invocation message."""
    return AIMessage(
        content="",
        additional_kwargs={
            "function_call": {
                "name": tool.name,
                "arguments": json.dumps(args),
            }
        },
    )
libs/langchain/langchain/automaton/typedefs.py (new file, 61 lines)
@@ -0,0 +1,61 @@
from __future__ import annotations

import enum
from typing import Callable, List

from langchain.schema import (
    BaseMessage,
    FunctionMessage,
    AIMessage,
    SystemMessage,
    HumanMessage,
    BaseChatMessageHistory,
    PromptValue,
)


class MessageType(enum.Enum):
    """The type of message."""

    SYSTEM = enum.auto()
    USER = enum.auto()
    FUNCTION = enum.auto()
    AI = enum.auto()
    AI_INVOKE = enum.auto()


def infer_message_type(message: BaseMessage) -> MessageType:
    """Infer the message type."""
    if isinstance(message, FunctionMessage):
        return MessageType.FUNCTION
    elif isinstance(message, AIMessage):
        if message.additional_kwargs:
            return MessageType.AI_INVOKE
        else:
            return MessageType.AI
    elif isinstance(message, SystemMessage):
        return MessageType.SYSTEM
    elif isinstance(message, HumanMessage):
        return MessageType.USER
    else:
        raise ValueError(f"Unknown message type: {type(message)}")


class Memory(BaseChatMessageHistory):
    """A memory for the automaton."""

    def __init__(self, messages: List[BaseMessage]) -> None:
        self.messages = messages

    def add_message(self, message: BaseMessage) -> None:
        """Add a message to the memory."""
        self.messages.append(message)

    def clear(self) -> None:
        """Clear the memory."""
        self.messages = []


# Interface that takes memory and returns a prompt value
PromptGenerator = Callable[[Memory], PromptValue]
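A quick sketch of how `infer_message_type` classifies the standard message types; note that any non-empty `additional_kwargs` on an `AIMessage`, not only a `function_call`, is treated as `AI_INVOKE`.

```python
from langchain.automaton.typedefs import MessageType, infer_message_type
from langchain.schema import AIMessage, FunctionMessage, HumanMessage, SystemMessage

assert infer_message_type(SystemMessage(content="hi")) == MessageType.SYSTEM
assert infer_message_type(HumanMessage(content="hi")) == MessageType.USER
assert infer_message_type(AIMessage(content="hi")) == MessageType.AI
assert infer_message_type(FunctionMessage(name="f", content="ok")) == MessageType.FUNCTION

# An AIMessage carrying a function_call payload routes differently.
invoking = AIMessage(
    content="",
    additional_kwargs={"function_call": {"name": "f", "arguments": "{}"}},
)
assert infer_message_type(invoking) == MessageType.AI_INVOKE
```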
libs/langchain/langchain/automaton/well_known_states.py (new file, 79 lines)
@@ -0,0 +1,79 @@
from __future__ import annotations

import dataclasses
from typing import Sequence

from langchain.automaton.automaton import State, ExecutedState
from langchain.automaton.open_ai_functions import create_action_taking_llm
from langchain.automaton.typedefs import (
    Memory,
    PromptGenerator,
    infer_message_type,
    MessageType,
)
from langchain.schema import HumanMessage, FunctionMessage, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool


@dataclasses.dataclass
class LLMProgram(State):
    """A state that executes an LLM program."""

    llm: BaseLanguageModel
    tools: Sequence[BaseTool]
    prompt_generator: PromptGenerator

    def execute(self, memory: Memory) -> ExecutedState:
        """Execute LLM program."""
        action_taking_llm = create_action_taking_llm(self.llm, tools=self.tools)
        prompt_value = self.prompt_generator(memory)
        result = action_taking_llm.invoke(prompt_value)
        # Memory is mutable
        message = result["message"]
        if not isinstance(message, AIMessage):
            raise AssertionError(
                f"LLM program should return an AI message. Got a {type(message)}."
            )
        memory.add_message(message)

        if infer_message_type(message) == MessageType.AI_INVOKE:
            function_call = result["function_call"]
            function_message = FunctionMessage(
                name=function_call["name"],
                content=function_call["result"],
            )
            memory.add_message(function_message)
            routing_message = function_message
        else:
            routing_message = message

        # What information should the state return in this case?
        # Does it matter? Folks can use it or not...
        return {
            "id": "llm_program",
            "data": {
                "message": routing_message,  # Last message
            },
        }


@dataclasses.dataclass
class UserInputState(State):
    """A state that prompts the user for input from stdin.

    This is primarily useful for interactive development.
    """

    def execute(self, memory: Memory) -> ExecutedState:
        """Execute user input state."""
        user_input = input("Enter your input: ")
        message = HumanMessage(content=user_input)
        memory.add_message(message)

        return {
            "id": "user_input",
            "data": {
                "message": message,
            },
        }
libs/langchain/langchain/output_parsers/openai_functions.py (modified)
@@ -25,8 +25,8 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
             )
         message = generation.message
         try:
-            func_call = message.additional_kwargs["function_call"]
-        except ValueError as exc:
+            func_call = message.additional_kwargs["function_call"].copy()
+        except KeyError as exc:
             raise OutputParserException(f"Could not parse function call: {exc}")
 
         if self.args_only:
@@ -38,11 +38,16 @@ class JsonOutputFunctionsParser(OutputFunctionsParser):
     """Parse an output as the Json object."""
 
     def parse_result(self, result: List[Generation]) -> Any:
-        func = super().parse_result(result)
+        function_call_info = super().parse_result(result)
         if self.args_only:
-            return json.loads(func)
-        func["arguments"] = json.loads(func["arguments"])
-        return func
+            try:
+                return json.loads(function_call_info)
+            except (json.JSONDecodeError, TypeError) as exc:
+                raise OutputParserException(
+                    f"Could not parse function call data: {exc}"
+                )
+        function_call_info["arguments"] = json.loads(function_call_info["arguments"])
+        return function_call_info
 
 
 class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
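Sketch of the behavior this hunk adds, assuming this branch: malformed function-call arguments now surface as `OutputParserException` rather than an uncaught `json.JSONDecodeError`. The bad message below is a stand-in.

```python
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.schema import AIMessage, ChatGeneration, OutputParserException

bad = AIMessage(
    content="",
    additional_kwargs={"function_call": {"name": "f", "arguments": "not json"}},
)
try:
    JsonOutputFunctionsParser(args_only=True).parse_result([ChatGeneration(message=bad)])
except OutputParserException as exc:
    print(f"Rejected as expected: {exc}")
```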
libs/langchain/langchain/prompts/chat.py (modified)
@@ -316,6 +316,16 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
         """Format kwargs into a list of messages."""
 
 
+MessageLike = Union[
+    BaseMessagePromptTemplate,
+    BaseChatPromptTemplate,
+    BaseMessage,
+    Tuple[str, str],
+    Tuple[Type, str],
+    str,
+]
+
+
 class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
     """A prompt template for chat models.
 
@@ -363,6 +373,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
             other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate)
         ):
             return ChatPromptTemplate(messages=self.messages + [other])
+        elif isinstance(other, (list, tuple)):
+            _other = ChatPromptTemplate.from_messages(other)
+            return ChatPromptTemplate(messages=self.messages + _other.messages)
         elif isinstance(other, str):
             prompt = HumanMessagePromptTemplate.from_template(other)
             return ChatPromptTemplate(messages=self.messages + [prompt])
@@ -453,17 +466,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
 
     @classmethod
     def from_messages(
-        cls,
-        messages: Sequence[
-            Union[
-                BaseMessagePromptTemplate,
-                BaseChatPromptTemplate,
-                BaseMessage,
-                Tuple[str, str],
-                Tuple[Type, str],
-                str,
-            ]
-        ],
+        cls,
+        messages: Sequence[MessageLike],
     ) -> ChatPromptTemplate:
         """Create a chat prompt template from a variety of message formats.
 
@@ -589,6 +592,18 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
         prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
         return type(self)(**prompt_dict)
 
+    def append(self, message: MessageLike) -> None:
+        """Append message to the end of the chat template.
+
+        Args:
+            message: representation of a message to append.
+        """
+        self.messages.append(_convert_to_message(message))
+
+    def extend(self, messages: Sequence[MessageLike]) -> None:
+        """Extend the chat template with a sequence of messages."""
+        self.messages.extend([_convert_to_message(message) for message in messages])
+
     @property
     def _prompt_type(self) -> str:
         """Name of prompt type."""
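A sketch of the new mutators, assuming this branch; the template content is illustrative. Note that `append`/`extend` mutate `messages` in place without recomputing `input_variables`, likely the mutability concern flagged in the TODO in `chat_automaton.py`.

```python
# Sketch: append/extend accept any MessageLike form and convert it via
# _convert_to_message.
from langchain.prompts.chat import ChatPromptTemplate

template = ChatPromptTemplate.from_messages([("system", "You are a helpful bot.")])
template.append("{question}")               # a bare str becomes a human message template
template.extend([("ai", "Let me think.")])  # (role, template) tuples also work
messages = template.format_messages(question="What is 2 + 2?")
assert len(messages) == 3
```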
@@ -632,14 +647,7 @@ def _create_template_from_message_type(
 
 
 def _convert_to_message(
-    message: Union[
-        BaseMessagePromptTemplate,
-        BaseChatPromptTemplate,
-        BaseMessage,
-        Tuple[str, str],
-        Tuple[Type, str],
-        str,
-    ]
+    message: MessageLike
 ) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
     """Instantiate a message from a variety of message formats.
 
libs/langchain/langchain/tools/base.py (modified)
@@ -685,7 +685,7 @@ def tool(
     return_direct: bool = False,
     args_schema: Optional[Type[BaseModel]] = None,
     infer_schema: bool = True,
-) -> Callable:
+) -> Callable[[Callable], BaseTool]:
     """Make tools out of functions, can be used with or without arguments.
 
     Args:
@@ -0,0 +1,76 @@
import json

import pytest

from langchain.output_parsers.openai_functions import (
    JsonOutputFunctionsParser,
)
from langchain.schema import BaseMessage, ChatGeneration, OutputParserException
from langchain.schema.messages import AIMessage, HumanMessage


@pytest.fixture
def ai_message() -> AIMessage:
    """Return a simple AIMessage."""
    content = "This is a test message"

    args = json.dumps(
        {
            "arg1": "value1",
        }
    )

    function_call = {"name": "function_name", "arguments": args}
    additional_kwargs = {"function_call": function_call}
    return AIMessage(content=content, additional_kwargs=additional_kwargs)


def test_json_output_function_parser(ai_message: AIMessage) -> None:
    """Test the JsonOutputFunctionsParser with full output."""
    chat_generation = ChatGeneration(message=ai_message)

    # Full output
    parser = JsonOutputFunctionsParser(args_only=False)
    result = parser.parse_result([chat_generation])
    assert result == {"arguments": {"arg1": "value1"}, "name": "function_name"}

    # Args only
    parser = JsonOutputFunctionsParser(args_only=True)
    result = parser.parse_result([chat_generation])
    assert result == {"arg1": "value1"}

    # Verify that the original message is not modified
    assert ai_message.additional_kwargs == {
        "function_call": {"name": "function_name", "arguments": '{"arg1": "value1"}'}
    }


@pytest.mark.parametrize(
    "bad_message",
    [
        # Human message has no function call
        HumanMessage(content="This is a test message"),
        # AIMessage has no function call information.
        AIMessage(content="This is a test message", additional_kwargs={}),
        # Bad function call information (arguments should be a string)
        AIMessage(
            content="This is a test message",
            additional_kwargs={
                "function_call": {"name": "function_name", "arguments": {}}
            },
        ),
        # Bad function call information (arguments should be proper json)
        AIMessage(
            content="This is a test message",
            additional_kwargs={
                "function_call": {"name": "function_name", "arguments": "noqweqwe"}
            },
        ),
    ],
)
def test_exceptions_raised_while_parsing(bad_message: BaseMessage) -> None:
    """Test exceptions raised correctly while using JSON parser."""
    chat_generation = ChatGeneration(message=bad_message)

    with pytest.raises(OutputParserException):
        JsonOutputFunctionsParser().parse_result([chat_generation])