Harrison/add back in experimental (#8128)
commit 86946a47a8 (parent 8b08687fc4)
16 MIGRATE.md
@@ -1,6 +1,6 @@
-# Migrating to `langchain.experimental`
+# Migrating to `langchain._experimental`

-We are moving any experimental components of langchain, or components with vulnerability issues, into `langchain_experimental`.
+We are moving any experimental components of LangChain, or components with vulnerability issues, into `langchain_experimental`.
This guide covers how to migrate.

## Installation

@@ -9,10 +9,20 @@ Previously:

`pip install -U langchain`

-Now:
+Now (only if you want to access things in experimental):

`pip install -U langchain langchain_experimental`

## Things in `langchain.experimental`

Previously:

`from langchain.experimental import ...`

Now:

`from langchain_experimental import ...`

## PALChain

Previously:
19 libs/langchain/langchain/experimental/__init__.py Normal file
@@ -0,0 +1,19 @@
from langchain.experimental.autonomous_agents.autogpt.agent import AutoGPT
from langchain.experimental.autonomous_agents.baby_agi.baby_agi import BabyAGI
from langchain.experimental.generative_agents.generative_agent import GenerativeAgent
from langchain.experimental.generative_agents.memory import GenerativeAgentMemory
from langchain.experimental.plan_and_execute import (
    PlanAndExecute,
    load_agent_executor,
    load_chat_planner,
)

__all__ = [
    "BabyAGI",
    "AutoGPT",
    "GenerativeAgent",
    "GenerativeAgentMemory",
    "PlanAndExecute",
    "load_agent_executor",
    "load_chat_planner",
]
@@ -0,0 +1,4 @@
from langchain.experimental.autonomous_agents.autogpt.agent import AutoGPT
from langchain.experimental.autonomous_agents.baby_agi.baby_agi import BabyAGI

__all__ = ["BabyAGI", "AutoGPT"]
@@ -0,0 +1,143 @@
from __future__ import annotations

from typing import List, Optional

from pydantic import ValidationError

from langchain.chains.llm import LLMChain
from langchain.chat_models.base import BaseChatModel
from langchain.experimental.autonomous_agents.autogpt.output_parser import (
    AutoGPTOutputParser,
    BaseAutoGPTOutputParser,
)
from langchain.experimental.autonomous_agents.autogpt.prompt import AutoGPTPrompt
from langchain.experimental.autonomous_agents.autogpt.prompt_generator import (
    FINISH_NAME,
)
from langchain.memory import ChatMessageHistory
from langchain.schema import (
    BaseChatMessageHistory,
    Document,
)
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
from langchain.tools.base import BaseTool
from langchain.tools.human.tool import HumanInputRun
from langchain.vectorstores.base import VectorStoreRetriever


class AutoGPT:
    """Agent class for interacting with Auto-GPT."""

    def __init__(
        self,
        ai_name: str,
        memory: VectorStoreRetriever,
        chain: LLMChain,
        output_parser: BaseAutoGPTOutputParser,
        tools: List[BaseTool],
        feedback_tool: Optional[HumanInputRun] = None,
        chat_history_memory: Optional[BaseChatMessageHistory] = None,
    ):
        self.ai_name = ai_name
        self.memory = memory
        self.next_action_count = 0
        self.chain = chain
        self.output_parser = output_parser
        self.tools = tools
        self.feedback_tool = feedback_tool
        self.chat_history_memory = chat_history_memory or ChatMessageHistory()

    @classmethod
    def from_llm_and_tools(
        cls,
        ai_name: str,
        ai_role: str,
        memory: VectorStoreRetriever,
        tools: List[BaseTool],
        llm: BaseChatModel,
        human_in_the_loop: bool = False,
        output_parser: Optional[BaseAutoGPTOutputParser] = None,
        chat_history_memory: Optional[BaseChatMessageHistory] = None,
    ) -> AutoGPT:
        prompt = AutoGPTPrompt(
            ai_name=ai_name,
            ai_role=ai_role,
            tools=tools,
            input_variables=["memory", "messages", "goals", "user_input"],
            token_counter=llm.get_num_tokens,
        )
        human_feedback_tool = HumanInputRun() if human_in_the_loop else None
        chain = LLMChain(llm=llm, prompt=prompt)
        return cls(
            ai_name,
            memory,
            chain,
            output_parser or AutoGPTOutputParser(),
            tools,
            feedback_tool=human_feedback_tool,
            chat_history_memory=chat_history_memory,
        )

    def run(self, goals: List[str]) -> str:
        user_input = (
            "Determine which next command to use, "
            "and respond using the format specified above:"
        )
        # Interaction Loop
        loop_count = 0
        while True:
            # Discontinue if continuous limit is reached
            loop_count += 1

            # Send message to AI, get response
            assistant_reply = self.chain.run(
                goals=goals,
                messages=self.chat_history_memory.messages,
                memory=self.memory,
                user_input=user_input,
            )

            # Print Assistant thoughts
            print(assistant_reply)
            self.chat_history_memory.add_message(HumanMessage(content=user_input))
            self.chat_history_memory.add_message(AIMessage(content=assistant_reply))

            # Get command name and arguments
            action = self.output_parser.parse(assistant_reply)
            tools = {t.name: t for t in self.tools}
            if action.name == FINISH_NAME:
                return action.args["response"]
            if action.name in tools:
                tool = tools[action.name]
                try:
                    observation = tool.run(action.args)
                except ValidationError as e:
                    observation = (
                        f"Validation Error in args: {str(e)}, args: {action.args}"
                    )
                except Exception as e:
                    observation = (
                        f"Error: {str(e)}, {type(e).__name__}, args: {action.args}"
                    )
                result = f"Command {tool.name} returned: {observation}"
            elif action.name == "ERROR":
                result = f"Error: {action.args}. "
            else:
                result = (
                    f"Unknown command '{action.name}'. "
                    f"Please refer to the 'COMMANDS' list for available "
                    f"commands and only respond in the specified JSON format."
                )

            memory_to_add = (
                f"Assistant Reply: {assistant_reply} " f"\nResult: {result} "
            )
            if self.feedback_tool is not None:
                feedback = f"\n{self.feedback_tool.run('Input: ')}"
                if feedback in {"q", "stop"}:
                    print("EXITING")
                    return "EXITING"
                memory_to_add += feedback

            self.memory.add_documents([Document(page_content=memory_to_add)])
            self.chat_history_memory.add_message(SystemMessage(content=result))
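For orientation, here is a minimal usage sketch for the class above; it is not part of the commit, and the chat model, embeddings, FAISS store, and file tools are illustrative assumptions only.

from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.experimental import AutoGPT
from langchain.tools import ReadFileTool, WriteFileTool
from langchain.vectorstores import FAISS

# Hypothetical setup: any VectorStoreRetriever works as the agent's long-term memory.
retriever = FAISS.from_texts(["placeholder"], OpenAIEmbeddings()).as_retriever()

agent = AutoGPT.from_llm_and_tools(
    ai_name="Tom",
    ai_role="Assistant",
    memory=retriever,
    tools=[ReadFileTool(), WriteFileTool()],
    llm=ChatOpenAI(temperature=0),
)
# run() loops until the model issues the FINISH_NAME ("finish") command,
# then returns that command's "response" argument.
agent.run(["write a weather report for SF today"])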
@@ -0,0 +1,30 @@
from typing import Any, Dict, List

from pydantic import Field

from langchain.memory.chat_memory import BaseChatMemory, get_prompt_input_key
from langchain.vectorstores.base import VectorStoreRetriever


class AutoGPTMemory(BaseChatMemory):
    retriever: VectorStoreRetriever = Field(exclude=True)
    """VectorStoreRetriever object to connect to."""

    @property
    def memory_variables(self) -> List[str]:
        return ["chat_history", "relevant_context"]

    def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
        """Get the input key for the prompt."""
        if self.input_key is None:
            return get_prompt_input_key(inputs, self.memory_variables)
        return self.input_key

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        input_key = self._get_prompt_input_key(inputs)
        query = inputs[input_key]
        docs = self.retriever.get_relevant_documents(query)
        return {
            "chat_history": self.chat_memory.messages[-10:],
            "relevant_context": docs,
        }
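A brief, hypothetical sketch of how this memory class might be wired up (the retriever and input text are assumptions, and the module path is not shown in this hunk):

# Hypothetical: `retriever` is any VectorStoreRetriever, e.g. the FAISS one above.
memory = AutoGPTMemory(retriever=retriever)
memory.chat_memory.add_user_message("Summarize what we know so far.")

# With no explicit input_key, get_prompt_input_key picks the non-memory key ("input").
variables = memory.load_memory_variables({"input": "What did I ask you to do first?"})
# -> {"chat_history": [last 10 messages], "relevant_context": [retrieved Documents]}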
@@ -0,0 +1,60 @@
import json
import re
from abc import abstractmethod
from typing import Dict, NamedTuple

from langchain.schema import BaseOutputParser


class AutoGPTAction(NamedTuple):
    name: str
    args: Dict


class BaseAutoGPTOutputParser(BaseOutputParser):
    @abstractmethod
    def parse(self, text: str) -> AutoGPTAction:
        """Return AutoGPTAction"""


def preprocess_json_input(input_str: str) -> str:
    """Preprocesses a string to be parsed as json.

    Replace single backslashes with double backslashes,
    while leaving already escaped ones intact.

    Args:
        input_str: String to be preprocessed

    Returns:
        Preprocessed string
    """
    corrected_str = re.sub(
        r'(?<!\\)\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})', r"\\\\", input_str
    )
    return corrected_str


class AutoGPTOutputParser(BaseAutoGPTOutputParser):
    def parse(self, text: str) -> AutoGPTAction:
        try:
            parsed = json.loads(text, strict=False)
        except json.JSONDecodeError:
            preprocessed_text = preprocess_json_input(text)
            try:
                parsed = json.loads(preprocessed_text, strict=False)
            except Exception:
                return AutoGPTAction(
                    name="ERROR",
                    args={"error": f"Could not parse invalid json: {text}"},
                )
        try:
            return AutoGPTAction(
                name=parsed["command"]["name"],
                args=parsed["command"]["args"],
            )
        except (KeyError, TypeError):
            # If the command is null or incomplete, return an erroneous tool
            return AutoGPTAction(
                name="ERROR", args={"error": f"Incomplete command args: {parsed}"}
            )
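As a quick illustration of the parser above (the reply text and the "write_file" command name are assumed, matching the JSON response format defined in prompt_generator.py): a well-formed reply resolves to its command, while unparseable text degrades to an ERROR action.

parser = AutoGPTOutputParser()

# Hypothetical assistant reply in the expected {"thoughts": ..., "command": ...} shape.
good = '{"thoughts": {"text": "..."}, "command": {"name": "write_file", "args": {"file_path": "plan.txt", "text": "..."}}}'
action = parser.parse(good)
assert action.name == "write_file" and action.args["file_path"] == "plan.txt"

# Anything that cannot be parsed as JSON comes back as an ERROR action.
assert parser.parse("not json at all").name == "ERROR"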
@ -0,0 +1,78 @@
|
||||
import time
|
||||
from typing import Any, Callable, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.experimental.autonomous_agents.autogpt.prompt_generator import get_prompt
|
||||
from langchain.prompts.chat import (
|
||||
BaseChatPromptTemplate,
|
||||
)
|
||||
from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.vectorstores.base import VectorStoreRetriever
|
||||
|
||||
|
||||
class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel):
|
||||
ai_name: str
|
||||
ai_role: str
|
||||
tools: List[BaseTool]
|
||||
token_counter: Callable[[str], int]
|
||||
send_token_limit: int = 4196
|
||||
|
||||
def construct_full_prompt(self, goals: List[str]) -> str:
|
||||
prompt_start = (
|
||||
"Your decisions must always be made independently "
|
||||
"without seeking user assistance.\n"
|
||||
"Play to your strengths as an LLM and pursue simple "
|
||||
"strategies with no legal complications.\n"
|
||||
"If you have completed all your tasks, make sure to "
|
||||
'use the "finish" command.'
|
||||
)
|
||||
# Construct full prompt
|
||||
full_prompt = (
|
||||
f"You are {self.ai_name}, {self.ai_role}\n{prompt_start}\n\nGOALS:\n\n"
|
||||
)
|
||||
for i, goal in enumerate(goals):
|
||||
full_prompt += f"{i+1}. {goal}\n"
|
||||
|
||||
full_prompt += f"\n\n{get_prompt(self.tools)}"
|
||||
return full_prompt
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
base_prompt = SystemMessage(content=self.construct_full_prompt(kwargs["goals"]))
|
||||
time_prompt = SystemMessage(
|
||||
content=f"The current time and date is {time.strftime('%c')}"
|
||||
)
|
||||
used_tokens = self.token_counter(base_prompt.content) + self.token_counter(
|
||||
time_prompt.content
|
||||
)
|
||||
memory: VectorStoreRetriever = kwargs["memory"]
|
||||
previous_messages = kwargs["messages"]
|
||||
relevant_docs = memory.get_relevant_documents(str(previous_messages[-10:]))
|
||||
relevant_memory = [d.page_content for d in relevant_docs]
|
||||
relevant_memory_tokens = sum(
|
||||
[self.token_counter(doc) for doc in relevant_memory]
|
||||
)
|
||||
while used_tokens + relevant_memory_tokens > 2500:
|
||||
relevant_memory = relevant_memory[:-1]
|
||||
relevant_memory_tokens = sum(
|
||||
[self.token_counter(doc) for doc in relevant_memory]
|
||||
)
|
||||
content_format = (
|
||||
f"This reminds you of these events "
|
||||
f"from your past:\n{relevant_memory}\n\n"
|
||||
)
|
||||
memory_message = SystemMessage(content=content_format)
|
||||
used_tokens += self.token_counter(memory_message.content)
|
||||
historical_messages: List[BaseMessage] = []
|
||||
for message in previous_messages[-10:][::-1]:
|
||||
message_tokens = self.token_counter(message.content)
|
||||
if used_tokens + message_tokens > self.send_token_limit - 1000:
|
||||
break
|
||||
historical_messages = [message] + historical_messages
|
||||
used_tokens += message_tokens
|
||||
input_message = HumanMessage(content=kwargs["user_input"])
|
||||
messages: List[BaseMessage] = [base_prompt, time_prompt, memory_message]
|
||||
messages += historical_messages
|
||||
messages.append(input_message)
|
||||
return messages
|
@ -0,0 +1,186 @@
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
FINISH_NAME = "finish"
|
||||
|
||||
|
||||
class PromptGenerator:
|
||||
"""A class for generating custom prompt strings.
|
||||
|
||||
Does this based on constraints, commands, resources, and performance evaluations.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the PromptGenerator object.
|
||||
|
||||
Starts with empty lists of constraints, commands, resources,
|
||||
and performance evaluations.
|
||||
"""
|
||||
self.constraints: List[str] = []
|
||||
self.commands: List[BaseTool] = []
|
||||
self.resources: List[str] = []
|
||||
self.performance_evaluation: List[str] = []
|
||||
self.response_format = {
|
||||
"thoughts": {
|
||||
"text": "thought",
|
||||
"reasoning": "reasoning",
|
||||
"plan": "- short bulleted\n- list that conveys\n- long-term plan",
|
||||
"criticism": "constructive self-criticism",
|
||||
"speak": "thoughts summary to say to user",
|
||||
},
|
||||
"command": {"name": "command name", "args": {"arg name": "value"}},
|
||||
}
|
||||
|
||||
def add_constraint(self, constraint: str) -> None:
|
||||
"""
|
||||
Add a constraint to the constraints list.
|
||||
|
||||
Args:
|
||||
constraint (str): The constraint to be added.
|
||||
"""
|
||||
self.constraints.append(constraint)
|
||||
|
||||
def add_tool(self, tool: BaseTool) -> None:
|
||||
self.commands.append(tool)
|
||||
|
||||
def _generate_command_string(self, tool: BaseTool) -> str:
|
||||
output = f"{tool.name}: {tool.description}"
|
||||
output += f", args json schema: {json.dumps(tool.args)}"
|
||||
return output
|
||||
|
||||
def add_resource(self, resource: str) -> None:
|
||||
"""
|
||||
Add a resource to the resources list.
|
||||
|
||||
Args:
|
||||
resource (str): The resource to be added.
|
||||
"""
|
||||
self.resources.append(resource)
|
||||
|
||||
def add_performance_evaluation(self, evaluation: str) -> None:
|
||||
"""
|
||||
Add a performance evaluation item to the performance_evaluation list.
|
||||
|
||||
Args:
|
||||
evaluation (str): The evaluation item to be added.
|
||||
"""
|
||||
self.performance_evaluation.append(evaluation)
|
||||
|
||||
def _generate_numbered_list(self, items: list, item_type: str = "list") -> str:
|
||||
"""
|
||||
Generate a numbered list from given items based on the item_type.
|
||||
|
||||
Args:
|
||||
items (list): A list of items to be numbered.
|
||||
item_type (str, optional): The type of items in the list.
|
||||
Defaults to 'list'.
|
||||
|
||||
Returns:
|
||||
str: The formatted numbered list.
|
||||
"""
|
||||
if item_type == "command":
|
||||
command_strings = [
|
||||
f"{i + 1}. {self._generate_command_string(item)}"
|
||||
for i, item in enumerate(items)
|
||||
]
|
||||
finish_description = (
|
||||
"use this to signal that you have finished all your objectives"
|
||||
)
|
||||
finish_args = (
|
||||
'"response": "final response to let '
|
||||
'people know you have finished your objectives"'
|
||||
)
|
||||
finish_string = (
|
||||
f"{len(items) + 1}. {FINISH_NAME}: "
|
||||
f"{finish_description}, args: {finish_args}"
|
||||
)
|
||||
return "\n".join(command_strings + [finish_string])
|
||||
else:
|
||||
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))
|
||||
|
||||
def generate_prompt_string(self) -> str:
|
||||
"""Generate a prompt string.
|
||||
|
||||
Returns:
|
||||
str: The generated prompt string.
|
||||
"""
|
||||
formatted_response_format = json.dumps(self.response_format, indent=4)
|
||||
prompt_string = (
|
||||
f"Constraints:\n{self._generate_numbered_list(self.constraints)}\n\n"
|
||||
f"Commands:\n"
|
||||
f"{self._generate_numbered_list(self.commands, item_type='command')}\n\n"
|
||||
f"Resources:\n{self._generate_numbered_list(self.resources)}\n\n"
|
||||
f"Performance Evaluation:\n"
|
||||
f"{self._generate_numbered_list(self.performance_evaluation)}\n\n"
|
||||
f"You should only respond in JSON format as described below "
|
||||
f"\nResponse Format: \n{formatted_response_format} "
|
||||
f"\nEnsure the response can be parsed by Python json.loads"
|
||||
)
|
||||
|
||||
return prompt_string
|
||||
|
||||
|
||||
def get_prompt(tools: List[BaseTool]) -> str:
|
||||
"""This function generates a prompt string.
|
||||
|
||||
It includes various constraints, commands, resources, and performance evaluations.
|
||||
|
||||
Returns:
|
||||
str: The generated prompt string.
|
||||
"""
|
||||
|
||||
# Initialize the PromptGenerator object
|
||||
prompt_generator = PromptGenerator()
|
||||
|
||||
# Add constraints to the PromptGenerator object
|
||||
prompt_generator.add_constraint(
|
||||
"~4000 word limit for short term memory. "
|
||||
"Your short term memory is short, "
|
||||
"so immediately save important information to files."
|
||||
)
|
||||
prompt_generator.add_constraint(
|
||||
"If you are unsure how you previously did something "
|
||||
"or want to recall past events, "
|
||||
"thinking about similar events will help you remember."
|
||||
)
|
||||
prompt_generator.add_constraint("No user assistance")
|
||||
prompt_generator.add_constraint(
|
||||
'Exclusively use the commands listed in double quotes e.g. "command name"'
|
||||
)
|
||||
|
||||
# Add commands to the PromptGenerator object
|
||||
for tool in tools:
|
||||
prompt_generator.add_tool(tool)
|
||||
|
||||
# Add resources to the PromptGenerator object
|
||||
prompt_generator.add_resource(
|
||||
"Internet access for searches and information gathering."
|
||||
)
|
||||
prompt_generator.add_resource("Long Term memory management.")
|
||||
prompt_generator.add_resource(
|
||||
"GPT-3.5 powered Agents for delegation of simple tasks."
|
||||
)
|
||||
prompt_generator.add_resource("File output.")
|
||||
|
||||
# Add performance evaluations to the PromptGenerator object
|
||||
prompt_generator.add_performance_evaluation(
|
||||
"Continuously review and analyze your actions "
|
||||
"to ensure you are performing to the best of your abilities."
|
||||
)
|
||||
prompt_generator.add_performance_evaluation(
|
||||
"Constructively self-criticize your big-picture behavior constantly."
|
||||
)
|
||||
prompt_generator.add_performance_evaluation(
|
||||
"Reflect on past decisions and strategies to refine your approach."
|
||||
)
|
||||
prompt_generator.add_performance_evaluation(
|
||||
"Every command has a cost, so be smart and efficient. "
|
||||
"Aim to complete tasks in the least number of steps."
|
||||
)
|
||||
|
||||
# Generate the prompt string
|
||||
prompt_string = prompt_generator.generate_prompt_string()
|
||||
|
||||
return prompt_string
|
@@ -0,0 +1,17 @@
from langchain.experimental.autonomous_agents.baby_agi.baby_agi import BabyAGI
from langchain.experimental.autonomous_agents.baby_agi.task_creation import (
    TaskCreationChain,
)
from langchain.experimental.autonomous_agents.baby_agi.task_execution import (
    TaskExecutionChain,
)
from langchain.experimental.autonomous_agents.baby_agi.task_prioritization import (
    TaskPrioritizationChain,
)

__all__ = [
    "BabyAGI",
    "TaskPrioritizationChain",
    "TaskExecutionChain",
    "TaskCreationChain",
]
@ -0,0 +1,203 @@
|
||||
"""BabyAGI agent."""
|
||||
from collections import deque
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.experimental.autonomous_agents.baby_agi.task_creation import (
|
||||
TaskCreationChain,
|
||||
)
|
||||
from langchain.experimental.autonomous_agents.baby_agi.task_execution import (
|
||||
TaskExecutionChain,
|
||||
)
|
||||
from langchain.experimental.autonomous_agents.baby_agi.task_prioritization import (
|
||||
TaskPrioritizationChain,
|
||||
)
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
|
||||
class BabyAGI(Chain, BaseModel):
|
||||
"""Controller model for the BabyAGI agent."""
|
||||
|
||||
task_list: deque = Field(default_factory=deque)
|
||||
task_creation_chain: Chain = Field(...)
|
||||
task_prioritization_chain: Chain = Field(...)
|
||||
execution_chain: Chain = Field(...)
|
||||
task_id_counter: int = Field(1)
|
||||
vectorstore: VectorStore = Field(init=False)
|
||||
max_iterations: Optional[int] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def add_task(self, task: Dict) -> None:
|
||||
self.task_list.append(task)
|
||||
|
||||
def print_task_list(self) -> None:
|
||||
print("\033[95m\033[1m" + "\n*****TASK LIST*****\n" + "\033[0m\033[0m")
|
||||
for t in self.task_list:
|
||||
print(str(t["task_id"]) + ": " + t["task_name"])
|
||||
|
||||
def print_next_task(self, task: Dict) -> None:
|
||||
print("\033[92m\033[1m" + "\n*****NEXT TASK*****\n" + "\033[0m\033[0m")
|
||||
print(str(task["task_id"]) + ": " + task["task_name"])
|
||||
|
||||
def print_task_result(self, result: str) -> None:
|
||||
print("\033[93m\033[1m" + "\n*****TASK RESULT*****\n" + "\033[0m\033[0m")
|
||||
print(result)
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
return ["objective"]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
return []
|
||||
|
||||
def get_next_task(
|
||||
self, result: str, task_description: str, objective: str, **kwargs: Any
|
||||
) -> List[Dict]:
|
||||
"""Get the next task."""
|
||||
task_names = [t["task_name"] for t in self.task_list]
|
||||
|
||||
incomplete_tasks = ", ".join(task_names)
|
||||
response = self.task_creation_chain.run(
|
||||
result=result,
|
||||
task_description=task_description,
|
||||
incomplete_tasks=incomplete_tasks,
|
||||
objective=objective,
|
||||
**kwargs,
|
||||
)
|
||||
new_tasks = response.split("\n")
|
||||
return [
|
||||
{"task_name": task_name} for task_name in new_tasks if task_name.strip()
|
||||
]
|
||||
|
||||
def prioritize_tasks(
|
||||
self, this_task_id: int, objective: str, **kwargs: Any
|
||||
) -> List[Dict]:
|
||||
"""Prioritize tasks."""
|
||||
task_names = [t["task_name"] for t in list(self.task_list)]
|
||||
next_task_id = int(this_task_id) + 1
|
||||
response = self.task_prioritization_chain.run(
|
||||
task_names=", ".join(task_names),
|
||||
next_task_id=str(next_task_id),
|
||||
objective=objective,
|
||||
**kwargs,
|
||||
)
|
||||
new_tasks = response.split("\n")
|
||||
prioritized_task_list = []
|
||||
for task_string in new_tasks:
|
||||
if not task_string.strip():
|
||||
continue
|
||||
task_parts = task_string.strip().split(".", 1)
|
||||
if len(task_parts) == 2:
|
||||
task_id = task_parts[0].strip()
|
||||
task_name = task_parts[1].strip()
|
||||
prioritized_task_list.append(
|
||||
{"task_id": task_id, "task_name": task_name}
|
||||
)
|
||||
return prioritized_task_list
|
||||
|
||||
def _get_top_tasks(self, query: str, k: int) -> List[str]:
|
||||
"""Get the top k tasks based on the query."""
|
||||
results = self.vectorstore.similarity_search(query, k=k)
|
||||
if not results:
|
||||
return []
|
||||
return [str(item.metadata["task"]) for item in results]
|
||||
|
||||
def execute_task(self, objective: str, task: str, k: int = 5, **kwargs: Any) -> str:
|
||||
"""Execute a task."""
|
||||
context = self._get_top_tasks(query=objective, k=k)
|
||||
return self.execution_chain.run(
|
||||
objective=objective, context="\n".join(context), task=task, **kwargs
|
||||
)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the agent."""
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
objective = inputs["objective"]
|
||||
first_task = inputs.get("first_task", "Make a todo list")
|
||||
self.add_task({"task_id": 1, "task_name": first_task})
|
||||
num_iters = 0
|
||||
while True:
|
||||
if self.task_list:
|
||||
self.print_task_list()
|
||||
|
||||
# Step 1: Pull the first task
|
||||
task = self.task_list.popleft()
|
||||
self.print_next_task(task)
|
||||
|
||||
# Step 2: Execute the task
|
||||
result = self.execute_task(
|
||||
objective, task["task_name"], callbacks=_run_manager.get_child()
|
||||
)
|
||||
this_task_id = int(task["task_id"])
|
||||
self.print_task_result(result)
|
||||
|
||||
# Step 3: Store the result in Pinecone
|
||||
result_id = f"result_{task['task_id']}"
|
||||
self.vectorstore.add_texts(
|
||||
texts=[result],
|
||||
metadatas=[{"task": task["task_name"]}],
|
||||
ids=[result_id],
|
||||
)
|
||||
|
||||
# Step 4: Create new tasks and reprioritize task list
|
||||
new_tasks = self.get_next_task(
|
||||
result,
|
||||
task["task_name"],
|
||||
objective,
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
for new_task in new_tasks:
|
||||
self.task_id_counter += 1
|
||||
new_task.update({"task_id": self.task_id_counter})
|
||||
self.add_task(new_task)
|
||||
self.task_list = deque(
|
||||
self.prioritize_tasks(
|
||||
this_task_id, objective, callbacks=_run_manager.get_child()
|
||||
)
|
||||
)
|
||||
num_iters += 1
|
||||
if self.max_iterations is not None and num_iters == self.max_iterations:
|
||||
print(
|
||||
"\033[91m\033[1m" + "\n*****TASK ENDING*****\n" + "\033[0m\033[0m"
|
||||
)
|
||||
break
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
vectorstore: VectorStore,
|
||||
verbose: bool = False,
|
||||
task_execution_chain: Optional[Chain] = None,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> "BabyAGI":
|
||||
"""Initialize the BabyAGI Controller."""
|
||||
task_creation_chain = TaskCreationChain.from_llm(llm, verbose=verbose)
|
||||
task_prioritization_chain = TaskPrioritizationChain.from_llm(
|
||||
llm, verbose=verbose
|
||||
)
|
||||
if task_execution_chain is None:
|
||||
execution_chain: Chain = TaskExecutionChain.from_llm(llm, verbose=verbose)
|
||||
else:
|
||||
execution_chain = task_execution_chain
|
||||
return cls(
|
||||
task_creation_chain=task_creation_chain,
|
||||
task_prioritization_chain=task_prioritization_chain,
|
||||
execution_chain=execution_chain,
|
||||
vectorstore=vectorstore,
|
||||
**kwargs,
|
||||
)
|
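A minimal sketch of driving the BabyAGI controller above (not part of the commit; the model, embeddings, vectorstore, and objective are assumptions):

from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.experimental import BabyAGI
from langchain.vectorstores import FAISS

# Hypothetical components; any BaseLanguageModel / VectorStore pair works.
llm = ChatOpenAI(temperature=0)
vectorstore = FAISS.from_texts(
    ["placeholder"], OpenAIEmbeddings(), metadatas=[{"task": "init"}]
)

baby_agi = BabyAGI.from_llm(llm=llm, vectorstore=vectorstore, verbose=False, max_iterations=3)
# The chain takes a single "objective" input and loops create -> execute -> prioritize.
baby_agi({"objective": "Write a weather report for SF today"})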
@@ -0,0 +1,30 @@
from langchain import LLMChain, PromptTemplate
from langchain.schema.language_model import BaseLanguageModel


class TaskCreationChain(LLMChain):
    """Chain to generates tasks."""

    @classmethod
    def from_llm(cls, llm: BaseLanguageModel, verbose: bool = True) -> LLMChain:
        """Get the response parser."""
        task_creation_template = (
            "You are an task creation AI that uses the result of an execution agent"
            " to create new tasks with the following objective: {objective},"
            " The last completed task has the result: {result}."
            " This result was based on this task description: {task_description}."
            " These are incomplete tasks: {incomplete_tasks}."
            " Based on the result, create new tasks to be completed"
            " by the AI system that do not overlap with incomplete tasks."
            " Return the tasks as an array."
        )
        prompt = PromptTemplate(
            template=task_creation_template,
            input_variables=[
                "result",
                "task_description",
                "incomplete_tasks",
                "objective",
            ],
        )
        return cls(prompt=prompt, llm=llm, verbose=verbose)
@@ -0,0 +1,21 @@
from langchain import LLMChain, PromptTemplate
from langchain.schema.language_model import BaseLanguageModel


class TaskExecutionChain(LLMChain):
    """Chain to execute tasks."""

    @classmethod
    def from_llm(cls, llm: BaseLanguageModel, verbose: bool = True) -> LLMChain:
        """Get the response parser."""
        execution_template = (
            "You are an AI who performs one task based on the following objective: "
            "{objective}."
            "Take into account these previously completed tasks: {context}."
            " Your task: {task}. Response:"
        )
        prompt = PromptTemplate(
            template=execution_template,
            input_variables=["objective", "context", "task"],
        )
        return cls(prompt=prompt, llm=llm, verbose=verbose)
@@ -0,0 +1,24 @@
from langchain import LLMChain, PromptTemplate
from langchain.schema.language_model import BaseLanguageModel


class TaskPrioritizationChain(LLMChain):
    """Chain to prioritize tasks."""

    @classmethod
    def from_llm(cls, llm: BaseLanguageModel, verbose: bool = True) -> LLMChain:
        """Get the response parser."""
        task_prioritization_template = (
            "You are a task prioritization AI tasked with cleaning the formatting of "
            "and reprioritizing the following tasks: {task_names}."
            " Consider the ultimate objective of your team: {objective}."
            " Do not remove any tasks. Return the result as a numbered list, like:"
            " #. First task"
            " #. Second task"
            " Start the task list with number {next_task_id}."
        )
        prompt = PromptTemplate(
            template=task_prioritization_template,
            input_variables=["task_names", "next_task_id", "objective"],
        )
        return cls(prompt=prompt, llm=llm, verbose=verbose)
4 libs/langchain/langchain/experimental/cpal/README.md Normal file
@@ -0,0 +1,4 @@
# Causal program-aided language (CPAL) chain


see https://github.com/hwchase17/langchain/pull/6255
271 libs/langchain/langchain/experimental/cpal/base.py Normal file
@ -0,0 +1,271 @@
|
||||
"""
|
||||
CPAL Chain and its subchains
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Type
|
||||
|
||||
import pydantic
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.experimental.cpal.constants import Constant
|
||||
from langchain.experimental.cpal.models import (
|
||||
CausalModel,
|
||||
InterventionModel,
|
||||
NarrativeModel,
|
||||
QueryModel,
|
||||
StoryModel,
|
||||
)
|
||||
from langchain.experimental.cpal.templates.univariate.causal import (
|
||||
template as causal_template,
|
||||
)
|
||||
from langchain.experimental.cpal.templates.univariate.intervention import (
|
||||
template as intervention_template,
|
||||
)
|
||||
from langchain.experimental.cpal.templates.univariate.narrative import (
|
||||
template as narrative_template,
|
||||
)
|
||||
from langchain.experimental.cpal.templates.univariate.query import (
|
||||
template as query_template,
|
||||
)
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
class _BaseStoryElementChain(Chain):
|
||||
chain: LLMChain
|
||||
input_key: str = Constant.narrative_input.value #: :meta private:
|
||||
output_key: str = Constant.chain_answer.value #: :meta private:
|
||||
pydantic_model: ClassVar[
|
||||
Optional[Type[pydantic.BaseModel]]
|
||||
] = None #: :meta private:
|
||||
template: ClassVar[Optional[str]] = None #: :meta private:
|
||||
|
||||
@classmethod
|
||||
def parser(cls) -> PydanticOutputParser:
|
||||
"""Parse LLM output into a pydantic object."""
|
||||
if cls.pydantic_model is None:
|
||||
raise NotImplementedError(
|
||||
f"pydantic_model not implemented for {cls.__name__}"
|
||||
)
|
||||
return PydanticOutputParser(pydantic_object=cls.pydantic_model)
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Return the input keys.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return the output keys.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
_output_keys = [self.output_key]
|
||||
return _output_keys
|
||||
|
||||
@classmethod
|
||||
def from_univariate_prompt(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return cls(
|
||||
chain=LLMChain(
|
||||
llm=llm,
|
||||
prompt=PromptTemplate(
|
||||
input_variables=[Constant.narrative_input.value],
|
||||
template=kwargs.get("template", cls.template),
|
||||
partial_variables={
|
||||
"format_instructions": cls.parser().get_format_instructions()
|
||||
},
|
||||
),
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
completion = self.chain.run(inputs[self.input_key])
|
||||
pydantic_data = self.__class__.parser().parse(completion)
|
||||
return {
|
||||
Constant.chain_data.value: pydantic_data,
|
||||
Constant.chain_answer.value: None,
|
||||
}
|
||||
|
||||
|
||||
class NarrativeChain(_BaseStoryElementChain):
|
||||
"""Decompose the narrative into its story elements
|
||||
|
||||
- causal model
|
||||
- query
|
||||
- intervention
|
||||
"""
|
||||
|
||||
pydantic_model: ClassVar[Type[pydantic.BaseModel]] = NarrativeModel
|
||||
template: ClassVar[str] = narrative_template
|
||||
|
||||
|
||||
class CausalChain(_BaseStoryElementChain):
|
||||
"""Translate the causal narrative into a stack of operations."""
|
||||
|
||||
pydantic_model: ClassVar[Type[pydantic.BaseModel]] = CausalModel
|
||||
template: ClassVar[str] = causal_template
|
||||
|
||||
|
||||
class InterventionChain(_BaseStoryElementChain):
|
||||
"""Set the hypothetical conditions for the causal model."""
|
||||
|
||||
pydantic_model: ClassVar[Type[pydantic.BaseModel]] = InterventionModel
|
||||
template: ClassVar[str] = intervention_template
|
||||
|
||||
|
||||
class QueryChain(_BaseStoryElementChain):
|
||||
"""Query the outcome table using SQL."""
|
||||
|
||||
pydantic_model: ClassVar[Type[pydantic.BaseModel]] = QueryModel
|
||||
template: ClassVar[str] = query_template # TODO: incl. table schema
|
||||
|
||||
|
||||
class CPALChain(_BaseStoryElementChain):
|
||||
llm: BaseLanguageModel
|
||||
narrative_chain: Optional[NarrativeChain] = None
|
||||
causal_chain: Optional[CausalChain] = None
|
||||
intervention_chain: Optional[InterventionChain] = None
|
||||
query_chain: Optional[QueryChain] = None
|
||||
_story: StoryModel = pydantic.PrivateAttr(default=None) # TODO: change name ?
|
||||
|
||||
@classmethod
|
||||
def from_univariate_prompt(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
**kwargs: Any,
|
||||
) -> CPALChain:
|
||||
"""instantiation depends on component chains"""
|
||||
return cls(
|
||||
llm=llm,
|
||||
chain=LLMChain(
|
||||
llm=llm,
|
||||
prompt=PromptTemplate(
|
||||
input_variables=["question", "query_result"],
|
||||
template=(
|
||||
"Summarize this answer '{query_result}' to this "
|
||||
"question '{question}'? "
|
||||
),
|
||||
),
|
||||
),
|
||||
narrative_chain=NarrativeChain.from_univariate_prompt(llm=llm),
|
||||
causal_chain=CausalChain.from_univariate_prompt(llm=llm),
|
||||
intervention_chain=InterventionChain.from_univariate_prompt(llm=llm),
|
||||
query_chain=QueryChain.from_univariate_prompt(llm=llm),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
# instantiate component chains
|
||||
if self.narrative_chain is None:
|
||||
self.narrative_chain = NarrativeChain.from_univariate_prompt(llm=self.llm)
|
||||
if self.causal_chain is None:
|
||||
self.causal_chain = CausalChain.from_univariate_prompt(llm=self.llm)
|
||||
if self.intervention_chain is None:
|
||||
self.intervention_chain = InterventionChain.from_univariate_prompt(
|
||||
llm=self.llm
|
||||
)
|
||||
if self.query_chain is None:
|
||||
self.query_chain = QueryChain.from_univariate_prompt(llm=self.llm)
|
||||
|
||||
# decompose narrative into three causal story elements
|
||||
narrative = self.narrative_chain(inputs[Constant.narrative_input.value])[
|
||||
Constant.chain_data.value
|
||||
]
|
||||
|
||||
story = StoryModel(
|
||||
causal_operations=self.causal_chain(narrative.story_plot)[
|
||||
Constant.chain_data.value
|
||||
],
|
||||
intervention=self.intervention_chain(narrative.story_hypothetical)[
|
||||
Constant.chain_data.value
|
||||
],
|
||||
query=self.query_chain(narrative.story_outcome_question)[
|
||||
Constant.chain_data.value
|
||||
],
|
||||
)
|
||||
self._story = story
|
||||
|
||||
def pretty_print_str(title: str, d: str) -> str:
|
||||
return title + "\n" + d
|
||||
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
_run_manager.on_text(
|
||||
pretty_print_str("story outcome data", story._outcome_table.to_string()),
|
||||
color="green",
|
||||
end="\n\n",
|
||||
verbose=self.verbose,
|
||||
)
|
||||
|
||||
def pretty_print_dict(title: str, d: dict) -> str:
|
||||
return title + "\n" + json.dumps(d, indent=4)
|
||||
|
||||
_run_manager.on_text(
|
||||
pretty_print_dict("query data", story.query.dict()),
|
||||
color="blue",
|
||||
end="\n\n",
|
||||
verbose=self.verbose,
|
||||
)
|
||||
if story.query._result_table.empty:
|
||||
# prevent piping bad data into subsequent chains
|
||||
raise ValueError(
|
||||
(
|
||||
"unanswerable, query and outcome are incoherent\n"
|
||||
"\n"
|
||||
"outcome:\n"
|
||||
f"{story._outcome_table}\n"
|
||||
"query:\n"
|
||||
f"{story.query.dict()}"
|
||||
)
|
||||
)
|
||||
else:
|
||||
query_result = float(story.query._result_table.values[0][-1])
|
||||
if False:
|
||||
"""TODO: add this back in when demanded by composable chains"""
|
||||
reporting_chain = self.chain
|
||||
human_report = reporting_chain.run(
|
||||
question=story.query.question, query_result=query_result
|
||||
)
|
||||
query_result = {
|
||||
"query_result": query_result,
|
||||
"human_report": human_report,
|
||||
}
|
||||
output = {
|
||||
Constant.chain_data.value: story,
|
||||
self.output_key: query_result,
|
||||
**kwargs,
|
||||
}
|
||||
return output
|
||||
|
||||
def draw(self, **kwargs: Any) -> None:
|
||||
"""
|
||||
CPAL chain can draw its resulting DAG.
|
||||
|
||||
Usage in a jupyter notebook:
|
||||
|
||||
>>> from IPython.display import SVG
|
||||
>>> cpal_chain.draw(path="graph.svg")
|
||||
>>> SVG('graph.svg')
|
||||
"""
|
||||
self._story._networkx_wrapper.draw_graphviz(**kwargs)
|
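For reference, a hedged usage sketch of the CPAL chain defined in base.py above; the LLM choice is an assumption, and the narrative is taken from the few-shot example in the narrative template later in this diff.

from langchain.chat_models import ChatOpenAI
from langchain.experimental.cpal.base import CPALChain

# Hypothetical model; CPAL composes narrative -> causal -> intervention -> query subchains.
llm = ChatOpenAI(temperature=0, max_tokens=512)
cpal_chain = CPALChain.from_univariate_prompt(llm=llm, verbose=True)

narrative = (
    "Boris has seven times the number of pets as Marcia. "
    "Jan has three times the number of pets as Marcia. "
    "Marcia has two more pets than Cindy. "
    "If Cindy has four pets, how many total pets do the three have?"
)
cpal_chain.run(narrative)            # returns the numeric query result
# cpal_chain.draw(path="graph.svg")  # optionally render the causal DAG (needs graphviz)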
7 libs/langchain/langchain/experimental/cpal/constants.py Normal file
@@ -0,0 +1,7 @@
from enum import Enum


class Constant(Enum):
    narrative_input = "narrative_input"
    chain_answer = "chain_answer"  # natural language answer
    chain_data = "chain_data"  # pydantic instance
245 libs/langchain/langchain/experimental/cpal/models.py Normal file
@ -0,0 +1,245 @@
|
||||
from __future__ import annotations # allows pydantic model to reference itself
|
||||
|
||||
import re
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import duckdb
|
||||
import pandas as pd
|
||||
from pydantic import BaseModel, Field, PrivateAttr, root_validator, validator
|
||||
|
||||
from langchain.experimental.cpal.constants import Constant
|
||||
from langchain.graphs.networkx_graph import NetworkxEntityGraph
|
||||
|
||||
|
||||
class NarrativeModel(BaseModel):
|
||||
"""
|
||||
Represent the narrative input as three story elements.
|
||||
"""
|
||||
|
||||
story_outcome_question: str
|
||||
story_hypothetical: str
|
||||
story_plot: str # causal stack of operations
|
||||
|
||||
@validator("*", pre=True)
|
||||
def empty_str_to_none(cls, v: str) -> Union[str, None]:
|
||||
"""Empty strings are not allowed"""
|
||||
if v == "":
|
||||
return None
|
||||
return v
|
||||
|
||||
|
||||
class EntityModel(BaseModel):
|
||||
name: str = Field(description="entity name")
|
||||
code: str = Field(description="entity actions")
|
||||
value: float = Field(description="entity initial value")
|
||||
depends_on: list[str] = Field(default=[], description="ancestor entities")
|
||||
|
||||
# TODO: generalize to multivariate math
|
||||
# TODO: acyclic graph
|
||||
|
||||
class Config:
|
||||
validate_assignment = True
|
||||
|
||||
@validator("name")
|
||||
def lower_case_name(cls, v: str) -> str:
|
||||
v = v.lower()
|
||||
return v
|
||||
|
||||
|
||||
class CausalModel(BaseModel):
|
||||
attribute: str = Field(description="name of the attribute to be calculated")
|
||||
entities: list[EntityModel] = Field(description="entities in the story")
|
||||
|
||||
# TODO: root validate each `entity.depends_on` using system's entity names
|
||||
|
||||
|
||||
class EntitySettingModel(BaseModel):
|
||||
"""
|
||||
Initial conditions for an entity
|
||||
|
||||
{"name": "bud", "attribute": "pet_count", "value": 12}
|
||||
"""
|
||||
|
||||
name: str = Field(description="name of the entity")
|
||||
attribute: str = Field(description="name of the attribute to be calculated")
|
||||
value: float = Field(description="entity's attribute value (calculated)")
|
||||
|
||||
@validator("name")
|
||||
def lower_case_transform(cls, v: str) -> str:
|
||||
v = v.lower()
|
||||
return v
|
||||
|
||||
|
||||
class SystemSettingModel(BaseModel):
|
||||
"""
|
||||
Initial global conditions for the system.
|
||||
|
||||
{"parameter": "interest_rate", "value": .05}
|
||||
"""
|
||||
|
||||
parameter: str
|
||||
value: float
|
||||
|
||||
|
||||
class InterventionModel(BaseModel):
|
||||
"""
|
||||
aka initial conditions
|
||||
|
||||
>>> intervention.dict()
|
||||
{
|
||||
entity_settings: [
|
||||
{"name": "bud", "attribute": "pet_count", "value": 12},
|
||||
{"name": "pat", "attribute": "pet_count", "value": 0},
|
||||
],
|
||||
system_settings: None,
|
||||
}
|
||||
"""
|
||||
|
||||
entity_settings: list[EntitySettingModel]
|
||||
system_settings: Optional[list[SystemSettingModel]] = None
|
||||
|
||||
@validator("system_settings")
|
||||
def lower_case_name(cls, v: str) -> Union[str, None]:
|
||||
if v is not None:
|
||||
raise NotImplementedError("system_setting is not implemented yet")
|
||||
return v
|
||||
|
||||
|
||||
class QueryModel(BaseModel):
|
||||
"""translate a question about the story outcome into a programmatic expression"""
|
||||
|
||||
question: str = Field(alias=Constant.narrative_input.value) # input
|
||||
expression: str # output, part of llm completion
|
||||
llm_error_msg: str # output, part of llm completion
|
||||
_result_table: str = PrivateAttr() # result of the executed query
|
||||
|
||||
|
||||
class ResultModel(BaseModel):
|
||||
question: str = Field(alias=Constant.narrative_input.value) # input
|
||||
_result_table: str = PrivateAttr() # result of the executed query
|
||||
|
||||
|
||||
class StoryModel(BaseModel):
|
||||
causal_operations: Any = Field(required=True)
|
||||
intervention: Any = Field(required=True)
|
||||
query: Any = Field(required=True)
|
||||
_outcome_table: pd.DataFrame = PrivateAttr(default=None)
|
||||
_networkx_wrapper: Any = PrivateAttr(default=None)
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
self._compute()
|
||||
|
||||
# TODO: when langchain adopts pydantic.v2 replace w/ `__post_init__`
|
||||
# misses hints github.com/pydantic/pydantic/issues/1729#issuecomment-1300576214
|
||||
|
||||
@root_validator
|
||||
def check_intervention_is_valid(cls, values: dict) -> dict:
|
||||
valid_names = [e.name for e in values["causal_operations"].entities]
|
||||
for setting in values["intervention"].entity_settings:
|
||||
if setting.name not in valid_names:
|
||||
error_msg = f"""
|
||||
Hypothetical question has an invalid entity name.
|
||||
`{setting.name}` not in `{valid_names}`
|
||||
"""
|
||||
raise ValueError(error_msg)
|
||||
return values
|
||||
|
||||
def _block_back_door_paths(self) -> None:
|
||||
# stop intervention entities from depending on others
|
||||
intervention_entities = [
|
||||
entity_setting.name for entity_setting in self.intervention.entity_settings
|
||||
]
|
||||
for entity in self.causal_operations.entities:
|
||||
if entity.name in intervention_entities:
|
||||
entity.depends_on = []
|
||||
entity.code = "pass"
|
||||
|
||||
def _set_initial_conditions(self) -> None:
|
||||
for entity_setting in self.intervention.entity_settings:
|
||||
for entity in self.causal_operations.entities:
|
||||
if entity.name == entity_setting.name:
|
||||
entity.value = entity_setting.value
|
||||
|
||||
def _make_graph(self) -> None:
|
||||
self._networkx_wrapper = NetworkxEntityGraph()
|
||||
for entity in self.causal_operations.entities:
|
||||
for parent_name in entity.depends_on:
|
||||
self._networkx_wrapper._graph.add_edge(
|
||||
parent_name, entity.name, relation=entity.code
|
||||
)
|
||||
|
||||
# TODO: is it correct to drop entities with no impact on the outcome (?)
|
||||
self.causal_operations.entities = [
|
||||
entity
|
||||
for entity in self.causal_operations.entities
|
||||
if entity.name in self._networkx_wrapper.get_topological_sort()
|
||||
]
|
||||
|
||||
def _sort_entities(self) -> None:
|
||||
# order the sequence of causal actions
|
||||
sorted_nodes = self._networkx_wrapper.get_topological_sort()
|
||||
self.causal_operations.entities.sort(key=lambda x: sorted_nodes.index(x.name))
|
||||
|
||||
def _forward_propagate(self) -> None:
|
||||
entity_scope = {
|
||||
entity.name: entity for entity in self.causal_operations.entities
|
||||
}
|
||||
for entity in self.causal_operations.entities:
|
||||
if entity.code == "pass":
|
||||
continue
|
||||
else:
|
||||
# gist.github.com/dean0x7d/df5ce97e4a1a05be4d56d1378726ff92
|
||||
exec(entity.code, globals(), entity_scope)
|
||||
row_values = [entity.dict() for entity in entity_scope.values()]
|
||||
self._outcome_table = pd.DataFrame(row_values)
|
||||
|
||||
def _run_query(self) -> None:
|
||||
def humanize_sql_error_msg(error: str) -> str:
|
||||
pattern = r"column\s+(.*?)\s+not found"
|
||||
col_match = re.search(pattern, error)
|
||||
if col_match:
|
||||
return (
|
||||
"SQL error: "
|
||||
+ col_match.group(1)
|
||||
+ " is not an attribute in your story!"
|
||||
)
|
||||
else:
|
||||
return str(error)
|
||||
|
||||
if self.query.llm_error_msg == "":
|
||||
try:
|
||||
df = self._outcome_table # noqa
|
||||
query_result = duckdb.sql(self.query.expression).df()
|
||||
self.query._result_table = query_result
|
||||
except duckdb.BinderException as e:
|
||||
self.query._result_table = humanize_sql_error_msg(str(e))
|
||||
except Exception as e:
|
||||
self.query._result_table = str(e)
|
||||
else:
|
||||
msg = "LLM maybe failed to translate question to SQL query."
|
||||
raise ValueError(
|
||||
{
|
||||
"question": self.query.question,
|
||||
"llm_error_msg": self.query.llm_error_msg,
|
||||
"msg": msg,
|
||||
}
|
||||
)
|
||||
|
||||
def _compute(self) -> Any:
|
||||
self._block_back_door_paths()
|
||||
self._set_initial_conditions()
|
||||
self._make_graph()
|
||||
self._sort_entities()
|
||||
self._forward_propagate()
|
||||
self._run_query()
|
||||
|
||||
def print_debug_report(self) -> None:
|
||||
report = {
|
||||
"outcome": self._outcome_table,
|
||||
"query": self.query.dict(),
|
||||
"result": self.query._result_table,
|
||||
}
|
||||
from pprint import pprint
|
||||
|
||||
pprint(report)
|
@ -0,0 +1,113 @@
|
||||
# flake8: noqa E501
|
||||
|
||||
# fmt: off
|
||||
template = (
|
||||
"""
|
||||
Transform the math story plot into a JSON object. Don't guess at any of the parts.
|
||||
|
||||
{format_instructions}
|
||||
|
||||
|
||||
|
||||
Story: Boris has seven times the number of pets as Marcia. Jan has three times the number of pets as Marcia. Marcia has two more pets than Cindy.
|
||||
|
||||
|
||||
|
||||
# JSON:
|
||||
|
||||
|
||||
|
||||
{{
|
||||
"attribute": "pet_count",
|
||||
"entities": [
|
||||
{{
|
||||
"name": "cindy",
|
||||
"value": 0,
|
||||
"depends_on": [],
|
||||
"code": "pass"
|
||||
}},
|
||||
{{
|
||||
"name": "marcia",
|
||||
"value": 0,
|
||||
"depends_on": ["cindy"],
|
||||
"code": "marcia.value = cindy.value + 2"
|
||||
}},
|
||||
{{
|
||||
"name": "boris",
|
||||
"value": 0,
|
||||
"depends_on": ["marcia"],
|
||||
"code": "boris.value = marcia.value * 7"
|
||||
}},
|
||||
{{
|
||||
"name": "jan",
|
||||
"value": 0,
|
||||
"depends_on": ["marcia"],
|
||||
"code": "jan.value = marcia.value * 3"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
|
||||
|
||||
|
||||
Story: Boris gives 20 percent of his money to Marcia. Marcia gives 10
|
||||
percent of her money to Cindy. Cindy gives 5 percent of her money to Jan.
|
||||
|
||||
|
||||
|
||||
|
||||
# JSON:
|
||||
|
||||
|
||||
|
||||
{{
|
||||
"attribute": "money",
|
||||
"entities": [
|
||||
{{
|
||||
"name": "boris",
|
||||
"value": 0,
|
||||
"depends_on": [],
|
||||
"code": "pass"
|
||||
}},
|
||||
{{
|
||||
"name": "marcia",
|
||||
"value": 0,
|
||||
"depends_on": ["boris"],
|
||||
"code": "
|
||||
marcia.value = boris.value * 0.2
|
||||
boris.value = boris.value * 0.8
|
||||
"
|
||||
}},
|
||||
{{
|
||||
"name": "cindy",
|
||||
"value": 0,
|
||||
"depends_on": ["marcia"],
|
||||
"code": "
|
||||
cindy.value = marcia.value * 0.1
|
||||
marcia.value = marcia.value * 0.9
|
||||
"
|
||||
}},
|
||||
{{
|
||||
"name": "jan",
|
||||
"value": 0,
|
||||
"depends_on": ["cindy"],
|
||||
"code": "
|
||||
jan.value = cindy.value * 0.05
|
||||
cindy.value = cindy.value * 0.9
|
||||
"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
|
||||
|
||||
|
||||
Story: {narrative_input}
|
||||
|
||||
|
||||
|
||||
# JSON:
|
||||
""".strip()
|
||||
+ "\n"
|
||||
)
|
||||
# fmt: on
|
@ -0,0 +1,59 @@
|
||||
# flake8: noqa E501
|
||||
|
||||
# fmt: off
|
||||
template = (
|
||||
"""
|
||||
Transform the hypothetical whatif statement into JSON. Don't guess at any of the parts. Write NONE if you are unsure.
|
||||
|
||||
{format_instructions}
|
||||
|
||||
|
||||
|
||||
statement: if cindy's pet count was 4
|
||||
|
||||
|
||||
|
||||
|
||||
# JSON:
|
||||
|
||||
|
||||
|
||||
{{
|
||||
"entity_settings" : [
|
||||
{{ "name": "cindy", "attribute": "pet_count", "value": "4" }}
|
||||
]
|
||||
}}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
statement: Let's say boris has ten dollars and Bill has 20 dollars.
|
||||
|
||||
|
||||
|
||||
|
||||
# JSON:
|
||||
|
||||
|
||||
{{
|
||||
"entity_settings" : [
|
||||
{{ "name": "boris", "attribute": "dollars", "value": "10" }},
|
||||
{{ "name": "bill", "attribute": "dollars", "value": "20" }}
|
||||
]
|
||||
}}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Statement: {narrative_input}
|
||||
|
||||
|
||||
|
||||
|
||||
# JSON:
|
||||
""".strip()
|
||||
+ "\n\n\n"
|
||||
)
|
||||
# fmt: on
|
@ -0,0 +1,79 @@
|
||||
# flake8: noqa E501
|
||||
|
||||
|
||||
# fmt: off
|
||||
template = (
|
||||
"""
|
||||
Split the given text into three parts: the question, the story_hypothetical, and the logic. Don't guess at any of the parts. Write NONE if you are unsure.
|
||||
|
||||
{format_instructions}
|
||||
|
||||
|
||||
|
||||
Q: Boris has seven times the number of pets as Marcia. Jan has three times the number of pets as Marcia. Marcia has two more pets than Cindy. If Cindy has four pets, how many total pets do the three have?
|
||||
|
||||
|
||||
|
||||
# JSON
|
||||
|
||||
|
||||
|
||||
{{
|
||||
"story_outcome_question": "how many total pets do the three have?",
|
||||
"story_hypothetical": "If Cindy has four pets",
|
||||
"story_plot": "Boris has seven times the number of pets as Marcia. Jan has three times the number of pets as Marcia. Marcia has two more pets than Cindy."
|
||||
}}
|
||||
|
||||
|
||||
|
||||
Q: boris gives ten percent of his money to marcia. marcia gives ten
|
||||
percent of her money to andy. If boris has 100 dollars, how much money
|
||||
will andy have?
|
||||
|
||||
|
||||
|
||||
# JSON
|
||||
|
||||
|
||||
|
||||
{{
|
||||
"story_outcome_question": "how much money will andy have?",
|
||||
"story_hypothetical": "If boris has 100 dollars"
|
||||
"story_plot": "boris gives ten percent of his money to marcia. marcia gives ten percent of her money to andy."
|
||||
}}
|
||||
|
||||
|
||||
|
||||
|
||||
Q: boris gives ten percent of his candy to marcia. marcia gives ten
|
||||
percent of her candy to andy. If boris has 100 pounds of candy and marcia has
|
||||
200 pounds of candy, then how many pounds of candy will andy have?
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# JSON
|
||||
|
||||
|
||||
|
||||
|
||||
{{
|
||||
"story_outcome_question": "how many pounds of candy will andy have?",
|
||||
"story_hypothetical": "If boris has 100 pounds of candy and marcia has 200 pounds of candy"
|
||||
"story_plot": "boris gives ten percent of his candy to marcia. marcia gives ten percent of her candy to andy."
|
||||
}}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Q: {narrative_input}
|
||||
|
||||
|
||||
|
||||
# JSON
|
||||
""".strip()
|
||||
+ "\n\n\n"
|
||||
)
|
||||
# fmt: on
|
@ -0,0 +1,270 @@
# flake8: noqa E501


# fmt: off
template = (
    """
Transform the narrative_input into an SQL expression. If you are
unsure, then do not guess, instead add an llm_error_msg that explains why you are unsure.

{format_instructions}


narrative_input: how much money will boris have?


# JSON:

{{
"narrative_input": "how much money will boris have?",
"llm_error_msg": "",
"expression": "SELECT name, value FROM df WHERE name = 'boris'"
}}


narrative_input: How much money does ted have?


# JSON:

{{
"narrative_input": "How much money does ted have?",
"llm_error_msg": "",
"expression": "SELECT name, value FROM df WHERE name = 'ted'"
}}


narrative_input: what is the sum of pet count for all the people?


# JSON:

{{
"narrative_input": "what is the sum of pet count for all the people?",
"llm_error_msg": "",
"expression": "SELECT SUM(value) FROM df"
}}


narrative_input: what's the average of the pet counts for all the people?


# JSON:

{{
"narrative_input": "what's the average of the pet counts for all the people?",
"llm_error_msg": "",
"expression": "SELECT AVG(value) FROM df"
}}


narrative_input: what's the maximum of the pet counts for all the people?


# JSON:

{{
"narrative_input": "what's the maximum of the pet counts for all the people?",
"llm_error_msg": "",
"expression": "SELECT MAX(value) FROM df"
}}


narrative_input: what's the minimum of the pet counts for all the people?


# JSON:

{{
"narrative_input": "what's the minimum of the pet counts for all the people?",
"llm_error_msg": "",
"expression": "SELECT MIN(value) FROM df"
}}


narrative_input: what's the number of people with pet counts greater than 10?


# JSON:

{{
"narrative_input": "what's the number of people with pet counts greater than 10?",
"llm_error_msg": "",
"expression": "SELECT COUNT(*) FROM df WHERE value > 10"
}}


narrative_input: what's the pet count for boris?


# JSON:

{{
"narrative_input": "what's the pet count for boris?",
"llm_error_msg": "",
"expression": "SELECT name, value FROM df WHERE name = 'boris'"
}}


narrative_input: what's the pet count for cindy and marcia?


# JSON:

{{
"narrative_input": "what's the pet count for cindy and marcia?",
"llm_error_msg": "",
"expression": "SELECT name, value FROM df WHERE name IN ('cindy', 'marcia')"
}}


narrative_input: what's the total pet count for cindy and marcia?


# JSON:

{{
"narrative_input": "what's the total pet count for cindy and marcia?",
"llm_error_msg": "",
"expression": "SELECT SUM(value) FROM df WHERE name IN ('cindy', 'marcia')"
}}


narrative_input: what's the total pet count for TED?


# JSON:

{{
"narrative_input": "what's the total pet count for TED?",
"llm_error_msg": "",
"expression": "SELECT SUM(value) FROM df WHERE name = 'TED'"
}}


narrative_input: what's the total dollar count for TED and cindy?


# JSON:

{{
"narrative_input": "what's the total dollar count for TED and cindy?",
"llm_error_msg": "",
"expression": "SELECT SUM(value) FROM df WHERE name IN ('TED', 'cindy')"
}}


narrative_input: what's the total pet count for TED and cindy?


# JSON:

{{
"narrative_input": "what's the total pet count for TED and cindy?",
"llm_error_msg": "",
"expression": "SELECT SUM(value) FROM df WHERE name IN ('TED', 'cindy')"
}}


narrative_input: what's the best for TED and cindy?


# JSON:

{{
"narrative_input": "what's the best for TED and cindy?",
"llm_error_msg": "ambiguous narrative_input, not sure what 'best' means",
"expression": ""
}}


narrative_input: what's the value?


# JSON:

{{
"narrative_input": "what's the value?",
"llm_error_msg": "ambiguous narrative_input, not sure what entity is being asked about",
"expression": ""
}}


narrative_input: how many total pets do the three have?


# JSON:

{{
"narrative_input": "how many total pets do the three have?",
"llm_error_msg": "",
"expression": "SELECT SUM(value) FROM df"
}}


narrative_input: {narrative_input}


# JSON:
""".strip()
    + "\n\n\n"
)
# fmt: on
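
For orientation, a minimal sketch (not part of this commit) of one way such a few-shot template could be wrapped for use; the placeholder format_instructions string and the variable names below are assumptions, not something this diff defines.

# Assumed usage sketch: render the template above with PromptTemplate.
# `format_instructions` is stubbed here; in practice it would come from an
# output parser's get_format_instructions().
from langchain.prompts import PromptTemplate

narrative_sql_prompt = PromptTemplate(
    template=template,
    input_variables=["narrative_input"],
    partial_variables={"format_instructions": "Return a single JSON object."},
)
print(narrative_sql_prompt.format(narrative_input="how many pets does jan have?"))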
@ -0,0 +1,5 @@
"""Generative Agents primitives."""
from langchain.experimental.generative_agents.generative_agent import GenerativeAgent
from langchain.experimental.generative_agents.memory import GenerativeAgentMemory

__all__ = ["GenerativeAgent", "GenerativeAgentMemory"]
@ -0,0 +1,252 @@
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

from pydantic import BaseModel, Field

from langchain import LLMChain
from langchain.experimental.generative_agents.memory import GenerativeAgentMemory
from langchain.prompts import PromptTemplate
from langchain.schema.language_model import BaseLanguageModel


class GenerativeAgent(BaseModel):
    """A character with memory and innate characteristics."""

    name: str
    """The character's name."""

    age: Optional[int] = None
    """The optional age of the character."""
    traits: str = "N/A"
    """Permanent traits to ascribe to the character."""
    status: str
    """The traits of the character you wish not to change."""
    memory: GenerativeAgentMemory
    """The memory object that combines relevance, recency, and 'importance'."""
    llm: BaseLanguageModel
    """The underlying language model."""
    verbose: bool = False
    summary: str = ""  #: :meta private:
    """Stateful self-summary generated via reflection on the character's memory."""

    summary_refresh_seconds: int = 3600  #: :meta private:
    """How frequently to re-generate the summary."""

    last_refreshed: datetime = Field(default_factory=datetime.now)  # : :meta private:
    """The last time the character's summary was regenerated."""

    daily_summaries: List[str] = Field(default_factory=list)  # : :meta private:
    """Summary of the events in the plan that the agent took."""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    # LLM-related methods
    @staticmethod
    def _parse_list(text: str) -> List[str]:
        """Parse a newline-separated string into a list of strings."""
        lines = re.split(r"\n", text.strip())
        return [re.sub(r"^\s*\d+\.\s*", "", line).strip() for line in lines]

    def chain(self, prompt: PromptTemplate) -> LLMChain:
        return LLMChain(
            llm=self.llm, prompt=prompt, verbose=self.verbose, memory=self.memory
        )

    def _get_entity_from_observation(self, observation: str) -> str:
        prompt = PromptTemplate.from_template(
            "What is the observed entity in the following observation? {observation}"
            + "\nEntity="
        )
        return self.chain(prompt).run(observation=observation).strip()

    def _get_entity_action(self, observation: str, entity_name: str) -> str:
        prompt = PromptTemplate.from_template(
            "What is the {entity} doing in the following observation? {observation}"
            + "\nThe {entity} is"
        )
        return (
            self.chain(prompt).run(entity=entity_name, observation=observation).strip()
        )

    def summarize_related_memories(self, observation: str) -> str:
        """Summarize memories that are most relevant to an observation."""
        prompt = PromptTemplate.from_template(
            """
{q1}?
Context from memory:
{relevant_memories}
Relevant context:
"""
        )
        entity_name = self._get_entity_from_observation(observation)
        entity_action = self._get_entity_action(observation, entity_name)
        q1 = f"What is the relationship between {self.name} and {entity_name}"
        q2 = f"{entity_name} is {entity_action}"
        return self.chain(prompt=prompt).run(q1=q1, queries=[q1, q2]).strip()

    def _generate_reaction(
        self, observation: str, suffix: str, now: Optional[datetime] = None
    ) -> str:
        """React to a given observation or dialogue act."""
        prompt = PromptTemplate.from_template(
            "{agent_summary_description}"
            + "\nIt is {current_time}."
            + "\n{agent_name}'s status: {agent_status}"
            + "\nSummary of relevant context from {agent_name}'s memory:"
            + "\n{relevant_memories}"
            + "\nMost recent observations: {most_recent_memories}"
            + "\nObservation: {observation}"
            + "\n\n"
            + suffix
        )
        agent_summary_description = self.get_summary(now=now)
        relevant_memories_str = self.summarize_related_memories(observation)
        current_time_str = (
            datetime.now().strftime("%B %d, %Y, %I:%M %p")
            if now is None
            else now.strftime("%B %d, %Y, %I:%M %p")
        )
        kwargs: Dict[str, Any] = dict(
            agent_summary_description=agent_summary_description,
            current_time=current_time_str,
            relevant_memories=relevant_memories_str,
            agent_name=self.name,
            observation=observation,
            agent_status=self.status,
        )
        consumed_tokens = self.llm.get_num_tokens(
            prompt.format(most_recent_memories="", **kwargs)
        )
        kwargs[self.memory.most_recent_memories_token_key] = consumed_tokens
        return self.chain(prompt=prompt).run(**kwargs).strip()

    def _clean_response(self, text: str) -> str:
        return re.sub(f"^{self.name} ", "", text.strip()).strip()

    def generate_reaction(
        self, observation: str, now: Optional[datetime] = None
    ) -> Tuple[bool, str]:
        """React to a given observation."""
        call_to_action_template = (
            "Should {agent_name} react to the observation, and if so,"
            + " what would be an appropriate reaction? Respond in one line."
            + ' If the action is to engage in dialogue, write:\nSAY: "what to say"'
            + "\notherwise, write:\nREACT: {agent_name}'s reaction (if anything)."
            + "\nEither do nothing, react, or say something but not both.\n\n"
        )
        full_result = self._generate_reaction(
            observation, call_to_action_template, now=now
        )
        result = full_result.strip().split("\n")[0]
        # Save the observation and the agent's reaction to memory.
        self.memory.save_context(
            {},
            {
                self.memory.add_memory_key: f"{self.name} observed "
                f"{observation} and reacted by {result}",
                self.memory.now_key: now,
            },
        )
        if "REACT:" in result:
            reaction = self._clean_response(result.split("REACT:")[-1])
            return False, f"{self.name} {reaction}"
        if "SAY:" in result:
            said_value = self._clean_response(result.split("SAY:")[-1])
            return True, f"{self.name} said {said_value}"
        else:
            return False, result

    def generate_dialogue_response(
        self, observation: str, now: Optional[datetime] = None
    ) -> Tuple[bool, str]:
        """React to a given observation."""
        call_to_action_template = (
            "What would {agent_name} say? To end the conversation, write:"
            ' GOODBYE: "what to say". Otherwise to continue the conversation,'
            ' write: SAY: "what to say next"\n\n'
        )
        full_result = self._generate_reaction(
            observation, call_to_action_template, now=now
        )
        result = full_result.strip().split("\n")[0]
        if "GOODBYE:" in result:
            farewell = self._clean_response(result.split("GOODBYE:")[-1])
            self.memory.save_context(
                {},
                {
                    self.memory.add_memory_key: f"{self.name} observed "
                    f"{observation} and said {farewell}",
                    self.memory.now_key: now,
                },
            )
            return False, f"{self.name} said {farewell}"
        if "SAY:" in result:
            response_text = self._clean_response(result.split("SAY:")[-1])
            self.memory.save_context(
                {},
                {
                    self.memory.add_memory_key: f"{self.name} observed "
                    f"{observation} and said {response_text}",
                    self.memory.now_key: now,
                },
            )
            return True, f"{self.name} said {response_text}"
        else:
            return False, result

    ######################################################
    # Agent stateful summary methods.                    #
    # Each dialog or response prompt includes a header   #
    # summarizing the agent's self-description. This is  #
    # updated periodically through probing its memories. #
    ######################################################
    def _compute_agent_summary(self) -> str:
        """Summarize the agent's core characteristics from its memories."""
        prompt = PromptTemplate.from_template(
            "How would you summarize {name}'s core characteristics given the"
            + " following statements:\n"
            + "{relevant_memories}"
            + "Do not embellish."
            + "\n\nSummary: "
        )
        # The agent seeks to think about their core characteristics.
        return (
            self.chain(prompt)
            .run(name=self.name, queries=[f"{self.name}'s core characteristics"])
            .strip()
        )

    def get_summary(
        self, force_refresh: bool = False, now: Optional[datetime] = None
    ) -> str:
        """Return a descriptive summary of the agent."""
        current_time = datetime.now() if now is None else now
        since_refresh = (current_time - self.last_refreshed).seconds
        if (
            not self.summary
            or since_refresh >= self.summary_refresh_seconds
            or force_refresh
        ):
            self.summary = self._compute_agent_summary()
            self.last_refreshed = current_time
        age = self.age if self.age is not None else "N/A"
        return (
            f"Name: {self.name} (age: {age})"
            + f"\nInnate traits: {self.traits}"
            + f"\n{self.summary}"
        )

    def get_full_header(
        self, force_refresh: bool = False, now: Optional[datetime] = None
    ) -> str:
        """Return a full header of the agent's status, summary, and current time."""
        now = datetime.now() if now is None else now
        summary = self.get_summary(force_refresh=force_refresh, now=now)
        current_time_str = now.strftime("%B %d, %Y, %I:%M %p")
        return (
            f"{summary}\nIt is {current_time_str}.\n{self.name}'s status: {self.status}"
        )
@ -0,0 +1,299 @@
import logging
import re
from datetime import datetime
from typing import Any, Dict, List, Optional

from langchain import LLMChain
from langchain.prompts import PromptTemplate
from langchain.retrievers import TimeWeightedVectorStoreRetriever
from langchain.schema import BaseMemory, Document
from langchain.schema.language_model import BaseLanguageModel
from langchain.utils import mock_now

logger = logging.getLogger(__name__)


class GenerativeAgentMemory(BaseMemory):
    llm: BaseLanguageModel
    """The core language model."""

    memory_retriever: TimeWeightedVectorStoreRetriever
    """The retriever to fetch related memories."""
    verbose: bool = False

    reflection_threshold: Optional[float] = None
    """When aggregate_importance exceeds reflection_threshold, stop to reflect."""

    current_plan: List[str] = []
    """The current plan of the agent."""

    # A weight of 0.15 makes this less important than it
    # would be otherwise, relative to salience and time
    importance_weight: float = 0.15
    """How much weight to assign the memory importance."""

    aggregate_importance: float = 0.0  # : :meta private:
    """Track the sum of the 'importance' of recent memories.

    Triggers reflection when it reaches reflection_threshold."""

    max_tokens_limit: int = 1200  # : :meta private:
    # input keys
    queries_key: str = "queries"
    most_recent_memories_token_key: str = "recent_memories_token"
    add_memory_key: str = "add_memory"
    # output keys
    relevant_memories_key: str = "relevant_memories"
    relevant_memories_simple_key: str = "relevant_memories_simple"
    most_recent_memories_key: str = "most_recent_memories"
    now_key: str = "now"
    reflecting: bool = False

    def chain(self, prompt: PromptTemplate) -> LLMChain:
        return LLMChain(llm=self.llm, prompt=prompt, verbose=self.verbose)

    @staticmethod
    def _parse_list(text: str) -> List[str]:
        """Parse a newline-separated string into a list of strings."""
        lines = re.split(r"\n", text.strip())
        lines = [line for line in lines if line.strip()]  # remove empty lines
        return [re.sub(r"^\s*\d+\.\s*", "", line).strip() for line in lines]

    def _get_topics_of_reflection(self, last_k: int = 50) -> List[str]:
        """Return the 3 most salient high-level questions about recent observations."""
        prompt = PromptTemplate.from_template(
            "{observations}\n\n"
            "Given only the information above, what are the 3 most salient "
            "high-level questions we can answer about the subjects in the statements?\n"
            "Provide each question on a new line."
        )
        observations = self.memory_retriever.memory_stream[-last_k:]
        observation_str = "\n".join(
            [self._format_memory_detail(o) for o in observations]
        )
        result = self.chain(prompt).run(observations=observation_str)
        return self._parse_list(result)

    def _get_insights_on_topic(
        self, topic: str, now: Optional[datetime] = None
    ) -> List[str]:
        """Generate 'insights' on a topic of reflection, based on pertinent memories."""
        prompt = PromptTemplate.from_template(
            "Statements relevant to: '{topic}'\n"
            "---\n"
            "{related_statements}\n"
            "---\n"
            "What 5 high-level novel insights can you infer from the above statements "
            "that are relevant for answering the following question?\n"
            "Do not include any insights that are not relevant to the question.\n"
            "Do not repeat any insights that have already been made.\n\n"
            "Question: {topic}\n\n"
            "(example format: insight (because of 1, 5, 3))\n"
        )

        related_memories = self.fetch_memories(topic, now=now)
        related_statements = "\n".join(
            [
                self._format_memory_detail(memory, prefix=f"{i+1}. ")
                for i, memory in enumerate(related_memories)
            ]
        )
        result = self.chain(prompt).run(
            topic=topic, related_statements=related_statements
        )
        # TODO: Parse the connections between memories and insights
        return self._parse_list(result)

    def pause_to_reflect(self, now: Optional[datetime] = None) -> List[str]:
        """Reflect on recent observations and generate 'insights'."""
        if self.verbose:
            logger.info("Character is reflecting")
        new_insights = []
        topics = self._get_topics_of_reflection()
        for topic in topics:
            insights = self._get_insights_on_topic(topic, now=now)
            for insight in insights:
                self.add_memory(insight, now=now)
            new_insights.extend(insights)
        return new_insights

    def _score_memory_importance(self, memory_content: str) -> float:
        """Score the absolute importance of the given memory."""
        prompt = PromptTemplate.from_template(
            "On the scale of 1 to 10, where 1 is purely mundane"
            + " (e.g., brushing teeth, making bed) and 10 is"
            + " extremely poignant (e.g., a break up, college"
            + " acceptance), rate the likely poignancy of the"
            + " following piece of memory. Respond with a single integer."
            + "\nMemory: {memory_content}"
            + "\nRating: "
        )
        score = self.chain(prompt).run(memory_content=memory_content).strip()
        if self.verbose:
            logger.info(f"Importance score: {score}")
        match = re.search(r"^\D*(\d+)", score)
        if match:
            return (float(match.group(1)) / 10) * self.importance_weight
        else:
            return 0.0

    def _score_memories_importance(self, memory_content: str) -> List[float]:
        """Score the absolute importance of the given memories."""
        prompt = PromptTemplate.from_template(
            "On the scale of 1 to 10, where 1 is purely mundane"
            + " (e.g., brushing teeth, making bed) and 10 is"
            + " extremely poignant (e.g., a break up, college"
            + " acceptance), rate the likely poignancy of the"
            + " following piece of memory. Always answer with only a list of numbers."
            + " If just given one memory still respond in a list."
            + " Memories are separated by semicolons (;)"
            + "\nMemories: {memory_content}"
            + "\nRating: "
        )
        scores = self.chain(prompt).run(memory_content=memory_content).strip()

        if self.verbose:
            logger.info(f"Importance scores: {scores}")

        # Split into list of strings and convert to floats
        scores_list = [float(x) for x in scores.split(";")]

        return scores_list

    def add_memories(
        self, memory_content: str, now: Optional[datetime] = None
    ) -> List[str]:
        """Add observations or memories to the agent's memory."""
        importance_scores = self._score_memories_importance(memory_content)

        self.aggregate_importance += max(importance_scores)
        memory_list = memory_content.split(";")
        documents = []

        for i in range(len(memory_list)):
            documents.append(
                Document(
                    page_content=memory_list[i],
                    metadata={"importance": importance_scores[i]},
                )
            )

        result = self.memory_retriever.add_documents(documents, current_time=now)

        # After an agent has processed a certain amount of memories (as measured by
        # aggregate importance), it is time to reflect on recent events to add
        # more synthesized memories to the agent's memory stream.
        if (
            self.reflection_threshold is not None
            and self.aggregate_importance > self.reflection_threshold
            and not self.reflecting
        ):
            self.reflecting = True
            self.pause_to_reflect(now=now)
            # Hack to clear the importance from reflection
            self.aggregate_importance = 0.0
            self.reflecting = False
        return result

    def add_memory(
        self, memory_content: str, now: Optional[datetime] = None
    ) -> List[str]:
        """Add an observation or memory to the agent's memory."""
        importance_score = self._score_memory_importance(memory_content)
        self.aggregate_importance += importance_score
        document = Document(
            page_content=memory_content, metadata={"importance": importance_score}
        )
        result = self.memory_retriever.add_documents([document], current_time=now)

        # After an agent has processed a certain amount of memories (as measured by
        # aggregate importance), it is time to reflect on recent events to add
        # more synthesized memories to the agent's memory stream.
        if (
            self.reflection_threshold is not None
            and self.aggregate_importance > self.reflection_threshold
            and not self.reflecting
        ):
            self.reflecting = True
            self.pause_to_reflect(now=now)
            # Hack to clear the importance from reflection
            self.aggregate_importance = 0.0
            self.reflecting = False
        return result

    def fetch_memories(
        self, observation: str, now: Optional[datetime] = None
    ) -> List[Document]:
        """Fetch related memories."""
        if now is not None:
            with mock_now(now):
                return self.memory_retriever.get_relevant_documents(observation)
        else:
            return self.memory_retriever.get_relevant_documents(observation)

    def format_memories_detail(self, relevant_memories: List[Document]) -> str:
        content = []
        for mem in relevant_memories:
            content.append(self._format_memory_detail(mem, prefix="- "))
        return "\n".join([f"{mem}" for mem in content])

    def _format_memory_detail(self, memory: Document, prefix: str = "") -> str:
        created_time = memory.metadata["created_at"].strftime("%B %d, %Y, %I:%M %p")
        return f"{prefix}[{created_time}] {memory.page_content.strip()}"

    def format_memories_simple(self, relevant_memories: List[Document]) -> str:
        return "; ".join([f"{mem.page_content}" for mem in relevant_memories])

    def _get_memories_until_limit(self, consumed_tokens: int) -> str:
        """Reduce the number of tokens in the documents."""
        result = []
        for doc in self.memory_retriever.memory_stream[::-1]:
            if consumed_tokens >= self.max_tokens_limit:
                break
            consumed_tokens += self.llm.get_num_tokens(doc.page_content)
            if consumed_tokens < self.max_tokens_limit:
                result.append(doc)
        return self.format_memories_simple(result)

    @property
    def memory_variables(self) -> List[str]:
        """Input keys this memory class will load dynamically."""
        return []

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Return key-value pairs given the text input to the chain."""
        queries = inputs.get(self.queries_key)
        now = inputs.get(self.now_key)
        if queries is not None:
            relevant_memories = [
                mem for query in queries for mem in self.fetch_memories(query, now=now)
            ]
            return {
                self.relevant_memories_key: self.format_memories_detail(
                    relevant_memories
                ),
                self.relevant_memories_simple_key: self.format_memories_simple(
                    relevant_memories
                ),
            }

        most_recent_memories_token = inputs.get(self.most_recent_memories_token_key)
        if most_recent_memories_token is not None:
            return {
                self.most_recent_memories_key: self._get_memories_until_limit(
                    most_recent_memories_token
                )
            }
        return {}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        """Save the context of this model run to memory."""
        # TODO: fix the save memory key
        mem = outputs.get(self.add_memory_key)
        now = outputs.get(self.now_key)
        if mem:
            self.add_memory(mem, now=now)

    def clear(self) -> None:
        """Clear memory contents."""
        # TODO
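
For orientation, a heavily abridged usage sketch of the two generative-agent classes above (not part of this commit); it assumes an OpenAI API key, the `faiss-cpu` package, and illustrative names and values throughout.

# Assumed usage sketch: wire a time-weighted retriever into GenerativeAgentMemory,
# then give that memory to a GenerativeAgent and observe a reaction.
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.experimental import GenerativeAgent, GenerativeAgentMemory
from langchain.retrievers import TimeWeightedVectorStoreRetriever
from langchain.vectorstores import FAISS

llm = ChatOpenAI(temperature=0.9)
vectorstore = FAISS.from_texts(["placeholder"], OpenAIEmbeddings())
retriever = TimeWeightedVectorStoreRetriever(
    vectorstore=vectorstore, other_score_keys=["importance"], k=15
)

memory = GenerativeAgentMemory(llm=llm, memory_retriever=retriever, reflection_threshold=8)
tommie = GenerativeAgent(
    name="Tommie",
    age=25,
    traits="curious, talkative",
    status="looking for a job",
    llm=llm,
    memory=memory,
)
memory.add_memory("Tommie remembers his dog from childhood.")
_, reaction = tommie.generate_reaction("Tommie sees a dog in the park.")
print(reaction)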
libs/langchain/langchain/experimental/llms/__init__.py
@ -0,0 +1,6 @@
"""Experimental LLM wrappers."""

from langchain.experimental.llms.jsonformer_decoder import JsonFormer
from langchain.experimental.llms.rellm_decoder import RELLM

__all__ = ["RELLM", "JsonFormer"]
@ -0,0 +1,61 @@
"""Experimental implementation of jsonformer wrapped LLM."""
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any, List, Optional, cast

from pydantic import Field, root_validator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.huggingface_pipeline import HuggingFacePipeline

if TYPE_CHECKING:
    import jsonformer


def import_jsonformer() -> jsonformer:
    """Lazily import jsonformer."""
    try:
        import jsonformer
    except ImportError:
        raise ValueError(
            "Could not import jsonformer python package. "
            "Please install it with `pip install jsonformer`."
        )
    return jsonformer


class JsonFormer(HuggingFacePipeline):
    json_schema: dict = Field(..., description="The JSON Schema to complete.")
    max_new_tokens: int = Field(
        default=200, description="Maximum number of new tokens to generate."
    )
    debug: bool = Field(default=False, description="Debug mode.")

    @root_validator
    def check_jsonformer_installation(cls, values: dict) -> dict:
        import_jsonformer()
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        jsonformer = import_jsonformer()
        from transformers import Text2TextGenerationPipeline

        pipeline = cast(Text2TextGenerationPipeline, self.pipeline)

        model = jsonformer.Jsonformer(
            model=pipeline.model,
            tokenizer=pipeline.tokenizer,
            json_schema=self.json_schema,
            prompt=prompt,
            max_number_tokens=self.max_new_tokens,
            debug=self.debug,
        )
        text = model()
        return json.dumps(text)
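
For context, a rough sketch of how JsonFormer might be used (not part of this commit); the model name and schema are placeholders, and the `transformers` and `jsonformer` packages are assumed to be installed.

# Assumed usage sketch: wrap a local text2text pipeline and constrain its
# output to a JSON schema via JsonFormer.
from transformers import pipeline

from langchain.experimental.llms import JsonFormer

hf_pipeline = pipeline("text2text-generation", model="google/flan-t5-small")
json_schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "number"}},
}
llm = JsonFormer(pipeline=hf_pipeline, json_schema=json_schema)
print(llm("Describe a fictional person as JSON."))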
libs/langchain/langchain/experimental/llms/rellm_decoder.py
@ -0,0 +1,68 @@
"""Experimental implementation of RELLM wrapped LLM."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, List, Optional, cast

from pydantic import Field, root_validator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.llms.utils import enforce_stop_tokens

if TYPE_CHECKING:
    import rellm
    from regex import Pattern as RegexPattern
else:
    try:
        from regex import Pattern as RegexPattern
    except ImportError:
        pass


def import_rellm() -> rellm:
    """Lazily import rellm."""
    try:
        import rellm
    except ImportError:
        raise ValueError(
            "Could not import rellm python package. "
            "Please install it with `pip install rellm`."
        )
    return rellm


class RELLM(HuggingFacePipeline):
    regex: RegexPattern = Field(..., description="The structured format to complete.")
    max_new_tokens: int = Field(
        default=200, description="Maximum number of new tokens to generate."
    )

    @root_validator
    def check_rellm_installation(cls, values: dict) -> dict:
        import_rellm()
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        rellm = import_rellm()
        from transformers import Text2TextGenerationPipeline

        pipeline = cast(Text2TextGenerationPipeline, self.pipeline)

        text = rellm.complete_re(
            prompt,
            self.regex,
            tokenizer=pipeline.tokenizer,
            model=pipeline.model,
            max_new_tokens=self.max_new_tokens,
        )
        if stop is not None:
            # This is a bit hacky, but I can't figure out a better way to enforce
            # stop tokens when making calls to huggingface_hub.
            text = enforce_stop_tokens(text, stop)
        return text
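
Similarly, a rough RELLM usage sketch (not part of this commit); the model name and pattern are placeholders, and the `rellm` and `regex` packages are assumed.

# Assumed usage sketch: constrain a local text2text pipeline's completion to a
# regex pattern via RELLM.
import regex
from transformers import pipeline

from langchain.experimental.llms import RELLM

hf_pipeline = pipeline("text2text-generation", model="google/flan-t5-small")
pattern = regex.compile(r"Answer: (yes|no)")
llm = RELLM(pipeline=hf_pipeline, regex=pattern, max_new_tokens=20)
print(llm("Is the sky blue? Answer yes or no."))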
@ -0,0 +1,9 @@
from langchain.experimental.plan_and_execute.agent_executor import PlanAndExecute
from langchain.experimental.plan_and_execute.executors.agent_executor import (
    load_agent_executor,
)
from langchain.experimental.plan_and_execute.planners.chat_planner import (
    load_chat_planner,
)

__all__ = ["PlanAndExecute", "load_agent_executor", "load_chat_planner"]
@ -0,0 +1,60 @@
from typing import Any, Dict, List, Optional

from pydantic import Field

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.experimental.plan_and_execute.executors.base import BaseExecutor
from langchain.experimental.plan_and_execute.planners.base import BasePlanner
from langchain.experimental.plan_and_execute.schema import (
    BaseStepContainer,
    ListStepContainer,
)


class PlanAndExecute(Chain):
    planner: BasePlanner
    executor: BaseExecutor
    step_container: BaseStepContainer = Field(default_factory=ListStepContainer)
    input_key: str = "input"
    output_key: str = "output"

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        plan = self.planner.plan(
            inputs,
            callbacks=run_manager.get_child() if run_manager else None,
        )
        if run_manager:
            run_manager.on_text(str(plan), verbose=self.verbose)
        for step in plan.steps:
            _new_inputs = {
                "previous_steps": self.step_container,
                "current_step": step,
                "objective": inputs[self.input_key],
            }
            new_inputs = {**_new_inputs, **inputs}
            response = self.executor.step(
                new_inputs,
                callbacks=run_manager.get_child() if run_manager else None,
            )
            if run_manager:
                run_manager.on_text(
                    f"*****\n\nStep: {step.value}", verbose=self.verbose
                )
                run_manager.on_text(
                    f"\n\nResponse: {response.response}", verbose=self.verbose
                )
            self.step_container.add_step(step, response)
        return {self.output_key: self.step_container.get_final_response()}
@ -0,0 +1,54 @@
from typing import List

from langchain.agents.agent import AgentExecutor
from langchain.agents.structured_chat.base import StructuredChatAgent
from langchain.experimental.plan_and_execute.executors.base import ChainExecutor
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool

HUMAN_MESSAGE_TEMPLATE = """Previous steps: {previous_steps}

Current objective: {current_step}

{agent_scratchpad}"""

TASK_PREFIX = """{objective}

"""


def load_agent_executor(
    llm: BaseLanguageModel,
    tools: List[BaseTool],
    verbose: bool = False,
    include_task_in_prompt: bool = False,
) -> ChainExecutor:
    """
    Load an agent executor.

    Args:
        llm: BaseLanguageModel
        tools: List[BaseTool]
        verbose: bool. Defaults to False.
        include_task_in_prompt: bool. Defaults to False.

    Returns:
        ChainExecutor
    """
    input_variables = ["previous_steps", "current_step", "agent_scratchpad"]
    template = HUMAN_MESSAGE_TEMPLATE

    if include_task_in_prompt:
        input_variables.append("objective")
        template = TASK_PREFIX + template

    agent = StructuredChatAgent.from_llm_and_tools(
        llm,
        tools,
        human_message_template=template,
        input_variables=input_variables,
    )
    agent_executor = AgentExecutor.from_agent_and_tools(
        agent=agent, tools=tools, verbose=verbose
    )
    return ChainExecutor(chain=agent_executor)
@ -0,0 +1,40 @@
from abc import abstractmethod
from typing import Any

from pydantic import BaseModel

from langchain.callbacks.manager import Callbacks
from langchain.chains.base import Chain
from langchain.experimental.plan_and_execute.schema import StepResponse


class BaseExecutor(BaseModel):
    @abstractmethod
    def step(
        self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any
    ) -> StepResponse:
        """Take step."""

    @abstractmethod
    async def astep(
        self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any
    ) -> StepResponse:
        """Take step."""


class ChainExecutor(BaseExecutor):
    chain: Chain

    def step(
        self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any
    ) -> StepResponse:
        """Take step."""
        response = self.chain.run(**inputs, callbacks=callbacks)
        return StepResponse(response=response)

    async def astep(
        self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any
    ) -> StepResponse:
        """Take step."""
        response = await self.chain.arun(**inputs, callbacks=callbacks)
        return StepResponse(response=response)
@ -0,0 +1,40 @@
from abc import abstractmethod
from typing import Any, List, Optional

from pydantic import BaseModel

from langchain.callbacks.manager import Callbacks
from langchain.chains.llm import LLMChain
from langchain.experimental.plan_and_execute.schema import Plan, PlanOutputParser


class BasePlanner(BaseModel):
    @abstractmethod
    def plan(self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any) -> Plan:
        """Given input, decide what to do."""

    @abstractmethod
    async def aplan(
        self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any
    ) -> Plan:
        """Given input, decide what to do."""


class LLMPlanner(BasePlanner):
    llm_chain: LLMChain
    output_parser: PlanOutputParser
    stop: Optional[List] = None

    def plan(self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any) -> Plan:
        """Given input, decide what to do."""
        llm_response = self.llm_chain.run(**inputs, stop=self.stop, callbacks=callbacks)
        return self.output_parser.parse(llm_response)

    async def aplan(
        self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any
    ) -> Plan:
        """Given input, decide what to do."""
        llm_response = await self.llm_chain.arun(
            **inputs, stop=self.stop, callbacks=callbacks
        )
        return self.output_parser.parse(llm_response)
@ -0,0 +1,55 @@
import re

from langchain.chains import LLMChain
from langchain.experimental.plan_and_execute.planners.base import LLMPlanner
from langchain.experimental.plan_and_execute.schema import (
    Plan,
    PlanOutputParser,
    Step,
)
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import SystemMessage

SYSTEM_PROMPT = (
    "Let's first understand the problem and devise a plan to solve the problem."
    " Please output the plan starting with the header 'Plan:' "
    "and then followed by a numbered list of steps. "
    "Please make the plan the minimum number of steps required "
    "to accurately complete the task. If the task is a question, "
    "the final step should almost always be 'Given the above steps taken, "
    "please respond to the user's original question'. "
    "At the end of your plan, say '<END_OF_PLAN>'"
)


class PlanningOutputParser(PlanOutputParser):
    def parse(self, text: str) -> Plan:
        steps = [Step(value=v) for v in re.split(r"\n\s*\d+\. ", text)[1:]]
        return Plan(steps=steps)


def load_chat_planner(
    llm: BaseLanguageModel, system_prompt: str = SYSTEM_PROMPT
) -> LLMPlanner:
    """
    Load a chat planner.

    Args:
        llm: Language model.
        system_prompt: System prompt.

    Returns:
        LLMPlanner
    """
    prompt_template = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=system_prompt),
            HumanMessagePromptTemplate.from_template("{input}"),
        ]
    )
    llm_chain = LLMChain(llm=llm, prompt=prompt_template)
    return LLMPlanner(
        llm_chain=llm_chain,
        output_parser=PlanningOutputParser(),
        stop=["<END_OF_PLAN>"],
    )
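
Putting the planner and executor together, a minimal end-to-end sketch (not part of this commit); it assumes an OpenAI API key, and the tool choice and question are illustrative only.

# Assumed usage sketch: build a planner and an executor, then run the
# PlanAndExecute chain on a single question.
from langchain.agents import load_tools
from langchain.chat_models import ChatOpenAI
from langchain.experimental.plan_and_execute import (
    PlanAndExecute,
    load_agent_executor,
    load_chat_planner,
)

model = ChatOpenAI(temperature=0)
tools = load_tools(["llm-math"], llm=model)
planner = load_chat_planner(model)
executor = load_agent_executor(model, tools, verbose=True)
agent = PlanAndExecute(planner=planner, executor=executor, verbose=True)
print(agent.run("What is 3 raised to the 0.43 power?"))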
@ -0,0 +1,47 @@
from abc import abstractmethod
from typing import List, Tuple

from pydantic import BaseModel, Field

from langchain.schema import BaseOutputParser


class Step(BaseModel):
    value: str


class Plan(BaseModel):
    steps: List[Step]


class StepResponse(BaseModel):
    response: str


class BaseStepContainer(BaseModel):
    @abstractmethod
    def add_step(self, step: Step, step_response: StepResponse) -> None:
        """Add step and step response to the container."""

    @abstractmethod
    def get_final_response(self) -> str:
        """Return the final response based on steps taken."""


class ListStepContainer(BaseStepContainer):
    steps: List[Tuple[Step, StepResponse]] = Field(default_factory=list)

    def add_step(self, step: Step, step_response: StepResponse) -> None:
        self.steps.append((step, step_response))

    def get_steps(self) -> List[Tuple[Step, StepResponse]]:
        return self.steps

    def get_final_response(self) -> str:
        return self.steps[-1][1].response


class PlanOutputParser(BaseOutputParser):
    @abstractmethod
    def parse(self, text: str) -> Plan:
        """Parse into a plan."""
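
To show how these schema objects fit together, a small illustrative snippet (not part of this commit); the step texts are made up.

# Assumed usage sketch: ListStepContainer accumulates (Step, StepResponse)
# pairs and reports the last response as the final answer.
from langchain.experimental.plan_and_execute.schema import (
    ListStepContainer,
    Step,
    StepResponse,
)

container = ListStepContainer()
container.add_step(
    Step(value="Look up the population of France"),
    StepResponse(response="About 68 million"),
)
container.add_step(
    Step(value="Answer the original question"),
    StepResponse(response="France has roughly 68 million people."),
)
assert container.get_final_response() == "France has roughly 68 million people."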