mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-19 20:04:11 +00:00
Compare commits
1 Commits
harrison/s
...
harrison/l
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
844151605c |
@@ -21,7 +21,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 1,
|
||||
"id": "ac561cc4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -32,7 +32,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 2,
|
||||
"id": "07e96d99",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -43,16 +43,16 @@
|
||||
"db = SQLDatabase.from_uri(\"sqlite:///../../../notebooks/Chinook.db\")\n",
|
||||
"db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)\n",
|
||||
"tools = [\n",
|
||||
"# Tool(\n",
|
||||
"# name = \"Search\",\n",
|
||||
"# func=search.run,\n",
|
||||
"# description=\"useful for when you need to answer questions about current events\"\n",
|
||||
"# ),\n",
|
||||
"# Tool(\n",
|
||||
"# name=\"Calculator\",\n",
|
||||
"# func=llm_math_chain.run,\n",
|
||||
"# description=\"useful for when you need to answer questions about math\"\n",
|
||||
"# ),\n",
|
||||
" Tool(\n",
|
||||
" name = \"Search\",\n",
|
||||
" func=search.run,\n",
|
||||
" description=\"useful for when you need to answer questions about current events\"\n",
|
||||
" ),\n",
|
||||
" Tool(\n",
|
||||
" name=\"Calculator\",\n",
|
||||
" func=llm_math_chain.run,\n",
|
||||
" description=\"useful for when you need to answer questions about math\"\n",
|
||||
" ),\n",
|
||||
" Tool(\n",
|
||||
" name=\"FooBar DB\",\n",
|
||||
" func=db_chain.run,\n",
|
||||
@@ -63,7 +63,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 3,
|
||||
"id": "a069c4b6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -71,153 +71,6 @@
|
||||
"mrkl = initialize_agent(tools, llm, agent=\"zero-shot-react-description\", verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "356a1396",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"What was our total number of sales and total revenue in 2013?\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I need to find out how many sales and how much revenue we had in 2013\n",
|
||||
"Action: FooBar DB\n",
|
||||
"Action Input: sales 2013\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"sales 2013\n",
|
||||
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT SUM(Total) FROM Invoice WHERE InvoiceDate BETWEEN '2013-01-01' AND '2013-12-31'\u001b[0m\n",
|
||||
"SQLResult: \u001b[33;1m\u001b[1;3m[(450.58000000000027,)]\u001b[0m\n",
|
||||
"Answer:\u001b[32;1m\u001b[1;3m 450.58\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3m 450.58\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I need to find out how many sales and how much revenue we had in 2013\n",
|
||||
"Action: FooBar DB\n",
|
||||
"Action Input: revenue 2013\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"revenue 2013\n",
|
||||
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT SUM(Total) FROM Invoice WHERE strftime('%Y', InvoiceDate) = '2013'\u001b[0m\n",
|
||||
"SQLResult: \u001b[33;1m\u001b[1;3m[(450.58000000000027,)]\u001b[0m\n",
|
||||
"Answer:\u001b[32;1m\u001b[1;3m 450.58\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3m 450.58\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
||||
"Final Answer: We had 450.58 sales and 450.58 revenue in 2013\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'We had 450.58 sales and 450.58 revenue in 2013'"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"mrkl.run(\"What was our total number of sales and total revenue in 2013?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"id": "4e9c9b23",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.agents import ZeroShotAgent\n",
|
||||
"from langchain import PromptTemplate\n",
|
||||
"from langchain import LLMChain\n",
|
||||
"TEMPLATE = \"\"\"You are a data engineer answering questions using a SQL database.\n",
|
||||
"\n",
|
||||
"Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. If a previous query produced an error, do NOT try it again.\n",
|
||||
"\n",
|
||||
"Only use the following tables:\n",
|
||||
"\n",
|
||||
"{table_info}\n",
|
||||
"\n",
|
||||
"Use the following format:\n",
|
||||
"\n",
|
||||
"Question: the input question you must answer\n",
|
||||
"Thought: you should always think about what to do\n",
|
||||
"Action: SQLDB\n",
|
||||
"Action Input: the query to run against the SQL database\n",
|
||||
"Observation: the result of the action\n",
|
||||
"... (this Thought/Action/Action Input/Observation can repeat N times)\n",
|
||||
"Thought: I now know the final answer\n",
|
||||
"Final Answer: the final answer to the original input question\n",
|
||||
"\n",
|
||||
"Begin!\n",
|
||||
"\n",
|
||||
"Question: {{input}}\"\"\".format(**{\n",
|
||||
" \"dialect\": db.dialect,\n",
|
||||
" \"table_info\": db.table_info})\n",
|
||||
"prompt = PromptTemplate(template=TEMPLATE, input_variables=[\"input\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"id": "a6fb1b4f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm_chain = LLMChain(llm=llm, prompt=prompt)\n",
|
||||
"tools = [\n",
|
||||
" Tool(\"SQLDB\", db.run, \"foo\")\n",
|
||||
"]\n",
|
||||
"agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"id": "64578802",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"What was our total number of sales and total revenue in 2013?\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I need to find the total number of sales and total revenue\n",
|
||||
"Action: SQLDB\n",
|
||||
"Action Input: SELECT SUM(Total) AS \"Total Sales\", SUM(Total) AS \"Total Revenue\" FROM Invoice WHERE InvoiceDate LIKE '2013%'\u001b[0m\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3m[(450.58000000000027, 450.58000000000027)]\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
||||
"Final Answer: 450.58\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'450.58'"
|
||||
]
|
||||
},
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"agent.run(\"What was our total number of sales and total revenue in 2013?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
|
||||
@@ -7,9 +7,11 @@ from pydantic import BaseModel
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.input import ChainedInput, get_color_mapping
|
||||
from langchain.input import ChainedInput
|
||||
from langchain.printing import get_color_mapping
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.logger import CONTEXT_KEY
|
||||
|
||||
|
||||
class Action(NamedTuple):
|
||||
@@ -116,7 +118,7 @@ class Agent(Chain, BaseModel, ABC):
|
||||
tool, tool_input = parsed_output
|
||||
return Action(tool, tool_input, full_output)
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Run text through and get agent response."""
|
||||
text = inputs[self.input_key]
|
||||
# Construct a mapping of tool name to tool for easy lookup
|
||||
@@ -128,7 +130,7 @@ class Agent(Chain, BaseModel, ABC):
|
||||
# prompts the LLM to take an action.
|
||||
starter_string = text + self.starter_string + self.llm_prefix
|
||||
# We use the ChainedInput class to iteratively add to the input over time.
|
||||
chained_input = ChainedInput(starter_string, verbose=self.verbose)
|
||||
chained_input = ChainedInput(starter_string, inputs[CONTEXT_KEY], logger=self.logger)
|
||||
# We construct a mapping from each tool to a color, used for logging.
|
||||
color_mapping = get_color_mapping(
|
||||
[tool.name for tool in self.tools], excluded_colors=["green"]
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
TEMPLATE = """You are a data engineer answering questions using a SQL database.
|
||||
|
||||
Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
|
||||
|
||||
Only use the following tables:
|
||||
|
||||
{table_info}
|
||||
|
||||
Use the following format:
|
||||
|
||||
Question: the input question you must answer
|
||||
Thought: you should always think about what to do
|
||||
Action: SQLDB
|
||||
Action Input: the query to run against the SQL database
|
||||
Observation: the result of the action
|
||||
... (this Thought/Action/Action Input/Observation can repeat N times)
|
||||
Thought: I now know the final answer
|
||||
Final Answer: the final answer to the original input question
|
||||
|
||||
Begin!
|
||||
|
||||
Question: {input}"""
|
||||
@@ -1,9 +1,12 @@
|
||||
"""Base interface that all chains should implement."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from langchain.logger import PrintLogger
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.logger import Logger, CONTEXT_KEY
|
||||
|
||||
class Memory(BaseModel, ABC):
|
||||
"""Base interface for memory in chains."""
|
||||
@@ -35,6 +38,7 @@ class Chain(BaseModel, ABC):
|
||||
|
||||
verbose: bool = False
|
||||
"""Whether to print out response text."""
|
||||
logger: Optional[Logger] = None
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
@@ -46,6 +50,19 @@ class Chain(BaseModel, ABC):
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Output keys this chain expects."""
|
||||
|
||||
@root_validator()
|
||||
def add_logger(cls, values: Dict) -> Dict:
|
||||
"""Add a printing logger if verbose=True and none provided."""
|
||||
if values["verbose"] and values["logger"] is None:
|
||||
values["logger"] = PrintLogger()
|
||||
return values
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def _validate_inputs(self, inputs: Dict[str, str]) -> None:
|
||||
"""Check that all inputs are present."""
|
||||
missing_keys = set(self.input_keys).difference(inputs)
|
||||
@@ -76,16 +93,22 @@ class Chain(BaseModel, ABC):
|
||||
chain will be returned. Defaults to False.
|
||||
|
||||
"""
|
||||
if CONTEXT_KEY not in inputs:
|
||||
inputs[CONTEXT_KEY] = {}
|
||||
if "id" not in inputs[CONTEXT_KEY]:
|
||||
inputs[CONTEXT_KEY]["id"] = str(uuid.uuid4())
|
||||
|
||||
if self.memory is not None:
|
||||
external_context = self.memory.load_memory_variables(inputs)
|
||||
inputs = dict(inputs, **external_context)
|
||||
self._validate_inputs(inputs)
|
||||
if self.verbose:
|
||||
print("\n\n\033[1m> Entering new chain...\033[0m")
|
||||
if self.logger:
|
||||
self.logger.log_start_of_chain(inputs)
|
||||
outputs = self._call(inputs)
|
||||
if self.verbose:
|
||||
print("\n\033[1m> Finished chain.\033[0m")
|
||||
self._validate_outputs(outputs)
|
||||
outputs[CONTEXT_KEY] = inputs[CONTEXT_KEY]
|
||||
if self.logger:
|
||||
self.logger.log_end_of_chain(outputs)
|
||||
if self.memory is not None:
|
||||
self.memory.save_context(inputs, outputs)
|
||||
if return_only_outputs:
|
||||
|
||||
@@ -4,9 +4,10 @@ from typing import Any, Dict, List
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.input import print_text
|
||||
from langchain.printing import print_text
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.logger import CONTEXT_KEY
|
||||
|
||||
|
||||
class LLMChain(Chain, BaseModel):
|
||||
@@ -54,9 +55,9 @@ class LLMChain(Chain, BaseModel):
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
|
||||
prompt = self.prompt.format(**selected_inputs)
|
||||
if self.verbose:
|
||||
print("Prompt after formatting:")
|
||||
print_text(prompt, color="green", end="\n")
|
||||
if self.logger:
|
||||
title="Prompt after formatting:"
|
||||
self.logger.log(prompt, inputs[CONTEXT_KEY],title=title, color="green", end="\n")
|
||||
kwargs = {}
|
||||
if "stop" in inputs:
|
||||
kwargs["stop"] = inputs["stop"]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""Chain that interprets a prompt and executes python code to do math."""
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
@@ -9,6 +9,7 @@ from langchain.chains.llm_math.prompt import PROMPT
|
||||
from langchain.chains.python import PythonChain
|
||||
from langchain.input import ChainedInput
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.logger import CONTEXT_KEY
|
||||
|
||||
|
||||
class LLMMathChain(Chain, BaseModel):
|
||||
@@ -48,10 +49,10 @@ class LLMMathChain(Chain, BaseModel):
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
|
||||
python_executor = PythonChain()
|
||||
chained_input = ChainedInput(inputs[self.input_key], verbose=self.verbose)
|
||||
chained_input = ChainedInput(inputs[self.input_key], inputs[CONTEXT_KEY], logger=self.logger)
|
||||
t = llm_executor.predict(question=chained_input.input, stop=["```output"])
|
||||
chained_input.add(t, color="green")
|
||||
t = t.strip()
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Dict, List
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.input import get_color_mapping, print_text
|
||||
|
||||
|
||||
class SequentialChain(Chain, BaseModel):
|
||||
@@ -127,11 +126,8 @@ class SimpleSequentialChain(Chain, BaseModel):
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
_input = inputs[self.input_key]
|
||||
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
|
||||
for i, chain in enumerate(self.chains):
|
||||
_input = chain.run(_input)
|
||||
if self.strip_outputs:
|
||||
_input = _input.strip()
|
||||
if self.verbose:
|
||||
print_text(_input, color=color_mapping[str(i)], end="\n")
|
||||
return {self.output_key: _input}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""Chain for interacting with SQL Database."""
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
@@ -9,6 +9,7 @@ from langchain.chains.sql_database.prompt import PROMPT
|
||||
from langchain.input import ChainedInput
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.sql_database import SQLDatabase
|
||||
from langchain.logger import CONTEXT_KEY
|
||||
|
||||
|
||||
class SQLDatabaseChain(Chain, BaseModel):
|
||||
@@ -51,10 +52,10 @@ class SQLDatabaseChain(Chain, BaseModel):
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
|
||||
chained_input = ChainedInput(
|
||||
inputs[self.input_key] + "\nSQLQuery:", verbose=self.verbose
|
||||
inputs[self.input_key] + "\nSQLQuery:", inputs[CONTEXT_KEY], logger=self.logger
|
||||
)
|
||||
llm_inputs = {
|
||||
"input": chained_input.input,
|
||||
@@ -64,10 +65,7 @@ class SQLDatabaseChain(Chain, BaseModel):
|
||||
}
|
||||
sql_cmd = llm_chain.predict(**llm_inputs)
|
||||
chained_input.add(sql_cmd, color="green")
|
||||
try:
|
||||
result = self.database.run(sql_cmd)
|
||||
except Exception as e:
|
||||
result = str(e)
|
||||
result = self.database.run(sql_cmd)
|
||||
chained_input.add("\nSQLResult: ")
|
||||
chained_input.add(result, color="yellow")
|
||||
chained_input.add("\nAnswer:")
|
||||
|
||||
@@ -1,48 +1,24 @@
|
||||
"""Handle chained inputs."""
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
_TEXT_COLOR_MAPPING = {
|
||||
"blue": "36;1",
|
||||
"yellow": "33;1",
|
||||
"pink": "38;5;200",
|
||||
"green": "32;1",
|
||||
}
|
||||
|
||||
|
||||
def get_color_mapping(
|
||||
items: List[str], excluded_colors: Optional[List] = None
|
||||
) -> Dict[str, str]:
|
||||
"""Get mapping for items to a support color."""
|
||||
colors = list(_TEXT_COLOR_MAPPING.keys())
|
||||
if excluded_colors is not None:
|
||||
colors = [c for c in colors if c not in excluded_colors]
|
||||
color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
|
||||
return color_mapping
|
||||
|
||||
|
||||
def print_text(text: str, color: Optional[str] = None, end: str = "") -> None:
|
||||
"""Print text with highlighting and no end characters."""
|
||||
if color is None:
|
||||
print(text, end=end)
|
||||
else:
|
||||
color_str = _TEXT_COLOR_MAPPING[color]
|
||||
print(f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m", end=end)
|
||||
from langchain.logger import Logger
|
||||
|
||||
|
||||
class ChainedInput:
|
||||
"""Class for working with input that is the result of chains."""
|
||||
|
||||
def __init__(self, text: str, verbose: bool = False):
|
||||
def __init__(self, text: str, context: dict, logger: Optional[Logger] = None):
|
||||
"""Initialize with verbose flag and initial text."""
|
||||
self._verbose = verbose
|
||||
if self._verbose:
|
||||
print_text(text, None)
|
||||
self._logger = logger
|
||||
if self._logger:
|
||||
self._logger.log(text, context)
|
||||
self._input = text
|
||||
self._context = context
|
||||
|
||||
def add(self, text: str, color: Optional[str] = None) -> None:
|
||||
"""Add text to input, print if in verbose mode."""
|
||||
if self._verbose:
|
||||
print_text(text, color)
|
||||
if self._logger:
|
||||
self._logger.log(text, self._context, color=color)
|
||||
self._input += text
|
||||
|
||||
@property
|
||||
|
||||
70
langchain/logger.py
Normal file
70
langchain/logger.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Any
|
||||
from langchain.printing import print_text
|
||||
from pathlib import Path
|
||||
|
||||
CONTEXT_KEY = "__context__"
|
||||
|
||||
|
||||
class Logger(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def log_start_of_chain(self, inputs):
|
||||
""""""
|
||||
|
||||
@abstractmethod
|
||||
def log_end_of_chain(self, outputs):
|
||||
""""""
|
||||
|
||||
@abstractmethod
|
||||
def log(self, text: str, context: dict, **kwargs):
|
||||
""""""
|
||||
|
||||
|
||||
class PrintLogger(Logger):
|
||||
def log_start_of_chain(self, inputs):
|
||||
""""""
|
||||
print("\n\n\033[1m> Entering new chain...\033[0m")
|
||||
|
||||
def log_end_of_chain(self, outputs):
|
||||
""""""
|
||||
print("\n\033[1m> Finished chain.\033[0m")
|
||||
|
||||
def log(self, text: str, context: dict, title: Optional[str ] =None ,**kwargs:Any):
|
||||
""""""
|
||||
if title is not None:
|
||||
print(title)
|
||||
print_text(text, **kwargs)
|
||||
|
||||
import json
|
||||
class JSONLogger(Logger):
|
||||
|
||||
def __init__(self, log_dir):
|
||||
self.log_dir = Path(log_dir)
|
||||
self.log_dir.mkdir(exist_ok=True)
|
||||
|
||||
def log_start_of_chain(self, inputs):
|
||||
""""""
|
||||
fname = self.log_dir / f"{inputs[CONTEXT_KEY]['id']}.json"
|
||||
if not fname.exists():
|
||||
with open(fname, 'w') as f:
|
||||
json.dump([], f)
|
||||
|
||||
def log_end_of_chain(self, outputs):
|
||||
""""""
|
||||
fname = self.log_dir / f"{outputs[CONTEXT_KEY]['id']}.json"
|
||||
with open(fname) as f:
|
||||
logs = json.load(f)
|
||||
logs.append(outputs)
|
||||
with open(fname, 'w') as f:
|
||||
json.dump(logs, f)
|
||||
|
||||
def log(self, text: str, context: dict, title: Optional[str ] =None ,**kwargs:Any):
|
||||
""""""
|
||||
fname = self.log_dir / f"{context['id']}.json"
|
||||
with open(fname) as f:
|
||||
logs = json.load(f)
|
||||
logs.append({"text": text, "title": title})
|
||||
with open(fname, 'w') as f:
|
||||
json.dump(logs, f)
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import List, Optional, Sequence, Union
|
||||
from langchain.agents.agent import Agent
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.input import get_color_mapping, print_text
|
||||
from langchain.printing import print_text, get_color_mapping
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
29
langchain/printing.py
Normal file
29
langchain/printing.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from typing import Optional, List, Dict
|
||||
|
||||
|
||||
def print_text(text: str, color: Optional[str] = None, end: str = "") -> None:
|
||||
"""Print text with highlighting and no end characters."""
|
||||
if color is None:
|
||||
print(text, end=end)
|
||||
else:
|
||||
color_str = _TEXT_COLOR_MAPPING[color]
|
||||
print(f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m", end=end)
|
||||
|
||||
|
||||
def get_color_mapping(
|
||||
items: List[str], excluded_colors: Optional[List] = None
|
||||
) -> Dict[str, str]:
|
||||
"""Get mapping for items to a support color."""
|
||||
colors = list(_TEXT_COLOR_MAPPING.keys())
|
||||
if excluded_colors is not None:
|
||||
colors = [c for c in colors if c not in excluded_colors]
|
||||
color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
|
||||
return color_mapping
|
||||
|
||||
|
||||
_TEXT_COLOR_MAPPING = {
|
||||
"blue": "36;1",
|
||||
"yellow": "33;1",
|
||||
"pink": "38;5;200",
|
||||
"green": "32;1",
|
||||
}
|
||||
@@ -67,8 +67,5 @@ class SQLDatabase:
|
||||
|
||||
def run(self, command: str) -> str:
|
||||
"""Execute a SQL command and return a string of the results."""
|
||||
try:
|
||||
result = self._engine.execute(command).fetchall()
|
||||
except Exception as e:
|
||||
result = e
|
||||
result = self._engine.execute(command).fetchall()
|
||||
return str(result)
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
import sys
|
||||
from io import StringIO
|
||||
|
||||
from langchain.input import ChainedInput, get_color_mapping
|
||||
from langchain.input import ChainedInput
|
||||
from langchain.printing import get_color_mapping
|
||||
|
||||
|
||||
def test_chained_input_not_verbose() -> None:
|
||||
|
||||
Reference in New Issue
Block a user