From ddc26e074e67e3db2ce7f22e52fd6e81be2de5aa Mon Sep 17 00:00:00 2001 From: vowelparrot <130414180+vowelparrot@users.noreply.github.com> Date: Wed, 3 May 2023 09:32:24 -0700 Subject: [PATCH] [WIP] Example Notebook running a chain on a dataset --- .../evaluating_traced_examples.ipynb | 374 ++++++++++++++++++ langchain/callbacks/manager.py | 4 +- langchain/callbacks/tracers/base.py | 6 +- langchain/callbacks/tracers/langchain.py | 1 + langchain/callbacks/tracers/schemas.py | 1 + 5 files changed, 384 insertions(+), 2 deletions(-) create mode 100644 docs/use_cases/evaluation/evaluating_traced_examples.ipynb diff --git a/docs/use_cases/evaluation/evaluating_traced_examples.ipynb b/docs/use_cases/evaluation/evaluating_traced_examples.ipynb new file mode 100644 index 00000000000..957fb2f0d22 --- /dev/null +++ b/docs/use_cases/evaluation/evaluating_traced_examples.ipynb @@ -0,0 +1,374 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1a4596ea-a631-416d-a2a4-3577c140493d", + "metadata": {}, + "source": [ + "## Running chains on Datasets" + ] + }, + { + "cell_type": "markdown", + "id": "185fc992-7472-415b-8eda-cb4e11c9068f", + "metadata": {}, + "source": [ + "Some setup to upload the dataset to LangChainPlus. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54b4cfb2-2a4b-4017-a474-b088023ea3ec", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "!pip install datasets > /dev/null" + ] + }, + { + "cell_type": "markdown", + "id": "443eacbd", + "metadata": {}, + "source": [ + "You'll already have a dataset in LangChain+, but we'll upload this as an example" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "e360961b-65e9-4fd6-b609-916bae62ff72", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import os\n", + "import pandas as pd\n", + "from io import BytesIO\n", + "import requests\n", + "from typing import Sequence, Tuple, Optional\n", + "\n", + "def upload_dataframe(base_url: str,\n", + " df: pd.DataFrame,\n", + " input_keys: Sequence[str], \n", + " output_keys: Sequence[str], \n", + " name: str, \n", + " description: str, auth: Optional[dict] = None) -> Tuple[int, dict]:\n", + " buffer = BytesIO()\n", + " df.to_csv(buffer, index=False)\n", + " buffer.seek(0)\n", + " files = {\"file\": (f\"{name}.csv\", buffer)}\n", + " data = {\n", + " \"input_keys\": ','.join(input_keys),\n", + " \"output_keys\": ','.join(output_keys),\n", + " \"description\": description,\n", + " }\n", + " response = requests.post(base_url + \"/datasets/upload\", auth=auth, data=data, files=files)\n", + " return response.status_code, response.json()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "79b4605d-0adc-49a6-b542-5b3fe22fa020", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "endpoint = os.getenv(\"LANGCHAIN_ENDPOINT\", \"http://localhost:8000\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "69af24a2-9892-408d-91ac-6b265b9c903c", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Found cached dataset json (/Users/wfh/.cache/huggingface/datasets/LangChainDatasets___json/LangChainDatasets--agent-search-calculator-8a025c0ce5fb99d2/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e8c0711e6f3e40909da0459419166960", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1 [00:00 dict:\n", + " \"\"\"Run the chain asynchronously\"\"\"\n", + " langchain_tracer.example_id = example[\"id\"]\n", + " inputs = example[\"inputs\"]\n", + " try:\n", + " chain_output = await chain.arun(inputs, callbacks=[langchain_tracer])\n", + " langchain_tracer.example_id = None\n", + " except Exception as e:\n", + " logger.error(e)\n", + " return {\"Error\": str(e)}\n", + " finally:\n", + " langchain_tracer.example_id = None\n", + " return chain_output\n", + "\n", + "\n", + "async def run_dataset(dataset: dict, chain: Chain, batch_size: int = 5):\n", + " \"\"\"Grade the QA examples.\"\"\"\n", + " logger.info(\"`Grading QA performance ...`\")\n", + " tracers = [LangChainTracer() for _ in range(batch_size)]\n", + " for tracer in tracers:\n", + " tracer.load_session(\"default\")\n", + " graded_outputs = []\n", + " total_examples = len(dataset[\"examples\"])\n", + " for i in range(0, total_examples, batch_size):\n", + " batch_results = []\n", + " for j in range(i, min(total_examples, i + batch_size)):\n", + " example = dataset[\"examples\"][j]\n", + " batch_results.append(arun_chain(example, tracers[j % len(tracers)], chain))\n", + " graded_outputs.extend(await asyncio.gather(*batch_results))\n", + " return graded_outputs\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "6c588a1f-f869-4d63-a088-e01d0707116c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.agents import initialize_agent, load_tools\n", + "from langchain.agents import AgentType\n", + "from langchain import OpenAI\n", + "\n", + "llm = ChatOpenAI(temperature=0)\n", + "tools = load_tools(['serpapi', 'llm-math'], llm=llm)\n", + "agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "d8483f0e-e2ed-4835-96a9-30c2a22aa471", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "'age'. Please try again with a valid numerical expression\n", + "unknown format from LLM: Assuming we don't have any information about the actual number of points scored in the 2023 super bowl, we cannot provide a mathematical expression to solve this problem.\n", + "invalid syntax. Perhaps you forgot a comma? (, line 1). Please try again with a valid numerical expression\n", + "Could not parse LLM output: `The final answer is that there were no more points scored in the 2023 Super Bowl than in the 2022 Super Bowl.`\n" + ] + }, + { + "data": { + "text/plain": [ + "['The current population of Canada as of May 3, 2023 is 38,677,281.',\n", + " \"Anwar Hadid's age raised to the .43 power is approximately 2.68.\",\n", + " {'Error': \"'age'. Please try again with a valid numerical expression\"},\n", + " 'The distance between Paris and Boston is 3448 miles.',\n", + " {'Error': \"unknown format from LLM: Assuming we don't have any information about the actual number of points scored in the 2023 super bowl, we cannot provide a mathematical expression to solve this problem.\"},\n", + " {'Error': 'invalid syntax. Perhaps you forgot a comma? (, line 1). Please try again with a valid numerical expression'},\n", + " {'Error': 'Could not parse LLM output: `The final answer is that there were no more points scored in the 2023 Super Bowl than in the 2022 Super Bowl.`'},\n", + " '1.9347796717823205',\n", + " \"Bad Bunny's height (in inches) raised to the .13 power is approximately 1.740 inches.\",\n", + " '0.2791714614499425']" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results = await run_dataset(dataset, agent)\n", + "results" + ] + }, + { + "cell_type": "markdown", + "id": "b71e72db-fd81-498f-afc6-4596662aacc4", + "metadata": {}, + "source": [ + "scratch\n", + "----------\n", + "\n", + "ignore below" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63528052-06e9-4488-ab7b-6489b32531bd", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# from langchain import PromptTemplate\n", + "# template = \"\"\"You are assessing a submitted student answer to a question relative to the true answer based on the provided criteria: \n", + " \n", + "# ***\n", + "# QUESTION: {query}\n", + "# ***\n", + "# STUDENT ANSWER: {result}\n", + "# ***\n", + "# TRUE ANSWER: {answer}\n", + "# ***\n", + "# Criteria: \n", + "# relevance: Is the submission referring to a real quote from the text?\"\n", + "# conciseness: Is the answer concise and to the point?\"\n", + "# correct: Is the answer correct?\"\n", + "# ***\n", + "# Does the submission meet the criterion? First, write out in a step by step manner your reasoning about the criterion to be sure that your conclusion is correct. Avoid simply stating the correct answers at the outset. Then print \"Correct\" or \"Incorrect\" (without quotes or punctuation) on its own line corresponding to the correct answer.\n", + "# Reasoning:\n", + "# \"\"\"\n", + "\n", + "# GRADE_ANSWER_PROMPT_OPENAI = PromptTemplate(\n", + "# input_variables=[\"query\", \"result\", \"answer\"], template=template\n", + "# )\n", + "\n", + "# eval_chain = QAEvalChain.from_llm(\n", + "# llm=ChatOpenAI(model_name=\"gpt-3.5-turbo\", temperature=0),\n", + "# prompt=GRADE_ANSWER_PROMPT_OPENAI,\n", + "# )\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6145782c-5a85-4a7e-b62e-eeeb467c6924", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# # TODO: Update the runs syntax\n", + "# chain_runs = requests.get(endpoint + \"/chain-runs\").json()\n", + "# chain_run_ids = [chain_run['id'] for chain_run in chain_runs]\n", + "# from typing import List, Optional\n", + "# import pandas as pd\n", + "# from io import BytesIO\n", + "# import json\n", + "\n", + "# # Could do runs or whatever\n", + "# def create_dataset(chain_run_ids: List[str], name: str, description: str, auth: Optional[dict] = None) -> str:\n", + "# examples = []\n", + "# for run_id in chain_run_ids:\n", + "# run_url = endpoint + f\"/chain-runs/{run_id}\"\n", + "# run_response = requests.get(run_url).json()\n", + "# examples.append({\"chain_runs\": run_response})\n", + "# values[\"inputs\"].append(run_response[\"inputs\"])\n", + "# values[\"outputs\"].append((run_response[\"outputs\"])\n", + "# dataset_request_body = {\"name\": \"foo\", \n", + "# \"description\": \"bar\", \n", + "# \"examples\": [],\n", + "# }\n", + "# response = requests.post(endpoint + \"/datasets\", json=dataset_request_body)\n", + "# return response.json()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62b5ae32-2112-441e-acc0-c21428ba8bdc", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index f67298856db..133534b6bb4 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -44,11 +44,13 @@ def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]: @contextmanager def tracing_enabled( - session_name: str = "default", + session_name: str = "default", example_id: Optional[str] = None ) -> Generator[TracerSession, None, None]: """Get Tracer in a context manager.""" cb = LangChainTracer() session = cb.load_session(session_name) + if example_id: + cb.example_id tracing_callback_var.set(cb) yield session tracing_callback_var.set(None) diff --git a/langchain/callbacks/tracers/base.py b/langchain/callbacks/tracers/base.py index 6d036d32302..c474b2a13ea 100644 --- a/langchain/callbacks/tracers/base.py +++ b/langchain/callbacks/tracers/base.py @@ -29,7 +29,8 @@ class BaseTracer(BaseCallbackHandler, ABC): def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {} - self.session: Optional[Union[TracerSessionV2, TracerSession]] = None + self.session: Optional[TracerSession] = None + self.example_id: Optional[str] = None @staticmethod def _add_child_run( @@ -153,6 +154,7 @@ class BaseTracer(BaseCallbackHandler, ABC): execution_order=execution_order, child_execution_order=execution_order, session_id=self.session.id, + example_id=self.example_id, ) self._start_trace(llm_run) @@ -218,6 +220,7 @@ class BaseTracer(BaseCallbackHandler, ABC): child_execution_order=execution_order, child_runs=[], session_id=self.session.id, + example_id=self.example_id, ) self._start_trace(chain_run) @@ -283,6 +286,7 @@ class BaseTracer(BaseCallbackHandler, ABC): child_execution_order=execution_order, child_runs=[], session_id=self.session.id, + example_id=self.example_id, ) self._start_trace(tool_run) diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index fb1138a7e2e..71ca1dd4908 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -210,6 +210,7 @@ class LangChainTracerV2(BaseTracer): session_id=run.session_id, run_type=run_type, parent_run_id=run.parent_uuid, + example_id=run.example_id, child_runs=[LangChainTracerV2._convert_run(child) for child in child_runs], ) diff --git a/langchain/callbacks/tracers/schemas.py b/langchain/callbacks/tracers/schemas.py index f8190d13b19..506bd76b472 100644 --- a/langchain/callbacks/tracers/schemas.py +++ b/langchain/callbacks/tracers/schemas.py @@ -51,6 +51,7 @@ class BaseRun(BaseModel): serialized: Dict[str, Any] session_id: int error: Optional[str] = None + example_id: Optional[str] = None class LLMRun(BaseRun):