mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-10 05:20:39 +00:00
cr
This commit is contained in:
parent
a306baacd1
commit
6b64ff87b7
@ -36,7 +36,7 @@ def set_default_callback_manager() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def set_tracing_callback_manager(session_name: Optional[str] = None) -> None:
|
def set_tracing_callback_manager(session_name: Optional[str] = None, example_id: Optional[int] = None) -> None:
|
||||||
"""Set tracing callback manager."""
|
"""Set tracing callback manager."""
|
||||||
handler = SharedLangChainTracer()
|
handler = SharedLangChainTracer()
|
||||||
callback = get_callback_manager()
|
callback = get_callback_manager()
|
||||||
@ -49,6 +49,9 @@ def set_tracing_callback_manager(session_name: Optional[str] = None) -> None:
|
|||||||
except Exception:
|
except Exception:
|
||||||
raise ValueError(f"session {session_name} not found")
|
raise ValueError(f"session {session_name} not found")
|
||||||
|
|
||||||
|
if example_id is not None:
|
||||||
|
handler.example_id = example_id
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
|
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
|
||||||
|
@ -22,11 +22,22 @@ class BaseLangChainTracer(BaseTracer, ABC):
|
|||||||
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
||||||
|
|
||||||
always_verbose: bool = True
|
always_verbose: bool = True
|
||||||
|
_example_id: Optional[int] = None
|
||||||
_endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
|
_endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
|
||||||
_headers: Dict[str, Any] = {"Content-Type": "application/json"}
|
_headers: Dict[str, Any] = {"Content-Type": "application/json"}
|
||||||
if os.getenv("LANGCHAIN_API_KEY"):
|
if os.getenv("LANGCHAIN_API_KEY"):
|
||||||
_headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
|
_headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
|
||||||
|
|
||||||
|
@property
def example_id(self) -> Optional[int]:
    """Return the example id currently attached to this tracer.

    Returns:
        The value last stored through the setter, or ``None`` when no
        example id has been set.
    """
    return self._example_id

@example_id.setter
def example_id(self, value: Optional[int]) -> None:
    """Store the example id on the tracer.

    Args:
        value: Example id to attach. ``None`` is accepted so the tracer can
            be returned to its initial, untagged state; the backing
            attribute is itself declared ``Optional[int]``.
    """
    self._example_id = value
|
||||||
|
|
||||||
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||||
"""Persist a run."""
|
"""Persist a run."""
|
||||||
if isinstance(run, LLMRun):
|
if isinstance(run, LLMRun):
|
||||||
@ -36,6 +47,9 @@ class BaseLangChainTracer(BaseTracer, ABC):
|
|||||||
else:
|
else:
|
||||||
endpoint = f"{self._endpoint}/tool-runs"
|
endpoint = f"{self._endpoint}/tool-runs"
|
||||||
|
|
||||||
|
if self._example_id:
|
||||||
|
run.example_id = self._example_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
requests.post(
|
requests.post(
|
||||||
endpoint,
|
endpoint,
|
||||||
|
@ -40,6 +40,7 @@ class BaseRun(BaseModel):
|
|||||||
serialized: Dict[str, Any]
|
serialized: Dict[str, Any]
|
||||||
session_id: int
|
session_id: int
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
|
example_id: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
class LLMRun(BaseRun):
|
class LLMRun(BaseRun):
|
||||||
|
317
langchain/evaluation/ExampleRunner.ipynb
Normal file
317
langchain/evaluation/ExampleRunner.ipynb
Normal file
@ -0,0 +1,317 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "bd0b1515-e25f-4707-a240-b5c26d5d33f8",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# ExampleRunner Demo\n",
|
||||||
|
"\n",
|
||||||
|
"Run a chain on multiple examples for evaluation."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "13fbc962-98e0-470c-9467-c5e28db658a0",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"os.environ[\"LANGCHAIN_HANDLER\"] = \"langchain\"\n",
|
||||||
|
"os.environ[\"LANGCHAIN_ENDPOINT\"] = \"http://127.0.0.1:8000\" \n",
|
||||||
|
"\n",
|
||||||
|
"import langchain\n",
|
||||||
|
"from langchain.agents import Tool, initialize_agent, load_tools\n",
|
||||||
|
"from langchain.llms import OpenAI\n",
|
||||||
|
"from langchain.evaluation.example_runner import ExampleRunner, CsvDataset\n",
|
||||||
|
"from langchain.llms import OpenAI\n",
|
||||||
|
"from langchain.prompts import PromptTemplate\n",
|
||||||
|
"from langchain.chains import LLMChain"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "8afb6529-dfdb-4f0a-b19f-a700d80d4362",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Upload the dataset (only need to do once)\n",
|
||||||
|
"runner = ExampleRunner(\n",
|
||||||
|
" csv_dataset=CsvDataset(\n",
|
||||||
|
" csv_path=\"test_dataset.csv\",\n",
|
||||||
|
" description=\"Dummy dataset for testing\",\n",
|
||||||
|
" input_keys=[\"input1\", \"input2\", \"input3\"],\n",
|
||||||
|
" output_keys=[\"output1\"],\n",
|
||||||
|
" ),\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "603dc948-0a57-4696-8816-008aaf346538",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"[Example(created_at=datetime.datetime(2023, 2, 9, 7, 45, 27, 878780), inputs={'input1': 'one', 'input2': ' two', 'input3': ' three'}, outputs={'output1': ' four'}, dataset_id=1, id=1),\n",
|
||||||
|
" Example(created_at=datetime.datetime(2023, 2, 9, 7, 45, 27, 878790), inputs={'input1': 'five', 'input2': ' six', 'input3': ' seven'}, outputs={'output1': ' eight'}, dataset_id=1, id=2),\n",
|
||||||
|
" Example(created_at=datetime.datetime(2023, 2, 9, 7, 45, 27, 878792), inputs={'input1': 'nine', 'input2': ' ten', 'input3': ' eleven'}, outputs={'output1': ' twelve'}, dataset_id=1, id=3)]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"runner.dataset.examples[:3]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "11eba157-ec8e-43e9-8326-75099c07e574",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
", four\n",
|
||||||
|
"\n",
|
||||||
|
"Five\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"Eight\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n",
|
||||||
|
"Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n",
|
||||||
|
"Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"Twelve\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"Sixteen, seventeen, eighteen.\n",
|
||||||
|
", twenty\n",
|
||||||
|
"\n",
|
||||||
|
"Twenty-one\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"twenty-four\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
", twenty-eight\n",
|
||||||
|
"\n",
|
||||||
|
"Twenty-nine\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"Thirty-two.\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"Thirty-six\n",
|
||||||
|
", forty\n",
|
||||||
|
"\n",
|
||||||
|
"forty-one\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"Forty-four\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"Forty-eight\n",
|
||||||
|
"four\n",
|
||||||
|
" eight\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" twelve\n",
|
||||||
|
" sixteen\n",
|
||||||
|
"\n",
|
||||||
|
"twenty\n",
|
||||||
|
" twenty-four\n",
|
||||||
|
" twenty-eight\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" thirty-two\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Retrying langchain.llms.openai.BaseOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" thirty-six\n",
|
||||||
|
" forty\n",
|
||||||
|
" forty-four\n",
|
||||||
|
" forty-eight\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"runner = ExampleRunner(\n",
|
||||||
|
" langchain_dataset_name=\"test_dataset.csv\",\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"llm = OpenAI(temperature=0)\n",
|
||||||
|
"prompt1 = PromptTemplate(\n",
|
||||||
|
" input_variables=[\"input1\", \"input2\", \"input3\"],\n",
|
||||||
|
" template=\"Complete the sequence: {input1}, {input2}, {input3}\",\n",
|
||||||
|
")\n",
|
||||||
|
"chain1 = LLMChain(llm=llm, prompt=prompt1)\n",
|
||||||
|
"\n",
|
||||||
|
"prompt2 = PromptTemplate(\n",
|
||||||
|
" input_variables=[\"input1\", \"input2\", \"input3\"],\n",
|
||||||
|
" template=\"\"\"\n",
|
||||||
|
" You are given the text representation of three numbers. You are to give the next number in the sequence. Only provide one number! \n",
|
||||||
|
" \n",
|
||||||
|
" Example:\n",
|
||||||
|
" Input: one, two three. \n",
|
||||||
|
" Output: four \n",
|
||||||
|
" \n",
|
||||||
|
" Input: {input1}, {input2}, {input3}\n",
|
||||||
|
" Output:\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
")\n",
|
||||||
|
"chain2 = LLMChain(llm=llm, prompt=prompt2)\n",
|
||||||
|
"runner.run_chain(chain1)\n",
|
||||||
|
"runner.run_chain(chain2)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "b510335a-872f-4392-829d-bdcba3a052cb",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.9"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
202
langchain/evaluation/example_runner.py
Normal file
202
langchain/evaluation/example_runner.py
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
from pydantic import BaseModel, validator, root_validator, Field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Dict, Any, List
|
||||||
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
from pydantic.networks import AnyHttpUrl
|
||||||
|
import requests
|
||||||
|
import datetime
|
||||||
|
import langchain
|
||||||
|
from langchain.agents import AgentExecutor
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.llms.base import BaseLLM
|
||||||
|
from langchain.callbacks.base import CallbackManager
|
||||||
|
from langchain.callbacks.tracers import LangChainTracer
|
||||||
|
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleBase(BaseModel):
    """Fields shared by the example schemas in this module."""

    # Defaults to the (naive) UTC time at which the model is instantiated.
    created_at: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
    # Input values keyed by name; ExampleRunner passes them on as keyword
    # arguments (``**example.inputs``).
    inputs: Dict[str, Any]
    # Spelled with Optional[...] instead of ``X | None``: the annotation is
    # evaluated when the class is created, and the ``|`` form raises a
    # TypeError on Python versions older than 3.10.
    outputs: Optional[Dict[str, Any]] = None
    # Presumably the ``id`` of the owning Dataset - confirm against the
    # endpoint schema.
    dataset_id: int
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleCreate(ExampleBase):
    """Create schema for an example; declares no fields beyond ``ExampleBase``."""
|
||||||
|
|
||||||
|
|
||||||
|
class Example(ExampleBase):
    """An example that additionally carries its integer ``id``."""

    # ExampleRunner.run_chain forwards this value as ``example_id`` when it
    # configures tracing.
    id: int
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetBase(BaseModel):
    """Fields shared by the dataset schemas in this module."""

    name: str
    description: str
    # Defaults to the (naive) UTC time at which the model is instantiated.
    created_at: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
    # Each instance gets its own empty list when no examples are supplied.
    examples: List[Example] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetCreate(DatasetBase):
    """Create schema for a dataset; declares no fields beyond ``DatasetBase``."""
|
||||||
|
|
||||||
|
|
||||||
|
class Dataset(DatasetBase):
    """A dataset that additionally carries its integer ``id``."""

    id: int
|
||||||
|
|
||||||
|
|
||||||
|
class CsvDataset(BaseModel):
    """Description of a local CSV file that can be uploaded to a LangChain endpoint."""

    # Location of the CSV file on disk; checked by ``validate_csv_path``.
    csv_path: Path
    description: str
    # Sent with the upload as the ``input_keys`` / ``output_keys`` form
    # fields; presumably the CSV column names to treat as inputs and
    # outputs - confirm against the endpoint.
    input_keys: List[str]
    output_keys: List[str]

    @validator("csv_path")
    def validate_csv_path(cls, v: Path) -> Path:
        """Validate that ``csv_path`` points at an existing file.

        Args:
            v: The path supplied for ``csv_path``.

        Returns:
            The path, unchanged.

        Raises:
            ValueError: If nothing exists at the path, or if it is a directory.
        """
        if not v.exists():
            raise ValueError("CSV file does not exist.")
        # Reject directories here instead of letting the later ``open()`` in
        # the upload step fail with a less helpful error.
        if v.is_dir():
            raise ValueError("CSV path is a directory, not a file.")
        return v
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_dataset_from_endpoint(
    name: str, headers: Dict[str, str], endpoint: str = "https://localhost:8000"
) -> Dataset:
    """Fetch a dataset by name from a LangChain endpoint.

    Args:
        name: Name of the dataset to look up.
        headers: HTTP headers sent with the request.
        endpoint: Base URL of the endpoint; ``/datasets`` is appended to it.

    Returns:
        The first dataset in the endpoint's JSON response.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
        ValueError: If the response contains no dataset.
    """
    # Let requests build the query string so that names containing spaces,
    # "&", "#" and similar characters are percent-encoded rather than
    # corrupting the URL.
    response = requests.get(
        f"{endpoint}/datasets", params={"name": name}, headers=headers
    )
    response.raise_for_status()
    # Decode the body once instead of re-parsing it for every access.
    matches = response.json()
    if not matches:
        raise ValueError(f"Dataset with name {name} does not exist.")
    return Dataset(**matches[0])
|
||||||
|
|
||||||
|
|
||||||
|
def upload_csv_dataset_to_endpoint(
    csv_dataset: CsvDataset,
    headers: Dict[str, str],
    endpoint: str = "https://localhost:8000",
) -> Dataset:
    """Upload a CSV file and its metadata to ``{endpoint}/datasets/upload``.

    Args:
        csv_dataset: The CSV file to send plus its description and key lists.
        headers: HTTP headers sent with the request.
        endpoint: Base URL of the endpoint.

    Returns:
        A ``Dataset`` built from the JSON body of the response.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
    """
    upload_url = f"{endpoint}/datasets/upload"
    csv_path = csv_dataset.csv_path
    form_fields = {
        "input_keys": csv_dataset.input_keys,
        "output_keys": csv_dataset.output_keys,
        "description": csv_dataset.description,
    }
    # The file is streamed as the multipart "file" part, so it must stay open
    # for the duration of the POST.
    with open(csv_path, "rb") as csv_file:
        response = requests.post(
            upload_url,
            headers=headers,
            files={"file": (csv_path.name, csv_file)},
            data=form_fields,
        )
    response.raise_for_status()
    return Dataset(**response.json())
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleRunner(BaseModel):
    """Run an LLM, chain or agent on every example of a dataset.

    The dataset comes from exactly one of two sources: it is fetched by name
    from the LangChain endpoint (``langchain_dataset_name``) or created by
    uploading a CSV file (``csv_dataset``). Both happen while the model is
    being validated, so constructing a runner performs an HTTP request.
    """

    # Explicit value wins; otherwise $LANGCHAIN_ENDPOINT, otherwise
    # "https://localhost:8000" (see validate_fields).
    langchain_endpoint: AnyHttpUrl
    # Always (re)assigned by validate_fields from the endpoint's response.
    dataset: Dataset
    csv_dataset: Optional[CsvDataset] = None
    langchain_dataset_name: Optional[str] = None
    # Required by validate_fields (argument or $LANGCHAIN_API_KEY) unless the
    # endpoint host is local.
    langchain_api_key: Optional[str] = None

    @root_validator(pre=True)
    def validate_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Check the dataset source, resolve endpoint and API key, load the dataset.

        Args:
            values: Raw constructor arguments (this validator runs before
                field validation).

        Returns:
            ``values`` with ``langchain_endpoint`` and ``dataset`` filled in,
            plus ``langchain_api_key`` for non-local endpoints.

        Raises:
            ValueError: If neither or both of ``csv_dataset`` and
                ``langchain_dataset_name`` are given, or if the dataset
                lookup or API-key resolution raises one.
        """
        csv_dataset = values.get("csv_dataset")
        langchain_dataset_name = values.get("langchain_dataset_name")
        if csv_dataset is None and langchain_dataset_name is None:
            raise ValueError(
                "Must provide either csv_dataset or langchain_dataset_name."
            )
        if csv_dataset is not None and langchain_dataset_name is not None:
            raise ValueError(
                "Cannot provide both csv_dataset and langchain_dataset_name."
            )

        # Honour an endpoint passed by the caller; only fall back to the
        # environment / default when none was supplied.
        langchain_endpoint = values.get("langchain_endpoint") or os.environ.get(
            "LANGCHAIN_ENDPOINT", "https://localhost:8000"
        )
        values["langchain_endpoint"] = langchain_endpoint

        # Local endpoints are used without an API key; any other host
        # requires one.
        local_hosts = ("localhost", "127.0.0.1", "0.0.0.0")
        if urlparse(langchain_endpoint).hostname not in local_hosts:
            values["langchain_api_key"] = get_from_dict_or_env(
                values, "langchain_api_key", "LANGCHAIN_API_KEY"
            )

        headers: Dict[str, str] = {}
        if values.get("langchain_api_key"):
            headers["x-api-key"] = values["langchain_api_key"]

        if langchain_dataset_name is not None:
            # Fetching doubles as the check that the named dataset exists.
            values["dataset"] = fetch_dataset_from_endpoint(
                langchain_dataset_name, headers, langchain_endpoint
            )
        else:
            # Exactly one source is set, so this is the CSV upload path.
            values["dataset"] = upload_csv_dataset_to_endpoint(
                csv_dataset, headers, langchain_endpoint
            )
        return values

    def examples(self) -> List[Example]:
        """Return the examples of the loaded dataset."""
        return self.dataset.examples

    def run_agent(self, agent: AgentExecutor) -> None:
        """Run an agent on every example, passing the inputs as keyword arguments."""
        # NOTE(review): unlike run_chain, this does not tag traces with the
        # example id - confirm whether that is intended.
        for example in self.examples():
            agent.run(**example.inputs)

    def run_chain(self, chain: Chain) -> None:
        """Run a chain on every example and print each result."""
        for example in self.examples():
            # Re-point the tracing callback manager at the current example
            # before running the chain.
            langchain.set_tracing_callback_manager(example_id=example.id)
            print(chain.run(**example.inputs))

    def run_llm(self, llm: BaseLLM) -> None:
        """Run an LLM on every example, using the example's input values as prompts."""
        # NOTE(review): unlike run_chain, this does not tag traces with the
        # example id - confirm whether that is intended.
        for example in self.examples():
            llm.generate(list(example.inputs.values()))
|
||||||
|
|
||||||
|
|
||||||
|
# async def arun_agent(self, agent: AgentExecutor, num_workers: int = 1):
|
||||||
|
# """Run an agent on the examples."""
|
||||||
|
# # Copy the agent num_workers times
|
||||||
|
# agents = []
|
||||||
|
# for _ in range(num_workers):
|
||||||
|
# tracer = LangChainTracer()
|
||||||
|
# tracer.load_default_session()
|
||||||
|
# manager = CallbackManager([StdOutCallbackHandler(), tracer])
|
||||||
|
# agent.from_agent_and_tools(agent.agent, agent.tools, manager)
|
||||||
|
# agents.append(agent)
|
||||||
|
#
|
||||||
|
# i = 0
|
||||||
|
# while i < len(self.examples()):
|
||||||
|
# for agent in agents:
|
||||||
|
# example = self.examples()[i]
|
||||||
|
# await agent.arun(**example.inputs)
|
||||||
|
# i += 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual smoke test: upload the bundled CSV to a local endpoint, then run
    # a simple completion chain over every example.
    os.environ["LANGCHAIN_ENDPOINT"] = "http://127.0.0.1:8000"
    demo_csv = CsvDataset(
        csv_path="test_dataset.csv",
        description="Dummy dataset for testing",
        input_keys=["input1", "input2", "input3"],
        output_keys=["output1"],
    )
    runner = ExampleRunner(csv_dataset=demo_csv)

    # Alternative: look the dataset up by name instead of uploading it.
    # runner = ExampleRunner(
    #     langchain_dataset_name="test_dataset.csv",
    # )

    from langchain.llms import OpenAI
    from langchain.prompts import PromptTemplate
    from langchain.chains import LLMChain

    llm = OpenAI(temperature=0.9, model_name="text-ada-001")
    sequence_prompt = PromptTemplate(
        input_variables=["input1", "input2", "input3"],
        template="Complete the sequence: {input1}, {input2}, {input3}",
    )
    chain = LLMChain(llm=llm, prompt=sequence_prompt)
    runner.run_chain(chain)
|
||||||
|
|
||||||
|
|
||||||
|
|
13
langchain/evaluation/test_dataset.csv
Normal file
13
langchain/evaluation/test_dataset.csv
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
input1,input2,input3,output1
|
||||||
|
one, two, three, four
|
||||||
|
five, six, seven, eight
|
||||||
|
nine, ten, eleven, twelve
|
||||||
|
thirteen, fourteen, fifteen, sixteen
|
||||||
|
seventeen, eighteen, nineteen, twenty
|
||||||
|
twenty-one, twenty-two, twenty-three, twenty-four
|
||||||
|
twenty-five, twenty-six, twenty-seven, twenty-eight
|
||||||
|
twenty-nine, thirty, thirty-one, thirty-two
|
||||||
|
thirty-three, thirty-four, thirty-five, thirty-six
|
||||||
|
thirty-seven, thirty-eight, thirty-nine, forty
|
||||||
|
forty-one, forty-two, forty-three, forty-four
|
||||||
|
forty-five, forty-six, forty-seven, forty-eight
|
|
Loading…
Reference in New Issue
Block a user