This commit is contained in:
Ankush Gola 2023-02-09 10:46:30 -08:00
parent a306baacd1
commit 6b64ff87b7
6 changed files with 551 additions and 1 deletions

View File

@ -36,7 +36,7 @@ def set_default_callback_manager() -> None:
)
def set_tracing_callback_manager(session_name: Optional[str] = None) -> None:
def set_tracing_callback_manager(session_name: Optional[str] = None, example_id: Optional[int] = None) -> None:
"""Set tracing callback manager."""
handler = SharedLangChainTracer()
callback = get_callback_manager()
@ -49,6 +49,9 @@ def set_tracing_callback_manager(session_name: Optional[str] = None) -> None:
except Exception:
raise ValueError(f"session {session_name} not found")
if example_id is not None:
handler.example_id = example_id
@contextmanager
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:

View File

@ -22,11 +22,22 @@ class BaseLangChainTracer(BaseTracer, ABC):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
always_verbose: bool = True
_example_id: Optional[int] = None
_endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
_headers: Dict[str, Any] = {"Content-Type": "application/json"}
if os.getenv("LANGCHAIN_API_KEY"):
_headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
@property
def example_id(self) -> Optional[int]:
    """Return the example id currently set on this tracer.

    Returns:
        The stored example id, or ``None`` when none has been set.
    """
    return self._example_id

@example_id.setter
def example_id(self, value: Optional[int]) -> None:
    """Store the example id on this tracer.

    Args:
        value: Example id to store, or ``None`` to clear it. The annotation is
            ``Optional[int]`` so it agrees with the getter and the backing
            ``_example_id`` attribute.
    """
    self._example_id = value
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
"""Persist a run."""
if isinstance(run, LLMRun):
@ -36,6 +47,9 @@ class BaseLangChainTracer(BaseTracer, ABC):
else:
endpoint = f"{self._endpoint}/tool-runs"
if self._example_id:
run.example_id = self._example_id
try:
requests.post(
endpoint,

View File

@ -40,6 +40,7 @@ class BaseRun(BaseModel):
serialized: Dict[str, Any]
session_id: int
error: Optional[str] = None
example_id: Optional[int] = None
class LLMRun(BaseRun):

View File

@ -0,0 +1,317 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "bd0b1515-e25f-4707-a240-b5c26d5d33f8",
"metadata": {},
"source": [
"# ExampleRunner Demo\n",
"\n",
"Run a chain on multiple examples for evaluation."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "13fbc962-98e0-470c-9467-c5e28db658a0",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"LANGCHAIN_HANDLER\"] = \"langchain\"\n",
"os.environ[\"LANGCHAIN_ENDPOINT\"] = \"http://127.0.0.1:8000\" \n",
"\n",
"import langchain\n",
"from langchain.agents import Tool, initialize_agent, load_tools\n",
"from langchain.llms import OpenAI\n",
"from langchain.evaluation.example_runner import ExampleRunner, CsvDataset\n",
"from langchain.llms import OpenAI\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.chains import LLMChain"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8afb6529-dfdb-4f0a-b19f-a700d80d4362",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Upload the dataset (only need to do once)\n",
"runner = ExampleRunner(\n",
" csv_dataset=CsvDataset(\n",
" csv_path=\"test_dataset.csv\",\n",
" description=\"Dummy dataset for testing\",\n",
" input_keys=[\"input1\", \"input2\", \"input3\"],\n",
" output_keys=[\"output1\"],\n",
" ),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "603dc948-0a57-4696-8816-008aaf346538",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"[Example(created_at=datetime.datetime(2023, 2, 9, 7, 45, 27, 878780), inputs={'input1': 'one', 'input2': ' two', 'input3': ' three'}, outputs={'output1': ' four'}, dataset_id=1, id=1),\n",
" Example(created_at=datetime.datetime(2023, 2, 9, 7, 45, 27, 878790), inputs={'input1': 'five', 'input2': ' six', 'input3': ' seven'}, outputs={'output1': ' eight'}, dataset_id=1, id=2),\n",
" Example(created_at=datetime.datetime(2023, 2, 9, 7, 45, 27, 878792), inputs={'input1': 'nine', 'input2': ' ten', 'input3': ' eleven'}, outputs={'output1': ' twelve'}, dataset_id=1, id=3)]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"runner.dataset.examples[:3]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "11eba157-ec8e-43e9-8326-75099c07e574",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
", four\n",
"\n",
"Five\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Eight\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n",
"Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n",
"Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Twelve\n",
"\n",
"\n",
"Sixteen, seventeen, eighteen.\n",
", twenty\n",
"\n",
"Twenty-one\n",
"\n",
"\n",
"twenty-four\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
", twenty-eight\n",
"\n",
"Twenty-nine\n",
"\n",
"\n",
"Thirty-two.\n",
"\n",
"\n",
"Thirty-six\n",
", forty\n",
"\n",
"forty-one\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Forty-four\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Forty-eight\n",
"four\n",
" eight\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" twelve\n",
" sixteen\n",
"\n",
"twenty\n",
" twenty-four\n",
" twenty-eight\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" thirty-two\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" thirty-six\n",
" forty\n",
" forty-four\n",
" forty-eight\n"
]
}
],
"source": [
"runner = ExampleRunner(\n",
" langchain_dataset_name=\"test_dataset.csv\",\n",
")\n",
"\n",
"llm = OpenAI(temperature=0)\n",
"prompt1 = PromptTemplate(\n",
" input_variables=[\"input1\", \"input2\", \"input3\"],\n",
" template=\"Complete the sequence: {input1}, {input2}, {input3}\",\n",
")\n",
"chain1 = LLMChain(llm=llm, prompt=prompt1)\n",
"\n",
"prompt2 = PromptTemplate(\n",
" input_variables=[\"input1\", \"input2\", \"input3\"],\n",
" template=\"\"\"\n",
" You are given the text representation of three numbers. You are to give the next number in the sequence. Only provide one number! \n",
" \n",
" Example:\n",
" Input: one, two three. \n",
" Output: four \n",
" \n",
" Input: {input1}, {input2}, {input3}\n",
" Output:\n",
" \"\"\"\n",
")\n",
"chain2 = LLMChain(llm=llm, prompt=prompt2)\n",
"runner.run_chain(chain1)\n",
"runner.run_chain(chain2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b510335a-872f-4392-829d-bdcba3a052cb",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,202 @@
import asyncio
from pydantic import BaseModel, validator, root_validator, Field
from pathlib import Path
from typing import Optional, Dict, Any, List
from langchain.utils import get_from_dict_or_env
from pydantic.networks import AnyHttpUrl
import requests
import datetime
import langchain
from langchain.agents import AgentExecutor
from langchain.chains.base import Chain
from langchain.llms.base import BaseLLM
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.tracers import LangChainTracer
from langchain.callbacks.stdout import StdOutCallbackHandler
from urllib.parse import urlparse
import os
class ExampleBase(BaseModel):
    """Fields shared by every Example schema."""

    # Creation time; defaults to the current (naive) UTC time.
    created_at: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
    # Input values for the example, keyed by input name.
    inputs: Dict[str, Any]
    # Output values for the example, if any. Spelled with ``Optional`` instead of
    # ``X | None``: the annotation is evaluated when the class is created, and
    # the ``|`` form raises TypeError on Python < 3.10. The rest of this module
    # already uses ``Optional``.
    outputs: Optional[Dict[str, Any]] = None
    # Id of the dataset the example belongs to.
    dataset_id: int
class ExampleCreate(ExampleBase):
    """Schema for creating an Example; declares no fields beyond ``ExampleBase``."""
class Example(ExampleBase):
    """Example schema: the ``ExampleBase`` fields plus an integer ``id``."""

    # Identifier of the example.
    id: int
class DatasetBase(BaseModel):
    """Fields shared by every Dataset schema."""

    # Name of the dataset.
    name: str
    # Free-text description of the dataset.
    description: str
    # Creation time; defaults to the current (naive) UTC time.
    created_at: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
    # Examples contained in the dataset; defaults to a fresh empty list.
    examples: List[Example] = Field(default_factory=list)
class DatasetCreate(DatasetBase):
    """Schema for creating a Dataset; declares no fields beyond ``DatasetBase``."""
class Dataset(DatasetBase):
    """Dataset schema: the ``DatasetBase`` fields plus an integer ``id``."""

    # Identifier of the dataset.
    id: int
class CsvDataset(BaseModel):
    """Describe a local CSV file that can be uploaded to a LangChain endpoint."""

    # Location of the CSV file on disk.
    csv_path: Path
    # Description sent along with the upload.
    description: str
    # Keys forwarded to the upload endpoint as ``input_keys`` / ``output_keys``.
    input_keys: List[str]
    output_keys: List[str]

    @validator("csv_path")
    def validate_csv_path(cls, v: Path) -> Path:
        """Check that ``csv_path`` points at an existing regular file.

        ``is_file`` is used instead of ``exists`` so that a directory is rejected
        here rather than failing later when the path is opened for upload.

        Args:
            v: The ``csv_path`` value being validated.

        Returns:
            The unchanged path.

        Raises:
            ValueError: If there is no file at ``v``.
        """
        if not v.is_file():
            raise ValueError("CSV file does not exist.")
        return v
def fetch_dataset_from_endpoint(
    name: str, headers: Dict[str, str], endpoint: str = "https://localhost:8000"
) -> Dataset:
    """Fetch a dataset by name from a LangChain endpoint.

    Args:
        name: Name of the dataset to look up.
        headers: HTTP headers to send with the request.
        endpoint: Base URL of the LangChain endpoint.

    Returns:
        A ``Dataset`` built from the first entry of the endpoint's response.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
        ValueError: If the response contains no dataset for ``name``.
    """
    # Hand the name to requests via ``params`` so it is URL-encoded; when it is
    # interpolated into the URL, names containing "&", "#" or "+" corrupt the query.
    response = requests.get(
        f"{endpoint}/datasets", params={"name": name}, headers=headers
    )
    response.raise_for_status()
    # Decode the body once instead of re-parsing it on every access.
    datasets = response.json()
    if not datasets:
        raise ValueError(f"Dataset with name {name} does not exist.")
    return Dataset(**datasets[0])
def upload_csv_dataset_to_endpoint(
    csv_dataset: CsvDataset,
    headers: Dict[str, str],
    endpoint: str = "https://localhost:8000",
) -> Dataset:
    """Upload a CSV file to a LangChain endpoint and return the resulting dataset.

    Args:
        csv_dataset: The CSV file and the metadata to send with it.
        headers: HTTP headers to send with the request.
        endpoint: Base URL of the LangChain endpoint.

    Returns:
        A ``Dataset`` built from the endpoint's JSON response.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
    """
    upload_url = f"{endpoint}/datasets/upload"
    form_fields = {
        "input_keys": csv_dataset.input_keys,
        "output_keys": csv_dataset.output_keys,
        "description": csv_dataset.description,
    }
    # The file is sent as a multipart upload under the "file" form field.
    with open(csv_dataset.csv_path, "rb") as csv_file:
        response = requests.post(
            upload_url,
            headers=headers,
            files={"file": (csv_dataset.csv_path.name, csv_file)},
            data=form_fields,
        )
    response.raise_for_status()
    return Dataset(**response.json())
class ExampleRunner(BaseModel):
    """Run an LLM, chain or agent on every example of a dataset.

    The dataset is either fetched by name (``langchain_dataset_name``) or created
    by uploading a CSV file (``csv_dataset``); exactly one of the two must be
    given. ``dataset`` is filled in by ``validate_fields``.
    """

    langchain_endpoint: AnyHttpUrl
    dataset: Dataset
    csv_dataset: Optional[CsvDataset] = None
    langchain_dataset_name: Optional[str] = None
    langchain_api_key: Optional[str] = None

    @root_validator(pre=True)
    def validate_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Resolve endpoint and API key, then fetch or upload the dataset.

        Args:
            values: Raw constructor arguments (this validator runs with
                ``pre=True``).

        Returns:
            ``values`` with ``langchain_endpoint`` and ``dataset`` populated, and
            ``langchain_api_key`` populated for non-local endpoints.

        Raises:
            ValueError: If neither or both of ``csv_dataset`` and
                ``langchain_dataset_name`` are provided.
        """
        csv_dataset = values.get("csv_dataset")
        langchain_dataset_name = values.get("langchain_dataset_name")
        # The messages name the actual constructor fields so callers can act on them.
        if csv_dataset is None and langchain_dataset_name is None:
            raise ValueError(
                "Must provide either csv_dataset or langchain_dataset_name."
            )
        if csv_dataset is not None and langchain_dataset_name is not None:
            raise ValueError(
                "Cannot provide both csv_dataset and langchain_dataset_name."
            )

        # Honor an explicitly passed endpoint; only fall back to the environment
        # (and then the default) when the caller did not supply one.
        # NOTE(review): this default uses https while BaseLangChainTracer falls
        # back to "http://localhost:8000" - confirm which scheme is intended.
        langchain_endpoint = str(
            values.get("langchain_endpoint")
            or os.environ.get("LANGCHAIN_ENDPOINT", "https://localhost:8000")
        )
        values["langchain_endpoint"] = langchain_endpoint

        # An API key is only looked up for endpoints that are not on this machine.
        local_hosts = ("localhost", "127.0.0.1", "0.0.0.0")
        if urlparse(langchain_endpoint).hostname not in local_hosts:
            values["langchain_api_key"] = get_from_dict_or_env(
                values, "langchain_api_key", "LANGCHAIN_API_KEY"
            )

        # Build the request headers once; both branches below use them.
        headers: Dict[str, str] = {}
        if values.get("langchain_api_key"):
            headers["x-api-key"] = values["langchain_api_key"]

        # Exactly one of the two sources is set at this point (checked above).
        if langchain_dataset_name is not None:
            # Fetching also verifies that the named dataset exists.
            values["dataset"] = fetch_dataset_from_endpoint(
                langchain_dataset_name, headers, langchain_endpoint
            )
        else:
            values["dataset"] = upload_csv_dataset_to_endpoint(
                csv_dataset, headers, langchain_endpoint
            )
        return values

    def examples(self) -> List[Example]:
        """Return the examples of the loaded dataset."""
        return self.dataset.examples

    def run_agent(self, agent: AgentExecutor) -> None:
        """Run ``agent`` once per example, passing the example inputs as kwargs."""
        # NOTE(review): unlike run_chain, no example id is set for tracing here -
        # confirm whether that is intended.
        for example in self.examples():
            agent.run(**example.inputs)

    def run_chain(self, chain: Chain) -> None:
        """Run ``chain`` once per example and print each result."""
        for example in self.examples():
            # Set this example's id on the tracing callback manager before the run.
            langchain.set_tracing_callback_manager(example_id=example.id)
            print(chain.run(**example.inputs))

    def run_llm(self, llm: BaseLLM) -> None:
        """Call ``llm.generate`` once per example with the example's input values."""
        # NOTE(review): unlike run_chain, no example id is set for tracing here -
        # confirm whether that is intended.
        for example in self.examples():
            llm.generate(list(example.inputs.values()))
# async def arun_agent(self, agent: AgentExecutor, num_workers: int = 1):
# """Run an agent on the examples."""
# # Copy the agent num_workers times
# agents = []
# for _ in range(num_workers):
# tracer = LangChainTracer()
# tracer.load_default_session()
# manager = CallbackManager([StdOutCallbackHandler(), tracer])
# agent.from_agent_and_tools(agent.agent, agent.tools, manager)
# agents.append(agent)
#
# i = 0
# while i < len(self.examples()):
# for agent in agents:
# example = self.examples()[i]
# await agent.arun(**example.inputs)
# i += 1
if __name__ == "__main__":
    # Demo: upload the bundled CSV as a dataset and run a simple chain over it.
    os.environ["LANGCHAIN_ENDPOINT"] = "http://127.0.0.1:8000"

    demo_csv = CsvDataset(
        csv_path="test_dataset.csv",
        description="Dummy dataset for testing",
        input_keys=["input1", "input2", "input3"],
        output_keys=["output1"],
    )
    # Alternatively, pass langchain_dataset_name="test_dataset.csv" instead of
    # csv_dataset to fetch the dataset by name rather than uploading it.
    runner = ExampleRunner(csv_dataset=demo_csv)

    from langchain.llms import OpenAI
    from langchain.prompts import PromptTemplate
    from langchain.chains import LLMChain

    llm = OpenAI(temperature=0.9, model_name="text-ada-001")
    sequence_prompt = PromptTemplate(
        input_variables=["input1", "input2", "input3"],
        template="Complete the sequence: {input1}, {input2}, {input3}",
    )
    chain = LLMChain(llm=llm, prompt=sequence_prompt)
    runner.run_chain(chain)

View File

@ -0,0 +1,13 @@
input1,input2,input3,output1
one, two, three, four
five, six, seven, eight
nine, ten, eleven, twelve
thirteen, fourteen, fifteen, sixteen
seventeen, eighteen, nineteen, twenty
twenty-one, twenty-two, twenty-three, twenty-four
twenty-five, twenty-six, twenty-seven, twenty-eight
twenty-nine, thirty, thirty-one, thirty-two
thirty-three, thirty-four, thirty-five, thirty-six
thirty-seven, thirty-eight, thirty-nine, forty
forty-one, forty-two, forty-three, forty-four
forty-five, forty-six, forty-seven, forty-eight
1 input1 input2 input3 output1
2 one two three four
3 five six seven eight
4 nine ten eleven twelve
5 thirteen fourteen fifteen sixteen
6 seventeen eighteen nineteen twenty
7 twenty-one twenty-two twenty-three twenty-four
8 twenty-five twenty-six twenty-seven twenty-eight
9 twenty-nine thirty thirty-one thirty-two
10 thirty-three thirty-four thirty-five thirty-six
11 thirty-seven thirty-eight thirty-nine forty
12 forty-one forty-two forty-three forty-four
13 forty-five forty-six forty-seven forty-eight