mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-21 21:56:38 +00:00
Compare commits
35 Commits
erick/rele
...
vwp/eval_e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
15510bae4a | ||
|
|
15afdd0858 | ||
|
|
636877b3af | ||
|
|
b7508d9b52 | ||
|
|
1d58a49b78 | ||
|
|
8c84538bf5 | ||
|
|
d0aec0fda4 | ||
|
|
f7f9756576 | ||
|
|
ed1d3a48e4 | ||
|
|
e3808b63d3 | ||
|
|
b709a7b5f0 | ||
|
|
fa9b296c55 | ||
|
|
ede1bf204c | ||
|
|
c2f856fac9 | ||
|
|
059f4e8bb6 | ||
|
|
712d4a228e | ||
|
|
faa30c8ac5 | ||
|
|
4c3982464d | ||
|
|
52a7453469 | ||
|
|
f16e413c7c | ||
|
|
9d1889bb59 | ||
|
|
0304a1a563 | ||
|
|
7058b207fb | ||
|
|
40253143d6 | ||
|
|
4f4d1799b0 | ||
|
|
50615a5282 | ||
|
|
92a91e54fb | ||
|
|
054b4ff0d3 | ||
|
|
759445229b | ||
|
|
93159c6088 | ||
|
|
ee4d92aa00 | ||
|
|
ddc26e074e | ||
|
|
2fd3133239 | ||
|
|
0a87cbd1e6 | ||
|
|
2df6119194 |
967
docs/use_cases/evaluation/evaluating_traced_examples.ipynb
Normal file
967
docs/use_cases/evaluation/evaluating_traced_examples.ipynb
Normal file
@@ -0,0 +1,967 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1a4596ea-a631-416d-a2a4-3577c140493d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Running Chains on Traced Datasets\n",
|
||||
"\n",
|
||||
"Developing applications with language models can be uniquely challenging. To manage this complexity and ensure reliable performance, LangChain provides tracing and evaluation functionality through . This notebook demonstrates how to run Chains, which are language model functions, on previously captured datasets or traces. Some common use cases for this approach include:\n",
|
||||
"\n",
|
||||
"- Running an evaluation chain to grade previous runs.\n",
|
||||
"- Comparing different chains, LLMs, and agents on traced datasets.\n",
|
||||
"- Executing a stochastic chain multiple times over a dataset to generate metrics before deployment.\n",
|
||||
"\n",
|
||||
"Please note that this notebook assumes you have LangChain+ tracing running in the background. It is also configured to work only with the V2 endpoints. To set it up, follow the [tracing directions here](..\\/..\\/tracing\\/local_installation.md).\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "904db9a5-f387-4a57-914c-c8af8d39e249",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.client import LangChainPlusClient\n",
|
||||
"\n",
|
||||
"client = LangChainPlusClient(\n",
|
||||
" api_url=\"http://localhost:8000\",\n",
|
||||
" api_key=None,\n",
|
||||
" # tenant_id=\"your_tenant_uuid\", # This is required when connecting to a hosted LangChain instance\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "db79dea2-fbaa-4c12-9083-f6154b51e2d3",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"source": [
|
||||
"## Seed an example dataset\n",
|
||||
"\n",
|
||||
"If you have been using LangChainPlus already, you may have datasets available. To view all saved datasets, run:\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"datasets = client.list_datasets()\n",
|
||||
"datasets\n",
|
||||
"```\n",
|
||||
"Datasets can be created in a number of ways, most often by collecting `Run`'s captured through the LangChain tracing API.\n",
|
||||
"\n",
|
||||
"However, this notebook assumes you're running locally for the first time, so we'll start by uploading an example evaluation dataset."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "1baa677c-5642-4378-8e01-3aa1647f19d6",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# !pip install datasets > /dev/null\n",
|
||||
"# !pip install pandas > /dev/null"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "60d14593-c61f-449f-a38f-772ca43707c2",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Found cached dataset json (/Users/wfh/.cache/huggingface/datasets/LangChainDatasets___json/LangChainDatasets--agent-search-calculator-8a025c0ce5fb99d2/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "60be643da25c47c6aa729b37046a2e64",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/1 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>input</th>\n",
|
||||
" <th>output</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>How many people live in canada as of 2023?</td>\n",
|
||||
" <td>approximately 38,625,801</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>who is dua lipa's boyfriend? what is his age r...</td>\n",
|
||||
" <td>her boyfriend is Romain Gravas. his age raised...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>what is dua lipa's boyfriend age raised to the...</td>\n",
|
||||
" <td>her boyfriend is Romain Gravas. his age raised...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>how far is it from paris to boston in miles</td>\n",
|
||||
" <td>approximately 3,435 mi</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>what was the total number of points scored in ...</td>\n",
|
||||
" <td>approximately 2.682651500990882</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" input \\\n",
|
||||
"0 How many people live in canada as of 2023? \n",
|
||||
"1 who is dua lipa's boyfriend? what is his age r... \n",
|
||||
"2 what is dua lipa's boyfriend age raised to the... \n",
|
||||
"3 how far is it from paris to boston in miles \n",
|
||||
"4 what was the total number of points scored in ... \n",
|
||||
"\n",
|
||||
" output \n",
|
||||
"0 approximately 38,625,801 \n",
|
||||
"1 her boyfriend is Romain Gravas. his age raised... \n",
|
||||
"2 her boyfriend is Romain Gravas. his age raised... \n",
|
||||
"3 approximately 3,435 mi \n",
|
||||
"4 approximately 2.682651500990882 "
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"from langchain.evaluation.loading import load_dataset\n",
|
||||
"\n",
|
||||
"dataset = load_dataset(\"agent-search-calculator\")\n",
|
||||
"df = pd.DataFrame(dataset, columns=[\"question\", \"answer\"])\n",
|
||||
"df.columns = [\"input\", \"output\"] # The chain we want to evaluate below expects inputs with the \"input\" key \n",
|
||||
"df.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "f27e45f1-e299-4de8-a538-ee1272ac5024",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataset_name = f\"calculator-example-dataset\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "52a7ea76-79ca-4765-abf7-231e884040d6",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if dataset_name not in set([dataset.name for dataset in client.list_datasets()]):\n",
|
||||
" dataset = client.upload_dataframe(df, \n",
|
||||
" name=dataset_name,\n",
|
||||
" description=\"Acalculator example dataset\",\n",
|
||||
" input_keys=[\"input\"],\n",
|
||||
" output_keys=[\"output\"],\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "07885b10",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"source": [
|
||||
"## Running a Chain on a Traced Dataset\n",
|
||||
"\n",
|
||||
"Once you have a dataset, you can run a chain over it to see its results. The run traces will automatically be associated with the dataset for easy attribution and analysis.\n",
|
||||
"\n",
|
||||
"**First, we'll define the chain we wish to run over the dataset.**\n",
|
||||
"\n",
|
||||
"In this case, we're using an agent, but it can be any simple chain."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "c2b59104-b90e-466a-b7ea-c5bd0194263b",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.agents import initialize_agent, load_tools\n",
|
||||
"from langchain.agents import AgentType\n",
|
||||
"\n",
|
||||
"llm = ChatOpenAI(temperature=0)\n",
|
||||
"tools = load_tools(['serpapi', 'llm-math'], llm=llm)\n",
|
||||
"agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "84094a4a-1d76-461c-bc37-8c537939b466",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**Now we're ready to run the chain!**\n",
|
||||
"\n",
|
||||
"The docstring below hints out ways you can configure the method to run."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "112d7bdf-7e50-4c1a-9285-5bac8473f2ee",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\u001b[0;31mSignature:\u001b[0m\n",
|
||||
"\u001b[0mclient\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marun_on_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
|
||||
"\u001b[0;34m\u001b[0m \u001b[0mdataset_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'str'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
||||
"\u001b[0;34m\u001b[0m \u001b[0mllm_or_chain\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Union[Chain, BaseLanguageModel]'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
||||
"\u001b[0;34m\u001b[0m \u001b[0mnum_workers\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
||||
"\u001b[0;34m\u001b[0m \u001b[0mnum_repetitions\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
||||
"\u001b[0;34m\u001b[0m \u001b[0msession_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Optional[str]'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
||||
"\u001b[0;34m\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
|
||||
"\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;34m'Dict[str, Any]'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;31mDocstring:\u001b[0m\n",
|
||||
"Run the chain on a dataset and store traces to the specified session name.\n",
|
||||
"\n",
|
||||
"Args:\n",
|
||||
" dataset_name: Name of the dataset to run the chain on.\n",
|
||||
" llm_or_chain: Chain or language model to run over the dataset.\n",
|
||||
" num_workers: Number of async workers to run in parallel.\n",
|
||||
" num_repetitions: Number of times to run the model on each example.\n",
|
||||
" This is useful when testing success rates or generating confidence\n",
|
||||
" intervals.\n",
|
||||
" session_name: Name of the session to store the traces in.\n",
|
||||
" Defaults to {dataset_name}-{chain class name}-{datetime}.\n",
|
||||
" verbose: Whether to print progress.\n",
|
||||
"\n",
|
||||
"Returns:\n",
|
||||
" A dictionary mapping example ids to the model outputs.\n",
|
||||
"\u001b[0;31mFile:\u001b[0m ~/code/lc/lckg/langchain/client/langchain.py\n",
|
||||
"\u001b[0;31mType:\u001b[0m method"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"?client.arun_on_dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "a8088b7d-3ab6-4279-94c8-5116fe7cee33",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/wfh/code/lc/lckg/langchain/callbacks/manager.py:65: UserWarning: The experimental tracing v2 is in development. This is not yet stable and may change in the future.\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed examples: 1\r"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Chain failed for example 4aca2037-5393-4357-9bf9-ef7451b8f0e3. Error: unknown format from LLM: Assuming we don't have any information about the actual number of points scored in the 2023 super bowl, we cannot provide a mathematical expression to solve this problem.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed examples: 2\r"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Chain failed for example b0adfea5-1212-4cee-9206-11bd69ccfcc5. Error: 'age'. Please try again with a valid numerical expression\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed examples: 5\r"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Chain failed for example 7e5873e4-ed38-428b-909f-f8c0fb80bab9. Error: invalid syntax. Perhaps you forgot a comma? (<expr>, line 1). Please try again with a valid numerical expression\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed examples: 10\r"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain_results = await client.arun_on_dataset(\n",
|
||||
" dataset_name=dataset_name,\n",
|
||||
" llm_or_chain=agent,\n",
|
||||
" verbose=True\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d2737458-b20c-4288-8790-1f4a8d237b2a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Reviewing the Chain Results\n",
|
||||
"\n",
|
||||
"The method called above returns a dictionary mapping Example IDs to the output of the chain.\n",
|
||||
"You can directly inspect the results below."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "136db492-d6ca-4215-96f9-439c23538241",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<a href=\"http://localhost\", target=\"_blank\" rel=\"noopener\">LangChain+ Client</a>"
|
||||
],
|
||||
"text/plain": [
|
||||
"LangChainPlusClient (API URL: http://localhost:8000)"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# You can navigate to the UI by clicking on the link below\n",
|
||||
"client"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c70cceb5-aa53-4851-bb12-386f092191f9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Running a Chat Model over a Traced Dataset\n",
|
||||
"\n",
|
||||
"We've shown how to run a _chain_ over a dataset, but you can also run an LLM or Chat model over a datasets formed from runs.\n",
|
||||
"\n",
|
||||
"First, we'll show an example using a ChatModel. This is useful for things like:\n",
|
||||
"- Comparing results under different decoding parameters\n",
|
||||
"- Comparing model providers\n",
|
||||
"- Testing for regressions in model behavior\n",
|
||||
"- Running multiple times with a temperature to gauge stability"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "64490d7c-9a18-49ed-a3ac-36049c522cb4",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Found cached dataset parquet (/Users/wfh/.cache/huggingface/datasets/LangChainDatasets___parquet/LangChainDatasets--two-player-dnd-2e84407830cdedfc/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "35a85ea78f934758a5f8b16af0d1944a",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/1 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>prompts</th>\n",
|
||||
" <th>generations</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>[System: Here is the topic for a Dungeons & Dr...</td>\n",
|
||||
" <td>[[{'generation_info': None, 'message': {'conte...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>[System: Here is the topic for a Dungeons & Dr...</td>\n",
|
||||
" <td>[[{'generation_info': None, 'message': {'conte...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>[System: Here is the topic for a Dungeons & Dr...</td>\n",
|
||||
" <td>[[{'generation_info': None, 'message': {'conte...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>[System: Here is the topic for a Dungeons & Dr...</td>\n",
|
||||
" <td>[[{'generation_info': None, 'message': {'conte...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>[System: Here is the topic for a Dungeons & Dr...</td>\n",
|
||||
" <td>[[{'generation_info': None, 'message': {'conte...</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" prompts \\\n",
|
||||
"0 [System: Here is the topic for a Dungeons & Dr... \n",
|
||||
"1 [System: Here is the topic for a Dungeons & Dr... \n",
|
||||
"2 [System: Here is the topic for a Dungeons & Dr... \n",
|
||||
"3 [System: Here is the topic for a Dungeons & Dr... \n",
|
||||
"4 [System: Here is the topic for a Dungeons & Dr... \n",
|
||||
"\n",
|
||||
" generations \n",
|
||||
"0 [[{'generation_info': None, 'message': {'conte... \n",
|
||||
"1 [[{'generation_info': None, 'message': {'conte... \n",
|
||||
"2 [[{'generation_info': None, 'message': {'conte... \n",
|
||||
"3 [[{'generation_info': None, 'message': {'conte... \n",
|
||||
"4 [[{'generation_info': None, 'message': {'conte... "
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chat_dataset = load_dataset(\"two-player-dnd\")\n",
|
||||
"chat_df = pd.DataFrame(chat_dataset)\n",
|
||||
"chat_df.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "348acd86-a927-4d60-8d52-02e64585e4fc",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chat_dataset_name = \"two-player-dnd\"\n",
|
||||
"\n",
|
||||
"if chat_dataset_name not in set([dataset.name for dataset in client.list_datasets()]):\n",
|
||||
" client.upload_dataframe(chat_df, \n",
|
||||
" name=chat_dataset_name,\n",
|
||||
" description=\"An example dataset traced from chat models in a multiagent bidding dialogue\",\n",
|
||||
" input_keys=[\"prompts\"],\n",
|
||||
" output_keys=[\"generations\"],\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "a69dd183-ad5e-473d-b631-db90706e837f",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"\n",
|
||||
"chat_model = ChatOpenAI(temperature=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "063da2a9-3692-4b7b-8edb-e474824fe416",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/wfh/code/lc/lckg/langchain/callbacks/manager.py:65: UserWarning: The experimental tracing v2 is in development. This is not yet stable and may change in the future.\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed examples: 35\r"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Retrying langchain.chat_models.openai.acompletion_with_retry.<locals>._completion_with_retry in 1.0 seconds as it raised RateLimitError: That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 2a718b30703cc8ffef8de0bb48935fd6 in your message.).\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed examples: 36\r"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chat_model_results = await client.arun_on_dataset(\n",
|
||||
" dataset_name=chat_dataset_name,\n",
|
||||
" llm_or_chain=chat_model,\n",
|
||||
" num_workers=5, # Optional, sets the number of examples to run at a time\n",
|
||||
" # session_name=\"Calculator Dataset Runs\", # Optional. Will be seed with a default session otherwise\n",
|
||||
" verbose=True\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "de7bfe08-215c-4328-b9b0-631d9a41f0e8",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"source": [
|
||||
"## Reviewing the Chat Model Results\n",
|
||||
"\n",
|
||||
"You can once again review the latest runs by clicking on the link below and navigating to the \"two-player-dnd\" session."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "5b7a81f2-d19d-438b-a4bb-5678f746b965",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<a href=\"http://localhost\", target=\"_blank\" rel=\"noopener\">LangChain+ Client</a>"
|
||||
],
|
||||
"text/plain": [
|
||||
"LangChainPlusClient (API URL: http://localhost:8000)"
|
||||
]
|
||||
},
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"client"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7896cbeb-345f-430b-ab5e-e108973174f8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Running an LLM over a Traced Dataset\n",
|
||||
"\n",
|
||||
"You can run an LLM over a dataset in much the same way as the chain and chat models, provided the dataset you've captured is in the appropriate format. Again, we've cached one for you, but using application-specific traces will be much more useful for your use cases."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "d6805d0b-4612-4671-bffb-e6978992bd40",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/wfh/code/lc/lckg/langchain/llms/anthropic.py:134: UserWarning: This Anthropic LLM is deprecated. Please use `from langchain.chat_models import ChatAnthropic` instead\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.llms import Anthropic\n",
|
||||
"\n",
|
||||
"llm = Anthropic(temperature=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "5d7cb243-40c3-44dd-8158-a7b910441e9f",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Found cached dataset parquet (/Users/wfh/.cache/huggingface/datasets/LangChainDatasets___parquet/LangChainDatasets--state-of-the-union-completions-ae7542e7bbd0ae0a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "c8628ee0ce5b4924be636250b18e9803",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/1 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>prompts</th>\n",
|
||||
" <th>generations</th>\n",
|
||||
" <th>ground_truth</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>[Putin may circle Kyiv with tanks, but he will...</td>\n",
|
||||
" <td>[[{'generation_info': {'finish_reason': 'stop'...</td>\n",
|
||||
" <td>The pandemic has been punishing. \\n\\nAnd so ma...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>[Madam Speaker, Madam Vice President, our Firs...</td>\n",
|
||||
" <td>[[]]</td>\n",
|
||||
" <td>With a duty to one another to the American peo...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>[With a duty to one another to the American pe...</td>\n",
|
||||
" <td>[[{'generation_info': {'finish_reason': 'stop'...</td>\n",
|
||||
" <td>He thought he could roll into Ukraine and the ...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>[Madam Speaker, Madam Vice President, our Firs...</td>\n",
|
||||
" <td>[[{'generation_info': {'finish_reason': 'lengt...</td>\n",
|
||||
" <td>With a duty to one another to the American peo...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>[Please rise if you are able and show that, Ye...</td>\n",
|
||||
" <td>[[]]</td>\n",
|
||||
" <td>And the costs and the threats to America and t...</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" prompts \\\n",
|
||||
"0 [Putin may circle Kyiv with tanks, but he will... \n",
|
||||
"1 [Madam Speaker, Madam Vice President, our Firs... \n",
|
||||
"2 [With a duty to one another to the American pe... \n",
|
||||
"3 [Madam Speaker, Madam Vice President, our Firs... \n",
|
||||
"4 [Please rise if you are able and show that, Ye... \n",
|
||||
"\n",
|
||||
" generations \\\n",
|
||||
"0 [[{'generation_info': {'finish_reason': 'stop'... \n",
|
||||
"1 [[]] \n",
|
||||
"2 [[{'generation_info': {'finish_reason': 'stop'... \n",
|
||||
"3 [[{'generation_info': {'finish_reason': 'lengt... \n",
|
||||
"4 [[]] \n",
|
||||
"\n",
|
||||
" ground_truth \n",
|
||||
"0 The pandemic has been punishing. \\n\\nAnd so ma... \n",
|
||||
"1 With a duty to one another to the American peo... \n",
|
||||
"2 He thought he could roll into Ukraine and the ... \n",
|
||||
"3 With a duty to one another to the American peo... \n",
|
||||
"4 And the costs and the threats to America and t... "
|
||||
]
|
||||
},
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"completions_dataset = load_dataset(\"state-of-the-union-completions\")\n",
|
||||
"completions_df = pd.DataFrame(completions_dataset)\n",
|
||||
"completions_df.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"id": "c7dcc1b2-7aef-44c0-ba0f-c812279099a5",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"completions_dataset_name = \"state-of-the-union-completions\"\n",
|
||||
"\n",
|
||||
"if completions_dataset_name not in set([dataset.name for dataset in client.list_datasets()]):\n",
|
||||
" client.upload_dataframe(completions_df, \n",
|
||||
" name=completions_dataset_name,\n",
|
||||
" description=\"An example dataset traced from completion endpoints over the state of the union address\",\n",
|
||||
" input_keys=[\"prompts\"],\n",
|
||||
" output_keys=[\"generations\"],\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"id": "02a0b79d-4867-45b2-b43d-662e5dcefb0f",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/wfh/code/lc/lckg/langchain/llms/anthropic.py:134: UserWarning: This Anthropic LLM is deprecated. Please use `from langchain.chat_models import ChatAnthropic` instead\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.llms import Anthropic\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"llm = Anthropic(temperature=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"id": "e946138e-bf7c-43d7-861d-9c5740c933fa",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/wfh/code/lc/lckg/langchain/callbacks/manager.py:65: UserWarning: The experimental tracing v2 is in development. This is not yet stable and may change in the future.\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed examples: 55\r"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"completions_model_results = await client.arun_on_dataset(\n",
|
||||
" dataset_name=completions_dataset_name,\n",
|
||||
" llm_or_chain=llm,\n",
|
||||
" num_workers=2, # Optional, sets the number of examples to run at a time\n",
|
||||
" num_repetitions=1, # Increasing this will run the model multiple times per example\n",
|
||||
" verbose=True\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cc86e8e6-cee2-429e-942b-289284d14816",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Reviewing the LLM Results\n",
|
||||
"\n",
|
||||
"You can once again inspect the latest runs by clicking on the link below and navigating to the \"two-player-dnd\" session."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"id": "2bf96f17-74c1-4f7d-8458-ae5ab5c6bd36",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<a href=\"http://localhost\", target=\"_blank\" rel=\"noopener\">LangChain+ Client</a>"
|
||||
],
|
||||
"text/plain": [
|
||||
"LangChainPlusClient (API URL: http://localhost:8000)"
|
||||
]
|
||||
},
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"client"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bbafa5a5-a278-42c0-a24f-191cbcb41156",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -58,6 +58,7 @@ def tracing_enabled(
|
||||
@contextmanager
|
||||
def tracing_v2_enabled(
|
||||
session_name: str = "default",
|
||||
example_id: Optional[Union[str, UUID]] = None,
|
||||
) -> Generator[TracerSessionV2, None, None]:
|
||||
"""Get the experimental tracer handler in a context manager."""
|
||||
# Issue a warning that this is experimental
|
||||
@@ -65,8 +66,10 @@ def tracing_v2_enabled(
|
||||
"The experimental tracing v2 is in development. "
|
||||
"This is not yet stable and may change in the future."
|
||||
)
|
||||
cb = LangChainTracerV2()
|
||||
session = cb.load_session(session_name)
|
||||
if isinstance(example_id, str):
|
||||
example_id = UUID(example_id)
|
||||
cb = LangChainTracerV2(example_id=example_id)
|
||||
session = cast(TracerSessionV2, cb.new_session(session_name))
|
||||
tracing_callback_var.set(cb)
|
||||
yield session
|
||||
tracing_callback_var.set(None)
|
||||
|
||||
@@ -29,7 +29,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {}
|
||||
self.session: Optional[Union[TracerSessionV2, TracerSession]] = None
|
||||
self.session: Optional[Union[TracerSession, TracerSessionV2]] = None
|
||||
|
||||
@staticmethod
|
||||
def _add_child_run(
|
||||
@@ -165,7 +165,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
llm_run = self.run_map.get(run_id_)
|
||||
if llm_run is None or not isinstance(llm_run, LLMRun):
|
||||
raise TracerException("No LLMRun found to be traced")
|
||||
|
||||
llm_run.response = response
|
||||
llm_run.end_time = datetime.utcnow()
|
||||
self._end_trace(llm_run)
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import requests
|
||||
|
||||
@@ -11,13 +12,14 @@ from langchain.callbacks.tracers.base import BaseTracer
|
||||
from langchain.callbacks.tracers.schemas import (
|
||||
ChainRun,
|
||||
LLMRun,
|
||||
Run,
|
||||
RunCreate,
|
||||
ToolRun,
|
||||
TracerSession,
|
||||
TracerSessionBase,
|
||||
TracerSessionV2,
|
||||
TracerSessionV2Create,
|
||||
)
|
||||
from langchain.utils import raise_for_status_with_text
|
||||
|
||||
|
||||
def _get_headers() -> Dict[str, Any]:
|
||||
@@ -51,11 +53,12 @@ class LangChainTracer(BaseTracer):
|
||||
endpoint = f"{self._endpoint}/tool-runs"
|
||||
|
||||
try:
|
||||
requests.post(
|
||||
response = requests.post(
|
||||
endpoint,
|
||||
data=run.json(),
|
||||
headers=self._headers,
|
||||
)
|
||||
raise_for_status_with_text(response)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to persist run: {e}")
|
||||
|
||||
@@ -111,7 +114,7 @@ def _get_tenant_id() -> Optional[str]:
|
||||
endpoint = _get_endpoint()
|
||||
headers = _get_headers()
|
||||
response = requests.get(endpoint + "/tenants", headers=headers)
|
||||
response.raise_for_status()
|
||||
raise_for_status_with_text(response)
|
||||
tenants: List[Dict[str, Any]] = response.json()
|
||||
if not tenants:
|
||||
raise ValueError(f"No tenants found for URL {endpoint}")
|
||||
@@ -121,12 +124,13 @@ def _get_tenant_id() -> Optional[str]:
|
||||
class LangChainTracerV2(LangChainTracer):
|
||||
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
def __init__(self, example_id: Optional[UUID] = None, **kwargs: Any) -> None:
|
||||
"""Initialize the LangChain tracer."""
|
||||
super().__init__(**kwargs)
|
||||
self._endpoint = _get_endpoint()
|
||||
self._headers = _get_headers()
|
||||
self.tenant_id = _get_tenant_id()
|
||||
self.example_id = example_id
|
||||
|
||||
def _get_session_create(
|
||||
self, name: Optional[str] = None, **kwargs: Any
|
||||
@@ -135,16 +139,30 @@ class LangChainTracerV2(LangChainTracer):
|
||||
|
||||
def _persist_session(self, session_create: TracerSessionBase) -> TracerSessionV2:
|
||||
"""Persist a session."""
|
||||
session: Optional[TracerSessionV2] = None
|
||||
try:
|
||||
r = requests.post(
|
||||
f"{self._endpoint}/sessions",
|
||||
data=session_create.json(),
|
||||
headers=self._headers,
|
||||
)
|
||||
session = TracerSessionV2(id=r.json()["id"], **session_create.dict())
|
||||
raise_for_status_with_text(r)
|
||||
creation_args = session_create.dict()
|
||||
if "id" in creation_args:
|
||||
del creation_args["id"]
|
||||
return TracerSessionV2(id=r.json()["id"], **creation_args)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to create session, using default session: {e}")
|
||||
session = self.load_session("default")
|
||||
if session_create.name is not None:
|
||||
try:
|
||||
return self.load_session(session_create.name)
|
||||
except Exception:
|
||||
pass
|
||||
logging.warning(
|
||||
f"Failed to create session {session_create.name},"
|
||||
f" using empty session: {e}"
|
||||
)
|
||||
session = TracerSessionV2(id=uuid4(), **session_create.dict())
|
||||
|
||||
return session
|
||||
|
||||
def _get_default_query_params(self) -> Dict[str, Any]:
|
||||
@@ -159,13 +177,14 @@ class LangChainTracerV2(LangChainTracer):
|
||||
if session_name:
|
||||
params["name"] = session_name
|
||||
r = requests.get(url, headers=self._headers, params=params)
|
||||
raise_for_status_with_text(r)
|
||||
tracer_session = TracerSessionV2(**r.json()[0])
|
||||
except Exception as e:
|
||||
session_type = "default" if not session_name else session_name
|
||||
logging.warning(
|
||||
f"Failed to load {session_type} session, using empty session: {e}"
|
||||
)
|
||||
tracer_session = TracerSessionV2(id=1, tenant_id=self.tenant_id)
|
||||
tracer_session = TracerSessionV2(id=uuid4(), tenant_id=self.tenant_id)
|
||||
|
||||
self.session = tracer_session
|
||||
return tracer_session
|
||||
@@ -174,7 +193,7 @@ class LangChainTracerV2(LangChainTracer):
|
||||
"""Load the default tracing session and set it as the Tracer's session."""
|
||||
return self.load_session("default")
|
||||
|
||||
def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> Run:
|
||||
def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> RunCreate:
|
||||
"""Convert a run to a Run."""
|
||||
session = self.session or self.load_default_session()
|
||||
inputs: Dict[str, Any] = {}
|
||||
@@ -204,9 +223,9 @@ class LangChainTracerV2(LangChainTracer):
|
||||
*run.child_tool_runs,
|
||||
]
|
||||
|
||||
return Run(
|
||||
return RunCreate(
|
||||
id=run.uuid,
|
||||
name=run.serialized.get("name", f"{run_type}-{run.uuid}"),
|
||||
name=run.serialized.get("name"),
|
||||
start_time=run.start_time,
|
||||
end_time=run.end_time,
|
||||
extra=run.extra or {},
|
||||
@@ -217,7 +236,7 @@ class LangChainTracerV2(LangChainTracer):
|
||||
outputs=outputs,
|
||||
session_id=session.id,
|
||||
run_type=run_type,
|
||||
parent_run_id=run.parent_uuid,
|
||||
reference_example_id=self.example_id,
|
||||
child_runs=[self._convert_run(child) for child in child_runs],
|
||||
)
|
||||
|
||||
@@ -225,11 +244,11 @@ class LangChainTracerV2(LangChainTracer):
|
||||
"""Persist a run."""
|
||||
run_create = self._convert_run(run)
|
||||
try:
|
||||
result = requests.post(
|
||||
response = requests.post(
|
||||
f"{self._endpoint}/runs",
|
||||
data=run_create.json(),
|
||||
headers=self._headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
raise_for_status_with_text(response)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to persist run: {e}")
|
||||
|
||||
@@ -37,9 +37,11 @@ class TracerSessionV2Base(TracerSessionBase):
|
||||
tenant_id: UUID
|
||||
|
||||
|
||||
class TracerSessionV2Create(TracerSessionBase):
|
||||
class TracerSessionV2Create(TracerSessionV2Base):
|
||||
"""A creation class for TracerSessionV2."""
|
||||
|
||||
id: Optional[UUID]
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@@ -100,9 +102,10 @@ class RunTypeEnum(str, Enum):
|
||||
llm = "llm"
|
||||
|
||||
|
||||
class Run(BaseModel):
|
||||
class RunBase(BaseModel):
|
||||
"""Base Run schema."""
|
||||
|
||||
id: Optional[UUID]
|
||||
name: str
|
||||
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
||||
end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
||||
extra: dict
|
||||
@@ -112,10 +115,22 @@ class Run(BaseModel):
|
||||
inputs: dict
|
||||
outputs: Optional[dict]
|
||||
session_id: UUID
|
||||
parent_run_id: Optional[UUID]
|
||||
reference_example_id: Optional[UUID]
|
||||
run_type: RunTypeEnum
|
||||
child_runs: List[Run] = Field(default_factory=list)
|
||||
|
||||
|
||||
class RunCreate(RunBase):
|
||||
"""Schema to create a run in the DB."""
|
||||
|
||||
name: Optional[str]
|
||||
child_runs: List[RunCreate] = Field(default_factory=list)
|
||||
|
||||
|
||||
class Run(RunBase):
|
||||
"""Run schema when loading from the DB."""
|
||||
|
||||
name: str
|
||||
parent_run_id: Optional[UUID]
|
||||
|
||||
|
||||
ChainRun.update_forward_refs()
|
||||
|
||||
6
langchain/client/__init__.py
Normal file
6
langchain/client/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""LangChain+ Client."""
|
||||
|
||||
|
||||
from langchain.client.langchain import LangChainPlusClient
|
||||
|
||||
__all__ = ["LangChainPlusClient"]
|
||||
519
langchain/client/langchain.py
Normal file
519
langchain/client/langchain.py
Normal file
@@ -0,0 +1,519 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import socket
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from urllib.parse import urlsplit
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
from pydantic import BaseSettings, Field, root_validator
|
||||
from requests import Response
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import tracing_v2_enabled
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracerV2
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.client.models import Dataset, DatasetCreate, Example, ExampleCreate
|
||||
from langchain.client.utils import parse_chat_messages
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.schema import ChatResult, LLMResult
|
||||
from langchain.utils import raise_for_status_with_text, xor_args
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pandas as pd
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_link_stem(url: str) -> str:
|
||||
scheme = urlsplit(url).scheme
|
||||
netloc_prefix = urlsplit(url).netloc.split(":")[0]
|
||||
return f"{scheme}://{netloc_prefix}"
|
||||
|
||||
|
||||
def _is_localhost(url: str) -> bool:
|
||||
"""Check if the URL is localhost."""
|
||||
try:
|
||||
netloc = urlsplit(url).netloc.split(":")[0]
|
||||
ip = socket.gethostbyname(netloc)
|
||||
return ip == "127.0.0.1" or ip.startswith("0.0.0.0") or ip.startswith("::")
|
||||
except socket.gaierror:
|
||||
return False
|
||||
|
||||
|
||||
class LangChainPlusClient(BaseSettings):
|
||||
"""Client for interacting with the LangChain+ API."""
|
||||
|
||||
api_key: Optional[str] = Field(default=None, env="LANGCHAIN_API_KEY")
|
||||
api_url: str = Field(..., env="LANGCHAIN_ENDPOINT")
|
||||
tenant_id: str = Field(..., env="LANGCHAIN_TENANT_ID")
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_api_key_if_hosted(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Verify API key is provided if url not localhost."""
|
||||
api_url: str = values.get("api_url", "http://localhost:8000")
|
||||
api_key: Optional[str] = values.get("api_key")
|
||||
if not _is_localhost(api_url):
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"API key must be provided when using hosted LangChain+ API"
|
||||
)
|
||||
else:
|
||||
tenant_id = values.get("tenant_id")
|
||||
if not tenant_id:
|
||||
values["tenant_id"] = LangChainPlusClient._get_seeded_tenant_id(
|
||||
api_url, api_key
|
||||
)
|
||||
return values
|
||||
|
||||
@staticmethod
|
||||
def _get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
|
||||
"""Get the tenant ID from the seeded tenant."""
|
||||
url = f"{api_url}/tenants"
|
||||
headers = {"authorization": f"Bearer {api_key}"} if api_key else {}
|
||||
response = requests.get(url, headers=headers)
|
||||
try:
|
||||
raise_for_status_with_text(response)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
"Unable to get seeded tenant ID. Please manually provide."
|
||||
) from e
|
||||
results: List[dict] = response.json()
|
||||
if len(results) == 0:
|
||||
raise ValueError("No seeded tenant found")
|
||||
return results[0]["id"]
|
||||
|
||||
def _repr_html_(self) -> str:
|
||||
"""Return an HTML representation of the instance with a link to the URL."""
|
||||
link = _get_link_stem(self.api_url)
|
||||
return f'<a href="{link}", target="_blank" rel="noopener">LangChain+ Client</a>'
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return a string representation of the instance with a link to the URL."""
|
||||
return f"LangChainPlusClient (API URL: {self.api_url})"
|
||||
|
||||
@property
|
||||
def _headers(self) -> Dict[str, str]:
|
||||
"""Get the headers for the API request."""
|
||||
headers = {}
|
||||
if self.api_key:
|
||||
headers["authorization"] = f"Bearer {self.api_key}"
|
||||
return headers
|
||||
|
||||
@property
|
||||
def query_params(self) -> Dict[str, str]:
|
||||
"""Get the headers for the API request."""
|
||||
return {"tenant_id": self.tenant_id}
|
||||
|
||||
def _get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Response:
|
||||
"""Make a GET request."""
|
||||
query_params = self.query_params
|
||||
if params:
|
||||
query_params.update(params)
|
||||
return requests.get(
|
||||
f"{self.api_url}{path}", headers=self._headers, params=query_params
|
||||
)
|
||||
|
||||
@xor_args(("dataset_id", "dataset_name"))
|
||||
def create_example(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
dataset_id: Optional[UUID] = None,
|
||||
dataset_name: Optional[str] = None,
|
||||
created_at: Optional[datetime] = None,
|
||||
outputs: Dict[str, Any] | None = None,
|
||||
) -> Example:
|
||||
"""Create a dataset example in the LangChain+ API."""
|
||||
if dataset_id is None:
|
||||
dataset_id = self.read_dataset(dataset_name).id
|
||||
|
||||
data = {
|
||||
"inputs": inputs,
|
||||
"outputs": outputs,
|
||||
"dataset_id": dataset_id,
|
||||
}
|
||||
if created_at:
|
||||
data["created_at"] = created_at.isoformat()
|
||||
example = ExampleCreate(**data)
|
||||
response = requests.post(
|
||||
f"{self.api_url}/examples", headers=self._headers, data=example.json()
|
||||
)
|
||||
raise_for_status_with_text(response)
|
||||
result = response.json()
|
||||
return Example(**result)
|
||||
|
||||
def upload_dataframe(
|
||||
self,
|
||||
df: pd.DataFrame,
|
||||
name: str,
|
||||
description: str,
|
||||
input_keys: List[str],
|
||||
output_keys: List[str],
|
||||
) -> Dataset:
|
||||
"""Upload a dataframe as individual examples to the LangChain+ API."""
|
||||
dataset = self.create_dataset(dataset_name=name, description=description)
|
||||
for row in df.itertuples():
|
||||
inputs = {key: getattr(row, key) for key in input_keys}
|
||||
outputs = {key: getattr(row, key) for key in output_keys}
|
||||
self.create_example(inputs, outputs=outputs, dataset_id=dataset.id)
|
||||
return dataset
|
||||
|
||||
def upload_csv(
|
||||
self,
|
||||
csv_file: Union[str, Tuple[str, BytesIO]],
|
||||
description: str,
|
||||
input_keys: List[str],
|
||||
output_keys: List[str],
|
||||
) -> Dataset:
|
||||
"""Upload a CSV file to the LangChain+ API."""
|
||||
files = {"file": csv_file}
|
||||
data = {
|
||||
"input_keys": ",".join(input_keys),
|
||||
"output_keys": ",".join(output_keys),
|
||||
"description": description,
|
||||
"tenant_id": self.tenant_id,
|
||||
}
|
||||
response = requests.post(
|
||||
self.api_url + "/datasets/upload",
|
||||
headers=self._headers,
|
||||
data=data,
|
||||
files=files,
|
||||
)
|
||||
raise_for_status_with_text(response)
|
||||
result = response.json()
|
||||
# TODO: Make this more robust server-side
|
||||
if "detail" in result and "already exists" in result["detail"]:
|
||||
file_name = csv_file if isinstance(csv_file, str) else csv_file[0]
|
||||
file_name = file_name.split("/")[-1]
|
||||
raise ValueError(f"Dataset {file_name} already exists")
|
||||
return Dataset(**result)
|
||||
|
||||
def create_dataset(self, dataset_name: str, description: str) -> Dataset:
|
||||
"""Create a dataset in the LangChain+ API."""
|
||||
dataset = DatasetCreate(
|
||||
tenant_id=self.tenant_id,
|
||||
name=dataset_name,
|
||||
description=description,
|
||||
)
|
||||
response = requests.post(
|
||||
self.api_url + "/datasets",
|
||||
headers=self._headers,
|
||||
data=dataset.json(),
|
||||
)
|
||||
raise_for_status_with_text(response)
|
||||
return Dataset(**response.json())
|
||||
|
||||
@xor_args(("dataset_name", "dataset_id"))
|
||||
def read_dataset(
|
||||
self, *, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None
|
||||
) -> Dataset:
|
||||
path = "/datasets"
|
||||
params: Dict[str, Any] = {"limit": 1, "tenant_id": self.tenant_id}
|
||||
if dataset_id is not None:
|
||||
path += f"/{dataset_id}"
|
||||
elif dataset_name is not None:
|
||||
params["name"] = dataset_name
|
||||
else:
|
||||
raise ValueError("Must provide dataset_name or dataset_id")
|
||||
response = self._get(
|
||||
path,
|
||||
params=params,
|
||||
)
|
||||
raise_for_status_with_text(response)
|
||||
result = response.json()
|
||||
if isinstance(result, list):
|
||||
if len(result) == 0:
|
||||
raise ValueError(f"Dataset {dataset_name} not found")
|
||||
return Dataset(**result[0])
|
||||
return Dataset(**result)
|
||||
|
||||
def list_datasets(self, limit: int = 100) -> Iterable[Dataset]:
|
||||
"""List the datasets on the LangChain+ API."""
|
||||
response = self._get("/datasets", params={"limit": limit})
|
||||
raise_for_status_with_text(response)
|
||||
return [Dataset(**dataset) for dataset in response.json()]
|
||||
|
||||
@xor_args(("dataset_id", "dataset_name"))
|
||||
def delete_dataset(
|
||||
self, *, dataset_id: Optional[str] = None, dataset_name: Optional[str] = None
|
||||
) -> Dataset:
|
||||
"""Delete a dataset by ID or name."""
|
||||
if dataset_name is not None:
|
||||
dataset_id = self.read_dataset(dataset_name=dataset_name).id
|
||||
if dataset_id is None:
|
||||
raise ValueError("Must provide either dataset name or ID")
|
||||
response = requests.delete(
|
||||
f"{self.api_url}/datasets/{dataset_id}",
|
||||
headers=self._headers,
|
||||
)
|
||||
raise_for_status_with_text(response)
|
||||
return response.json()
|
||||
|
||||
def read_example(self, example_id: str) -> Example:
|
||||
"""Read an example from the LangChain+ API."""
|
||||
response = self._get(f"/examples/{example_id}")
|
||||
raise_for_status_with_text(response)
|
||||
return Example(**response.json())
|
||||
|
||||
def list_examples(
|
||||
self, dataset_id: Optional[str] = None, dataset_name: Optional[str] = None
|
||||
) -> Iterable[Example]:
|
||||
"""List the datasets on the LangChain+ API."""
|
||||
params = {}
|
||||
if dataset_id is not None:
|
||||
params["dataset"] = dataset_id
|
||||
elif dataset_name is not None:
|
||||
dataset_id = self.read_dataset(dataset_name=dataset_name).id
|
||||
params["dataset"] = dataset_id
|
||||
else:
|
||||
pass
|
||||
response = self._get("/examples", params=params)
|
||||
raise_for_status_with_text(response)
|
||||
return [Example(**dataset) for dataset in response.json()]
|
||||
|
||||
@staticmethod
|
||||
async def _arun_llm(
|
||||
llm: BaseLanguageModel,
|
||||
inputs: Dict[str, Any],
|
||||
langchain_tracer: LangChainTracerV2,
|
||||
) -> Union[LLMResult, ChatResult]:
|
||||
if isinstance(llm, BaseLLM):
|
||||
llm_prompts: List[str] = inputs["prompts"]
|
||||
llm_output = await llm.agenerate(llm_prompts, callbacks=[langchain_tracer])
|
||||
elif isinstance(llm, BaseChatModel):
|
||||
chat_prompts: List[str] = inputs["prompts"]
|
||||
messages = [
|
||||
parse_chat_messages(chat_prompt) for chat_prompt in chat_prompts
|
||||
]
|
||||
llm_output = await llm.agenerate(messages, callbacks=[langchain_tracer])
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM type {type(llm)}")
|
||||
return llm_output
|
||||
|
||||
@staticmethod
|
||||
async def _arun_llm_or_chain(
|
||||
example: Example,
|
||||
langchain_tracer: LangChainTracerV2,
|
||||
llm_or_chain: Union[Chain, BaseLanguageModel],
|
||||
n_repetitions: int,
|
||||
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
|
||||
"""Run the chain asynchronously."""
|
||||
previous_example_id = langchain_tracer.example_id
|
||||
langchain_tracer.example_id = example.id
|
||||
outputs = []
|
||||
for _ in range(n_repetitions):
|
||||
try:
|
||||
if isinstance(llm_or_chain, BaseLanguageModel):
|
||||
output: Any = await LangChainPlusClient._arun_llm(
|
||||
llm_or_chain, example.inputs, langchain_tracer
|
||||
)
|
||||
else:
|
||||
output = await llm_or_chain.arun(
|
||||
example.inputs, callbacks=[langchain_tracer]
|
||||
)
|
||||
outputs.append(output)
|
||||
except Exception as e:
|
||||
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
|
||||
outputs.append({"Error": str(e)})
|
||||
finally:
|
||||
langchain_tracer.example_id = previous_example_id
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
async def _worker(
|
||||
queue: asyncio.Queue,
|
||||
tracer: LangChainTracerV2,
|
||||
llm_or_chain: Union[Chain, BaseLanguageModel],
|
||||
n_repetitions: int,
|
||||
results: Dict[str, Any],
|
||||
job_state: Dict[str, Any],
|
||||
verbose: bool,
|
||||
) -> None:
|
||||
"""Worker for running the chain on examples."""
|
||||
while True:
|
||||
example: Optional[Example] = await queue.get()
|
||||
if example is None:
|
||||
break
|
||||
|
||||
result = await LangChainPlusClient._arun_llm_or_chain(
|
||||
example,
|
||||
tracer,
|
||||
llm_or_chain,
|
||||
n_repetitions,
|
||||
)
|
||||
results[str(example.id)] = result
|
||||
queue.task_done()
|
||||
job_state["num_processed"] += 1
|
||||
if verbose:
|
||||
print(
|
||||
f"Processed examples: {job_state['num_processed']}",
|
||||
end="\r",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
async def arun_on_dataset(
|
||||
self,
|
||||
dataset_name: str,
|
||||
llm_or_chain: Union[Chain, BaseLanguageModel],
|
||||
num_workers: int = 5,
|
||||
num_repetitions: int = 1,
|
||||
session_name: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the chain on a dataset and store traces to the specified session name.
|
||||
|
||||
Args:
|
||||
dataset_name: Name of the dataset to run the chain on.
|
||||
llm_or_chain: Chain or language model to run over the dataset.
|
||||
num_workers: Number of async workers to run in parallel.
|
||||
num_repetitions: Number of times to run the model on each example.
|
||||
This is useful when testing success rates or generating confidence
|
||||
intervals.
|
||||
session_name: Name of the session to store the traces in.
|
||||
Defaults to {dataset_name}-{chain class name}-{datetime}.
|
||||
verbose: Whether to print progress.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping example ids to the model outputs.
|
||||
"""
|
||||
if session_name is None:
|
||||
current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
||||
session_name = (
|
||||
f"{dataset_name}-{llm_or_chain.__class__.__name__}-{current_time}"
|
||||
)
|
||||
dataset = self.read_dataset(dataset_name=dataset_name)
|
||||
workers = []
|
||||
examples = self.list_examples(dataset_id=str(dataset.id))
|
||||
results: Dict[str, Any] = {}
|
||||
queue: asyncio.Queue[Optional[Example]] = asyncio.Queue()
|
||||
job_state = {"num_processed": 0}
|
||||
with tracing_v2_enabled(session_name=session_name) as session:
|
||||
for _ in range(num_workers):
|
||||
tracer = LangChainTracerV2()
|
||||
tracer.session = session
|
||||
task = asyncio.create_task(
|
||||
LangChainPlusClient._worker(
|
||||
queue,
|
||||
tracer,
|
||||
llm_or_chain,
|
||||
num_repetitions,
|
||||
results,
|
||||
job_state,
|
||||
verbose,
|
||||
)
|
||||
)
|
||||
workers.append(task)
|
||||
|
||||
for example in examples:
|
||||
await queue.put(example)
|
||||
|
||||
await queue.join() # Wait for all tasks to complete
|
||||
|
||||
for _ in workers:
|
||||
await queue.put(None) # Signal the workers to exit
|
||||
|
||||
await asyncio.gather(*workers)
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def run_llm(
|
||||
llm: BaseLanguageModel,
|
||||
inputs: Dict[str, Any],
|
||||
langchain_tracer: LangChainTracerV2,
|
||||
) -> Union[LLMResult, ChatResult]:
|
||||
"""Run the language model on the example."""
|
||||
if isinstance(llm, BaseLLM):
|
||||
llm_prompts: List[str] = inputs["prompts"]
|
||||
llm_output = llm.generate(llm_prompts, callbacks=[langchain_tracer])
|
||||
elif isinstance(llm, BaseChatModel):
|
||||
chat_prompts: List[str] = inputs["prompts"]
|
||||
messages = [
|
||||
parse_chat_messages(chat_prompt) for chat_prompt in chat_prompts
|
||||
]
|
||||
llm_output = llm.generate(messages, callbacks=[langchain_tracer])
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM type {type(llm)}")
|
||||
return llm_output
|
||||
|
||||
@staticmethod
|
||||
def run_llm_or_chain(
|
||||
example: Example,
|
||||
langchain_tracer: LangChainTracerV2,
|
||||
llm_or_chain: Union[Chain, BaseLanguageModel],
|
||||
n_repetitions: int,
|
||||
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
|
||||
"""Run the chain synchronously."""
|
||||
previous_example_id = langchain_tracer.example_id
|
||||
langchain_tracer.example_id = example.id
|
||||
outputs = []
|
||||
for _ in range(n_repetitions):
|
||||
try:
|
||||
if isinstance(llm_or_chain, BaseLanguageModel):
|
||||
output: Any = LangChainPlusClient.run_llm(
|
||||
llm_or_chain, example.inputs, langchain_tracer
|
||||
)
|
||||
else:
|
||||
output = llm_or_chain.run(
|
||||
example.inputs, callbacks=[langchain_tracer]
|
||||
)
|
||||
outputs.append(output)
|
||||
except Exception as e:
|
||||
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
|
||||
outputs.append({"Error": str(e)})
|
||||
finally:
|
||||
langchain_tracer.example_id = previous_example_id
|
||||
return outputs
|
||||
|
||||
def run_on_dataset(
|
||||
self,
|
||||
dataset_name: str,
|
||||
llm_or_chain: Union[Chain, BaseLanguageModel],
|
||||
num_repetitions: int = 1,
|
||||
session_name: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the chain on a dataset and store traces to the specified session name.
|
||||
|
||||
Args:
|
||||
dataset_name: Name of the dataset to run the chain on.
|
||||
llm_or_chain: Chain or language model to run over the dataset.
|
||||
num_workers: Number of async workers to run in parallel.
|
||||
num_repetitions: Number of times to run the model on each example.
|
||||
This is useful when testing success rates or generating confidence
|
||||
intervals.
|
||||
session_name: Name of the session to store the traces in.
|
||||
Defaults to {dataset_name}-{chain class name}-{datetime}.
|
||||
verbose: Whether to print progress.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping example ids to the model outputs.
|
||||
"""
|
||||
if session_name is None:
|
||||
current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
||||
session_name = (
|
||||
f"{dataset_name}-{llm_or_chain.__class__.__name__}-{current_time}"
|
||||
)
|
||||
dataset = self.read_dataset(dataset_name=dataset_name)
|
||||
examples = self.list_examples(dataset_id=str(dataset.id))
|
||||
results: Dict[str, Any] = {}
|
||||
with tracing_v2_enabled(session_name=session_name) as session:
|
||||
tracer = LangChainTracerV2()
|
||||
tracer.session = session
|
||||
|
||||
for i, example in enumerate(examples):
|
||||
result = self.run_llm_or_chain(
|
||||
example,
|
||||
tracer,
|
||||
llm_or_chain,
|
||||
num_repetitions,
|
||||
)
|
||||
if verbose:
|
||||
print(f"{i+1} processed", flush=True, end="\r")
|
||||
results[str(example.id)] = result
|
||||
return results
|
||||
54
langchain/client/models.py
Normal file
54
langchain/client/models.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.callbacks.tracers.schemas import Run
|
||||
|
||||
|
||||
class ExampleBase(BaseModel):
|
||||
"""Example base model."""
|
||||
|
||||
dataset_id: UUID
|
||||
inputs: Dict[str, Any]
|
||||
outputs: Optional[Dict[str, Any]] = Field(default=None)
|
||||
|
||||
|
||||
class ExampleCreate(ExampleBase):
|
||||
"""Example create model."""
|
||||
|
||||
id: Optional[UUID]
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class Example(ExampleBase):
|
||||
"""Example model."""
|
||||
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
modified_at: Optional[datetime] = Field(default=None)
|
||||
runs: List[Run] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DatasetBase(BaseModel):
|
||||
"""Dataset base model."""
|
||||
|
||||
tenant_id: UUID
|
||||
name: str
|
||||
description: str
|
||||
|
||||
|
||||
class DatasetCreate(DatasetBase):
|
||||
"""Dataset create model."""
|
||||
|
||||
id: Optional[UUID]
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class Dataset(DatasetBase):
|
||||
"""Dataset ORM model."""
|
||||
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
modified_at: Optional[datetime] = Field(default=None)
|
||||
42
langchain/client/utils.py
Normal file
42
langchain/client/utils.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""Client Utils."""
|
||||
import re
|
||||
from typing import Dict, List, Optional, Sequence, Type, Union
|
||||
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
|
||||
_DEFAULT_MESSAGES_T = Union[Type[HumanMessage], Type[SystemMessage], Type[AIMessage]]
|
||||
_RESOLUTION_MAP: Dict[str, _DEFAULT_MESSAGES_T] = {
|
||||
"Human": HumanMessage,
|
||||
"AI": AIMessage,
|
||||
"System": SystemMessage,
|
||||
}
|
||||
|
||||
|
||||
def parse_chat_messages(
|
||||
input_text: str, roles: Optional[Sequence[str]] = None
|
||||
) -> List[BaseMessage]:
|
||||
"""Parse chat messages from a string. This is not robust."""
|
||||
roles = roles or ["Human", "AI", "System"]
|
||||
roles_pattern = "|".join(roles)
|
||||
pattern = (
|
||||
rf"(?P<entity>{roles_pattern}): (?P<message>"
|
||||
rf"(?:.*\n?)*?)(?=(?:{roles_pattern}): |\Z)"
|
||||
)
|
||||
matches = re.finditer(pattern, input_text, re.MULTILINE)
|
||||
|
||||
results: List[BaseMessage] = []
|
||||
for match in matches:
|
||||
entity = match.group("entity")
|
||||
message = match.group("message").rstrip("\n")
|
||||
if entity in _RESOLUTION_MAP:
|
||||
results.append(_RESOLUTION_MAP[entity](content=message))
|
||||
else:
|
||||
results.append(ChatMessage(role=entity, content=message))
|
||||
|
||||
return results
|
||||
@@ -2,6 +2,8 @@
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
|
||||
from requests import HTTPError, Response
|
||||
|
||||
|
||||
def get_from_dict_or_env(
|
||||
data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
|
||||
@@ -52,6 +54,14 @@ def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
|
||||
return decorator
|
||||
|
||||
|
||||
def raise_for_status_with_text(response: Response) -> None:
|
||||
"""Raise an error with the response text."""
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except HTTPError as e:
|
||||
raise ValueError(response.text) from e
|
||||
|
||||
|
||||
def stringify_value(val: Any) -> str:
|
||||
if isinstance(val, str):
|
||||
return val
|
||||
|
||||
@@ -18,7 +18,12 @@ from langchain.callbacks.tracers.base import (
|
||||
TracerSession,
|
||||
)
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracerV2
|
||||
from langchain.callbacks.tracers.schemas import Run, TracerSessionBase, TracerSessionV2
|
||||
from langchain.callbacks.tracers.schemas import (
|
||||
RunCreate,
|
||||
TracerSessionBase,
|
||||
TracerSessionV2,
|
||||
TracerSessionV2Create,
|
||||
)
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
TEST_SESSION_ID = 2023
|
||||
@@ -541,14 +546,12 @@ def sample_runs() -> Tuple[LLMRun, ChainRun, ToolRun]:
|
||||
return llm_run, chain_run, tool_run
|
||||
|
||||
|
||||
# Test _get_default_query_params method
|
||||
def test_get_default_query_params(lang_chain_tracer_v2: LangChainTracerV2) -> None:
|
||||
expected = {"tenant_id": "test-tenant-id"}
|
||||
result = lang_chain_tracer_v2._get_default_query_params()
|
||||
assert result == expected
|
||||
|
||||
|
||||
# Test load_session method
|
||||
@patch("langchain.callbacks.tracers.langchain.requests.get")
|
||||
def test_load_session(
|
||||
mock_requests_get: Mock,
|
||||
@@ -577,23 +580,65 @@ def test_convert_run(
|
||||
converted_chain_run = lang_chain_tracer_v2._convert_run(chain_run)
|
||||
converted_tool_run = lang_chain_tracer_v2._convert_run(tool_run)
|
||||
|
||||
assert isinstance(converted_llm_run, Run)
|
||||
assert isinstance(converted_chain_run, Run)
|
||||
assert isinstance(converted_tool_run, Run)
|
||||
assert isinstance(converted_llm_run, RunCreate)
|
||||
assert isinstance(converted_chain_run, RunCreate)
|
||||
assert isinstance(converted_tool_run, RunCreate)
|
||||
|
||||
|
||||
@patch("langchain.callbacks.tracers.langchain.requests.post")
|
||||
def test_persist_run(
|
||||
mock_requests_post: Mock,
|
||||
lang_chain_tracer_v2: LangChainTracerV2,
|
||||
sample_tracer_session_v2: TracerSessionV2,
|
||||
sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
|
||||
) -> None:
|
||||
mock_requests_post.return_value.raise_for_status.return_value = None
|
||||
lang_chain_tracer_v2.session = sample_tracer_session_v2
|
||||
llm_run, chain_run, tool_run = sample_runs
|
||||
lang_chain_tracer_v2._persist_run(llm_run)
|
||||
lang_chain_tracer_v2._persist_run(chain_run)
|
||||
lang_chain_tracer_v2._persist_run(tool_run)
|
||||
"""Test that persist_run method calls requests.post once per method call."""
|
||||
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
|
||||
"langchain.callbacks.tracers.langchain.requests.get"
|
||||
) as get:
|
||||
post.return_value.raise_for_status.return_value = None
|
||||
lang_chain_tracer_v2.session = sample_tracer_session_v2
|
||||
llm_run, chain_run, tool_run = sample_runs
|
||||
lang_chain_tracer_v2._persist_run(llm_run)
|
||||
lang_chain_tracer_v2._persist_run(chain_run)
|
||||
lang_chain_tracer_v2._persist_run(tool_run)
|
||||
|
||||
assert mock_requests_post.call_count == 3
|
||||
assert post.call_count == 3
|
||||
assert get.call_count == 0
|
||||
|
||||
|
||||
def test_get_session_create(lang_chain_tracer_v2: LangChainTracerV2) -> None:
|
||||
"""Test creating the 'SessionCreate' object."""
|
||||
lang_chain_tracer_v2.tenant_id = str(_TENANT_ID)
|
||||
session_create = lang_chain_tracer_v2._get_session_create(name="test")
|
||||
assert isinstance(session_create, TracerSessionV2Create)
|
||||
assert session_create.name == "test"
|
||||
assert session_create.tenant_id == _TENANT_ID
|
||||
|
||||
|
||||
@patch("langchain.callbacks.tracers.langchain.requests.post")
|
||||
def test_persist_session(
|
||||
mock_requests_post: Mock,
|
||||
lang_chain_tracer_v2: LangChainTracerV2,
|
||||
sample_tracer_session_v2: TracerSessionV2,
|
||||
) -> None:
|
||||
"""Test persist_session returns a TracerSessionV2 with the updated ID."""
|
||||
session_create = TracerSessionV2Create(**sample_tracer_session_v2.dict())
|
||||
new_id = str(uuid4())
|
||||
mock_requests_post.return_value.json.return_value = {"id": new_id}
|
||||
result = lang_chain_tracer_v2._persist_session(session_create)
|
||||
assert isinstance(result, TracerSessionV2)
|
||||
res = sample_tracer_session_v2.dict()
|
||||
res["id"] = UUID(new_id)
|
||||
assert result.dict() == res
|
||||
|
||||
|
||||
@patch("langchain.callbacks.tracers.langchain.LangChainTracerV2.load_session")
|
||||
def test_load_default_session(
|
||||
mock_load_session: Mock,
|
||||
lang_chain_tracer_v2: LangChainTracerV2,
|
||||
sample_tracer_session_v2: TracerSessionV2,
|
||||
) -> None:
|
||||
"""Test load_default_session attempts to load with the default name."""
|
||||
mock_load_session.return_value = sample_tracer_session_v2
|
||||
result = lang_chain_tracer_v2.load_default_session()
|
||||
assert result == sample_tracer_session_v2
|
||||
mock_load_session.assert_called_with("default")
|
||||
|
||||
0
tests/unit_tests/client/__init__.py
Normal file
0
tests/unit_tests/client/__init__.py
Normal file
210
tests/unit_tests/client/test_langchain.py
Normal file
210
tests/unit_tests/client/test_langchain.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Test the LangChain+ client."""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, List, Union
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracerV2
|
||||
from langchain.callbacks.tracers.schemas import TracerSessionV2
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.client.langchain import (
|
||||
LangChainPlusClient,
|
||||
_get_link_stem,
|
||||
_is_localhost,
|
||||
)
|
||||
from langchain.client.models import Dataset, Example
|
||||
|
||||
_CREATED_AT = datetime(2015, 1, 1, 0, 0, 0)
|
||||
_TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"api_url, expected_url",
|
||||
[
|
||||
("http://localhost:8000", "http://localhost"),
|
||||
("http://www.example.com", "http://www.example.com"),
|
||||
(
|
||||
"https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
|
||||
"https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
|
||||
),
|
||||
("https://www.langchain.com/path/to/nowhere", "https://www.langchain.com"),
|
||||
],
|
||||
)
|
||||
def test_link_split(api_url: str, expected_url: str) -> None:
|
||||
"""Test the link splitting handles both localhost and deployed urls."""
|
||||
assert _get_link_stem(api_url) == expected_url
|
||||
|
||||
|
||||
def test_is_localhost() -> None:
|
||||
assert _is_localhost("http://localhost:8000")
|
||||
assert _is_localhost("http://127.0.0.1:8000")
|
||||
assert _is_localhost("http://0.0.0.0:8000")
|
||||
assert not _is_localhost("http://example.com:8000")
|
||||
|
||||
|
||||
def test_validate_api_key_if_hosted() -> None:
|
||||
with pytest.raises(ValueError, match="API key must be provided"):
|
||||
LangChainPlusClient(api_url="http://www.example.com")
|
||||
|
||||
client = LangChainPlusClient(api_url="http://localhost:8000")
|
||||
assert client.api_url == "http://localhost:8000"
|
||||
assert client.api_key is None
|
||||
|
||||
|
||||
def test_headers() -> None:
|
||||
client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
|
||||
assert client._headers == {"authorization": "Bearer 123"}
|
||||
|
||||
client_no_key = LangChainPlusClient(api_url="http://localhost:8000")
|
||||
assert client_no_key._headers == {}
|
||||
|
||||
|
||||
@mock.patch("langchain.client.langchain.requests.post")
|
||||
def test_upload_csv(mock_post: mock.Mock) -> None:
|
||||
mock_response = mock.Mock()
|
||||
dataset_id = str(uuid.uuid4())
|
||||
example_1 = Example(
|
||||
id=str(uuid.uuid4()),
|
||||
created_at=_CREATED_AT,
|
||||
inputs={"input": "1"},
|
||||
outputs={"output": "2"},
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
example_2 = Example(
|
||||
id=str(uuid.uuid4()),
|
||||
created_at=_CREATED_AT,
|
||||
inputs={"input": "3"},
|
||||
outputs={"output": "4"},
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
|
||||
mock_response.json.return_value = {
|
||||
"id": "1",
|
||||
"name": "test.csv",
|
||||
"description": "Test dataset",
|
||||
"owner_id": "the owner",
|
||||
"created_at": _CREATED_AT,
|
||||
"examples": [example_1, example_2],
|
||||
}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
|
||||
csv_file = ("test.csv", BytesIO(b"input,output\n1,2\n3,4\n"))
|
||||
|
||||
dataset = client.upload_csv(
|
||||
csv_file, "Test dataset", input_keys=["input"], output_keys=["output"]
|
||||
)
|
||||
|
||||
assert dataset.id == "1"
|
||||
assert dataset.name == "test.csv"
|
||||
assert dataset.description == "Test dataset"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arun_on_dataset() -> None:
|
||||
dataset = Dataset(
|
||||
id="1",
|
||||
name="test",
|
||||
description="Test dataset",
|
||||
owner_id="owner",
|
||||
created_at=_CREATED_AT,
|
||||
)
|
||||
uuids = [
|
||||
"0c193153-2309-4704-9a47-17aee4fb25c8",
|
||||
"0d11b5fd-8e66-4485-b696-4b55155c0c05",
|
||||
"90d696f0-f10d-4fd0-b88b-bfee6df08b84",
|
||||
"4ce2c6d8-5124-4c0c-8292-db7bdebcf167",
|
||||
"7b5a524c-80fa-4960-888e-7d380f9a11ee",
|
||||
]
|
||||
examples = [
|
||||
Example(
|
||||
id=uuids[0],
|
||||
created_at=_CREATED_AT,
|
||||
inputs={"input": "1"},
|
||||
outputs={"output": "2"},
|
||||
dataset_id=str(uuid.uuid4()),
|
||||
),
|
||||
Example(
|
||||
id=uuids[1],
|
||||
created_at=_CREATED_AT,
|
||||
inputs={"input": "3"},
|
||||
outputs={"output": "4"},
|
||||
dataset_id=str(uuid.uuid4()),
|
||||
),
|
||||
Example(
|
||||
id=uuids[2],
|
||||
created_at=_CREATED_AT,
|
||||
inputs={"input": "5"},
|
||||
outputs={"output": "6"},
|
||||
dataset_id=str(uuid.uuid4()),
|
||||
),
|
||||
Example(
|
||||
id=uuids[3],
|
||||
created_at=_CREATED_AT,
|
||||
inputs={"input": "7"},
|
||||
outputs={"output": "8"},
|
||||
dataset_id=str(uuid.uuid4()),
|
||||
),
|
||||
Example(
|
||||
id=uuids[4],
|
||||
created_at=_CREATED_AT,
|
||||
inputs={"input": "9"},
|
||||
outputs={"output": "10"},
|
||||
dataset_id=str(uuid.uuid4()),
|
||||
),
|
||||
]
|
||||
|
||||
async def mock_aread_dataset(*args: Any, **kwargs: Any) -> Dataset:
|
||||
return dataset
|
||||
|
||||
async def mock_alist_examples(*args: Any, **kwargs: Any) -> List[Example]:
|
||||
return examples
|
||||
|
||||
async def mock_arun_chain(
|
||||
example: Example,
|
||||
tracer: Any,
|
||||
llm_or_chain: Union[BaseLanguageModel, Chain],
|
||||
n_repetitions: int,
|
||||
) -> List[Dict[str, Any]]:
|
||||
return [
|
||||
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
|
||||
]
|
||||
|
||||
def mock_load_session(
|
||||
self: Any, name: str, *args: Any, **kwargs: Any
|
||||
) -> TracerSessionV2:
|
||||
return TracerSessionV2(name=name, tenant_id=_TENANT_ID, id=uuid.uuid4())
|
||||
|
||||
with mock.patch.object(
|
||||
LangChainPlusClient, "aread_dataset", new=mock_aread_dataset
|
||||
), mock.patch.object(
|
||||
LangChainPlusClient, "alist_examples", new=mock_alist_examples
|
||||
), mock.patch.object(
|
||||
LangChainPlusClient, "_arun_llm_or_chain", new=mock_arun_chain
|
||||
), mock.patch.object(
|
||||
LangChainTracerV2, "load_session", new=mock_load_session
|
||||
):
|
||||
client = LangChainPlusClient(
|
||||
api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID
|
||||
)
|
||||
chain = mock.MagicMock()
|
||||
|
||||
results = await client.arun_on_dataset(
|
||||
dataset_name="test",
|
||||
llm_or_chain=chain,
|
||||
num_workers=2,
|
||||
session_name="test_session",
|
||||
num_repetitions=3,
|
||||
)
|
||||
|
||||
expected = {
|
||||
uuid.UUID(uuid_): [
|
||||
{"result": f"Result for example {uuid.UUID(uuid_)}"} for _ in range(3)
|
||||
]
|
||||
for uuid_ in uuids
|
||||
}
|
||||
assert results == expected
|
||||
70
tests/unit_tests/client/test_utils.py
Normal file
70
tests/unit_tests/client/test_utils.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Test LangChain+ Client Utils."""
|
||||
|
||||
from typing import List
|
||||
|
||||
from langchain.client.utils import parse_chat_messages
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
|
||||
|
||||
def test_parse_chat_messages() -> None:
|
||||
"""Test that chat messages are parsed correctly."""
|
||||
input_text = (
|
||||
"Human: I am human roar\nAI: I am AI beep boop\nSystem: I am a system message"
|
||||
)
|
||||
expected = [
|
||||
HumanMessage(content="I am human roar"),
|
||||
AIMessage(content="I am AI beep boop"),
|
||||
SystemMessage(content="I am a system message"),
|
||||
]
|
||||
assert parse_chat_messages(input_text) == expected
|
||||
|
||||
|
||||
def test_parse_chat_messages_empty_input() -> None:
|
||||
"""Test that an empty input string returns an empty list."""
|
||||
input_text = ""
|
||||
expected: List[BaseMessage] = []
|
||||
assert parse_chat_messages(input_text) == expected
|
||||
|
||||
|
||||
def test_parse_chat_messages_multiline_messages() -> None:
|
||||
"""Test that multiline messages are parsed correctly."""
|
||||
input_text = (
|
||||
"Human: I am a human\nand I roar\nAI: I am an AI\nand I"
|
||||
" beep boop\nSystem: I am a system\nand a message"
|
||||
)
|
||||
expected = [
|
||||
HumanMessage(content="I am a human\nand I roar"),
|
||||
AIMessage(content="I am an AI\nand I beep boop"),
|
||||
SystemMessage(content="I am a system\nand a message"),
|
||||
]
|
||||
assert parse_chat_messages(input_text) == expected
|
||||
|
||||
|
||||
def test_parse_chat_messages_custom_roles() -> None:
|
||||
"""Test that custom roles are parsed correctly."""
|
||||
input_text = "Client: I need help\nAgent: I'm here to help\nClient: Thank you"
|
||||
expected = [
|
||||
ChatMessage(role="Client", content="I need help"),
|
||||
ChatMessage(role="Agent", content="I'm here to help"),
|
||||
ChatMessage(role="Client", content="Thank you"),
|
||||
]
|
||||
assert parse_chat_messages(input_text, roles=["Client", "Agent"]) == expected
|
||||
|
||||
|
||||
def test_parse_chat_messages_embedded_roles() -> None:
|
||||
"""Test that messages with embedded role references are parsed correctly."""
|
||||
input_text = (
|
||||
"Human: Oh ai what if you said AI: foo bar?"
|
||||
"\nAI: Well, that would be interesting!"
|
||||
)
|
||||
expected = [
|
||||
HumanMessage(content="Oh ai what if you said AI: foo bar?"),
|
||||
AIMessage(content="Well, that would be interesting!"),
|
||||
]
|
||||
assert parse_chat_messages(input_text) == expected
|
||||
Reference in New Issue
Block a user