diff --git a/langchain/callbacks/__init__.py b/langchain/callbacks/__init__.py index 2bc79d6aacf..d5f6179a2fc 100644 --- a/langchain/callbacks/__init__.py +++ b/langchain/callbacks/__init__.py @@ -36,7 +36,7 @@ def set_default_callback_manager() -> None: ) -def set_tracing_callback_manager(session_name: Optional[str] = None) -> None: +def set_tracing_callback_manager(session_name: Optional[str] = None, example_id: Optional[int] = None) -> None: """Set tracing callback manager.""" handler = SharedLangChainTracer() callback = get_callback_manager() @@ -49,6 +49,9 @@ def set_tracing_callback_manager(session_name: Optional[str] = None) -> None: except Exception: raise ValueError(f"session {session_name} not found") + if example_id is not None: + handler.example_id = example_id + @contextmanager def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]: diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index d25022041aa..aa5085b7baf 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -22,11 +22,22 @@ class BaseLangChainTracer(BaseTracer, ABC): """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" always_verbose: bool = True + _example_id: Optional[int] = None _endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000") _headers: Dict[str, Any] = {"Content-Type": "application/json"} if os.getenv("LANGCHAIN_API_KEY"): _headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY") + @property + def example_id(self) -> Optional[int]: + """Return the example_id.""" + return self._example_id + + @example_id.setter + def example_id(self, value: int) -> None: + """Set the example_id.""" + self._example_id = value + def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: """Persist a run.""" if isinstance(run, LLMRun): @@ -36,6 +47,9 @@ class BaseLangChainTracer(BaseTracer, ABC): else: endpoint = 
f"{self._endpoint}/tool-runs" + if self._example_id: + run.example_id = self._example_id + try: requests.post( endpoint, diff --git a/langchain/callbacks/tracers/schemas.py b/langchain/callbacks/tracers/schemas.py index bb77d747e7c..0d376b097bc 100644 --- a/langchain/callbacks/tracers/schemas.py +++ b/langchain/callbacks/tracers/schemas.py @@ -40,6 +40,7 @@ class BaseRun(BaseModel): serialized: Dict[str, Any] session_id: int error: Optional[str] = None + example_id: Optional[int] = None class LLMRun(BaseRun): diff --git a/langchain/evaluation/ExampleRunner.ipynb b/langchain/evaluation/ExampleRunner.ipynb new file mode 100644 index 00000000000..bfb2a0f34a4 --- /dev/null +++ b/langchain/evaluation/ExampleRunner.ipynb @@ -0,0 +1,317 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "bd0b1515-e25f-4707-a240-b5c26d5d33f8", + "metadata": {}, + "source": [ + "# ExampleRunner Demo\n", + "\n", + "Run a chain on multiple examples for evaluation." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "13fbc962-98e0-470c-9467-c5e28db658a0", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"LANGCHAIN_HANDLER\"] = \"langchain\"\n", + "os.environ[\"LANGCHAIN_ENDPOINT\"] = \"http://127.0.0.1:8000\" \n", + "\n", + "import langchain\n", + "from langchain.agents import Tool, initialize_agent, load_tools\n", + "from langchain.llms import OpenAI\n", + "from langchain.evaluation.example_runner import ExampleRunner, CsvDataset\n", + "from langchain.llms import OpenAI\n", + "from langchain.prompts import PromptTemplate\n", + "from langchain.chains import LLMChain" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8afb6529-dfdb-4f0a-b19f-a700d80d4362", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Upload the dataset (only needs to be done once)\n", + "runner = ExampleRunner(\n", + " csv_dataset=CsvDataset(\n", + " csv_path=\"test_dataset.csv\",\n", + " description=\"Dummy dataset for 
testing\",\n", + " input_keys=[\"input1\", \"input2\", \"input3\"],\n", + " output_keys=[\"output1\"],\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "603dc948-0a57-4696-8816-008aaf346538", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[Example(created_at=datetime.datetime(2023, 2, 9, 7, 45, 27, 878780), inputs={'input1': 'one', 'input2': ' two', 'input3': ' three'}, outputs={'output1': ' four'}, dataset_id=1, id=1),\n", + " Example(created_at=datetime.datetime(2023, 2, 9, 7, 45, 27, 878790), inputs={'input1': 'five', 'input2': ' six', 'input3': ' seven'}, outputs={'output1': ' eight'}, dataset_id=1, id=2),\n", + " Example(created_at=datetime.datetime(2023, 2, 9, 7, 45, 27, 878792), inputs={'input1': 'nine', 'input2': ' ten', 'input3': ' eleven'}, outputs={'output1': ' twelve'}, dataset_id=1, id=3)]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "runner.dataset.examples[:3]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "11eba157-ec8e-43e9-8326-75099c07e574", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + ", four\n", + "\n", + "Five\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. 
Sorry about that!.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "Eight\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n", + "Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n", + "Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "Twelve\n", + "\n", + "\n", + "Sixteen, seventeen, eighteen.\n", + ", twenty\n", + "\n", + "Twenty-one\n", + "\n", + "\n", + "twenty-four\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + ", twenty-eight\n", + "\n", + "Twenty-nine\n", + "\n", + "\n", + "Thirty-two.\n", + "\n", + "\n", + "Thirty-six\n", + ", forty\n", + "\n", + "forty-one\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. 
Sorry about that!.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "Forty-four\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "Forty-eight\n", + "four\n", + " eight\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " twelve\n", + " sixteen\n", + "\n", + "twenty\n", + " twenty-four\n", + " twenty-eight\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " thirty-two\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. 
Sorry about that!.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " thirty-six\n", + " forty\n", + " forty-four\n", + " forty-eight\n" + ] + } + ], + "source": [ + "runner = ExampleRunner(\n", + " langchain_dataset_name=\"test_dataset.csv\",\n", + ")\n", + "\n", + "llm = OpenAI(temperature=0)\n", + "prompt1 = PromptTemplate(\n", + " input_variables=[\"input1\", \"input2\", \"input3\"],\n", + " template=\"Complete the sequence: {input1}, {input2}, {input3}\",\n", + ")\n", + "chain1 = LLMChain(llm=llm, prompt=prompt1)\n", + "\n", + "prompt2 = PromptTemplate(\n", + " input_variables=[\"input1\", \"input2\", \"input3\"],\n", + " template=\"\"\"\n", + " You are given the text representation of three numbers. You are to give the next number in the sequence. Only provide one number! \n", + " \n", + " Example:\n", + " Input: one, two three. \n", + " Output: four \n", + " \n", + " Input: {input1}, {input2}, {input3}\n", + " Output:\n", + " \"\"\"\n", + ")\n", + "chain2 = LLMChain(llm=llm, prompt=prompt2)\n", + "runner.run_chain(chain1)\n", + "runner.run_chain(chain2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b510335a-872f-4392-829d-bdcba3a052cb", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/evaluation/example_runner.py b/langchain/evaluation/example_runner.py new file mode 100644 index 00000000000..532153bd3ca --- /dev/null +++ b/langchain/evaluation/example_runner.py @@ -0,0 +1,202 @@ +import asyncio + +from pydantic import BaseModel, validator, 
class ExampleBase(BaseModel):
    """Fields common to an example being created and one already stored."""

    # Creation time; defaults to ``datetime.utcnow()`` (naive UTC) at
    # instantiation.
    created_at: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
    # Named inputs of the example.
    inputs: Dict[str, Any]
    # Named reference outputs; may be absent.
    # ``Optional[...]`` rather than ``... | None``: the annotation is evaluated
    # when the class body runs, and ``Dict[str, Any] | None`` raises TypeError
    # before Python 3.10. It also matches the ``Optional`` spelling used by the
    # rest of this module.
    outputs: Optional[Dict[str, Any]] = None
    # Id of the dataset the example belongs to.
    dataset_id: int


class ExampleCreate(ExampleBase):
    """Schema for creating an example; adds nothing to ``ExampleBase``."""


class Example(ExampleBase):
    """A stored example: ``ExampleBase`` plus its integer ``id``."""

    id: int


class DatasetBase(BaseModel):
    """Fields common to a dataset being created and one already stored."""

    name: str
    description: str
    # Creation time; defaults to ``datetime.utcnow()`` (naive UTC) at
    # instantiation.
    created_at: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
    # Examples contained in the dataset; a fresh empty list by default.
    examples: List[Example] = Field(default_factory=list)


class DatasetCreate(DatasetBase):
    """Schema for creating a dataset; adds nothing to ``DatasetBase``."""


class Dataset(DatasetBase):
    """A stored dataset: ``DatasetBase`` plus its integer ``id``."""

    id: int


class CsvDataset(BaseModel):
    """A local CSV file plus the metadata sent when uploading it as a dataset."""

    # Location of the CSV file on disk.
    csv_path: Path
    # Free-text description sent along with the upload.
    description: str
    # Names of the keys treated as example inputs and outputs respectively.
    input_keys: List[str]
    output_keys: List[str]

    @validator("csv_path")
    def validate_csv_path(cls, v: Path) -> Path:
        """Ensure ``csv_path`` points at an existing regular file.

        Raises:
            ValueError: If nothing exists at the path or it is not a file.
        """
        # is_file() instead of exists(): a directory would pass exists() and
        # only fail later, when the path is opened for upload.
        if not v.is_file():
            raise ValueError("CSV file does not exist.")
        return v
def fetch_dataset_from_endpoint(
    name: str, headers: Dict[str, str], endpoint: str = "https://localhost:8000"
) -> Dataset:
    """Fetch a dataset by name from a LangChain endpoint.

    Args:
        name: Name of the dataset to look up.
        headers: HTTP headers sent with the request.
        endpoint: Base URL of the LangChain endpoint.

    Returns:
        The first entry of the endpoint's JSON response, parsed as a ``Dataset``.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
        ValueError: If the response contains no dataset.
    """
    # Hand the name to requests via ``params`` so it is URL-encoded; pasting it
    # into the query string by hand breaks for names containing '&', '#', '='.
    response = requests.get(
        f"{endpoint}/datasets", params={"name": name}, headers=headers
    )
    response.raise_for_status()
    datasets = response.json()  # decode the body once and reuse it
    if not datasets:
        raise ValueError(f"Dataset with name {name} does not exist.")
    return Dataset(**datasets[0])


def upload_csv_dataset_to_endpoint(
    csv_dataset: CsvDataset,
    headers: Dict[str, str],
    endpoint: str = "https://localhost:8000",
) -> Dataset:
    """Upload a CSV file to a LangChain endpoint and return the created dataset.

    Args:
        csv_dataset: The CSV path together with its description and key lists.
        headers: HTTP headers sent with the request.
        endpoint: Base URL of the LangChain endpoint.

    Returns:
        The endpoint's JSON response parsed as a ``Dataset``.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
    """
    form_fields = {
        "input_keys": csv_dataset.input_keys,
        "output_keys": csv_dataset.output_keys,
        "description": csv_dataset.description,
    }
    # Binary mode: the file is forwarded as-is in the multipart body.
    with open(csv_dataset.csv_path, "rb") as csv_file:
        response = requests.post(
            f"{endpoint}/datasets/upload",
            headers=headers,
            files={"file": (csv_dataset.csv_path.name, csv_file)},
            data=form_fields,
        )
    response.raise_for_status()
    return Dataset(**response.json())
class ExampleRunner(BaseModel):
    """Run an LLM, chain or agent over every example of a dataset.

    The dataset is obtained while the model is being validated: either by
    uploading ``csv_dataset`` or by looking up ``langchain_dataset_name`` on
    the endpoint. Exactly one of the two must be supplied.
    """

    langchain_endpoint: AnyHttpUrl
    dataset: Dataset
    csv_dataset: Optional[CsvDataset] = None
    langchain_dataset_name: Optional[str] = None
    langchain_api_key: Optional[str] = None

    @root_validator(pre=True)
    def validate_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Resolve the endpoint and API key, then load the dataset.

        Runs before field validation (``pre=True``), so ``values`` holds the
        raw constructor arguments.

        Raises:
            ValueError: If neither or both of ``csv_dataset`` and
                ``langchain_dataset_name`` are given.
        """
        csv_dataset = values.get("csv_dataset")
        langchain_dataset_name = values.get("langchain_dataset_name")
        # The messages name the real constructor arguments so callers can act
        # on them.
        if csv_dataset is None and langchain_dataset_name is None:
            raise ValueError(
                "Must provide either csv_dataset or langchain_dataset_name."
            )
        if csv_dataset is not None and langchain_dataset_name is not None:
            raise ValueError(
                "Cannot provide both csv_dataset and langchain_dataset_name."
            )

        # Honour an explicitly passed endpoint; fall back to the environment
        # only when the caller did not supply one.
        langchain_endpoint = values.get("langchain_endpoint") or os.environ.get(
            "LANGCHAIN_ENDPOINT", "https://localhost:8000"
        )
        values["langchain_endpoint"] = langchain_endpoint

        # An API key is only looked up for non-local endpoints.
        local_hosts = ("localhost", "127.0.0.1", "0.0.0.0")
        if urlparse(str(langchain_endpoint)).hostname not in local_hosts:
            values["langchain_api_key"] = get_from_dict_or_env(
                values, "langchain_api_key", "LANGCHAIN_API_KEY"
            )

        headers: Dict[str, str] = {}
        if values.get("langchain_api_key"):
            headers["x-api-key"] = values["langchain_api_key"]

        if langchain_dataset_name is not None:
            # Fetching doubles as the check that the named dataset exists.
            values["dataset"] = fetch_dataset_from_endpoint(
                langchain_dataset_name, headers, langchain_endpoint
            )
        else:
            values["dataset"] = upload_csv_dataset_to_endpoint(
                csv_dataset, headers, langchain_endpoint
            )
        return values

    def examples(self) -> List[Example]:
        """Return the examples of the loaded dataset."""
        return self.dataset.examples

    def run_agent(self, agent: AgentExecutor) -> None:
        """Run an agent on every example, passing the inputs as keyword args."""
        for example in self.examples():
            # Tag the tracer first so the runs it persists carry this
            # example's id, exactly as run_chain does.
            langchain.set_tracing_callback_manager(example_id=example.id)
            agent.run(**example.inputs)

    def run_chain(self, chain: Chain) -> None:
        """Run a chain on every example and print each result."""
        for example in self.examples():
            # Tag the tracer first so the runs it persists carry this
            # example's id.
            langchain.set_tracing_callback_manager(example_id=example.id)
            print(chain.run(**example.inputs))

    def run_llm(self, llm: BaseLLM) -> None:
        """Run an LLM on every example, using the input values as prompts."""
        for example in self.examples():
            # Tag the tracer first so the runs it persists carry this
            # example's id, exactly as run_chain does.
            langchain.set_tracing_callback_manager(example_id=example.id)
            llm.generate(list(example.inputs.values()))

    # TODO: add an async ``arun_agent`` that spreads the examples over several
    # agent copies, each with its own tracer and callback manager.
testing", + input_keys=["input1", "input2", "input3"], + output_keys=["output1"], + ), + ) + + # runner = ExampleRunner( + # langchain_dataset_name="test_dataset.csv", + # ) + + from langchain.llms import OpenAI + from langchain.prompts import PromptTemplate + from langchain.chains import LLMChain + llm = OpenAI(temperature=0.9, model_name="text-ada-001") + prompt = PromptTemplate( + input_variables=["input1", "input2", "input3"], + template="Complete the sequence: {input1}, {input2}, {input3}", + ) + chain = LLMChain(llm=llm, prompt=prompt) + runner.run_chain(chain) + + + diff --git a/langchain/evaluation/test_dataset.csv b/langchain/evaluation/test_dataset.csv new file mode 100644 index 00000000000..49ccea39b13 --- /dev/null +++ b/langchain/evaluation/test_dataset.csv @@ -0,0 +1,13 @@ +input1,input2,input3,output1 +one, two, three, four +five, six, seven, eight +nine, ten, eleven, twelve +thirteen, fourteen, fifteen, sixteen +seventeen, eighteen, nineteen, twenty +twenty-one, twenty-two, twenty-three, twenty-four +twenty-five, twenty-six, twenty-seven, twenty-eight +twenty-nine, thirty, thirty-one, thirty-two +thirty-three, thirty-four, thirty-five, thirty-six +thirty-seven, thirty-eight, thirty-nine, forty +forty-one, forty-two, forty-three, forty-four +forty-five, forty-six, forty-seven, forty-eight \ No newline at end of file