core, partners: implement standard tracing params for LLMs (#25410)

ccurme 2024-08-16 13:18:09 -04:00 committed by GitHub
parent 9f0c76bf89
commit b83f1eb0d5
17 changed files with 298 additions and 36 deletions


@@ -10,7 +10,7 @@ def test_standard_params() -> None:
class ExpectedParams(BaseModel):
ls_provider: str
ls_model_name: str
ls_model_type: Literal["chat"]
ls_model_type: Literal["chat", "llm"]
ls_temperature: Optional[float]
ls_max_tokens: Optional[int]
ls_stop: Optional[List[str]]

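For orientation, here is a rough sketch of the kind of check this widened schema enables, validating an LLM's tracing params rather than a chat model's. The ExpectedParams copy and the fake model are illustrative, not the actual standard-tests fixture, and ls_model_name is relaxed to Optional because the fake model exposes no model name:

from typing import List, Literal, Optional

from langchain_core.language_models import FakeListLLM
from langchain_core.pydantic_v1 import BaseModel


class ExpectedParams(BaseModel):
    ls_provider: str
    ls_model_name: Optional[str]  # relaxed: fake models expose no model name
    ls_model_type: Literal["chat", "llm"]
    ls_temperature: Optional[float]
    ls_max_tokens: Optional[int]
    ls_stop: Optional[List[str]]


ls_params = FakeListLLM(responses=["hi"])._get_ls_params()
ExpectedParams(**ls_params)  # raises if a required key is missing or mistyped
assert ls_params["ls_model_type"] == "llm"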

@@ -39,6 +39,7 @@ https://python.langchain.com/v0.2/docs/how_to/custom_llm/
from langchain_core.language_models.base import (
BaseLanguageModel,
LangSmithParams,
LanguageModelInput,
LanguageModelLike,
LanguageModelOutput,
@@ -62,6 +63,7 @@ __all__ = [
"LLM",
"LanguageModelInput",
"get_tokenizer",
"LangSmithParams",
"LanguageModelOutput",
"LanguageModelLike",
"FakeListLLM",

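A quick sanity check of the new package-level export (sketch):

import langchain_core.language_models as language_models
from langchain_core.language_models import LangSmithParams  # newly re-exported

assert "LangSmithParams" in language_models.__all__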

@@ -8,6 +8,7 @@ from typing import (
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
@@ -17,7 +18,7 @@ from typing import (
Union,
)
from typing_extensions import TypeAlias
from typing_extensions import TypeAlias, TypedDict
from langchain_core._api import deprecated
from langchain_core.messages import (
@@ -37,6 +38,23 @@ if TYPE_CHECKING:
from langchain_core.outputs import LLMResult
class LangSmithParams(TypedDict, total=False):
"""LangSmith parameters for tracing."""
ls_provider: str
"""Provider of the model."""
ls_model_name: str
"""Name of the model."""
ls_model_type: Literal["chat", "llm"]
"""Type of the model. Should be 'chat' or 'llm'."""
ls_temperature: Optional[float]
"""Temperature for generation."""
ls_max_tokens: Optional[int]
"""Max tokens for generation."""
ls_stop: Optional[List[str]]
"""Stop words for generation."""
@lru_cache(maxsize=None) # Cache the tokenizer
def get_tokenizer() -> Any:
"""Get a GPT-2 tokenizer instance.

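Because the TypedDict is declared with total=False, every key is optional, so a provider can fill in only what it knows; and at runtime it is an ordinary dict, which is what allows it to be merged into run metadata later in this commit. A small sketch with placeholder values:

from langchain_core.language_models.base import LangSmithParams

ls_params = LangSmithParams(ls_provider="foo", ls_model_type="llm")
ls_params["ls_model_name"] = "foo-model"  # placeholder model name

assert isinstance(ls_params, dict)  # TypedDicts are plain dicts at runtime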

@@ -23,8 +23,6 @@ from typing import (
cast,
)
from typing_extensions import TypedDict
from langchain_core._api import deprecated
from langchain_core.caches import BaseCache
from langchain_core.callbacks import (
@@ -36,7 +34,11 @@ from langchain_core.callbacks import (
Callbacks,
)
from langchain_core.globals import get_llm_cache
from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
from langchain_core.language_models.base import (
BaseLanguageModel,
LangSmithParams,
LanguageModelInput,
)
from langchain_core.load import dumpd, dumps
from langchain_core.messages import (
AIMessage,
@@ -73,23 +75,6 @@ if TYPE_CHECKING:
from langchain_core.tools import BaseTool
class LangSmithParams(TypedDict, total=False):
"""LangSmith parameters for tracing."""
ls_provider: str
"""Provider of the model."""
ls_model_name: str
"""Name of the model."""
ls_model_type: Literal["chat"]
"""Type of the model. Should be 'chat'."""
ls_temperature: Optional[float]
"""Temperature for generation."""
ls_max_tokens: Optional[int]
"""Max tokens for generation."""
ls_stop: Optional[List[str]]
"""Stop words for generation."""
def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
"""Generate from a stream.

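Since chat_models.py now imports the definition from base instead of declaring its own copy, both import paths refer to the same class, so existing imports keep working; a quick check (sketch):

from langchain_core.language_models.base import LangSmithParams
from langchain_core.language_models.chat_models import LangSmithParams as ChatLangSmithParams

assert LangSmithParams is ChatLangSmithParams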

@@ -48,7 +48,11 @@ from langchain_core.callbacks import (
Callbacks,
)
from langchain_core.globals import get_llm_cache
from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
from langchain_core.language_models.base import (
BaseLanguageModel,
LangSmithParams,
LanguageModelInput,
)
from langchain_core.load import dumpd
from langchain_core.messages import (
AIMessage,
@@ -331,6 +335,43 @@ class BaseLLM(BaseLanguageModel[str], ABC):
"Must be a PromptValue, str, or list of BaseMessages."
)
def _get_ls_params(
self,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> LangSmithParams:
"""Get standard params for tracing."""
# get default provider from class name
default_provider = self.__class__.__name__
if default_provider.endswith("LLM"):
default_provider = default_provider[:-3]
default_provider = default_provider.lower()
ls_params = LangSmithParams(ls_provider=default_provider, ls_model_type="llm")
if stop:
ls_params["ls_stop"] = stop
# model
if hasattr(self, "model") and isinstance(self.model, str):
ls_params["ls_model_name"] = self.model
elif hasattr(self, "model_name") and isinstance(self.model_name, str):
ls_params["ls_model_name"] = self.model_name
# temperature
if "temperature" in kwargs and isinstance(kwargs["temperature"], float):
ls_params["ls_temperature"] = kwargs["temperature"]
elif hasattr(self, "temperature") and isinstance(self.temperature, float):
ls_params["ls_temperature"] = self.temperature
# max_tokens
if "max_tokens" in kwargs and isinstance(kwargs["max_tokens"], int):
ls_params["ls_max_tokens"] = kwargs["max_tokens"]
elif hasattr(self, "max_tokens") and isinstance(self.max_tokens, int):
ls_params["ls_max_tokens"] = self.max_tokens
return ls_params
def invoke(
self,
input: LanguageModelInput,
@@ -487,13 +528,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
params["stop"] = stop
params = {**params, **kwargs}
options = {"stop": stop}
inheritable_metadata = {
**(config.get("metadata") or {}),
**self._get_ls_params(stop=stop, **kwargs),
}
callback_manager = CallbackManager.configure(
config.get("callbacks"),
self.callbacks,
self.verbose,
config.get("tags"),
self.tags,
config.get("metadata"),
inheritable_metadata,
self.metadata,
)
(run_manager,) = callback_manager.on_llm_start(
@@ -548,13 +593,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
params["stop"] = stop
params = {**params, **kwargs}
options = {"stop": stop}
inheritable_metadata = {
**(config.get("metadata") or {}),
**self._get_ls_params(stop=stop, **kwargs),
}
callback_manager = AsyncCallbackManager.configure(
config.get("callbacks"),
self.callbacks,
self.verbose,
config.get("tags"),
self.tags,
config.get("metadata"),
inheritable_metadata,
self.metadata,
)
(run_manager,) = await callback_manager.on_llm_start(
@@ -796,6 +845,21 @@ class BaseLLM(BaseLanguageModel[str], ABC):
f" argument of type {type(prompts)}."
)
# Create callback managers
if isinstance(metadata, list):
metadata = [
{
**(meta or {}),
**self._get_ls_params(stop=stop, **kwargs),
}
for meta in metadata
]
elif isinstance(metadata, dict):
metadata = {
**(metadata or {}),
**self._get_ls_params(stop=stop, **kwargs),
}
else:
pass
if (
isinstance(callbacks, list)
and callbacks
@@ -1017,6 +1081,21 @@ class BaseLLM(BaseLanguageModel[str], ABC):
An LLMResult, which contains a list of candidate Generations for each input
prompt and additional model provider-specific output.
"""
if isinstance(metadata, list):
metadata = [
{
**(meta or {}),
**self._get_ls_params(stop=stop, **kwargs),
}
for meta in metadata
]
elif isinstance(metadata, dict):
metadata = {
**(metadata or {}),
**self._get_ls_params(stop=stop, **kwargs),
}
else:
pass
# Create callback managers
if isinstance(callbacks, list) and (
isinstance(callbacks[0], (list, BaseCallbackManager))

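Taken together, the llms.py hunks above give BaseLLM a default _get_ls_params that infers provider, model name, temperature, and max_tokens from common attribute names, and merge its output into the inheritable run metadata on the stream and generate paths. A rough end-to-end sketch; CustomLLM and its field values are purely illustrative:

from typing import Any, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM


class CustomLLM(LLM):
    """Hypothetical LLM used only to illustrate the default behavior."""

    model: str = "my-model"
    temperature: float = 0.5

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return "output"


llm = CustomLLM()

# Provider defaults to the lowercased class name with a trailing "LLM" stripped.
assert llm._get_ls_params(stop=["\n"]) == {
    "ls_provider": "custom",
    "ls_model_type": "llm",
    "ls_stop": ["\n"],
    "ls_model_name": "my-model",
    "ls_temperature": 0.5,
}

# On generate (and stream), these params are merged with caller-supplied
# metadata into the inheritable metadata passed to callbacks and tracers.
llm.generate(["hello"], metadata={"my_key": "my_value"})

Provider integrations with non-standard attribute names (the Anthropic, Ollama, and Azure overrides below) refine these defaults via super()._get_ls_params rather than reimplementing them.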

@@ -6,6 +6,7 @@ EXPECTED_ALL = [
"SimpleChatModel",
"BaseLLM",
"LLM",
"LangSmithParams",
"LanguageModelInput",
"LanguageModelOutput",
"LanguageModelLike",

File diff suppressed because one or more lines are too long


@@ -2180,7 +2180,7 @@ async def test_prompt_with_llm(
"value": {
"end_time": None,
"final_output": None,
"metadata": {},
"metadata": {"ls_model_type": "llm", "ls_provider": "fakelist"},
"name": "FakeListLLM",
"start_time": "2023-01-01T00:00:00.000+00:00",
"streamed_output": [],
@@ -2384,7 +2384,10 @@ async def test_prompt_with_llm_parser(
"value": {
"end_time": None,
"final_output": None,
"metadata": {},
"metadata": {
"ls_model_type": "llm",
"ls_provider": "fakestreaminglist",
},
"name": "FakeStreamingListLLM",
"start_time": "2023-01-01T00:00:00.000+00:00",
"streamed_output": [],

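The snapshot updates follow directly from the new default implementation; for the fake LLM used in these tests one would expect (sketch, not part of the test file):

from langchain_core.language_models import FakeListLLM

assert FakeListLLM(responses=["hi"])._get_ls_params() == {
    "ls_provider": "fakelist",
    "ls_model_type": "llm",
}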

@@ -17,7 +17,7 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models import BaseLanguageModel, LangSmithParams
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.prompt_values import PromptValue
@@ -204,6 +204,19 @@ class AnthropicLLM(LLM, _AnthropicCommon):
"max_retries": self.max_retries,
}
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = super()._get_ls_params(stop=stop, **kwargs)
identifying_params = self._identifying_params
if max_tokens := kwargs.get(
"max_tokens_to_sample",
identifying_params.get("max_tokens"),
):
params["ls_max_tokens"] = max_tokens
return params
def _wrap_prompt(self, prompt: str) -> str:
if not self.HUMAN_PROMPT or not self.AI_PROMPT:
raise NameError("Please ensure the anthropic package is loaded")

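One detail of the override above: a per-call max_tokens_to_sample kwarg takes precedence over the default pulled from _identifying_params. A small sketch, assuming the anthropic package is installed and a dummy API key as in the unit test that follows:

import os

from langchain_anthropic import AnthropicLLM

os.environ["ANTHROPIC_API_KEY"] = "foo"

llm = AnthropicLLM(model="foo")  # type: ignore[call-arg]
assert llm._get_ls_params(max_tokens_to_sample=100)["ls_max_tokens"] == 100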

@@ -0,0 +1,29 @@
import os
from langchain_anthropic import AnthropicLLM
os.environ["ANTHROPIC_API_KEY"] = "foo"
def test_anthropic_model_params() -> None:
# Test standard tracing params
llm = AnthropicLLM(model="foo") # type: ignore[call-arg]
ls_params = llm._get_ls_params()
assert ls_params == {
"ls_provider": "anthropic",
"ls_model_type": "llm",
"ls_model_name": "foo",
"ls_max_tokens": 1024,
}
llm = AnthropicLLM(model="foo", temperature=0.1) # type: ignore[call-arg]
ls_params = llm._get_ls_params()
assert ls_params == {
"ls_provider": "anthropic",
"ls_model_type": "llm",
"ls_model_name": "foo",
"ls_max_tokens": 1024,
"ls_temperature": 0.1,
}


@@ -69,3 +69,31 @@ def test_fireworks_uses_actual_secret_value_from_secretstr() -> None:
max_tokens=250,
)
assert cast(SecretStr, llm.fireworks_api_key).get_secret_value() == "secret-api-key"
def test_fireworks_model_params() -> None:
# Test standard tracing params
llm = Fireworks(model="foo", api_key="secret-api-key") # type: ignore[arg-type]
ls_params = llm._get_ls_params()
assert ls_params == {
"ls_provider": "fireworks",
"ls_model_type": "llm",
"ls_model_name": "foo",
}
llm = Fireworks(
model="foo",
api_key="secret-api-key", # type: ignore[arg-type]
max_tokens=10,
temperature=0.1,
)
ls_params = llm._get_ls_params()
assert ls_params == {
"ls_provider": "fireworks",
"ls_model_type": "llm",
"ls_model_name": "foo",
"ls_max_tokens": 10,
"ls_temperature": 0.1,
}


@@ -16,7 +16,7 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseLLM
from langchain_core.language_models import BaseLLM, LangSmithParams
from langchain_core.outputs import GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import Field, root_validator
from ollama import AsyncClient, Client, Options
@@ -155,6 +155,15 @@ class OllamaLLM(BaseLLM):
"""Return type of LLM."""
return "ollama-llm"
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = super()._get_ls_params(stop=stop, **kwargs)
if max_tokens := kwargs.get("num_predict", self.num_predict):
params["ls_max_tokens"] = max_tokens
return params
@root_validator(pre=False, skip_on_failure=True)
def _set_clients(cls, values: dict) -> dict:
"""Set clients to use for ollama."""

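Similarly here, a per-call num_predict kwarg feeds ls_max_tokens, falling back to the num_predict field on the model; a sketch mirroring the unit test that follows (model name arbitrary):

from langchain_ollama import OllamaLLM

llm = OllamaLLM(model="llama3")
assert llm._get_ls_params(num_predict=5)["ls_max_tokens"] == 5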

@@ -6,3 +6,23 @@ from langchain_ollama import OllamaLLM
def test_initialization() -> None:
"""Test integration initialization."""
OllamaLLM(model="llama3")
def test_model_params() -> None:
# Test standard tracing params
llm = OllamaLLM(model="llama3")
ls_params = llm._get_ls_params()
assert ls_params == {
"ls_provider": "ollama",
"ls_model_type": "llm",
"ls_model_name": "llama3",
}
llm = OllamaLLM(model="llama3", num_predict=3)
ls_params = llm._get_ls_params()
assert ls_params == {
"ls_provider": "ollama",
"ls_model_type": "llm",
"ls_model_name": "llama3",
"ls_max_tokens": 3,
}


@@ -5,6 +5,7 @@ import os
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
import openai
from langchain_core.language_models import LangSmithParams
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
@@ -195,6 +196,17 @@ class AzureOpenAI(BaseOpenAI):
openai_params = {"model": self.deployment_name}
return {**openai_params, **super()._invocation_params}
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = super()._get_ls_params(stop=stop, **kwargs)
invocation_params = self._invocation_params
params["ls_provider"] = "azure"
if model_name := invocation_params.get("model"):
params["ls_model_name"] = model_name
return params
@property
def _llm_type(self) -> str:
"""Return type of llm."""


@@ -0,0 +1,23 @@
from typing import Any
from langchain_openai import AzureOpenAI
def test_azure_model_param(monkeypatch: Any) -> None:
monkeypatch.delenv("OPENAI_API_BASE", raising=False)
llm = AzureOpenAI(
openai_api_key="secret-api-key", # type: ignore[call-arg]
azure_endpoint="endpoint",
api_version="version",
azure_deployment="gpt-35-turbo-instruct",
)
# Test standard tracing params
ls_params = llm._get_ls_params()
assert ls_params == {
"ls_provider": "azure",
"ls_model_type": "llm",
"ls_model_name": "gpt-35-turbo-instruct",
"ls_temperature": 0.7,
"ls_max_tokens": 256,
}


@@ -14,6 +14,16 @@ def test_openai_model_param() -> None:
llm = OpenAI(model_name="foo") # type: ignore[call-arg]
assert llm.model_name == "foo"
# Test standard tracing params
ls_params = llm._get_ls_params()
assert ls_params == {
"ls_provider": "openai",
"ls_model_type": "llm",
"ls_model_name": "foo",
"ls_temperature": 0.7,
"ls_max_tokens": 256,
}
def test_openai_model_kwargs() -> None:
llm = OpenAI(model_kwargs={"foo": "bar"})


@@ -69,3 +69,33 @@ def test_together_uses_actual_secret_value_from_secretstr_api_key() -> None:
max_tokens=250,
)
assert cast(SecretStr, llm.together_api_key).get_secret_value() == "secret-api-key"
def test_together_model_params() -> None:
# Test standard tracing params
llm = Together(
api_key="secret-api-key", # type: ignore[arg-type]
model="foo",
)
ls_params = llm._get_ls_params()
assert ls_params == {
"ls_provider": "together",
"ls_model_type": "llm",
"ls_model_name": "foo",
"ls_max_tokens": 200,
}
llm = Together(
api_key="secret-api-key", # type: ignore[arg-type]
model="foo",
temperature=0.2,
max_tokens=250,
)
ls_params = llm._get_ls_params()
assert ls_params == {
"ls_provider": "together",
"ls_model_type": "llm",
"ls_model_name": "foo",
"ls_max_tokens": 250,
"ls_temperature": 0.2,
}