Compare commits

..

1 Commits

Author SHA1 Message Date
Harrison Chase
844151605c WIP logging to disk 2022-11-27 15:20:44 -08:00
14 changed files with 172 additions and 247 deletions

View File

@@ -21,7 +21,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 1,
"id": "ac561cc4",
"metadata": {},
"outputs": [],
@@ -32,7 +32,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 2,
"id": "07e96d99",
"metadata": {},
"outputs": [],
@@ -43,16 +43,16 @@
"db = SQLDatabase.from_uri(\"sqlite:///../../../notebooks/Chinook.db\")\n",
"db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)\n",
"tools = [\n",
"# Tool(\n",
"# name = \"Search\",\n",
"# func=search.run,\n",
"# description=\"useful for when you need to answer questions about current events\"\n",
"# ),\n",
"# Tool(\n",
"# name=\"Calculator\",\n",
"# func=llm_math_chain.run,\n",
"# description=\"useful for when you need to answer questions about math\"\n",
"# ),\n",
" Tool(\n",
" name = \"Search\",\n",
" func=search.run,\n",
" description=\"useful for when you need to answer questions about current events\"\n",
" ),\n",
" Tool(\n",
" name=\"Calculator\",\n",
" func=llm_math_chain.run,\n",
" description=\"useful for when you need to answer questions about math\"\n",
" ),\n",
" Tool(\n",
" name=\"FooBar DB\",\n",
" func=db_chain.run,\n",
@@ -63,7 +63,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 3,
"id": "a069c4b6",
"metadata": {},
"outputs": [],
@@ -71,153 +71,6 @@
"mrkl = initialize_agent(tools, llm, agent=\"zero-shot-react-description\", verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "356a1396",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"What was our total number of sales and total revenue in 2013?\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out how many sales and how much revenue we had in 2013\n",
"Action: FooBar DB\n",
"Action Input: sales 2013\u001b[0m\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"sales 2013\n",
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT SUM(Total) FROM Invoice WHERE InvoiceDate BETWEEN '2013-01-01' AND '2013-12-31'\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(450.58000000000027,)]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m 450.58\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"Observation: \u001b[36;1m\u001b[1;3m 450.58\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out how many sales and how much revenue we had in 2013\n",
"Action: FooBar DB\n",
"Action Input: revenue 2013\u001b[0m\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"revenue 2013\n",
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT SUM(Total) FROM Invoice WHERE strftime('%Y', InvoiceDate) = '2013'\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(450.58000000000027,)]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m 450.58\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"Observation: \u001b[36;1m\u001b[1;3m 450.58\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: We had 450.58 sales and 450.58 revenue in 2013\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'We had 450.58 sales and 450.58 revenue in 2013'"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mrkl.run(\"What was our total number of sales and total revenue in 2013?\")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "4e9c9b23",
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import ZeroShotAgent\n",
"from langchain import PromptTemplate\n",
"from langchain import LLMChain\n",
"TEMPLATE = \"\"\"You are a data engineer answering questions using a SQL database.\n",
"\n",
"Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. If a previous query produced an error, do NOT try it again.\n",
"\n",
"Only use the following tables:\n",
"\n",
"{table_info}\n",
"\n",
"Use the following format:\n",
"\n",
"Question: the input question you must answer\n",
"Thought: you should always think about what to do\n",
"Action: SQLDB\n",
"Action Input: the query to run against the SQL database\n",
"Observation: the result of the action\n",
"... (this Thought/Action/Action Input/Observation can repeat N times)\n",
"Thought: I now know the final answer\n",
"Final Answer: the final answer to the original input question\n",
"\n",
"Begin!\n",
"\n",
"Question: {{input}}\"\"\".format(**{\n",
" \"dialect\": db.dialect,\n",
" \"table_info\": db.table_info})\n",
"prompt = PromptTemplate(template=TEMPLATE, input_variables=[\"input\"])"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "a6fb1b4f",
"metadata": {},
"outputs": [],
"source": [
"llm_chain = LLMChain(llm=llm, prompt=prompt)\n",
"tools = [\n",
" Tool(\"SQLDB\", db.run, \"foo\")\n",
"]\n",
"agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "64578802",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"What was our total number of sales and total revenue in 2013?\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find the total number of sales and total revenue\n",
"Action: SQLDB\n",
"Action Input: SELECT SUM(Total) AS \"Total Sales\", SUM(Total) AS \"Total Revenue\" FROM Invoice WHERE InvoiceDate LIKE '2013%'\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m[(450.58000000000027, 450.58000000000027)]\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: 450.58\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'450.58'"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.run(\"What was our total number of sales and total revenue in 2013?\")"
]
},
{
"cell_type": "code",
"execution_count": 4,

View File

@@ -7,9 +7,11 @@ from pydantic import BaseModel
from langchain.agents.tools import Tool
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import ChainedInput, get_color_mapping
from langchain.input import ChainedInput
from langchain.printing import get_color_mapping
from langchain.llms.base import LLM
from langchain.prompts.base import BasePromptTemplate
from langchain.logger import CONTEXT_KEY
class Action(NamedTuple):
@@ -116,7 +118,7 @@ class Agent(Chain, BaseModel, ABC):
tool, tool_input = parsed_output
return Action(tool, tool_input, full_output)
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Run text through and get agent response."""
text = inputs[self.input_key]
# Construct a mapping of tool name to tool for easy lookup
@@ -128,7 +130,7 @@ class Agent(Chain, BaseModel, ABC):
# prompts the LLM to take an action.
starter_string = text + self.starter_string + self.llm_prefix
# We use the ChainedInput class to iteratively add to the input over time.
chained_input = ChainedInput(starter_string, verbose=self.verbose)
chained_input = ChainedInput(starter_string, inputs[CONTEXT_KEY], logger=self.logger)
# We construct a mapping from each tool to a color, used for logging.
color_mapping = get_color_mapping(
[tool.name for tool in self.tools], excluded_colors=["green"]

View File

@@ -1,22 +0,0 @@
TEMPLATE = """You are a data engineer answering questions using a SQL database.
Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Only use the following tables:
{table_info}
Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: SQLDB
Action Input: the query to run against the SQL database
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin!
Question: {input}"""

View File

@@ -1,9 +1,12 @@
"""Base interface that all chains should implement."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from langchain.logger import PrintLogger
import uuid
from pydantic import BaseModel, Extra
from pydantic import BaseModel, Extra, root_validator
from langchain.logger import Logger, CONTEXT_KEY
class Memory(BaseModel, ABC):
"""Base interface for memory in chains."""
@@ -35,6 +38,7 @@ class Chain(BaseModel, ABC):
verbose: bool = False
"""Whether to print out response text."""
logger: Optional[Logger] = None
@property
@abstractmethod
@@ -46,6 +50,19 @@ class Chain(BaseModel, ABC):
def output_keys(self) -> List[str]:
"""Output keys this chain expects."""
@root_validator()
def add_logger(cls, values: Dict) -> Dict:
"""Add a printing logger if verbose=True and none provided."""
if values["verbose"] and values["logger"] is None:
values["logger"] = PrintLogger()
return values
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
def _validate_inputs(self, inputs: Dict[str, str]) -> None:
"""Check that all inputs are present."""
missing_keys = set(self.input_keys).difference(inputs)
@@ -76,16 +93,22 @@ class Chain(BaseModel, ABC):
chain will be returned. Defaults to False.
"""
if CONTEXT_KEY not in inputs:
inputs[CONTEXT_KEY] = {}
if "id" not in inputs[CONTEXT_KEY]:
inputs[CONTEXT_KEY]["id"] = str(uuid.uuid4())
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
if self.verbose:
print("\n\n\033[1m> Entering new chain...\033[0m")
if self.logger:
self.logger.log_start_of_chain(inputs)
outputs = self._call(inputs)
if self.verbose:
print("\n\033[1m> Finished chain.\033[0m")
self._validate_outputs(outputs)
outputs[CONTEXT_KEY] = inputs[CONTEXT_KEY]
if self.logger:
self.logger.log_end_of_chain(outputs)
if self.memory is not None:
self.memory.save_context(inputs, outputs)
if return_only_outputs:

View File

@@ -4,9 +4,10 @@ from typing import Any, Dict, List
from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.input import print_text
from langchain.printing import print_text
from langchain.llms.base import LLM
from langchain.prompts.base import BasePromptTemplate
from langchain.logger import CONTEXT_KEY
class LLMChain(Chain, BaseModel):
@@ -54,9 +55,9 @@ class LLMChain(Chain, BaseModel):
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
prompt = self.prompt.format(**selected_inputs)
if self.verbose:
print("Prompt after formatting:")
print_text(prompt, color="green", end="\n")
if self.logger:
title="Prompt after formatting:"
self.logger.log(prompt, inputs[CONTEXT_KEY],title=title, color="green", end="\n")
kwargs = {}
if "stop" in inputs:
kwargs["stop"] = inputs["stop"]

View File

@@ -1,5 +1,5 @@
"""Chain that interprets a prompt and executes python code to do math."""
from typing import Dict, List
from typing import Dict, List, Any
from pydantic import BaseModel, Extra
@@ -9,6 +9,7 @@ from langchain.chains.llm_math.prompt import PROMPT
from langchain.chains.python import PythonChain
from langchain.input import ChainedInput
from langchain.llms.base import LLM
from langchain.logger import CONTEXT_KEY
class LLMMathChain(Chain, BaseModel):
@@ -48,10 +49,10 @@ class LLMMathChain(Chain, BaseModel):
"""
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
python_executor = PythonChain()
chained_input = ChainedInput(inputs[self.input_key], verbose=self.verbose)
chained_input = ChainedInput(inputs[self.input_key], inputs[CONTEXT_KEY], logger=self.logger)
t = llm_executor.predict(question=chained_input.input, stop=["```output"])
chained_input.add(t, color="green")
t = t.strip()

View File

@@ -5,7 +5,6 @@ from typing import Dict, List
from pydantic import BaseModel, Extra, root_validator
from langchain.chains.base import Chain
from langchain.input import get_color_mapping, print_text
class SequentialChain(Chain, BaseModel):
@@ -127,11 +126,8 @@ class SimpleSequentialChain(Chain, BaseModel):
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
_input = inputs[self.input_key]
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
for i, chain in enumerate(self.chains):
_input = chain.run(_input)
if self.strip_outputs:
_input = _input.strip()
if self.verbose:
print_text(_input, color=color_mapping[str(i)], end="\n")
return {self.output_key: _input}

View File

@@ -1,5 +1,5 @@
"""Chain for interacting with SQL Database."""
from typing import Dict, List
from typing import Dict, List, Any
from pydantic import BaseModel, Extra
@@ -9,6 +9,7 @@ from langchain.chains.sql_database.prompt import PROMPT
from langchain.input import ChainedInput
from langchain.llms.base import LLM
from langchain.sql_database import SQLDatabase
from langchain.logger import CONTEXT_KEY
class SQLDatabaseChain(Chain, BaseModel):
@@ -51,10 +52,10 @@ class SQLDatabaseChain(Chain, BaseModel):
"""
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
chained_input = ChainedInput(
inputs[self.input_key] + "\nSQLQuery:", verbose=self.verbose
inputs[self.input_key] + "\nSQLQuery:", inputs[CONTEXT_KEY], logger=self.logger
)
llm_inputs = {
"input": chained_input.input,
@@ -64,10 +65,7 @@ class SQLDatabaseChain(Chain, BaseModel):
}
sql_cmd = llm_chain.predict(**llm_inputs)
chained_input.add(sql_cmd, color="green")
try:
result = self.database.run(sql_cmd)
except Exception as e:
result = str(e)
result = self.database.run(sql_cmd)
chained_input.add("\nSQLResult: ")
chained_input.add(result, color="yellow")
chained_input.add("\nAnswer:")

View File

@@ -1,48 +1,24 @@
"""Handle chained inputs."""
from typing import Dict, List, Optional
from typing import Optional
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
"yellow": "33;1",
"pink": "38;5;200",
"green": "32;1",
}
def get_color_mapping(
items: List[str], excluded_colors: Optional[List] = None
) -> Dict[str, str]:
"""Get mapping for items to a support color."""
colors = list(_TEXT_COLOR_MAPPING.keys())
if excluded_colors is not None:
colors = [c for c in colors if c not in excluded_colors]
color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
return color_mapping
def print_text(text: str, color: Optional[str] = None, end: str = "") -> None:
"""Print text with highlighting and no end characters."""
if color is None:
print(text, end=end)
else:
color_str = _TEXT_COLOR_MAPPING[color]
print(f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m", end=end)
from langchain.logger import Logger
class ChainedInput:
"""Class for working with input that is the result of chains."""
def __init__(self, text: str, verbose: bool = False):
def __init__(self, text: str, context: dict, logger: Optional[Logger] = None):
"""Initialize with verbose flag and initial text."""
self._verbose = verbose
if self._verbose:
print_text(text, None)
self._logger = logger
if self._logger:
self._logger.log(text, context)
self._input = text
self._context = context
def add(self, text: str, color: Optional[str] = None) -> None:
"""Add text to input, print if in verbose mode."""
if self._verbose:
print_text(text, color)
if self._logger:
self._logger.log(text, self._context, color=color)
self._input += text
@property

70
langchain/logger.py Normal file
View File

@@ -0,0 +1,70 @@
from abc import ABC, abstractmethod
from typing import Optional, Any
from langchain.printing import print_text
from pathlib import Path
CONTEXT_KEY = "__context__"
class Logger(ABC):
@abstractmethod
def log_start_of_chain(self, inputs):
""""""
@abstractmethod
def log_end_of_chain(self, outputs):
""""""
@abstractmethod
def log(self, text: str, context: dict, **kwargs):
""""""
class PrintLogger(Logger):
def log_start_of_chain(self, inputs):
""""""
print("\n\n\033[1m> Entering new chain...\033[0m")
def log_end_of_chain(self, outputs):
""""""
print("\n\033[1m> Finished chain.\033[0m")
def log(self, text: str, context: dict, title: Optional[str ] =None ,**kwargs:Any):
""""""
if title is not None:
print(title)
print_text(text, **kwargs)
import json
class JSONLogger(Logger):
def __init__(self, log_dir):
self.log_dir = Path(log_dir)
self.log_dir.mkdir(exist_ok=True)
def log_start_of_chain(self, inputs):
""""""
fname = self.log_dir / f"{inputs[CONTEXT_KEY]['id']}.json"
if not fname.exists():
with open(fname, 'w') as f:
json.dump([], f)
def log_end_of_chain(self, outputs):
""""""
fname = self.log_dir / f"{outputs[CONTEXT_KEY]['id']}.json"
with open(fname) as f:
logs = json.load(f)
logs.append(outputs)
with open(fname, 'w') as f:
json.dump(logs, f)
def log(self, text: str, context: dict, title: Optional[str ] =None ,**kwargs:Any):
""""""
fname = self.log_dir / f"{context['id']}.json"
with open(fname) as f:
logs = json.load(f)
logs.append({"text": text, "title": title})
with open(fname, 'w') as f:
json.dump(logs, f)

View File

@@ -4,7 +4,7 @@ from typing import List, Optional, Sequence, Union
from langchain.agents.agent import Agent
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping, print_text
from langchain.printing import print_text, get_color_mapping
from langchain.llms.base import LLM
from langchain.prompts.prompt import PromptTemplate

29
langchain/printing.py Normal file
View File

@@ -0,0 +1,29 @@
from typing import Optional, List, Dict
def print_text(text: str, color: Optional[str] = None, end: str = "") -> None:
"""Print text with highlighting and no end characters."""
if color is None:
print(text, end=end)
else:
color_str = _TEXT_COLOR_MAPPING[color]
print(f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m", end=end)
def get_color_mapping(
items: List[str], excluded_colors: Optional[List] = None
) -> Dict[str, str]:
"""Get mapping for items to a support color."""
colors = list(_TEXT_COLOR_MAPPING.keys())
if excluded_colors is not None:
colors = [c for c in colors if c not in excluded_colors]
color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
return color_mapping
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
"yellow": "33;1",
"pink": "38;5;200",
"green": "32;1",
}

View File

@@ -67,8 +67,5 @@ class SQLDatabase:
def run(self, command: str) -> str:
"""Execute a SQL command and return a string of the results."""
try:
result = self._engine.execute(command).fetchall()
except Exception as e:
result = e
result = self._engine.execute(command).fetchall()
return str(result)

View File

@@ -3,7 +3,8 @@
import sys
from io import StringIO
from langchain.input import ChainedInput, get_color_mapping
from langchain.input import ChainedInput
from langchain.printing import get_color_mapping
def test_chained_input_not_verbose() -> None: