core, partners: implement standard tracing params for LLMs (#25410)

Authored by ccurme on 2024-08-16 13:18:09 -04:00; committed by GitHub
parent 9f0c76bf89
commit b83f1eb0d5
17 changed files with 298 additions and 36 deletions
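
In short: the `LangSmithParams` TypedDict moves from `langchain_core.language_models.chat_models` into `langchain_core.language_models.base` (gaining `"llm"` as an allowed `ls_model_type`), `BaseLLM` gets a default `_get_ls_params()` that infers `ls_provider`, `ls_model_name`, `ls_temperature`, and `ls_max_tokens`, and those params are merged into the inheritable callback metadata in `invoke`/`stream`/`generate` so tracers such as LangSmith receive them. Partner packages override the hook where their parameter names differ (Anthropic's `max_tokens_to_sample`, Ollama's `num_predict`, Azure's deployment-based model name); Fireworks and Together only gain tests. A minimal sketch of the override pattern — `MyLLM` and `max_new_tokens` are illustrative only, not part of this commit:

```python
from typing import Any, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LangSmithParams
from langchain_core.language_models.llms import LLM


class MyLLM(LLM):
    """Hypothetical integration, used only to illustrate the override pattern."""

    model: str = "my-model"
    max_new_tokens: int = 64  # provider-specific name, not the standard `max_tokens`

    @property
    def _llm_type(self) -> str:
        return "my-llm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return "hello"

    def _get_ls_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> LangSmithParams:
        # Start from the defaults (provider inferred from the class name,
        # ls_model_name picked up from `self.model`), then map the
        # provider-specific token limit onto the standard key.
        params = super()._get_ls_params(stop=stop, **kwargs)
        params["ls_max_tokens"] = kwargs.get("max_new_tokens", self.max_new_tokens)
        return params


print(MyLLM()._get_ls_params())
# expected with the defaults added in this commit:
# {'ls_provider': 'my', 'ls_model_type': 'llm', 'ls_model_name': 'my-model', 'ls_max_tokens': 64}
```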

View File

@@ -10,7 +10,7 @@ def test_standard_params() -> None:
     class ExpectedParams(BaseModel):
         ls_provider: str
         ls_model_name: str
-        ls_model_type: Literal["chat"]
+        ls_model_type: Literal["chat", "llm"]
         ls_temperature: Optional[float]
         ls_max_tokens: Optional[int]
         ls_stop: Optional[List[str]]

View File

@@ -39,6 +39,7 @@ https://python.langchain.com/v0.2/docs/how_to/custom_llm/
 from langchain_core.language_models.base import (
     BaseLanguageModel,
+    LangSmithParams,
     LanguageModelInput,
     LanguageModelLike,
     LanguageModelOutput,
@@ -62,6 +63,7 @@ __all__ = [
     "LLM",
     "LanguageModelInput",
     "get_tokenizer",
+    "LangSmithParams",
     "LanguageModelOutput",
     "LanguageModelLike",
     "FakeListLLM",

View File

@@ -8,6 +8,7 @@ from typing import (
     Callable,
     Dict,
     List,
+    Literal,
     Mapping,
     Optional,
     Sequence,
@@ -17,7 +18,7 @@ from typing import (
     Union,
 )
 
-from typing_extensions import TypeAlias
+from typing_extensions import TypeAlias, TypedDict
 
 from langchain_core._api import deprecated
 from langchain_core.messages import (
@@ -37,6 +38,23 @@ if TYPE_CHECKING:
     from langchain_core.outputs import LLMResult
 
 
+class LangSmithParams(TypedDict, total=False):
+    """LangSmith parameters for tracing."""
+
+    ls_provider: str
+    """Provider of the model."""
+    ls_model_name: str
+    """Name of the model."""
+    ls_model_type: Literal["chat", "llm"]
+    """Type of the model. Should be 'chat' or 'llm'."""
+    ls_temperature: Optional[float]
+    """Temperature for generation."""
+    ls_max_tokens: Optional[int]
+    """Max tokens for generation."""
+    ls_stop: Optional[List[str]]
+    """Stop words for generation."""
+
+
 @lru_cache(maxsize=None)  # Cache the tokenizer
 def get_tokenizer() -> Any:
     """Get a GPT-2 tokenizer instance.

View File

@@ -23,8 +23,6 @@ from typing import (
     cast,
 )
 
-from typing_extensions import TypedDict
-
 from langchain_core._api import deprecated
 from langchain_core.caches import BaseCache
 from langchain_core.callbacks import (
@@ -36,7 +34,11 @@ from langchain_core.callbacks import (
     Callbacks,
 )
 from langchain_core.globals import get_llm_cache
-from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
+from langchain_core.language_models.base import (
+    BaseLanguageModel,
+    LangSmithParams,
+    LanguageModelInput,
+)
 from langchain_core.load import dumpd, dumps
 from langchain_core.messages import (
     AIMessage,
@@ -73,23 +75,6 @@ if TYPE_CHECKING:
     from langchain_core.tools import BaseTool
 
 
-class LangSmithParams(TypedDict, total=False):
-    """LangSmith parameters for tracing."""
-
-    ls_provider: str
-    """Provider of the model."""
-    ls_model_name: str
-    """Name of the model."""
-    ls_model_type: Literal["chat"]
-    """Type of the model. Should be 'chat'."""
-    ls_temperature: Optional[float]
-    """Temperature for generation."""
-    ls_max_tokens: Optional[int]
-    """Max tokens for generation."""
-    ls_stop: Optional[List[str]]
-    """Stop words for generation."""
-
-
 def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
     """Generate from a stream.

View File

@@ -48,7 +48,11 @@ from langchain_core.callbacks import (
     Callbacks,
 )
 from langchain_core.globals import get_llm_cache
-from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
+from langchain_core.language_models.base import (
+    BaseLanguageModel,
+    LangSmithParams,
+    LanguageModelInput,
+)
 from langchain_core.load import dumpd
 from langchain_core.messages import (
     AIMessage,
@@ -331,6 +335,43 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 "Must be a PromptValue, str, or list of BaseMessages."
             )
 
+    def _get_ls_params(
+        self,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> LangSmithParams:
+        """Get standard params for tracing."""
+        # get default provider from class name
+        default_provider = self.__class__.__name__
+        if default_provider.endswith("LLM"):
+            default_provider = default_provider[:-3]
+        default_provider = default_provider.lower()
+
+        ls_params = LangSmithParams(ls_provider=default_provider, ls_model_type="llm")
+        if stop:
+            ls_params["ls_stop"] = stop
+
+        # model
+        if hasattr(self, "model") and isinstance(self.model, str):
+            ls_params["ls_model_name"] = self.model
+        elif hasattr(self, "model_name") and isinstance(self.model_name, str):
+            ls_params["ls_model_name"] = self.model_name
+
+        # temperature
+        if "temperature" in kwargs and isinstance(kwargs["temperature"], float):
+            ls_params["ls_temperature"] = kwargs["temperature"]
+        elif hasattr(self, "temperature") and isinstance(self.temperature, float):
+            ls_params["ls_temperature"] = self.temperature
+
+        # max_tokens
+        if "max_tokens" in kwargs and isinstance(kwargs["max_tokens"], int):
+            ls_params["ls_max_tokens"] = kwargs["max_tokens"]
+        elif hasattr(self, "max_tokens") and isinstance(self.max_tokens, int):
+            ls_params["ls_max_tokens"] = self.max_tokens
+
+        return ls_params
+
     def invoke(
         self,
         input: LanguageModelInput,
@@ -487,13 +528,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             params["stop"] = stop
             params = {**params, **kwargs}
             options = {"stop": stop}
+            inheritable_metadata = {
+                **(config.get("metadata") or {}),
+                **self._get_ls_params(stop=stop, **kwargs),
+            }
             callback_manager = CallbackManager.configure(
                 config.get("callbacks"),
                 self.callbacks,
                 self.verbose,
                 config.get("tags"),
                 self.tags,
-                config.get("metadata"),
+                inheritable_metadata,
                 self.metadata,
             )
             (run_manager,) = callback_manager.on_llm_start(
@@ -548,13 +593,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             params["stop"] = stop
             params = {**params, **kwargs}
             options = {"stop": stop}
+            inheritable_metadata = {
+                **(config.get("metadata") or {}),
+                **self._get_ls_params(stop=stop, **kwargs),
+            }
             callback_manager = AsyncCallbackManager.configure(
                 config.get("callbacks"),
                 self.callbacks,
                 self.verbose,
                 config.get("tags"),
                 self.tags,
-                config.get("metadata"),
+                inheritable_metadata,
                 self.metadata,
             )
             (run_manager,) = await callback_manager.on_llm_start(
@@ -796,6 +845,21 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 f" argument of type {type(prompts)}."
             )
         # Create callback managers
+        if isinstance(metadata, list):
+            metadata = [
+                {
+                    **(meta or {}),
+                    **self._get_ls_params(stop=stop, **kwargs),
+                }
+                for meta in metadata
+            ]
+        elif isinstance(metadata, dict):
+            metadata = {
+                **(metadata or {}),
+                **self._get_ls_params(stop=stop, **kwargs),
+            }
+        else:
+            pass
         if (
             isinstance(callbacks, list)
             and callbacks
@@ -1017,6 +1081,21 @@ class BaseLLM(BaseLanguageModel[str], ABC):
            An LLMResult, which contains a list of candidate Generations for each input
                prompt and additional model provider-specific output.
        """
+        if isinstance(metadata, list):
+            metadata = [
+                {
+                    **(meta or {}),
+                    **self._get_ls_params(stop=stop, **kwargs),
+                }
+                for meta in metadata
+            ]
+        elif isinstance(metadata, dict):
+            metadata = {
+                **(metadata or {}),
+                **self._get_ls_params(stop=stop, **kwargs),
+            }
+        else:
+            pass
         # Create callback managers
         if isinstance(callbacks, list) and (
             isinstance(callbacks[0], (list, BaseCallbackManager))
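
For orientation, a minimal sketch of what this looks like from the caller's side, using `FakeListLLM` from `langchain_core` (the expected values mirror the updated `test_prompt_with_llm` trace further down; assumes default config):

```python
from langchain_core.language_models import FakeListLLM

llm = FakeListLLM(responses=["hello"])

# Provider is derived from the class name ("FakeListLLM" -> "fakelist"),
# and every completion-style model is tagged with ls_model_type="llm".
print(llm._get_ls_params())
# {'ls_provider': 'fakelist', 'ls_model_type': 'llm'}

# During invoke()/stream()/generate(), these params are merged into the
# inheritable callback metadata, so tracers (e.g. LangSmith) receive them
# alongside any user-supplied metadata.
llm.invoke("hi", config={"metadata": {"my_key": "my_value"}})
```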

View File

@@ -6,6 +6,7 @@ EXPECTED_ALL = [
     "SimpleChatModel",
     "BaseLLM",
     "LLM",
+    "LangSmithParams",
     "LanguageModelInput",
     "LanguageModelOutput",
     "LanguageModelLike",

File diff suppressed because one or more lines are too long

View File

@@ -2180,7 +2180,7 @@ async def test_prompt_with_llm(
             "value": {
                 "end_time": None,
                 "final_output": None,
-                "metadata": {},
+                "metadata": {"ls_model_type": "llm", "ls_provider": "fakelist"},
                 "name": "FakeListLLM",
                 "start_time": "2023-01-01T00:00:00.000+00:00",
                 "streamed_output": [],
@@ -2384,7 +2384,10 @@ async def test_prompt_with_llm_parser(
             "value": {
                 "end_time": None,
                 "final_output": None,
-                "metadata": {},
+                "metadata": {
+                    "ls_model_type": "llm",
+                    "ls_provider": "fakestreaminglist",
+                },
                 "name": "FakeStreamingListLLM",
                 "start_time": "2023-01-01T00:00:00.000+00:00",
                 "streamed_output": [],

View File

@@ -17,7 +17,7 @@ from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models import BaseLanguageModel
+from langchain_core.language_models import BaseLanguageModel, LangSmithParams
 from langchain_core.language_models.llms import LLM
 from langchain_core.outputs import GenerationChunk
 from langchain_core.prompt_values import PromptValue
@@ -204,6 +204,19 @@ class AnthropicLLM(LLM, _AnthropicCommon):
             "max_retries": self.max_retries,
         }
 
+    def _get_ls_params(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> LangSmithParams:
+        """Get standard params for tracing."""
+        params = super()._get_ls_params(stop=stop, **kwargs)
+        identifying_params = self._identifying_params
+        if max_tokens := kwargs.get(
+            "max_tokens_to_sample",
+            identifying_params.get("max_tokens"),
+        ):
+            params["ls_max_tokens"] = max_tokens
+        return params
+
     def _wrap_prompt(self, prompt: str) -> str:
         if not self.HUMAN_PROMPT or not self.AI_PROMPT:
             raise NameError("Please ensure the anthropic package is loaded")

View File

@@ -0,0 +1,29 @@
+import os
+
+from langchain_anthropic import AnthropicLLM
+
+os.environ["ANTHROPIC_API_KEY"] = "foo"
+
+
+def test_anthropic_model_params() -> None:
+    # Test standard tracing params
+    llm = AnthropicLLM(model="foo")  # type: ignore[call-arg]
+
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "anthropic",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+        "ls_max_tokens": 1024,
+    }
+
+    llm = AnthropicLLM(model="foo", temperature=0.1)  # type: ignore[call-arg]
+
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "anthropic",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+        "ls_max_tokens": 1024,
+        "ls_temperature": 0.1,
+    }

View File

@@ -69,3 +69,31 @@ def test_fireworks_uses_actual_secret_value_from_secretstr() -> None:
         max_tokens=250,
     )
     assert cast(SecretStr, llm.fireworks_api_key).get_secret_value() == "secret-api-key"
+
+
+def test_fireworks_model_params() -> None:
+    # Test standard tracing params
+    llm = Fireworks(model="foo", api_key="secret-api-key")  # type: ignore[arg-type]
+
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "fireworks",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+    }
+
+    llm = Fireworks(
+        model="foo",
+        api_key="secret-api-key",  # type: ignore[arg-type]
+        max_tokens=10,
+        temperature=0.1,
+    )
+
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "fireworks",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+        "ls_max_tokens": 10,
+        "ls_temperature": 0.1,
+    }

View File

@@ -16,7 +16,7 @@ from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models import BaseLLM
+from langchain_core.language_models import BaseLLM, LangSmithParams
 from langchain_core.outputs import GenerationChunk, LLMResult
 from langchain_core.pydantic_v1 import Field, root_validator
 from ollama import AsyncClient, Client, Options
@@ -155,6 +155,15 @@ class OllamaLLM(BaseLLM):
         """Return type of LLM."""
         return "ollama-llm"
 
+    def _get_ls_params(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> LangSmithParams:
+        """Get standard params for tracing."""
+        params = super()._get_ls_params(stop=stop, **kwargs)
+        if max_tokens := kwargs.get("num_predict", self.num_predict):
+            params["ls_max_tokens"] = max_tokens
+        return params
+
     @root_validator(pre=False, skip_on_failure=True)
     def _set_clients(cls, values: dict) -> dict:
         """Set clients to use for ollama."""

View File

@@ -6,3 +6,23 @@ from langchain_ollama import OllamaLLM
 def test_initialization() -> None:
     """Test integration initialization."""
     OllamaLLM(model="llama3")
+
+
+def test_model_params() -> None:
+    # Test standard tracing params
+    llm = OllamaLLM(model="llama3")
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "ollama",
+        "ls_model_type": "llm",
+        "ls_model_name": "llama3",
+    }
+
+    llm = OllamaLLM(model="llama3", num_predict=3)
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "ollama",
+        "ls_model_type": "llm",
+        "ls_model_name": "llama3",
+        "ls_max_tokens": 3,
+    }

View File

@@ -5,6 +5,7 @@ import os
 from typing import Any, Callable, Dict, List, Mapping, Optional, Union
 
 import openai
+from langchain_core.language_models import LangSmithParams
 from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
@@ -195,6 +196,17 @@ class AzureOpenAI(BaseOpenAI):
         openai_params = {"model": self.deployment_name}
         return {**openai_params, **super()._invocation_params}
 
+    def _get_ls_params(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> LangSmithParams:
+        """Get standard params for tracing."""
+        params = super()._get_ls_params(stop=stop, **kwargs)
+        invocation_params = self._invocation_params
+        params["ls_provider"] = "azure"
+        if model_name := invocation_params.get("model"):
+            params["ls_model_name"] = model_name
+        return params
+
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""

View File

@@ -0,0 +1,23 @@
+from typing import Any
+
+from langchain_openai import AzureOpenAI
+
+
+def test_azure_model_param(monkeypatch: Any) -> None:
+    monkeypatch.delenv("OPENAI_API_BASE", raising=False)
+    llm = AzureOpenAI(
+        openai_api_key="secret-api-key",  # type: ignore[call-arg]
+        azure_endpoint="endpoint",
+        api_version="version",
+        azure_deployment="gpt-35-turbo-instruct",
+    )
+
+    # Test standard tracing params
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "azure",
+        "ls_model_type": "llm",
+        "ls_model_name": "gpt-35-turbo-instruct",
+        "ls_temperature": 0.7,
+        "ls_max_tokens": 256,
+    }

View File

@@ -14,6 +14,16 @@ def test_openai_model_param() -> None:
     llm = OpenAI(model_name="foo")  # type: ignore[call-arg]
     assert llm.model_name == "foo"
 
+    # Test standard tracing params
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "openai",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+        "ls_temperature": 0.7,
+        "ls_max_tokens": 256,
+    }
+
 
 def test_openai_model_kwargs() -> None:
     llm = OpenAI(model_kwargs={"foo": "bar"})

View File

@@ -69,3 +69,33 @@ def test_together_uses_actual_secret_value_from_secretstr_api_key() -> None:
         max_tokens=250,
     )
     assert cast(SecretStr, llm.together_api_key).get_secret_value() == "secret-api-key"
+
+
+def test_together_model_params() -> None:
+    # Test standard tracing params
+    llm = Together(
+        api_key="secret-api-key",  # type: ignore[arg-type]
+        model="foo",
+    )
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "together",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+        "ls_max_tokens": 200,
+    }
+
+    llm = Together(
+        api_key="secret-api-key",  # type: ignore[arg-type]
+        model="foo",
+        temperature=0.2,
+        max_tokens=250,
+    )
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "together",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+        "ls_max_tokens": 250,
+        "ls_temperature": 0.2,
+    }