Compare commits

...

35 Commits

Author SHA1 Message Date
vowelparrot
15510bae4a Merge branch 'vwp/lcp_client' into vwp/eval_examples 2023-05-05 14:20:58 -07:00
vowelparrot
15afdd0858 Update V2 Tracer
- Update the RunCreate object to work with recent changes
- Adjust default persist_session behavior to attempt to
load the session if it exists
- Raise more useful HTTP errors for logging
- Add unit testing
2023-05-05 14:18:00 -07:00
vowelparrot
636877b3af Merge branch 'vwp/lcp_client' into vwp/eval_examples 2023-05-05 14:17:36 -07:00
vowelparrot
b7508d9b52 Update V2 Tracer
- Update the RunCreate object to work with recent changes
- Adjust default persist_session behavior to attempt to
load the session if it exists
- Raise more useful HTTP errors for logging
- Add unit testing
2023-05-05 14:16:55 -07:00
vowelparrot
1d58a49b78 Merge branch 'vwp/lcp_client' into vwp/eval_examples 2023-05-05 14:15:54 -07:00
vowelparrot
8c84538bf5 Update V2 Tracer
- Update the RunCreate object to work with recent changes
- Adjust default persist_session behavior to attempt to
load the session if it exists
- Raise more useful HTTP errors for logging
- Add unit testing
2023-05-05 14:11:28 -07:00
vowelparrot
d0aec0fda4 Merge branch 'master' into vwp/eval_examples 2023-05-05 11:40:32 -07:00
vowelparrot
f7f9756576 Update with LLM and Chat Models 2023-05-05 10:00:09 -07:00
vowelparrot
ed1d3a48e4 Update notebook 2023-05-05 02:25:28 -07:00
vowelparrot
e3808b63d3 add tests 2023-05-04 23:37:14 -07:00
vowelparrot
b709a7b5f0 Merge branch 'master' into vwp/eval_examples 2023-05-04 23:20:23 -07:00
vowelparrot
fa9b296c55 Merge branch 'vwp/add_tenant' into vwp/eval_examples 2023-05-04 21:09:25 -07:00
vowelparrot
ede1bf204c change name 2023-05-04 20:56:31 -07:00
vowelparrot
c2f856fac9 further linting 2023-05-04 20:56:31 -07:00
vowelparrot
059f4e8bb6 Update session 2023-05-04 20:56:31 -07:00
vowelparrot
712d4a228e Update V2 Tracer to be compatible with new data model
Add tenant ID's and update UUID's in runs and sessions
2023-05-04 20:56:31 -07:00
Zander Chase
faa30c8ac5 Visual Studio Code/Github Codespaces Dev Containers (#4035) (#4122)
Having dev containers makes it easier, faster, and more secure to set up the
dev environment for the repository.

The pull request consists of:

- .devcontainer folder with:
- **devcontainer.json :** (minimal necessary vscode extensions and
settings)
- **docker-compose.yaml :** (could be modified to run necessary services
as per need. Ex vectordbs, databases)
    - **Dockerfile:** (non-root, with dev tools)
- Changes to README - added the Open in Github Codespaces Badge - added
the Open in dev container Badge

Co-authored-by: Jinto Jose <129657162+jj701@users.noreply.github.com>
2023-05-04 20:56:31 -07:00
vowelparrot
4c3982464d Update uuid typing and the works 2023-05-04 20:56:21 -07:00
vowelparrot
52a7453469 Merge branch 'vwp/add_tenant' into vwp/eval_examples 2023-05-04 16:37:40 -07:00
vowelparrot
f16e413c7c further linting 2023-05-04 16:33:38 -07:00
vowelparrot
9d1889bb59 Update session 2023-05-04 16:25:41 -07:00
vowelparrot
0304a1a563 Migrating to new endpoints 2023-05-04 15:39:30 -07:00
vowelparrot
7058b207fb Let 'Outputs' be None 2023-05-04 15:06:13 -07:00
vowelparrot
40253143d6 Switch to using an asyncio Queue instead of batches 2023-05-04 15:06:13 -07:00
vowelparrot
4f4d1799b0 Add tests 2023-05-04 15:06:13 -07:00
vowelparrot
50615a5282 add html repr 2023-05-04 15:06:13 -07:00
vowelparrot
92a91e54fb Rerun 2023-05-04 15:06:13 -07:00
vowelparrot
054b4ff0d3 Change Name 2023-05-04 15:06:13 -07:00
vowelparrot
759445229b Add del 2023-05-04 15:06:13 -07:00
vowelparrot
93159c6088 add testing 2023-05-04 15:06:13 -07:00
vowelparrot
ee4d92aa00 Update Client 2023-05-04 15:06:13 -07:00
vowelparrot
ddc26e074e [WIP] Example Notebook running a chain on a dataset 2023-05-04 15:06:11 -07:00
vowelparrot
2fd3133239 Update session 2023-05-04 15:04:15 -07:00
vowelparrot
0a87cbd1e6 Update V2 Tracer to be compatible with new data model
Add tenant ID's and update UUID's in runs and sessions
2023-05-04 13:41:17 -07:00
Zander Chase
2df6119194 Visual Studio Code/Github Codespaces Dev Containers (#4035) (#4122)
Having dev containers makes it easier, faster, and more secure to set up the
dev environment for the repository.

The pull request consists of:

- .devcontainer folder with:
- **devcontainer.json :** (minimal necessary vscode extensions and
settings)
- **docker-compose.yaml :** (could be modified to run necessary services
as per need. Ex vectordbs, databases)
    - **Dockerfile:** (non-root, with dev tools)
- Changes to README - added the Open in Github Codespaces Badge - added
the Open in dev container Badge

Co-authored-by: Jinto Jose <129657162+jj701@users.noreply.github.com>
2023-05-04 13:41:17 -07:00
14 changed files with 1997 additions and 38 deletions

View File

@@ -0,0 +1,967 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "1a4596ea-a631-416d-a2a4-3577c140493d",
"metadata": {},
"source": [
"# Running Chains on Traced Datasets\n",
"\n",
"Developing applications with language models can be uniquely challenging. To manage this complexity and ensure reliable performance, LangChain provides tracing and evaluation functionality through . This notebook demonstrates how to run Chains, which are language model functions, on previously captured datasets or traces. Some common use cases for this approach include:\n",
"\n",
"- Running an evaluation chain to grade previous runs.\n",
"- Comparing different chains, LLMs, and agents on traced datasets.\n",
"- Executing a stochastic chain multiple times over a dataset to generate metrics before deployment.\n",
"\n",
"Please note that this notebook assumes you have LangChain+ tracing running in the background. It is also configured to work only with the V2 endpoints. To set it up, follow the [tracing directions here](..\\/..\\/tracing\\/local_installation.md).\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "904db9a5-f387-4a57-914c-c8af8d39e249",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.client import LangChainPlusClient\n",
"\n",
"client = LangChainPlusClient(\n",
" api_url=\"http://localhost:8000\",\n",
" api_key=None,\n",
" # tenant_id=\"your_tenant_uuid\", # This is required when connecting to a hosted LangChain instance\n",
")"
]
},
{
"cell_type": "markdown",
"id": "db79dea2-fbaa-4c12-9083-f6154b51e2d3",
"metadata": {
"tags": []
},
"source": [
"## Seed an example dataset\n",
"\n",
"If you have been using LangChainPlus already, you may have datasets available. To view all saved datasets, run:\n",
"\n",
"```\n",
"datasets = client.list_datasets()\n",
"datasets\n",
"```\n",
"Datasets can be created in a number of ways, most often by collecting `Run`'s captured through the LangChain tracing API.\n",
"\n",
"However, this notebook assumes you're running locally for the first time, so we'll start by uploading an example evaluation dataset."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1baa677c-5642-4378-8e01-3aa1647f19d6",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# !pip install datasets > /dev/null\n",
"# !pip install pandas > /dev/null"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "60d14593-c61f-449f-a38f-772ca43707c2",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found cached dataset json (/Users/wfh/.cache/huggingface/datasets/LangChainDatasets___json/LangChainDatasets--agent-search-calculator-8a025c0ce5fb99d2/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "60be643da25c47c6aa729b37046a2e64",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>input</th>\n",
" <th>output</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>How many people live in canada as of 2023?</td>\n",
" <td>approximately 38,625,801</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>who is dua lipa's boyfriend? what is his age r...</td>\n",
" <td>her boyfriend is Romain Gravas. his age raised...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>what is dua lipa's boyfriend age raised to the...</td>\n",
" <td>her boyfriend is Romain Gravas. his age raised...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>how far is it from paris to boston in miles</td>\n",
" <td>approximately 3,435 mi</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>what was the total number of points scored in ...</td>\n",
" <td>approximately 2.682651500990882</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" input \\\n",
"0 How many people live in canada as of 2023? \n",
"1 who is dua lipa's boyfriend? what is his age r... \n",
"2 what is dua lipa's boyfriend age raised to the... \n",
"3 how far is it from paris to boston in miles \n",
"4 what was the total number of points scored in ... \n",
"\n",
" output \n",
"0 approximately 38,625,801 \n",
"1 her boyfriend is Romain Gravas. his age raised... \n",
"2 her boyfriend is Romain Gravas. his age raised... \n",
"3 approximately 3,435 mi \n",
"4 approximately 2.682651500990882 "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"from langchain.evaluation.loading import load_dataset\n",
"\n",
"dataset = load_dataset(\"agent-search-calculator\")\n",
"df = pd.DataFrame(dataset, columns=[\"question\", \"answer\"])\n",
"df.columns = [\"input\", \"output\"] # The chain we want to evaluate below expects inputs with the \"input\" key \n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f27e45f1-e299-4de8-a538-ee1272ac5024",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"dataset_name = f\"calculator-example-dataset\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "52a7ea76-79ca-4765-abf7-231e884040d6",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"if dataset_name not in set([dataset.name for dataset in client.list_datasets()]):\n",
" dataset = client.upload_dataframe(df, \n",
" name=dataset_name,\n",
" description=\"Acalculator example dataset\",\n",
" input_keys=[\"input\"],\n",
" output_keys=[\"output\"],\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "07885b10",
"metadata": {
"tags": []
},
"source": [
"## Running a Chain on a Traced Dataset\n",
"\n",
"Once you have a dataset, you can run a chain over it to see its results. The run traces will automatically be associated with the dataset for easy attribution and analysis.\n",
"\n",
"**First, we'll define the chain we wish to run over the dataset.**\n",
"\n",
"In this case, we're using an agent, but it can be any simple chain."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "c2b59104-b90e-466a-b7ea-c5bd0194263b",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.agents import initialize_agent, load_tools\n",
"from langchain.agents import AgentType\n",
"\n",
"llm = ChatOpenAI(temperature=0)\n",
"tools = load_tools(['serpapi', 'llm-math'], llm=llm)\n",
"agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False)"
]
},
{
"cell_type": "markdown",
"id": "84094a4a-1d76-461c-bc37-8c537939b466",
"metadata": {},
"source": [
"**Now we're ready to run the chain!**\n",
"\n",
"The docstring below hints out ways you can configure the method to run."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "112d7bdf-7e50-4c1a-9285-5bac8473f2ee",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"\u001b[0;31mSignature:\u001b[0m\n",
"\u001b[0mclient\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marun_on_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mdataset_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'str'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mllm_or_chain\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Union[Chain, BaseLanguageModel]'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mnum_workers\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mnum_repetitions\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0msession_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Optional[str]'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;34m'Dict[str, Any]'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mDocstring:\u001b[0m\n",
"Run the chain on a dataset and store traces to the specified session name.\n",
"\n",
"Args:\n",
" dataset_name: Name of the dataset to run the chain on.\n",
" llm_or_chain: Chain or language model to run over the dataset.\n",
" num_workers: Number of async workers to run in parallel.\n",
" num_repetitions: Number of times to run the model on each example.\n",
" This is useful when testing success rates or generating confidence\n",
" intervals.\n",
" session_name: Name of the session to store the traces in.\n",
" Defaults to {dataset_name}-{chain class name}-{datetime}.\n",
" verbose: Whether to print progress.\n",
"\n",
"Returns:\n",
" A dictionary mapping example ids to the model outputs.\n",
"\u001b[0;31mFile:\u001b[0m ~/code/lc/lckg/langchain/client/langchain.py\n",
"\u001b[0;31mType:\u001b[0m method"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"?client.arun_on_dataset"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a8088b7d-3ab6-4279-94c8-5116fe7cee33",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/wfh/code/lc/lckg/langchain/callbacks/manager.py:65: UserWarning: The experimental tracing v2 is in development. This is not yet stable and may change in the future.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processed examples: 1\r"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Chain failed for example 4aca2037-5393-4357-9bf9-ef7451b8f0e3. Error: unknown format from LLM: Assuming we don't have any information about the actual number of points scored in the 2023 super bowl, we cannot provide a mathematical expression to solve this problem.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processed examples: 2\r"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Chain failed for example b0adfea5-1212-4cee-9206-11bd69ccfcc5. Error: 'age'. Please try again with a valid numerical expression\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processed examples: 5\r"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Chain failed for example 7e5873e4-ed38-428b-909f-f8c0fb80bab9. Error: invalid syntax. Perhaps you forgot a comma? (<expr>, line 1). Please try again with a valid numerical expression\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processed examples: 10\r"
]
}
],
"source": [
"chain_results = await client.arun_on_dataset(\n",
" dataset_name=dataset_name,\n",
" llm_or_chain=agent,\n",
" verbose=True\n",
")"
]
},
{
"cell_type": "markdown",
"id": "d2737458-b20c-4288-8790-1f4a8d237b2a",
"metadata": {},
"source": [
"## Reviewing the Chain Results\n",
"\n",
"The method called above returns a dictionary mapping Example IDs to the output of the chain.\n",
"You can directly inspect the results below."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "136db492-d6ca-4215-96f9-439c23538241",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"<a href=\"http://localhost\", target=\"_blank\" rel=\"noopener\">LangChain+ Client</a>"
],
"text/plain": [
"LangChainPlusClient (API URL: http://localhost:8000)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# You can navigate to the UI by clicking on the link below\n",
"client"
]
},
{
"cell_type": "markdown",
"id": "c70cceb5-aa53-4851-bb12-386f092191f9",
"metadata": {},
"source": [
"### Running a Chat Model over a Traced Dataset\n",
"\n",
"We've shown how to run a _chain_ over a dataset, but you can also run an LLM or Chat model over a datasets formed from runs.\n",
"\n",
"First, we'll show an example using a ChatModel. This is useful for things like:\n",
"- Comparing results under different decoding parameters\n",
"- Comparing model providers\n",
"- Testing for regressions in model behavior\n",
"- Running multiple times with a temperature to gauge stability"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "64490d7c-9a18-49ed-a3ac-36049c522cb4",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found cached dataset parquet (/Users/wfh/.cache/huggingface/datasets/LangChainDatasets___parquet/LangChainDatasets--two-player-dnd-2e84407830cdedfc/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "35a85ea78f934758a5f8b16af0d1944a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>prompts</th>\n",
" <th>generations</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>[System: Here is the topic for a Dungeons &amp; Dr...</td>\n",
" <td>[[{'generation_info': None, 'message': {'conte...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>[System: Here is the topic for a Dungeons &amp; Dr...</td>\n",
" <td>[[{'generation_info': None, 'message': {'conte...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>[System: Here is the topic for a Dungeons &amp; Dr...</td>\n",
" <td>[[{'generation_info': None, 'message': {'conte...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>[System: Here is the topic for a Dungeons &amp; Dr...</td>\n",
" <td>[[{'generation_info': None, 'message': {'conte...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>[System: Here is the topic for a Dungeons &amp; Dr...</td>\n",
" <td>[[{'generation_info': None, 'message': {'conte...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" prompts \\\n",
"0 [System: Here is the topic for a Dungeons & Dr... \n",
"1 [System: Here is the topic for a Dungeons & Dr... \n",
"2 [System: Here is the topic for a Dungeons & Dr... \n",
"3 [System: Here is the topic for a Dungeons & Dr... \n",
"4 [System: Here is the topic for a Dungeons & Dr... \n",
"\n",
" generations \n",
"0 [[{'generation_info': None, 'message': {'conte... \n",
"1 [[{'generation_info': None, 'message': {'conte... \n",
"2 [[{'generation_info': None, 'message': {'conte... \n",
"3 [[{'generation_info': None, 'message': {'conte... \n",
"4 [[{'generation_info': None, 'message': {'conte... "
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat_dataset = load_dataset(\"two-player-dnd\")\n",
"chat_df = pd.DataFrame(chat_dataset)\n",
"chat_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "348acd86-a927-4d60-8d52-02e64585e4fc",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"chat_dataset_name = \"two-player-dnd\"\n",
"\n",
"if chat_dataset_name not in set([dataset.name for dataset in client.list_datasets()]):\n",
" client.upload_dataframe(chat_df, \n",
" name=chat_dataset_name,\n",
" description=\"An example dataset traced from chat models in a multiagent bidding dialogue\",\n",
" input_keys=[\"prompts\"],\n",
" output_keys=[\"generations\"],\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "a69dd183-ad5e-473d-b631-db90706e837f",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"\n",
"chat_model = ChatOpenAI(temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "063da2a9-3692-4b7b-8edb-e474824fe416",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/wfh/code/lc/lckg/langchain/callbacks/manager.py:65: UserWarning: The experimental tracing v2 is in development. This is not yet stable and may change in the future.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processed examples: 35\r"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Retrying langchain.chat_models.openai.acompletion_with_retry.<locals>._completion_with_retry in 1.0 seconds as it raised RateLimitError: That model is currently overloaded with other requests. You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 2a718b30703cc8ffef8de0bb48935fd6 in your message.).\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processed examples: 36\r"
]
}
],
"source": [
"chat_model_results = await client.arun_on_dataset(\n",
" dataset_name=chat_dataset_name,\n",
" llm_or_chain=chat_model,\n",
" num_workers=5, # Optional, sets the number of examples to run at a time\n",
" # session_name=\"Calculator Dataset Runs\", # Optional. Will be seed with a default session otherwise\n",
" verbose=True\n",
")"
]
},
{
"cell_type": "markdown",
"id": "de7bfe08-215c-4328-b9b0-631d9a41f0e8",
"metadata": {
"tags": []
},
"source": [
"## Reviewing the Chat Model Results\n",
"\n",
"You can once again review the latest runs by clicking on the link below and navigating to the \"two-player-dnd\" session."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "5b7a81f2-d19d-438b-a4bb-5678f746b965",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"<a href=\"http://localhost\", target=\"_blank\" rel=\"noopener\">LangChain+ Client</a>"
],
"text/plain": [
"LangChainPlusClient (API URL: http://localhost:8000)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"client"
]
},
{
"cell_type": "markdown",
"id": "7896cbeb-345f-430b-ab5e-e108973174f8",
"metadata": {},
"source": [
"## Running an LLM over a Traced Dataset\n",
"\n",
"You can run an LLM over a dataset in much the same way as the chain and chat models, provided the dataset you've captured is in the appropriate format. Again, we've cached one for you, but using application-specific traces will be much more useful for your use cases."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "d6805d0b-4612-4671-bffb-e6978992bd40",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/wfh/code/lc/lckg/langchain/llms/anthropic.py:134: UserWarning: This Anthropic LLM is deprecated. Please use `from langchain.chat_models import ChatAnthropic` instead\n",
" warnings.warn(\n"
]
}
],
"source": [
"from langchain.llms import Anthropic\n",
"\n",
"llm = Anthropic(temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "5d7cb243-40c3-44dd-8158-a7b910441e9f",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found cached dataset parquet (/Users/wfh/.cache/huggingface/datasets/LangChainDatasets___parquet/LangChainDatasets--state-of-the-union-completions-ae7542e7bbd0ae0a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c8628ee0ce5b4924be636250b18e9803",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>prompts</th>\n",
" <th>generations</th>\n",
" <th>ground_truth</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>[Putin may circle Kyiv with tanks, but he will...</td>\n",
" <td>[[{'generation_info': {'finish_reason': 'stop'...</td>\n",
" <td>The pandemic has been punishing. \\n\\nAnd so ma...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>[Madam Speaker, Madam Vice President, our Firs...</td>\n",
" <td>[[]]</td>\n",
" <td>With a duty to one another to the American peo...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>[With a duty to one another to the American pe...</td>\n",
" <td>[[{'generation_info': {'finish_reason': 'stop'...</td>\n",
" <td>He thought he could roll into Ukraine and the ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>[Madam Speaker, Madam Vice President, our Firs...</td>\n",
" <td>[[{'generation_info': {'finish_reason': 'lengt...</td>\n",
" <td>With a duty to one another to the American peo...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>[Please rise if you are able and show that, Ye...</td>\n",
" <td>[[]]</td>\n",
" <td>And the costs and the threats to America and t...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" prompts \\\n",
"0 [Putin may circle Kyiv with tanks, but he will... \n",
"1 [Madam Speaker, Madam Vice President, our Firs... \n",
"2 [With a duty to one another to the American pe... \n",
"3 [Madam Speaker, Madam Vice President, our Firs... \n",
"4 [Please rise if you are able and show that, Ye... \n",
"\n",
" generations \\\n",
"0 [[{'generation_info': {'finish_reason': 'stop'... \n",
"1 [[]] \n",
"2 [[{'generation_info': {'finish_reason': 'stop'... \n",
"3 [[{'generation_info': {'finish_reason': 'lengt... \n",
"4 [[]] \n",
"\n",
" ground_truth \n",
"0 The pandemic has been punishing. \\n\\nAnd so ma... \n",
"1 With a duty to one another to the American peo... \n",
"2 He thought he could roll into Ukraine and the ... \n",
"3 With a duty to one another to the American peo... \n",
"4 And the costs and the threats to America and t... "
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"completions_dataset = load_dataset(\"state-of-the-union-completions\")\n",
"completions_df = pd.DataFrame(completions_dataset)\n",
"completions_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "c7dcc1b2-7aef-44c0-ba0f-c812279099a5",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"completions_dataset_name = \"state-of-the-union-completions\"\n",
"\n",
"if completions_dataset_name not in set([dataset.name for dataset in client.list_datasets()]):\n",
" client.upload_dataframe(completions_df, \n",
" name=completions_dataset_name,\n",
" description=\"An example dataset traced from completion endpoints over the state of the union address\",\n",
" input_keys=[\"prompts\"],\n",
" output_keys=[\"generations\"],\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "02a0b79d-4867-45b2-b43d-662e5dcefb0f",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/wfh/code/lc/lckg/langchain/llms/anthropic.py:134: UserWarning: This Anthropic LLM is deprecated. Please use `from langchain.chat_models import ChatAnthropic` instead\n",
" warnings.warn(\n"
]
}
],
"source": [
"from langchain.llms import Anthropic\n",
"\n",
"\n",
"llm = Anthropic(temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "e946138e-bf7c-43d7-861d-9c5740c933fa",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/wfh/code/lc/lckg/langchain/callbacks/manager.py:65: UserWarning: The experimental tracing v2 is in development. This is not yet stable and may change in the future.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processed examples: 55\r"
]
}
],
"source": [
"completions_model_results = await client.arun_on_dataset(\n",
" dataset_name=completions_dataset_name,\n",
" llm_or_chain=llm,\n",
" num_workers=2, # Optional, sets the number of examples to run at a time\n",
" num_repetitions=1, # Increasing this will run the model multiple times per example\n",
" verbose=True\n",
")"
]
},
{
"cell_type": "markdown",
"id": "cc86e8e6-cee2-429e-942b-289284d14816",
"metadata": {},
"source": [
"## Reviewing the LLM Results\n",
"\n",
"You can once again inspect the latest runs by clicking on the link below and navigating to the \"two-player-dnd\" session."
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "2bf96f17-74c1-4f7d-8458-ae5ab5c6bd36",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"<a href=\"http://localhost\", target=\"_blank\" rel=\"noopener\">LangChain+ Client</a>"
],
"text/plain": [
"LangChainPlusClient (API URL: http://localhost:8000)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"client"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bbafa5a5-a278-42c0-a24f-191cbcb41156",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
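
Taken together, the notebook reduces to a short script: create a client, seed a dataset, and run a chain over it. Below is a condensed sketch of that flow, assuming the same local LangChain+ server as the notebook and OpenAI/SerpAPI keys in the environment; the single dataset row is illustrative.

import asyncio

import pandas as pd
from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.chat_models import ChatOpenAI
from langchain.client import LangChainPlusClient

client = LangChainPlusClient(api_url="http://localhost:8000", api_key=None)

# Seed a tiny dataset; "input"/"output" are the keys the agent expects.
df = pd.DataFrame([
    {"input": "How many people live in canada as of 2023?",
     "output": "approximately 38,625,801"},
])
dataset_name = "calculator-example-dataset"
if dataset_name not in {d.name for d in client.list_datasets()}:
    client.upload_dataframe(
        df,
        name=dataset_name,
        description="A calculator example dataset",
        input_keys=["input"],
        output_keys=["output"],
    )

# Run an agent over every example; returns {example_id: outputs}.
llm = ChatOpenAI(temperature=0)
tools = load_tools(["serpapi", "llm-math"], llm=llm)
agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION)
results = asyncio.run(
    client.arun_on_dataset(dataset_name=dataset_name, llm_or_chain=agent, verbose=True)
)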

View File

@@ -58,6 +58,7 @@ def tracing_enabled(
@contextmanager
def tracing_v2_enabled(
session_name: str = "default",
example_id: Optional[Union[str, UUID]] = None,
) -> Generator[TracerSessionV2, None, None]:
"""Get the experimental tracer handler in a context manager."""
# Issue a warning that this is experimental
@@ -65,8 +66,10 @@ def tracing_v2_enabled(
"The experimental tracing v2 is in development. "
"This is not yet stable and may change in the future."
)
cb = LangChainTracerV2()
session = cb.load_session(session_name)
if isinstance(example_id, str):
example_id = UUID(example_id)
cb = LangChainTracerV2(example_id=example_id)
session = cast(TracerSessionV2, cb.new_session(session_name))
tracing_callback_var.set(cb)
yield session
tracing_callback_var.set(None)
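
For orientation, a minimal usage sketch of the updated context manager, assuming the signature shown in this hunk; string ids are converted to UUID as above, and the UUID here is a hypothetical placeholder for a dataset example id.

from langchain.callbacks.manager import tracing_v2_enabled

# Runs executed inside the block are traced into "my-session" and linked
# back to the given example (placeholder UUID shown).
with tracing_v2_enabled(
    session_name="my-session",
    example_id="00000000-0000-0000-0000-000000000000",
) as session:
    print(session.id)  # the TracerSessionV2 created by new_session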

View File

@@ -29,7 +29,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {}
self.session: Optional[Union[TracerSessionV2, TracerSession]] = None
self.session: Optional[Union[TracerSession, TracerSessionV2]] = None
@staticmethod
def _add_child_run(
@@ -165,7 +165,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
llm_run = self.run_map.get(run_id_)
if llm_run is None or not isinstance(llm_run, LLMRun):
raise TracerException("No LLMRun found to be traced")
llm_run.response = response
llm_run.end_time = datetime.utcnow()
self._end_trace(llm_run)

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import logging
import os
from typing import Any, Dict, List, Optional, Union
from uuid import UUID, uuid4
import requests
@@ -11,13 +12,14 @@ from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import (
ChainRun,
LLMRun,
Run,
RunCreate,
ToolRun,
TracerSession,
TracerSessionBase,
TracerSessionV2,
TracerSessionV2Create,
)
from langchain.utils import raise_for_status_with_text
def _get_headers() -> Dict[str, Any]:
@@ -51,11 +53,12 @@ class LangChainTracer(BaseTracer):
endpoint = f"{self._endpoint}/tool-runs"
try:
requests.post(
response = requests.post(
endpoint,
data=run.json(),
headers=self._headers,
)
raise_for_status_with_text(response)
except Exception as e:
logging.warning(f"Failed to persist run: {e}")
@@ -111,7 +114,7 @@ def _get_tenant_id() -> Optional[str]:
endpoint = _get_endpoint()
headers = _get_headers()
response = requests.get(endpoint + "/tenants", headers=headers)
response.raise_for_status()
raise_for_status_with_text(response)
tenants: List[Dict[str, Any]] = response.json()
if not tenants:
raise ValueError(f"No tenants found for URL {endpoint}")
@@ -121,12 +124,13 @@ def _get_tenant_id() -> Optional[str]:
class LangChainTracerV2(LangChainTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
def __init__(self, **kwargs: Any) -> None:
def __init__(self, example_id: Optional[UUID] = None, **kwargs: Any) -> None:
"""Initialize the LangChain tracer."""
super().__init__(**kwargs)
self._endpoint = _get_endpoint()
self._headers = _get_headers()
self.tenant_id = _get_tenant_id()
self.example_id = example_id
def _get_session_create(
self, name: Optional[str] = None, **kwargs: Any
@@ -135,16 +139,30 @@ class LangChainTracerV2(LangChainTracer):
def _persist_session(self, session_create: TracerSessionBase) -> TracerSessionV2:
"""Persist a session."""
session: Optional[TracerSessionV2] = None
try:
r = requests.post(
f"{self._endpoint}/sessions",
data=session_create.json(),
headers=self._headers,
)
session = TracerSessionV2(id=r.json()["id"], **session_create.dict())
raise_for_status_with_text(r)
creation_args = session_create.dict()
if "id" in creation_args:
del creation_args["id"]
return TracerSessionV2(id=r.json()["id"], **creation_args)
except Exception as e:
logging.warning(f"Failed to create session, using default session: {e}")
session = self.load_session("default")
if session_create.name is not None:
try:
return self.load_session(session_create.name)
except Exception:
pass
logging.warning(
f"Failed to create session {session_create.name},"
f" using empty session: {e}"
)
session = TracerSessionV2(id=uuid4(), **session_create.dict())
return session
def _get_default_query_params(self) -> Dict[str, Any]:
@@ -159,13 +177,14 @@ class LangChainTracerV2(LangChainTracer):
if session_name:
params["name"] = session_name
r = requests.get(url, headers=self._headers, params=params)
raise_for_status_with_text(r)
tracer_session = TracerSessionV2(**r.json()[0])
except Exception as e:
session_type = "default" if not session_name else session_name
logging.warning(
f"Failed to load {session_type} session, using empty session: {e}"
)
tracer_session = TracerSessionV2(id=1, tenant_id=self.tenant_id)
tracer_session = TracerSessionV2(id=uuid4(), tenant_id=self.tenant_id)
self.session = tracer_session
return tracer_session
@@ -174,7 +193,7 @@ class LangChainTracerV2(LangChainTracer):
"""Load the default tracing session and set it as the Tracer's session."""
return self.load_session("default")
def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> Run:
def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> RunCreate:
"""Convert a run to a Run."""
session = self.session or self.load_default_session()
inputs: Dict[str, Any] = {}
@@ -204,9 +223,9 @@ class LangChainTracerV2(LangChainTracer):
*run.child_tool_runs,
]
return Run(
return RunCreate(
id=run.uuid,
name=run.serialized.get("name", f"{run_type}-{run.uuid}"),
name=run.serialized.get("name"),
start_time=run.start_time,
end_time=run.end_time,
extra=run.extra or {},
@@ -217,7 +236,7 @@ class LangChainTracerV2(LangChainTracer):
outputs=outputs,
session_id=session.id,
run_type=run_type,
parent_run_id=run.parent_uuid,
reference_example_id=self.example_id,
child_runs=[self._convert_run(child) for child in child_runs],
)
@@ -225,11 +244,11 @@ class LangChainTracerV2(LangChainTracer):
"""Persist a run."""
run_create = self._convert_run(run)
try:
result = requests.post(
response = requests.post(
f"{self._endpoint}/runs",
data=run_create.json(),
headers=self._headers,
)
result.raise_for_status()
raise_for_status_with_text(response)
except Exception as e:
logging.warning(f"Failed to persist run: {e}")

View File

@@ -37,9 +37,11 @@ class TracerSessionV2Base(TracerSessionBase):
tenant_id: UUID
class TracerSessionV2Create(TracerSessionBase):
class TracerSessionV2Create(TracerSessionV2Base):
"""A creation class for TracerSessionV2."""
id: Optional[UUID]
pass
@@ -100,9 +102,10 @@ class RunTypeEnum(str, Enum):
llm = "llm"
class Run(BaseModel):
class RunBase(BaseModel):
"""Base Run schema."""
id: Optional[UUID]
name: str
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
extra: dict
@@ -112,10 +115,22 @@ class Run(BaseModel):
inputs: dict
outputs: Optional[dict]
session_id: UUID
parent_run_id: Optional[UUID]
reference_example_id: Optional[UUID]
run_type: RunTypeEnum
child_runs: List[Run] = Field(default_factory=list)
class RunCreate(RunBase):
"""Schema to create a run in the DB."""
name: Optional[str]
child_runs: List[RunCreate] = Field(default_factory=list)
class Run(RunBase):
"""Run schema when loading from the DB."""
name: str
parent_run_id: Optional[UUID]
ChainRun.update_forward_refs()
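
The intent of the split: RunCreate relaxes fields the server can fill in (name becomes optional, and parent_run_id gives way to reference_example_id), while Run is the stricter shape returned when loading from the DB. A toy model of the same pydantic pattern, with illustrative fields only:

from typing import List, Optional
from uuid import UUID

from pydantic import BaseModel, Field

class RunBase(BaseModel):
    id: Optional[UUID]
    inputs: dict
    outputs: Optional[dict]

class RunCreate(RunBase):
    name: Optional[str]  # the server may derive a name
    child_runs: List["RunCreate"] = Field(default_factory=list)

class Run(RunBase):
    name: str  # always present once stored
    parent_run_id: Optional[UUID]

RunCreate.update_forward_refs()  # resolve the self-reference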

View File

@@ -0,0 +1,6 @@
"""LangChain+ Client."""
from langchain.client.langchain import LangChainPlusClient
__all__ = ["LangChainPlusClient"]

View File

@@ -0,0 +1,519 @@
from __future__ import annotations
import asyncio
import logging
import socket
from datetime import datetime
from io import BytesIO
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
from urllib.parse import urlsplit
from uuid import UUID
import requests
from pydantic import BaseSettings, Field, root_validator
from requests import Response
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import tracing_v2_enabled
from langchain.callbacks.tracers.langchain import LangChainTracerV2
from langchain.chains.base import Chain
from langchain.chat_models.base import BaseChatModel
from langchain.client.models import Dataset, DatasetCreate, Example, ExampleCreate
from langchain.client.utils import parse_chat_messages
from langchain.llms.base import BaseLLM
from langchain.schema import ChatResult, LLMResult
from langchain.utils import raise_for_status_with_text, xor_args
if TYPE_CHECKING:
import pandas as pd
logger = logging.getLogger(__name__)
def _get_link_stem(url: str) -> str:
scheme = urlsplit(url).scheme
netloc_prefix = urlsplit(url).netloc.split(":")[0]
return f"{scheme}://{netloc_prefix}"
def _is_localhost(url: str) -> bool:
"""Check if the URL is localhost."""
try:
netloc = urlsplit(url).netloc.split(":")[0]
ip = socket.gethostbyname(netloc)
return ip == "127.0.0.1" or ip.startswith("0.0.0.0") or ip.startswith("::")
except socket.gaierror:
return False
class LangChainPlusClient(BaseSettings):
"""Client for interacting with the LangChain+ API."""
api_key: Optional[str] = Field(default=None, env="LANGCHAIN_API_KEY")
api_url: str = Field(..., env="LANGCHAIN_ENDPOINT")
tenant_id: str = Field(..., env="LANGCHAIN_TENANT_ID")
@root_validator(pre=True)
def validate_api_key_if_hosted(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Verify API key is provided if url not localhost."""
api_url: str = values.get("api_url", "http://localhost:8000")
api_key: Optional[str] = values.get("api_key")
if not _is_localhost(api_url):
if not api_key:
raise ValueError(
"API key must be provided when using hosted LangChain+ API"
)
else:
tenant_id = values.get("tenant_id")
if not tenant_id:
values["tenant_id"] = LangChainPlusClient._get_seeded_tenant_id(
api_url, api_key
)
return values
@staticmethod
def _get_seeded_tenant_id(api_url: str, api_key: Optional[str]) -> str:
"""Get the tenant ID from the seeded tenant."""
url = f"{api_url}/tenants"
headers = {"authorization": f"Bearer {api_key}"} if api_key else {}
response = requests.get(url, headers=headers)
try:
raise_for_status_with_text(response)
except Exception as e:
raise ValueError(
"Unable to get seeded tenant ID. Please manually provide."
) from e
results: List[dict] = response.json()
if len(results) == 0:
raise ValueError("No seeded tenant found")
return results[0]["id"]
def _repr_html_(self) -> str:
"""Return an HTML representation of the instance with a link to the URL."""
link = _get_link_stem(self.api_url)
return f'<a href="{link}", target="_blank" rel="noopener">LangChain+ Client</a>'
def __repr__(self) -> str:
"""Return a string representation of the instance with a link to the URL."""
return f"LangChainPlusClient (API URL: {self.api_url})"
@property
def _headers(self) -> Dict[str, str]:
"""Get the headers for the API request."""
headers = {}
if self.api_key:
headers["authorization"] = f"Bearer {self.api_key}"
return headers
@property
def query_params(self) -> Dict[str, str]:
"""Get the headers for the API request."""
return {"tenant_id": self.tenant_id}
def _get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Response:
"""Make a GET request."""
query_params = self.query_params
if params:
query_params.update(params)
return requests.get(
f"{self.api_url}{path}", headers=self._headers, params=query_params
)
@xor_args(("dataset_id", "dataset_name"))
def create_example(
self,
inputs: Dict[str, Any],
dataset_id: Optional[UUID] = None,
dataset_name: Optional[str] = None,
created_at: Optional[datetime] = None,
outputs: Dict[str, Any] | None = None,
) -> Example:
"""Create a dataset example in the LangChain+ API."""
if dataset_id is None:
dataset_id = self.read_dataset(dataset_name).id
data = {
"inputs": inputs,
"outputs": outputs,
"dataset_id": dataset_id,
}
if created_at:
data["created_at"] = created_at.isoformat()
example = ExampleCreate(**data)
response = requests.post(
f"{self.api_url}/examples", headers=self._headers, data=example.json()
)
raise_for_status_with_text(response)
result = response.json()
return Example(**result)
def upload_dataframe(
self,
df: pd.DataFrame,
name: str,
description: str,
input_keys: List[str],
output_keys: List[str],
) -> Dataset:
"""Upload a dataframe as individual examples to the LangChain+ API."""
dataset = self.create_dataset(dataset_name=name, description=description)
for row in df.itertuples():
inputs = {key: getattr(row, key) for key in input_keys}
outputs = {key: getattr(row, key) for key in output_keys}
self.create_example(inputs, outputs=outputs, dataset_id=dataset.id)
return dataset
def upload_csv(
self,
csv_file: Union[str, Tuple[str, BytesIO]],
description: str,
input_keys: List[str],
output_keys: List[str],
) -> Dataset:
"""Upload a CSV file to the LangChain+ API."""
files = {"file": csv_file}
data = {
"input_keys": ",".join(input_keys),
"output_keys": ",".join(output_keys),
"description": description,
"tenant_id": self.tenant_id,
}
response = requests.post(
self.api_url + "/datasets/upload",
headers=self._headers,
data=data,
files=files,
)
raise_for_status_with_text(response)
result = response.json()
# TODO: Make this more robust server-side
if "detail" in result and "already exists" in result["detail"]:
file_name = csv_file if isinstance(csv_file, str) else csv_file[0]
file_name = file_name.split("/")[-1]
raise ValueError(f"Dataset {file_name} already exists")
return Dataset(**result)
def create_dataset(self, dataset_name: str, description: str) -> Dataset:
"""Create a dataset in the LangChain+ API."""
dataset = DatasetCreate(
tenant_id=self.tenant_id,
name=dataset_name,
description=description,
)
response = requests.post(
self.api_url + "/datasets",
headers=self._headers,
data=dataset.json(),
)
raise_for_status_with_text(response)
return Dataset(**response.json())
@xor_args(("dataset_name", "dataset_id"))
def read_dataset(
self, *, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None
) -> Dataset:
path = "/datasets"
params: Dict[str, Any] = {"limit": 1, "tenant_id": self.tenant_id}
if dataset_id is not None:
path += f"/{dataset_id}"
elif dataset_name is not None:
params["name"] = dataset_name
else:
raise ValueError("Must provide dataset_name or dataset_id")
response = self._get(
path,
params=params,
)
raise_for_status_with_text(response)
result = response.json()
if isinstance(result, list):
if len(result) == 0:
raise ValueError(f"Dataset {dataset_name} not found")
return Dataset(**result[0])
return Dataset(**result)
def list_datasets(self, limit: int = 100) -> Iterable[Dataset]:
"""List the datasets on the LangChain+ API."""
response = self._get("/datasets", params={"limit": limit})
raise_for_status_with_text(response)
return [Dataset(**dataset) for dataset in response.json()]
@xor_args(("dataset_id", "dataset_name"))
def delete_dataset(
self, *, dataset_id: Optional[str] = None, dataset_name: Optional[str] = None
) -> Dataset:
"""Delete a dataset by ID or name."""
if dataset_name is not None:
dataset_id = self.read_dataset(dataset_name=dataset_name).id
if dataset_id is None:
raise ValueError("Must provide either dataset name or ID")
response = requests.delete(
f"{self.api_url}/datasets/{dataset_id}",
headers=self._headers,
)
raise_for_status_with_text(response)
return response.json()
def read_example(self, example_id: str) -> Example:
"""Read an example from the LangChain+ API."""
response = self._get(f"/examples/{example_id}")
raise_for_status_with_text(response)
return Example(**response.json())
def list_examples(
self, dataset_id: Optional[str] = None, dataset_name: Optional[str] = None
) -> Iterable[Example]:
"""List the datasets on the LangChain+ API."""
params = {}
if dataset_id is not None:
params["dataset"] = dataset_id
elif dataset_name is not None:
dataset_id = self.read_dataset(dataset_name=dataset_name).id
params["dataset"] = dataset_id
else:
pass
response = self._get("/examples", params=params)
raise_for_status_with_text(response)
return [Example(**dataset) for dataset in response.json()]
@staticmethod
async def _arun_llm(
llm: BaseLanguageModel,
inputs: Dict[str, Any],
langchain_tracer: LangChainTracerV2,
) -> Union[LLMResult, ChatResult]:
if isinstance(llm, BaseLLM):
llm_prompts: List[str] = inputs["prompts"]
llm_output = await llm.agenerate(llm_prompts, callbacks=[langchain_tracer])
elif isinstance(llm, BaseChatModel):
chat_prompts: List[str] = inputs["prompts"]
messages = [
parse_chat_messages(chat_prompt) for chat_prompt in chat_prompts
]
llm_output = await llm.agenerate(messages, callbacks=[langchain_tracer])
else:
raise ValueError(f"Unsupported LLM type {type(llm)}")
return llm_output
@staticmethod
async def _arun_llm_or_chain(
example: Example,
langchain_tracer: LangChainTracerV2,
llm_or_chain: Union[Chain, BaseLanguageModel],
n_repetitions: int,
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
"""Run the chain asynchronously."""
previous_example_id = langchain_tracer.example_id
langchain_tracer.example_id = example.id
outputs = []
for _ in range(n_repetitions):
try:
if isinstance(llm_or_chain, BaseLanguageModel):
output: Any = await LangChainPlusClient._arun_llm(
llm_or_chain, example.inputs, langchain_tracer
)
else:
output = await llm_or_chain.arun(
example.inputs, callbacks=[langchain_tracer]
)
outputs.append(output)
except Exception as e:
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
outputs.append({"Error": str(e)})
finally:
langchain_tracer.example_id = previous_example_id
return outputs
@staticmethod
async def _worker(
queue: asyncio.Queue,
tracer: LangChainTracerV2,
llm_or_chain: Union[Chain, BaseLanguageModel],
n_repetitions: int,
results: Dict[str, Any],
job_state: Dict[str, Any],
verbose: bool,
) -> None:
"""Worker for running the chain on examples."""
while True:
example: Optional[Example] = await queue.get()
if example is None:
break
result = await LangChainPlusClient._arun_llm_or_chain(
example,
tracer,
llm_or_chain,
n_repetitions,
)
results[str(example.id)] = result
queue.task_done()
job_state["num_processed"] += 1
if verbose:
print(
f"Processed examples: {job_state['num_processed']}",
end="\r",
flush=True,
)
async def arun_on_dataset(
self,
dataset_name: str,
llm_or_chain: Union[Chain, BaseLanguageModel],
num_workers: int = 5,
num_repetitions: int = 1,
session_name: Optional[str] = None,
verbose: bool = False,
) -> Dict[str, Any]:
"""Run the chain on a dataset and store traces to the specified session name.
Args:
dataset_name: Name of the dataset to run the chain on.
llm_or_chain: Chain or language model to run over the dataset.
num_workers: Number of async workers to run in parallel.
num_repetitions: Number of times to run the model on each example.
This is useful when testing success rates or generating confidence
intervals.
session_name: Name of the session to store the traces in.
Defaults to {dataset_name}-{chain class name}-{datetime}.
verbose: Whether to print progress.
Returns:
A dictionary mapping example ids to the model outputs.
"""
if session_name is None:
current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
session_name = (
f"{dataset_name}-{llm_or_chain.__class__.__name__}-{current_time}"
)
dataset = self.read_dataset(dataset_name=dataset_name)
workers = []
examples = self.list_examples(dataset_id=str(dataset.id))
results: Dict[str, Any] = {}
queue: asyncio.Queue[Optional[Example]] = asyncio.Queue()
job_state = {"num_processed": 0}
with tracing_v2_enabled(session_name=session_name) as session:
for _ in range(num_workers):
tracer = LangChainTracerV2()
tracer.session = session
task = asyncio.create_task(
LangChainPlusClient._worker(
queue,
tracer,
llm_or_chain,
num_repetitions,
results,
job_state,
verbose,
)
)
workers.append(task)
for example in examples:
await queue.put(example)
await queue.join() # Wait for all tasks to complete
for _ in workers:
await queue.put(None) # Signal the workers to exit
await asyncio.gather(*workers)
return results
@staticmethod
def run_llm(
llm: BaseLanguageModel,
inputs: Dict[str, Any],
langchain_tracer: LangChainTracerV2,
) -> Union[LLMResult, ChatResult]:
"""Run the language model on the example."""
if isinstance(llm, BaseLLM):
llm_prompts: List[str] = inputs["prompts"]
llm_output = llm.generate(llm_prompts, callbacks=[langchain_tracer])
elif isinstance(llm, BaseChatModel):
chat_prompts: List[str] = inputs["prompts"]
messages = [
parse_chat_messages(chat_prompt) for chat_prompt in chat_prompts
]
llm_output = llm.generate(messages, callbacks=[langchain_tracer])
else:
raise ValueError(f"Unsupported LLM type {type(llm)}")
return llm_output
@staticmethod
def run_llm_or_chain(
example: Example,
langchain_tracer: LangChainTracerV2,
llm_or_chain: Union[Chain, BaseLanguageModel],
n_repetitions: int,
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
"""Run the chain synchronously."""
previous_example_id = langchain_tracer.example_id
langchain_tracer.example_id = example.id
outputs = []
for _ in range(n_repetitions):
try:
if isinstance(llm_or_chain, BaseLanguageModel):
output: Any = LangChainPlusClient.run_llm(
llm_or_chain, example.inputs, langchain_tracer
)
else:
output = llm_or_chain.run(
example.inputs, callbacks=[langchain_tracer]
)
outputs.append(output)
except Exception as e:
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
outputs.append({"Error": str(e)})
finally:
langchain_tracer.example_id = previous_example_id
return outputs
def run_on_dataset(
self,
dataset_name: str,
llm_or_chain: Union[Chain, BaseLanguageModel],
num_repetitions: int = 1,
session_name: Optional[str] = None,
verbose: bool = False,
) -> Dict[str, Any]:
"""Run the chain on a dataset and store traces to the specified session name.
Args:
dataset_name: Name of the dataset to run the chain on.
llm_or_chain: Chain or language model to run over the dataset.
num_workers: Number of async workers to run in parallel.
num_repetitions: Number of times to run the model on each example.
This is useful when testing success rates or generating confidence
intervals.
session_name: Name of the session to store the traces in.
Defaults to {dataset_name}-{chain class name}-{datetime}.
verbose: Whether to print progress.
Returns:
A dictionary mapping example ids to the model outputs.
"""
if session_name is None:
current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
session_name = (
f"{dataset_name}-{llm_or_chain.__class__.__name__}-{current_time}"
)
dataset = self.read_dataset(dataset_name=dataset_name)
examples = self.list_examples(dataset_id=str(dataset.id))
results: Dict[str, Any] = {}
with tracing_v2_enabled(session_name=session_name) as session:
tracer = LangChainTracerV2()
tracer.session = session
for i, example in enumerate(examples):
result = self.run_llm_or_chain(
example,
tracer,
llm_or_chain,
num_repetitions,
)
if verbose:
print(f"{i+1} processed", flush=True, end="\r")
results[str(example.id)] = result
return results
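
Putting the client together, a hedged end-to-end sketch; the method signatures come from this file, and the dataset is assumed to already exist (see the notebook above):

from langchain.client import LangChainPlusClient
from langchain.llms import Anthropic  # any Chain or BaseLanguageModel works

client = LangChainPlusClient(api_url="http://localhost:8000", api_key=None)
llm = Anthropic(temperature=0)

# Synchronous variant; arun_on_dataset is the asyncio equivalent, with
# num_workers controlling concurrency.
results = client.run_on_dataset(
    dataset_name="state-of-the-union-completions",
    llm_or_chain=llm,
    num_repetitions=1,
    session_name=None,  # defaults to {dataset_name}-{chain class name}-{datetime}
    verbose=True,
)
print(len(results))  # one entry per example id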

View File

@@ -0,0 +1,54 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel, Field
from langchain.callbacks.tracers.schemas import Run
class ExampleBase(BaseModel):
"""Example base model."""
dataset_id: UUID
inputs: Dict[str, Any]
outputs: Optional[Dict[str, Any]] = Field(default=None)
class ExampleCreate(ExampleBase):
"""Example create model."""
id: Optional[UUID]
created_at: datetime = Field(default_factory=datetime.utcnow)
class Example(ExampleBase):
"""Example model."""
id: UUID
created_at: datetime
modified_at: Optional[datetime] = Field(default=None)
runs: List[Run] = Field(default_factory=list)
class DatasetBase(BaseModel):
"""Dataset base model."""
tenant_id: UUID
name: str
description: str
class DatasetCreate(DatasetBase):
"""Dataset create model."""
id: Optional[UUID]
created_at: datetime = Field(default_factory=datetime.utcnow)
class Dataset(DatasetBase):
"""Dataset ORM model."""
id: UUID
created_at: datetime
modified_at: Optional[datetime] = Field(default=None)
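
These models follow the same create/read split as the run schemas: the Create variants leave id optional and default created_at client-side. A minimal sketch with made-up ids:

from uuid import uuid4

from langchain.client.models import DatasetCreate, ExampleCreate

dataset = DatasetCreate(tenant_id=uuid4(), name="demo", description="A toy dataset")
example = ExampleCreate(
    dataset_id=uuid4(),  # placeholder; normally the id of a persisted dataset
    inputs={"input": "2 + 2"},
    outputs={"output": "4"},
)
print(example.created_at)  # defaulted client-side via datetime.utcnow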

langchain/client/utils.py
View File

@@ -0,0 +1,42 @@
"""Client Utils."""
import re
from typing import Dict, List, Optional, Sequence, Type, Union
from langchain.schema import (
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
_DEFAULT_MESSAGES_T = Union[Type[HumanMessage], Type[SystemMessage], Type[AIMessage]]
_RESOLUTION_MAP: Dict[str, _DEFAULT_MESSAGES_T] = {
"Human": HumanMessage,
"AI": AIMessage,
"System": SystemMessage,
}
def parse_chat_messages(
input_text: str, roles: Optional[Sequence[str]] = None
) -> List[BaseMessage]:
"""Parse chat messages from a string. This is not robust."""
roles = roles or ["Human", "AI", "System"]
roles_pattern = "|".join(roles)
pattern = (
rf"(?P<entity>{roles_pattern}): (?P<message>"
rf"(?:.*\n?)*?)(?=(?:{roles_pattern}): |\Z)"
)
matches = re.finditer(pattern, input_text, re.MULTILINE)
results: List[BaseMessage] = []
for match in matches:
entity = match.group("entity")
message = match.group("message").rstrip("\n")
if entity in _RESOLUTION_MAP:
results.append(_RESOLUTION_MAP[entity](content=message))
else:
results.append(ChatMessage(role=entity, content=message))
return results
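
A quick check of the parser on a three-turn transcript (message classes per the schema imports above):

from langchain.client.utils import parse_chat_messages

text = "System: You are a bard.\nHuman: Roll for initiative.\nAI: A twenty!"
for message in parse_chat_messages(text):
    print(type(message).__name__, repr(message.content))
# SystemMessage 'You are a bard.'
# HumanMessage 'Roll for initiative.'
# AIMessage 'A twenty!'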

View File

@@ -2,6 +2,8 @@
import os
from typing import Any, Callable, Dict, Optional, Tuple
from requests import HTTPError, Response
def get_from_dict_or_env(
data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
@@ -52,6 +54,14 @@ def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
return decorator
def raise_for_status_with_text(response: Response) -> None:
"""Raise an error with the response text."""
try:
response.raise_for_status()
except HTTPError as e:
raise ValueError(response.text) from e
def stringify_value(val: Any) -> str:
if isinstance(val, str):
return val
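
Usage is a drop-in replacement for response.raise_for_status(); the only behavioral change is that the server's response body ends up in the raised error. A small sketch, assuming a local server:

import requests

from langchain.utils import raise_for_status_with_text

response = requests.get("http://localhost:8000/tenants")
try:
    raise_for_status_with_text(response)
except ValueError as e:
    print(e)  # the response text (e.g. a JSON error body), not just the status code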

View File

@@ -18,7 +18,12 @@ from langchain.callbacks.tracers.base import (
     TracerSession,
 )
 from langchain.callbacks.tracers.langchain import LangChainTracerV2
-from langchain.callbacks.tracers.schemas import Run, TracerSessionBase, TracerSessionV2
+from langchain.callbacks.tracers.schemas import (
+    RunCreate,
+    TracerSessionBase,
+    TracerSessionV2,
+    TracerSessionV2Create,
+)
 from langchain.schema import LLMResult

 TEST_SESSION_ID = 2023
@@ -541,14 +546,12 @@ def sample_runs() -> Tuple[LLMRun, ChainRun, ToolRun]:
    return llm_run, chain_run, tool_run


# Test _get_default_query_params method
def test_get_default_query_params(lang_chain_tracer_v2: LangChainTracerV2) -> None:
    expected = {"tenant_id": "test-tenant-id"}
    result = lang_chain_tracer_v2._get_default_query_params()
    assert result == expected


# Test load_session method
@patch("langchain.callbacks.tracers.langchain.requests.get")
def test_load_session(
    mock_requests_get: Mock,

@@ -577,23 +580,65 @@ def test_convert_run(
     converted_chain_run = lang_chain_tracer_v2._convert_run(chain_run)
     converted_tool_run = lang_chain_tracer_v2._convert_run(tool_run)
-    assert isinstance(converted_llm_run, Run)
-    assert isinstance(converted_chain_run, Run)
-    assert isinstance(converted_tool_run, Run)
+    assert isinstance(converted_llm_run, RunCreate)
+    assert isinstance(converted_chain_run, RunCreate)
+    assert isinstance(converted_tool_run, RunCreate)


-@patch("langchain.callbacks.tracers.langchain.requests.post")
 def test_persist_run(
-    mock_requests_post: Mock,
     lang_chain_tracer_v2: LangChainTracerV2,
     sample_tracer_session_v2: TracerSessionV2,
     sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
 ) -> None:
-    mock_requests_post.return_value.raise_for_status.return_value = None
-    lang_chain_tracer_v2.session = sample_tracer_session_v2
-    llm_run, chain_run, tool_run = sample_runs
-    lang_chain_tracer_v2._persist_run(llm_run)
-    lang_chain_tracer_v2._persist_run(chain_run)
-    lang_chain_tracer_v2._persist_run(tool_run)
+    """Test that persist_run method calls requests.post once per method call."""
+    with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
+        "langchain.callbacks.tracers.langchain.requests.get"
+    ) as get:
+        post.return_value.raise_for_status.return_value = None
+        lang_chain_tracer_v2.session = sample_tracer_session_v2
+        llm_run, chain_run, tool_run = sample_runs
+        lang_chain_tracer_v2._persist_run(llm_run)
+        lang_chain_tracer_v2._persist_run(chain_run)
+        lang_chain_tracer_v2._persist_run(tool_run)
-    assert mock_requests_post.call_count == 3
+        assert post.call_count == 3
+        assert get.call_count == 0


def test_get_session_create(lang_chain_tracer_v2: LangChainTracerV2) -> None:
    """Test creating the 'SessionCreate' object."""
    lang_chain_tracer_v2.tenant_id = str(_TENANT_ID)
    session_create = lang_chain_tracer_v2._get_session_create(name="test")
    assert isinstance(session_create, TracerSessionV2Create)
    assert session_create.name == "test"
    assert session_create.tenant_id == _TENANT_ID


@patch("langchain.callbacks.tracers.langchain.requests.post")
def test_persist_session(
    mock_requests_post: Mock,
    lang_chain_tracer_v2: LangChainTracerV2,
    sample_tracer_session_v2: TracerSessionV2,
) -> None:
    """Test persist_session returns a TracerSessionV2 with the updated ID."""
    session_create = TracerSessionV2Create(**sample_tracer_session_v2.dict())
    new_id = str(uuid4())
    mock_requests_post.return_value.json.return_value = {"id": new_id}
    result = lang_chain_tracer_v2._persist_session(session_create)
    assert isinstance(result, TracerSessionV2)
    res = sample_tracer_session_v2.dict()
    res["id"] = UUID(new_id)
    assert result.dict() == res


@patch("langchain.callbacks.tracers.langchain.LangChainTracerV2.load_session")
def test_load_default_session(
    mock_load_session: Mock,
    lang_chain_tracer_v2: LangChainTracerV2,
    sample_tracer_session_v2: TracerSessionV2,
) -> None:
    """Test load_default_session attempts to load with the default name."""
    mock_load_session.return_value = sample_tracer_session_v2
    result = lang_chain_tracer_v2.load_default_session()
    assert result == sample_tracer_session_v2
    mock_load_session.assert_called_with("default")

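A small sketch of the session bootstrap these tests exercise; network calls are elided and the behavior is inferred from the assertions above:

from langchain.callbacks.tracers.langchain import LangChainTracerV2

tracer = LangChainTracerV2()
# load_default_session attempts to load the session named "default"
tracer.session = tracer.load_default_session()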
@@ -0,0 +1,210 @@
"""Test the LangChain+ client."""
import uuid
from datetime import datetime
from io import BytesIO
from typing import Any, Dict, List, Union
from unittest import mock
import pytest
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.tracers.langchain import LangChainTracerV2
from langchain.callbacks.tracers.schemas import TracerSessionV2
from langchain.chains.base import Chain
from langchain.client.langchain import (
LangChainPlusClient,
_get_link_stem,
_is_localhost,
)
from langchain.client.models import Dataset, Example
_CREATED_AT = datetime(2015, 1, 1, 0, 0, 0)
_TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"
@pytest.mark.parametrize(
"api_url, expected_url",
[
("http://localhost:8000", "http://localhost"),
("http://www.example.com", "http://www.example.com"),
(
"https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
"https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
),
("https://www.langchain.com/path/to/nowhere", "https://www.langchain.com"),
],
)
def test_link_split(api_url: str, expected_url: str) -> None:
"""Test the link splitting handles both localhost and deployed urls."""
assert _get_link_stem(api_url) == expected_url
def test_is_localhost() -> None:
assert _is_localhost("http://localhost:8000")
assert _is_localhost("http://127.0.0.1:8000")
assert _is_localhost("http://0.0.0.0:8000")
assert not _is_localhost("http://example.com:8000")
def test_validate_api_key_if_hosted() -> None:
with pytest.raises(ValueError, match="API key must be provided"):
LangChainPlusClient(api_url="http://www.example.com")
client = LangChainPlusClient(api_url="http://localhost:8000")
assert client.api_url == "http://localhost:8000"
assert client.api_key is None
def test_headers() -> None:
client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
assert client._headers == {"authorization": "Bearer 123"}
client_no_key = LangChainPlusClient(api_url="http://localhost:8000")
assert client_no_key._headers == {}
@mock.patch("langchain.client.langchain.requests.post")
def test_upload_csv(mock_post: mock.Mock) -> None:
mock_response = mock.Mock()
dataset_id = str(uuid.uuid4())
example_1 = Example(
id=str(uuid.uuid4()),
created_at=_CREATED_AT,
inputs={"input": "1"},
outputs={"output": "2"},
dataset_id=dataset_id,
)
example_2 = Example(
id=str(uuid.uuid4()),
created_at=_CREATED_AT,
inputs={"input": "3"},
outputs={"output": "4"},
dataset_id=dataset_id,
)
mock_response.json.return_value = {
"id": "1",
"name": "test.csv",
"description": "Test dataset",
"owner_id": "the owner",
"created_at": _CREATED_AT,
"examples": [example_1, example_2],
}
mock_post.return_value = mock_response
client = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
csv_file = ("test.csv", BytesIO(b"input,output\n1,2\n3,4\n"))
dataset = client.upload_csv(
csv_file, "Test dataset", input_keys=["input"], output_keys=["output"]
)
assert dataset.id == "1"
assert dataset.name == "test.csv"
assert dataset.description == "Test dataset"
@pytest.mark.asyncio
async def test_arun_on_dataset() -> None:
dataset = Dataset(
id="1",
name="test",
description="Test dataset",
owner_id="owner",
created_at=_CREATED_AT,
)
uuids = [
"0c193153-2309-4704-9a47-17aee4fb25c8",
"0d11b5fd-8e66-4485-b696-4b55155c0c05",
"90d696f0-f10d-4fd0-b88b-bfee6df08b84",
"4ce2c6d8-5124-4c0c-8292-db7bdebcf167",
"7b5a524c-80fa-4960-888e-7d380f9a11ee",
]
examples = [
Example(
id=uuids[0],
created_at=_CREATED_AT,
inputs={"input": "1"},
outputs={"output": "2"},
dataset_id=str(uuid.uuid4()),
),
Example(
id=uuids[1],
created_at=_CREATED_AT,
inputs={"input": "3"},
outputs={"output": "4"},
dataset_id=str(uuid.uuid4()),
),
Example(
id=uuids[2],
created_at=_CREATED_AT,
inputs={"input": "5"},
outputs={"output": "6"},
dataset_id=str(uuid.uuid4()),
),
Example(
id=uuids[3],
created_at=_CREATED_AT,
inputs={"input": "7"},
outputs={"output": "8"},
dataset_id=str(uuid.uuid4()),
),
Example(
id=uuids[4],
created_at=_CREATED_AT,
inputs={"input": "9"},
outputs={"output": "10"},
dataset_id=str(uuid.uuid4()),
),
]
async def mock_aread_dataset(*args: Any, **kwargs: Any) -> Dataset:
return dataset
async def mock_alist_examples(*args: Any, **kwargs: Any) -> List[Example]:
return examples
async def mock_arun_chain(
example: Example,
tracer: Any,
llm_or_chain: Union[BaseLanguageModel, Chain],
n_repetitions: int,
) -> List[Dict[str, Any]]:
return [
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
]
def mock_load_session(
self: Any, name: str, *args: Any, **kwargs: Any
) -> TracerSessionV2:
return TracerSessionV2(name=name, tenant_id=_TENANT_ID, id=uuid.uuid4())
with mock.patch.object(
LangChainPlusClient, "aread_dataset", new=mock_aread_dataset
), mock.patch.object(
LangChainPlusClient, "alist_examples", new=mock_alist_examples
), mock.patch.object(
LangChainPlusClient, "_arun_llm_or_chain", new=mock_arun_chain
), mock.patch.object(
LangChainTracerV2, "load_session", new=mock_load_session
):
client = LangChainPlusClient(
api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID
)
chain = mock.MagicMock()
results = await client.arun_on_dataset(
dataset_name="test",
llm_or_chain=chain,
num_workers=2,
session_name="test_session",
num_repetitions=3,
)
expected = {
uuid.UUID(uuid_): [
{"result": f"Result for example {uuid.UUID(uuid_)}"} for _ in range(3)
]
for uuid_ in uuids
}
assert results == expected

@@ -0,0 +1,70 @@
"""Test LangChain+ Client Utils."""
from typing import List
from langchain.client.utils import parse_chat_messages
from langchain.schema import (
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
def test_parse_chat_messages() -> None:
"""Test that chat messages are parsed correctly."""
input_text = (
"Human: I am human roar\nAI: I am AI beep boop\nSystem: I am a system message"
)
expected = [
HumanMessage(content="I am human roar"),
AIMessage(content="I am AI beep boop"),
SystemMessage(content="I am a system message"),
]
assert parse_chat_messages(input_text) == expected
def test_parse_chat_messages_empty_input() -> None:
"""Test that an empty input string returns an empty list."""
input_text = ""
expected: List[BaseMessage] = []
assert parse_chat_messages(input_text) == expected
def test_parse_chat_messages_multiline_messages() -> None:
"""Test that multiline messages are parsed correctly."""
input_text = (
"Human: I am a human\nand I roar\nAI: I am an AI\nand I"
" beep boop\nSystem: I am a system\nand a message"
)
expected = [
HumanMessage(content="I am a human\nand I roar"),
AIMessage(content="I am an AI\nand I beep boop"),
SystemMessage(content="I am a system\nand a message"),
]
assert parse_chat_messages(input_text) == expected
def test_parse_chat_messages_custom_roles() -> None:
"""Test that custom roles are parsed correctly."""
input_text = "Client: I need help\nAgent: I'm here to help\nClient: Thank you"
expected = [
ChatMessage(role="Client", content="I need help"),
ChatMessage(role="Agent", content="I'm here to help"),
ChatMessage(role="Client", content="Thank you"),
]
assert parse_chat_messages(input_text, roles=["Client", "Agent"]) == expected
def test_parse_chat_messages_embedded_roles() -> None:
"""Test that messages with embedded role references are parsed correctly."""
input_text = (
"Human: Oh ai what if you said AI: foo bar?"
"\nAI: Well, that would be interesting!"
)
expected = [
HumanMessage(content="Oh ai what if you said AI: foo bar?"),
AIMessage(content="Well, that would be interesting!"),
]
assert parse_chat_messages(input_text) == expected