Add Tags for LLMs (#6229)
- [x] Add tracing tags to LLMs + Chat Models (both inheritable and local)
- [x] Add tags for the run_on_dataset helper function(s)
This commit is contained in: parent 8e1a7a8646 · commit ae76e473e1
@@ -39,6 +39,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
     """Whether to print out response text."""
     callbacks: Callbacks = Field(default=None, exclude=True)
     callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
+    tags: Optional[List[str]] = Field(default=None, exclude=True)
+    """Tags to add to the run trace."""

     @root_validator()
     def raise_deprecation(cls, values: Dict) -> Dict:
@@ -65,6 +67,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        *,
+        tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Top Level call"""
@@ -74,7 +78,11 @@ class BaseChatModel(BaseLanguageModel, ABC):
         options = {"stop": stop}

         callback_manager = CallbackManager.configure(
-            callbacks, self.callbacks, self.verbose
+            callbacks,
+            self.callbacks,
+            self.verbose,
+            tags,
+            self.tags,
         )
         run_manager = callback_manager.on_chat_model_start(
             dumpd(self), messages, invocation_params=params, options=options
@@ -106,6 +114,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        *,
+        tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Top Level call"""
@@ -114,7 +124,11 @@ class BaseChatModel(BaseLanguageModel, ABC):
         options = {"stop": stop}

         callback_manager = AsyncCallbackManager.configure(
-            callbacks, self.callbacks, self.verbose
+            callbacks,
+            self.callbacks,
+            self.verbose,
+            tags,
+            self.tags,
         )
         run_manager = await callback_manager.on_chat_model_start(
             dumpd(self), messages, invocation_params=params, options=options
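With these hunks in place, a chat model accepts tags at two levels: on the instance (`self.tags`, attached to every call) and per call (the new keyword-only `tags` argument). A minimal sketch of how a caller might combine the two; `ChatOpenAI` and the tag values are illustrative, not part of this diff:

```python
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

# Instance-level tags ride along with every run this model starts.
chat = ChatOpenAI(tags=["chat", "prod"])

# Call-level tags are merged in by CallbackManager.configure together
# with self.tags, so the trace for this run carries all three.
result = chat.generate(
    [[HumanMessage(content="Hello!")]],
    tags=["experiment-42"],
)
```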
@@ -5,7 +5,16 @@ import asyncio
 import functools
 import logging
 from datetime import datetime
-from typing import Any, Callable, Coroutine, Dict, Iterator, List, Optional, Union
+from typing import (
+    Any,
+    Callable,
+    Coroutine,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Union,
+)

 from langchainplus_sdk import LangChainPlusClient
 from langchainplus_sdk.schemas import Example
@@ -104,6 +113,8 @@ async def _arun_llm(
     llm: BaseLanguageModel,
     inputs: Dict[str, Any],
     langchain_tracer: Optional[LangChainTracer],
+    *,
+    tags: Optional[List[str]] = None,
 ) -> Union[LLMResult, ChatResult]:
     callbacks: Optional[List[BaseCallbackHandler]] = (
         [langchain_tracer] if langchain_tracer else None
@@ -111,21 +122,27 @@ async def _arun_llm(
     if isinstance(llm, BaseLLM):
         try:
             llm_prompts = _get_prompts(inputs)
-            llm_output = await llm.agenerate(llm_prompts, callbacks=callbacks)
+            llm_output = await llm.agenerate(
+                llm_prompts, callbacks=callbacks, tags=tags
+            )
         except InputFormatError:
             llm_messages = _get_messages(inputs)
             buffer_strings = [get_buffer_string(messages) for messages in llm_messages]
-            llm_output = await llm.agenerate(buffer_strings, callbacks=callbacks)
+            llm_output = await llm.agenerate(
+                buffer_strings, callbacks=callbacks, tags=tags
+            )
     elif isinstance(llm, BaseChatModel):
         try:
             messages = _get_messages(inputs)
-            llm_output = await llm.agenerate(messages, callbacks=callbacks)
+            llm_output = await llm.agenerate(messages, callbacks=callbacks, tags=tags)
         except InputFormatError:
             prompts = _get_prompts(inputs)
             converted_messages: List[List[BaseMessage]] = [
                 [HumanMessage(content=prompt)] for prompt in prompts
             ]
-            llm_output = await llm.agenerate(converted_messages, callbacks=callbacks)
+            llm_output = await llm.agenerate(
+                converted_messages, callbacks=callbacks, tags=tags
+            )
     else:
         raise ValueError(f"Unsupported LLM type {type(llm)}")
     return llm_output
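Since `agenerate` now forwards `tags` (see the `BaseLLM` hunks further down), the helper above can label every evaluation run it starts. A short sketch of the same call pattern in user code, assuming an OpenAI key is configured; the model choice and tag values are placeholders:

```python
import asyncio

from langchain.llms import OpenAI

async def main() -> None:
    llm = OpenAI(tags=["eval"])  # attached via self.tags on every call
    result = await llm.agenerate(
        ["What is 2 + 2?"],
        tags=["smoke-test"],  # local to this call, merged with self.tags
    )
    print(result.generations[0][0].text)

asyncio.run(main())
```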
@@ -136,6 +153,8 @@ async def _arun_llm_or_chain(
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     n_repetitions: int,
     langchain_tracer: Optional[LangChainTracer],
+    *,
+    tags: Optional[List[str]] = None,
 ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
     """Run the chain asynchronously."""
     if langchain_tracer is not None:
@@ -150,11 +169,16 @@ async def _arun_llm_or_chain(
         try:
             if isinstance(llm_or_chain_factory, BaseLanguageModel):
                 output: Any = await _arun_llm(
-                    llm_or_chain_factory, example.inputs, langchain_tracer
+                    llm_or_chain_factory,
+                    example.inputs,
+                    langchain_tracer,
+                    tags=tags,
                 )
             else:
                 chain = llm_or_chain_factory()
-                output = await chain.acall(example.inputs, callbacks=callbacks)
+                output = await chain.acall(
+                    example.inputs, callbacks=callbacks, tags=tags
+                )
             outputs.append(output)
         except Exception as e:
             logger.warning(f"Chain failed for example {example.id}. Error: {e}")
@@ -230,6 +254,7 @@ async def arun_on_examples(
     num_repetitions: int = 1,
     session_name: Optional[str] = None,
     verbose: bool = False,
+    tags: Optional[List[str]] = None,
 ) -> Dict[str, Any]:
     """
     Run the chain on examples and store traces to the specified session name.
@@ -245,6 +270,7 @@ async def arun_on_examples(
             intervals.
         session_name: Session name to use when tracing runs.
         verbose: Whether to print progress.
+        tags: Tags to add to the traces.

     Returns:
         A dictionary mapping example ids to the model outputs.
@@ -260,6 +286,7 @@ async def arun_on_examples(
             llm_or_chain_factory,
             num_repetitions,
             tracer,
+            tags=tags,
         )
         results[str(example.id)] = result
         job_state["num_processed"] += 1
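A sketch of driving `arun_on_examples` with the new parameter. The positional argument order and the surrounding setup are assumptions based on the signature fragments above; `examples` stands for an iterable of `langchainplus_sdk` `Example` objects fetched elsewhere, and the import path is assumed from the file being patched here:

```python
import asyncio

from langchain.client.runner_utils import arun_on_examples  # path assumed
from langchain.llms import OpenAI

results = asyncio.run(
    arun_on_examples(
        examples,                    # iterable of Example objects (in scope)
        OpenAI(temperature=0),       # a model, or a chain factory
        num_repetitions=1,
        session_name="tagged-eval",  # placeholder session name
        tags=["nightly"],            # forwarded into every traced run
    )
)
```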
@@ -282,12 +309,14 @@ def run_llm(
     llm: BaseLanguageModel,
     inputs: Dict[str, Any],
     callbacks: Callbacks,
+    *,
+    tags: Optional[List[str]] = None,
 ) -> Union[LLMResult, ChatResult]:
     """Run the language model on the example."""
     if isinstance(llm, BaseLLM):
         try:
             llm_prompts = _get_prompts(inputs)
-            llm_output = llm.generate(llm_prompts, callbacks=callbacks)
+            llm_output = llm.generate(llm_prompts, callbacks=callbacks, tags=tags)
         except InputFormatError:
             llm_messages = _get_messages(inputs)
             buffer_strings = [get_buffer_string(messages) for messages in llm_messages]
@@ -295,13 +324,15 @@ def run_llm(
     elif isinstance(llm, BaseChatModel):
         try:
             messages = _get_messages(inputs)
-            llm_output = llm.generate(messages, callbacks=callbacks)
+            llm_output = llm.generate(messages, callbacks=callbacks, tags=tags)
         except InputFormatError:
             prompts = _get_prompts(inputs)
             converted_messages: List[List[BaseMessage]] = [
                 [HumanMessage(content=prompt)] for prompt in prompts
             ]
-            llm_output = llm.generate(converted_messages, callbacks=callbacks)
+            llm_output = llm.generate(
+                converted_messages, callbacks=callbacks, tags=tags
+            )
     else:
         raise ValueError(f"Unsupported LLM type {type(llm)}")
     return llm_output
@@ -312,6 +343,8 @@ def run_llm_or_chain(
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     n_repetitions: int,
     langchain_tracer: Optional[LangChainTracer] = None,
+    *,
+    tags: Optional[List[str]] = None,
 ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
     """Run the chain synchronously."""
     if langchain_tracer is not None:
@@ -325,10 +358,12 @@ def run_llm_or_chain(
     for _ in range(n_repetitions):
         try:
             if isinstance(llm_or_chain_factory, BaseLanguageModel):
-                output: Any = run_llm(llm_or_chain_factory, example.inputs, callbacks)
+                output: Any = run_llm(
+                    llm_or_chain_factory, example.inputs, callbacks, tags=tags
+                )
             else:
                 chain = llm_or_chain_factory()
-                output = chain(example.inputs, callbacks=callbacks)
+                output = chain(example.inputs, callbacks=callbacks, tags=tags)
             outputs.append(output)
         except Exception as e:
             logger.warning(f"Chain failed for example {example.id}. Error: {e}")
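Both the sync and async loops share an error-handling contract worth noting: a failing example is logged and skipped, so one bad run cannot abort the whole dataset. The pattern in isolation, as a hypothetical standalone helper:

```python
import logging

logger = logging.getLogger(__name__)

def run_all(examples, run_one, n_repetitions: int = 1) -> list:
    """Collect outputs per example, logging failures instead of raising."""
    outputs = []
    for example in examples:
        for _ in range(n_repetitions):
            try:
                outputs.append(run_one(example))
            except Exception as e:
                logger.warning(f"Chain failed for example {example.id}. Error: {e}")
    return outputs
```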
@@ -345,6 +380,7 @@ def run_on_examples(
     num_repetitions: int = 1,
     session_name: Optional[str] = None,
     verbose: bool = False,
+    tags: Optional[List[str]] = None,
 ) -> Dict[str, Any]:
     """Run the chain on examples and store traces to the specified session name.

@@ -359,6 +395,7 @@ def run_on_examples(
             intervals.
         session_name: Session name to use when tracing runs.
         verbose: Whether to print progress.
+        tags: Tags to add to the run traces.
     Returns:
         A dictionary mapping example ids to the model outputs.
     """
@@ -370,6 +407,7 @@ def run_on_examples(
             llm_or_chain_factory,
             num_repetitions,
             langchain_tracer=tracer,
+            tags=tags,
         )
         if verbose:
             print(f"{i+1} processed", flush=True, end="\r")
@@ -401,6 +439,7 @@ async def arun_on_dataset(
     session_name: Optional[str] = None,
     verbose: bool = False,
     client: Optional[LangChainPlusClient] = None,
+    tags: Optional[List[str]] = None,
 ) -> Dict[str, Any]:
     """
     Run the chain on a dataset and store traces to the specified session name.
@@ -420,6 +459,7 @@ async def arun_on_dataset(
         verbose: Whether to print progress.
         client: Client to use to read the dataset. If not provided, a new
             client will be created using the credentials in the environment.
+        tags: Tags to add to each run in the session.

     Returns:
         A dictionary containing the run's session name and the resulting model outputs.
@@ -436,6 +476,7 @@ async def arun_on_dataset(
         num_repetitions=num_repetitions,
         session_name=session_name,
         verbose=verbose,
+        tags=tags,
     )
     return {
         "session_name": session_name,
@@ -451,6 +492,7 @@ def run_on_dataset(
     session_name: Optional[str] = None,
     verbose: bool = False,
     client: Optional[LangChainPlusClient] = None,
+    tags: Optional[List[str]] = None,
 ) -> Dict[str, Any]:
     """Run the chain on a dataset and store traces to the specified session name.

@@ -468,6 +510,7 @@ def run_on_dataset(
         verbose: Whether to print progress.
         client: Client to use to access the dataset. If None, a new client
             will be created using the credentials in the environment.
+        tags: Tags to add to each run in the session.

     Returns:
         A dictionary containing the run's session name and the resulting model outputs.
@@ -482,6 +525,7 @@ def run_on_dataset(
         num_repetitions=num_repetitions,
         session_name=session_name,
         verbose=verbose,
+        tags=tags,
    )
    return {
        "session_name": session_name,
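At the top level, `run_on_dataset` and `arun_on_dataset` simply pass `tags` through to the per-example helpers. A sketch of the sync entry point; the dataset name, session name, and tag values are placeholders, and the positional argument order is assumed from the signature fragments above:

```python
from langchain.client.runner_utils import run_on_dataset  # import path assumed
from langchain.llms import OpenAI

results = run_on_dataset(
    "my-eval-dataset",               # placeholder dataset name
    OpenAI(temperature=0),           # a model, or a chain factory
    num_repetitions=1,
    session_name="tagged-session",   # placeholder session name
    tags=["nightly", "regression"],  # attached to each run in the session
)
print(results["session_name"])
```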
@@ -369,6 +369,7 @@
     "    session_name: 'Optional[str]' = None,\n",
     "    verbose: 'bool' = False,\n",
     "    client: 'Optional[LangChainPlusClient]' = None,\n",
+    "    tags: 'Optional[List[str]]' = None,\n",
     ") -> 'Dict[str, Any]'\n",
     "Docstring:\n",
     "Run the chain on a dataset and store traces to the specified session name.\n",
@@ -388,6 +389,7 @@
     "    verbose: Whether to print progress.\n",
     "    client: Client to use to read the dataset. If not provided, a new\n",
     "        client will be created using the credentials in the environment.\n",
+    "    tags: Tags to add to each run in the session.\n",
     "\n",
     "Returns:\n",
     "    A dictionary containing the run's session name and the resulting model outputs.\n",
@@ -430,7 +432,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 12,
+    "execution_count": 13,
     "id": "a8088b7d-3ab6-4279-94c8-5116fe7cee33",
     "metadata": {
      "tags": []
@@ -440,21 +442,21 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Processed examples: 4\r"
+      "Processed examples: 1\r"
      ]
     },
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Chain failed for example c855f923-4165-4fe0-a909-360749f3f764. Error: Could not parse LLM output: `The final answer is that there were no more points scored in the 2023 Super Bowl than in the 2022 Super Bowl.`\n"
+      "Chain failed for example b36a82d3-4fb6-4bc4-87df-b7c355742b8e. Error: unknown format from LLM: Sorry, I cannot answer this question as it requires information that is not currently available.\n"
      ]
     },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Processed examples: 5\r"
+      "Processed examples: 6\r"
      ]
     }
    ],
@@ -465,6 +467,7 @@
     "    concurrency_level=5, # Optional, sets the number of examples to run at a time\n",
     "    verbose=True,\n",
     "    client=client,\n",
+    "    tags=[\"testing-notebook\", \"turbo\"],  # Optional, adds a tag to the resulting chain runs\n",
     ")\n",
     "\n",
     "# Sometimes, the agent will error due to parsing issues, incompatible tool inputs, etc.\n",
@@ -486,7 +489,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 13,
+    "execution_count": 14,
     "id": "136db492-d6ca-4215-96f9-439c23538241",
     "metadata": {
      "tags": []
@@ -501,7 +504,7 @@
       "LangChainPlusClient (API URL: https://dev.api.langchain.plus)"
      ]
     },
-    "execution_count": 13,
+    "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -534,7 +537,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 14,
+    "execution_count": 15,
     "id": "35db4025-9183-4e5f-ba14-0b1b380f49c7",
     "metadata": {
      "tags": []
@@ -565,7 +568,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "9989f6507cd04ea7a09ea3c5723dc984",
+       "model_id": "5fce1ce42a8c4110b7d12443948ac697",
        "version_major": 2,
        "version_minor": 0
       },
@@ -592,12 +595,26 @@
    },
    {
     "cell_type": "code",
-    "execution_count": null,
+    "execution_count": 17,
     "id": "8696f167-dc75-4ef8-8bb3-ac1ce8324f30",
     "metadata": {
      "tags": []
     },
-    "outputs": [],
+    "outputs": [
+     {
+      "data": {
+       "text/html": [
+        "<a href=\"https://dev.langchain.plus\", target=\"_blank\" rel=\"noopener\">LangChain+ Client</a>"
+       ],
+       "text/plain": [
+        "LangChainPlusClient (API URL: https://dev.api.langchain.plus)"
+       ]
+      },
+      "execution_count": 17,
+      "metadata": {},
+      "output_type": "execute_result"
+     }
+    ],
    "source": [
     "client"
    ]

@@ -79,6 +79,8 @@ class BaseLLM(BaseLanguageModel, ABC):
     """Whether to print out response text."""
     callbacks: Callbacks = Field(default=None, exclude=True)
     callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
+    tags: Optional[List[str]] = Field(default=None, exclude=True)
+    """Tags to add to the run trace."""

     class Config:
         """Configuration for this pydantic object."""
@@ -155,6 +157,8 @@ class BaseLLM(BaseLanguageModel, ABC):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        *,
+        tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
@@ -176,7 +180,7 @@ class BaseLLM(BaseLanguageModel, ABC):
         ) = get_prompts(params, prompts)
         disregard_cache = self.cache is not None and not self.cache
         callback_manager = CallbackManager.configure(
-            callbacks, self.callbacks, self.verbose
+            callbacks, self.callbacks, self.verbose, tags, self.tags
         )
         new_arg_supported = inspect.signature(self._generate).parameters.get(
             "run_manager"
@@ -241,6 +245,8 @@ class BaseLLM(BaseLanguageModel, ABC):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        *,
+        tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
@@ -255,7 +261,7 @@ class BaseLLM(BaseLanguageModel, ABC):
         ) = get_prompts(params, prompts)
         disregard_cache = self.cache is not None and not self.cache
         callback_manager = AsyncCallbackManager.configure(
-            callbacks, self.callbacks, self.verbose
+            callbacks, self.callbacks, self.verbose, tags, self.tags
         )
         new_arg_supported = inspect.signature(self._agenerate).parameters.get(
             "run_manager"
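In both `generate` and `agenerate`, `configure` now takes the per-call `tags` and the instance's `self.tags` as two separate arguments, mirroring how call-level and instance-level callbacks were already split. A sketch of the resulting call with the argument roles spelled out; the inheritable-vs-local reading follows the commit message and the positional order shown here, not a quoted signature:

```python
from langchain.callbacks.manager import CallbackManager

manager = CallbackManager.configure(
    None,          # callbacks passed to this call
    None,          # self.callbacks on the model instance
    False,         # self.verbose
    ["per-call"],  # tags passed to this call
    ["instance"],  # self.tags, defined once on the model
)
```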
@@ -1,7 +1,7 @@
 """Test the LangChain+ client."""
 import uuid
 from datetime import datetime
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union
 from unittest import mock

 import pytest
@@ -170,6 +170,7 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
         llm_or_chain: Union[BaseLanguageModel, Chain],
         n_repetitions: int,
         tracer: Any,
+        tags: Optional[List[str]] = None,
     ) -> List[Dict[str, Any]]:
         return [
             {"result": f"Result for example {example.id}"} for _ in range(n_repetitions)