Mirror of https://github.com/hwchase17/langchain.git (synced 2025-04-29 04:16:02 +00:00)

core, partners: implement standard tracing params for LLMs (#25410)

parent 9f0c76bf89
commit b83f1eb0d5
@@ -10,7 +10,7 @@ def test_standard_params() -> None:
     class ExpectedParams(BaseModel):
         ls_provider: str
         ls_model_name: str
-        ls_model_type: Literal["chat"]
+        ls_model_type: Literal["chat", "llm"]
         ls_temperature: Optional[float]
         ls_max_tokens: Optional[int]
         ls_stop: Optional[List[str]]
@@ -39,6 +39,7 @@ https://python.langchain.com/v0.2/docs/how_to/custom_llm/
 from langchain_core.language_models.base import (
     BaseLanguageModel,
+    LangSmithParams,
     LanguageModelInput,
     LanguageModelLike,
     LanguageModelOutput,
@@ -62,6 +63,7 @@ __all__ = [
     "LLM",
     "LanguageModelInput",
     "get_tokenizer",
+    "LangSmithParams",
     "LanguageModelOutput",
     "LanguageModelLike",
     "FakeListLLM",
@@ -8,6 +8,7 @@ from typing import (
     Callable,
     Dict,
     List,
+    Literal,
     Mapping,
     Optional,
     Sequence,
@@ -17,7 +18,7 @@ from typing import (
     Union,
 )

-from typing_extensions import TypeAlias
+from typing_extensions import TypeAlias, TypedDict

 from langchain_core._api import deprecated
 from langchain_core.messages import (
@@ -37,6 +38,23 @@ if TYPE_CHECKING:
     from langchain_core.outputs import LLMResult


+class LangSmithParams(TypedDict, total=False):
+    """LangSmith parameters for tracing."""
+
+    ls_provider: str
+    """Provider of the model."""
+    ls_model_name: str
+    """Name of the model."""
+    ls_model_type: Literal["chat", "llm"]
+    """Type of the model. Should be 'chat' or 'llm'."""
+    ls_temperature: Optional[float]
+    """Temperature for generation."""
+    ls_max_tokens: Optional[int]
+    """Max tokens for generation."""
+    ls_stop: Optional[List[str]]
+    """Stop words for generation."""
+
+
 @lru_cache(maxsize=None)  # Cache the tokenizer
 def get_tokenizer() -> Any:
     """Get a GPT-2 tokenizer instance.
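A quick sketch of why `total=False` matters here: every field of the TypedDict is optional, which is what lets the `_get_ls_params` implementations below populate keys one at a time. The class body mirrors the one added in this hunk; the constructed values are illustrative only.

```python
from typing import List, Literal, Optional
from typing_extensions import TypedDict


class LangSmithParams(TypedDict, total=False):  # mirrors the class added above
    ls_provider: str
    ls_model_name: str
    ls_model_type: Literal["chat", "llm"]
    ls_temperature: Optional[float]
    ls_max_tokens: Optional[int]
    ls_stop: Optional[List[str]]


# No field is required at construction time...
params = LangSmithParams(ls_provider="openai", ls_model_type="llm")
# ...and keys can be filled in conditionally afterwards.
params["ls_temperature"] = 0.7
assert params == {
    "ls_provider": "openai",
    "ls_model_type": "llm",
    "ls_temperature": 0.7,
}
```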
@@ -23,8 +23,6 @@ from typing import (
     cast,
 )

-from typing_extensions import TypedDict
-
 from langchain_core._api import deprecated
 from langchain_core.caches import BaseCache
 from langchain_core.callbacks import (
@@ -36,7 +34,11 @@ from langchain_core.callbacks import (
     Callbacks,
 )
 from langchain_core.globals import get_llm_cache
-from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
+from langchain_core.language_models.base import (
+    BaseLanguageModel,
+    LangSmithParams,
+    LanguageModelInput,
+)
 from langchain_core.load import dumpd, dumps
 from langchain_core.messages import (
     AIMessage,
@@ -73,23 +75,6 @@ if TYPE_CHECKING:
     from langchain_core.tools import BaseTool


-class LangSmithParams(TypedDict, total=False):
-    """LangSmith parameters for tracing."""
-
-    ls_provider: str
-    """Provider of the model."""
-    ls_model_name: str
-    """Name of the model."""
-    ls_model_type: Literal["chat"]
-    """Type of the model. Should be 'chat'."""
-    ls_temperature: Optional[float]
-    """Temperature for generation."""
-    ls_max_tokens: Optional[int]
-    """Max tokens for generation."""
-    ls_stop: Optional[List[str]]
-    """Stop words for generation."""
-
-
 def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
     """Generate from a stream.
@@ -48,7 +48,11 @@ from langchain_core.callbacks import (
     Callbacks,
 )
 from langchain_core.globals import get_llm_cache
-from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
+from langchain_core.language_models.base import (
+    BaseLanguageModel,
+    LangSmithParams,
+    LanguageModelInput,
+)
 from langchain_core.load import dumpd
 from langchain_core.messages import (
     AIMessage,
@@ -331,6 +335,43 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 "Must be a PromptValue, str, or list of BaseMessages."
             )

+    def _get_ls_params(
+        self,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> LangSmithParams:
+        """Get standard params for tracing."""
+
+        # get default provider from class name
+        default_provider = self.__class__.__name__
+        if default_provider.endswith("LLM"):
+            default_provider = default_provider[:-3]
+        default_provider = default_provider.lower()
+
+        ls_params = LangSmithParams(ls_provider=default_provider, ls_model_type="llm")
+        if stop:
+            ls_params["ls_stop"] = stop
+
+        # model
+        if hasattr(self, "model") and isinstance(self.model, str):
+            ls_params["ls_model_name"] = self.model
+        elif hasattr(self, "model_name") and isinstance(self.model_name, str):
+            ls_params["ls_model_name"] = self.model_name
+
+        # temperature
+        if "temperature" in kwargs and isinstance(kwargs["temperature"], float):
+            ls_params["ls_temperature"] = kwargs["temperature"]
+        elif hasattr(self, "temperature") and isinstance(self.temperature, float):
+            ls_params["ls_temperature"] = self.temperature
+
+        # max_tokens
+        if "max_tokens" in kwargs and isinstance(kwargs["max_tokens"], int):
+            ls_params["ls_max_tokens"] = kwargs["max_tokens"]
+        elif hasattr(self, "max_tokens") and isinstance(self.max_tokens, int):
+            ls_params["ls_max_tokens"] = self.max_tokens
+
+        return ls_params
+
     def invoke(
         self,
         input: LanguageModelInput,
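A rough sketch of what the new default buys a subclass. `EchoLLM` and its attributes are hypothetical, but the resolution order shown follows the code above: explicit call-time kwargs win, then matching instance attributes, and the provider is derived from the class name (with a trailing "LLM" stripped). Assumes the `langchain-core` from this PR.

```python
from typing import Any, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM


class EchoLLM(LLM):
    # Hypothetical LLM whose attributes the default _get_ls_params can read.
    model_name: str = "echo-1"
    temperature: float = 0.5

    @property
    def _llm_type(self) -> str:
        return "echo"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return prompt


llm = EchoLLM()
# Class name "EchoLLM" -> provider "echo"; model_name and temperature come
# from the instance attributes.
assert llm._get_ls_params() == {
    "ls_provider": "echo",
    "ls_model_type": "llm",
    "ls_model_name": "echo-1",
    "ls_temperature": 0.5,
}
# A call-time kwarg takes precedence over the instance attribute.
assert llm._get_ls_params(temperature=0.1)["ls_temperature"] == 0.1
```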
@@ -487,13 +528,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             params["stop"] = stop
         params = {**params, **kwargs}
         options = {"stop": stop}
+        inheritable_metadata = {
+            **(config.get("metadata") or {}),
+            **self._get_ls_params(stop=stop, **kwargs),
+        }
         callback_manager = CallbackManager.configure(
             config.get("callbacks"),
             self.callbacks,
             self.verbose,
             config.get("tags"),
             self.tags,
-            config.get("metadata"),
+            inheritable_metadata,
             self.metadata,
         )
         (run_manager,) = callback_manager.on_llm_start(
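The merge order matters here: the tracing params are spread last, so they override any same-named keys in user-supplied run metadata while leaving other keys untouched. A minimal sketch of that dict-merge behavior:

```python
# In a {**a, **b} merge, later entries win on key collisions.
config_metadata = {"ls_provider": "stale", "user_key": 1}
ls_params = {"ls_provider": "openai", "ls_model_type": "llm"}
inheritable_metadata = {**config_metadata, **ls_params}
assert inheritable_metadata == {
    "user_key": 1,
    "ls_provider": "openai",  # the standard param wins
    "ls_model_type": "llm",
}
```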
@@ -548,13 +593,17 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             params["stop"] = stop
         params = {**params, **kwargs}
         options = {"stop": stop}
+        inheritable_metadata = {
+            **(config.get("metadata") or {}),
+            **self._get_ls_params(stop=stop, **kwargs),
+        }
         callback_manager = AsyncCallbackManager.configure(
             config.get("callbacks"),
             self.callbacks,
             self.verbose,
             config.get("tags"),
             self.tags,
-            config.get("metadata"),
+            inheritable_metadata,
             self.metadata,
         )
         (run_manager,) = await callback_manager.on_llm_start(
@@ -796,6 +845,21 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 f" argument of type {type(prompts)}."
             )
         # Create callback managers
+        if isinstance(metadata, list):
+            metadata = [
+                {
+                    **(meta or {}),
+                    **self._get_ls_params(stop=stop, **kwargs),
+                }
+                for meta in metadata
+            ]
+        elif isinstance(metadata, dict):
+            metadata = {
+                **(metadata or {}),
+                **self._get_ls_params(stop=stop, **kwargs),
+            }
+        else:
+            pass
         if (
             isinstance(callbacks, list)
             and callbacks
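`generate` accepts either one metadata dict shared by all prompts or a list with one dict per prompt; the branch above merges the same tracing params into whichever shape was passed, treating `None` entries as empty. A small sketch of the list case:

```python
# Per-prompt metadata: each entry (including None) gets the ls_* keys merged in.
ls_params = {"ls_provider": "fake", "ls_model_type": "llm"}
metadata = [{"run": 0}, None]
metadata = [{**(meta or {}), **ls_params} for meta in metadata]
assert metadata == [
    {"run": 0, "ls_provider": "fake", "ls_model_type": "llm"},
    {"ls_provider": "fake", "ls_model_type": "llm"},
]
```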
@@ -1017,6 +1081,21 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             An LLMResult, which contains a list of candidate Generations for each input
                 prompt and additional model provider-specific output.
         """
+        if isinstance(metadata, list):
+            metadata = [
+                {
+                    **(meta or {}),
+                    **self._get_ls_params(stop=stop, **kwargs),
+                }
+                for meta in metadata
+            ]
+        elif isinstance(metadata, dict):
+            metadata = {
+                **(metadata or {}),
+                **self._get_ls_params(stop=stop, **kwargs),
+            }
+        else:
+            pass
         # Create callback managers
         if isinstance(callbacks, list) and (
             isinstance(callbacks[0], (list, BaseCallbackManager))
@@ -6,6 +6,7 @@ EXPECTED_ALL = [
     "SimpleChatModel",
     "BaseLLM",
     "LLM",
+    "LangSmithParams",
     "LanguageModelInput",
     "LanguageModelOutput",
     "LanguageModelLike",
File diff suppressed because one or more lines are too long
@@ -2180,7 +2180,7 @@ async def test_prompt_with_llm(
                 "value": {
                     "end_time": None,
                     "final_output": None,
-                    "metadata": {},
+                    "metadata": {"ls_model_type": "llm", "ls_provider": "fakelist"},
                     "name": "FakeListLLM",
                     "start_time": "2023-01-01T00:00:00.000+00:00",
                     "streamed_output": [],
@@ -2384,7 +2384,10 @@ async def test_prompt_with_llm_parser(
                 "value": {
                     "end_time": None,
                     "final_output": None,
-                    "metadata": {},
+                    "metadata": {
+                        "ls_model_type": "llm",
+                        "ls_provider": "fakestreaminglist",
+                    },
                     "name": "FakeStreamingListLLM",
                     "start_time": "2023-01-01T00:00:00.000+00:00",
                     "streamed_output": [],
@@ -17,7 +17,7 @@ from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models import BaseLanguageModel
+from langchain_core.language_models import BaseLanguageModel, LangSmithParams
 from langchain_core.language_models.llms import LLM
 from langchain_core.outputs import GenerationChunk
 from langchain_core.prompt_values import PromptValue
@@ -204,6 +204,19 @@ class AnthropicLLM(LLM, _AnthropicCommon):
             "max_retries": self.max_retries,
         }

+    def _get_ls_params(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> LangSmithParams:
+        """Get standard params for tracing."""
+        params = super()._get_ls_params(stop=stop, **kwargs)
+        identifying_params = self._identifying_params
+        if max_tokens := kwargs.get(
+            "max_tokens_to_sample",
+            identifying_params.get("max_tokens"),
+        ):
+            params["ls_max_tokens"] = max_tokens
+        return params
+
     def _wrap_prompt(self, prompt: str) -> str:
         if not self.HUMAN_PROMPT or not self.AI_PROMPT:
            raise NameError("Please ensure the anthropic package is loaded")
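The override leans on `dict.get` with a default plus the walrus operator: a call-time `max_tokens_to_sample` wins, otherwise the instance-level `max_tokens` from `_identifying_params` is used, and a falsy result leaves the key unset. A standalone sketch of the lookup chain (`resolve` is illustrative, not part of the API):

```python
from typing import Optional

identifying_params = {"max_tokens": 1024}  # stands in for self._identifying_params


def resolve(kwargs: dict) -> Optional[int]:
    # kwargs.get(key, fallback) returns the call-time value when present,
    # otherwise the instance-level default.
    if max_tokens := kwargs.get(
        "max_tokens_to_sample", identifying_params.get("max_tokens")
    ):
        return max_tokens
    return None  # falsy value: the ls_max_tokens key would stay unset


assert resolve({}) == 1024  # falls back to the instance default
assert resolve({"max_tokens_to_sample": 256}) == 256  # call-time value wins
```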
libs/partners/anthropic/tests/unit_tests/test_llms.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+import os
+
+from langchain_anthropic import AnthropicLLM
+
+os.environ["ANTHROPIC_API_KEY"] = "foo"
+
+
+def test_anthropic_model_params() -> None:
+    # Test standard tracing params
+    llm = AnthropicLLM(model="foo")  # type: ignore[call-arg]
+
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "anthropic",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+        "ls_max_tokens": 1024,
+    }
+
+    llm = AnthropicLLM(model="foo", temperature=0.1)  # type: ignore[call-arg]
+
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "anthropic",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+        "ls_max_tokens": 1024,
+        "ls_temperature": 0.1,
+    }
@@ -69,3 +69,31 @@ def test_fireworks_uses_actual_secret_value_from_secretstr() -> None:
         max_tokens=250,
     )
     assert cast(SecretStr, llm.fireworks_api_key).get_secret_value() == "secret-api-key"
+
+
+def test_fireworks_model_params() -> None:
+    # Test standard tracing params
+    llm = Fireworks(model="foo", api_key="secret-api-key")  # type: ignore[arg-type]
+
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "fireworks",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+    }
+
+    llm = Fireworks(
+        model="foo",
+        api_key="secret-api-key",  # type: ignore[arg-type]
+        max_tokens=10,
+        temperature=0.1,
+    )
+
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "fireworks",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+        "ls_max_tokens": 10,
+        "ls_temperature": 0.1,
+    }
@@ -16,7 +16,7 @@ from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models import BaseLLM
+from langchain_core.language_models import BaseLLM, LangSmithParams
 from langchain_core.outputs import GenerationChunk, LLMResult
 from langchain_core.pydantic_v1 import Field, root_validator
 from ollama import AsyncClient, Client, Options
@@ -155,6 +155,15 @@ class OllamaLLM(BaseLLM):
         """Return type of LLM."""
         return "ollama-llm"

+    def _get_ls_params(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> LangSmithParams:
+        """Get standard params for tracing."""
+        params = super()._get_ls_params(stop=stop, **kwargs)
+        if max_tokens := kwargs.get("num_predict", self.num_predict):
+            params["ls_max_tokens"] = max_tokens
+        return params
+
     @root_validator(pre=False, skip_on_failure=True)
     def _set_clients(cls, values: dict) -> dict:
         """Set clients to use for ollama."""
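Because the override reads `kwargs` first, a call-time `num_predict` shadows the instance setting. A hedged usage sketch, assuming the `langchain-ollama` package from this PR is installed (`_get_ls_params` is a private method, called directly here only for illustration):

```python
from langchain_ollama import OllamaLLM

llm = OllamaLLM(model="llama3", num_predict=3)
# The instance-level num_predict is the default...
assert llm._get_ls_params()["ls_max_tokens"] == 3
# ...but a call-time kwarg takes precedence.
assert llm._get_ls_params(num_predict=8)["ls_max_tokens"] == 8
```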
@@ -6,3 +6,23 @@ from langchain_ollama import OllamaLLM
 def test_initialization() -> None:
     """Test integration initialization."""
     OllamaLLM(model="llama3")
+
+
+def test_model_params() -> None:
+    # Test standard tracing params
+    llm = OllamaLLM(model="llama3")
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "ollama",
+        "ls_model_type": "llm",
+        "ls_model_name": "llama3",
+    }
+
+    llm = OllamaLLM(model="llama3", num_predict=3)
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "ollama",
+        "ls_model_type": "llm",
+        "ls_model_name": "llama3",
+        "ls_max_tokens": 3,
+    }
@@ -5,6 +5,7 @@ import os
 from typing import Any, Callable, Dict, List, Mapping, Optional, Union

 import openai
+from langchain_core.language_models import LangSmithParams
 from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
@@ -195,6 +196,17 @@ class AzureOpenAI(BaseOpenAI):
         openai_params = {"model": self.deployment_name}
         return {**openai_params, **super()._invocation_params}

+    def _get_ls_params(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> LangSmithParams:
+        """Get standard params for tracing."""
+        params = super()._get_ls_params(stop=stop, **kwargs)
+        invocation_params = self._invocation_params
+        params["ls_provider"] = "azure"
+        if model_name := invocation_params.get("model"):
+            params["ls_model_name"] = model_name
+        return params
+
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
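The override pattern here is additive: `super()._get_ls_params()` fills in the generic fields (model type, temperature, max tokens), then the subclass corrects the provider and model name from its own invocation params. A compact, self-contained sketch of that layering; `Base` and `MyAzureLLM` are hypothetical stand-ins, not the real classes:

```python
from typing import Any, List, Optional


class Base:
    def _get_ls_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> dict:
        # Generic defaults, as the BaseLLM implementation provides.
        return {"ls_model_type": "llm", "ls_provider": "base"}


class MyAzureLLM(Base):
    def _get_ls_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> dict:
        params = super()._get_ls_params(stop=stop, **kwargs)
        params["ls_provider"] = "azure"  # subclass refines the provider...
        params["ls_model_name"] = "my-deployment"  # ...and the model name
        return params


assert MyAzureLLM()._get_ls_params() == {
    "ls_model_type": "llm",
    "ls_provider": "azure",
    "ls_model_name": "my-deployment",
}
```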
libs/partners/openai/tests/unit_tests/llms/test_azure.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+from typing import Any
+
+from langchain_openai import AzureOpenAI
+
+
+def test_azure_model_param(monkeypatch: Any) -> None:
+    monkeypatch.delenv("OPENAI_API_BASE", raising=False)
+    llm = AzureOpenAI(
+        openai_api_key="secret-api-key",  # type: ignore[call-arg]
+        azure_endpoint="endpoint",
+        api_version="version",
+        azure_deployment="gpt-35-turbo-instruct",
+    )
+
+    # Test standard tracing params
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "azure",
+        "ls_model_type": "llm",
+        "ls_model_name": "gpt-35-turbo-instruct",
+        "ls_temperature": 0.7,
+        "ls_max_tokens": 256,
+    }
@@ -14,6 +14,16 @@ def test_openai_model_param() -> None:
     llm = OpenAI(model_name="foo")  # type: ignore[call-arg]
     assert llm.model_name == "foo"

+    # Test standard tracing params
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "openai",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+        "ls_temperature": 0.7,
+        "ls_max_tokens": 256,
+    }
+

 def test_openai_model_kwargs() -> None:
     llm = OpenAI(model_kwargs={"foo": "bar"})
@@ -69,3 +69,33 @@ def test_together_uses_actual_secret_value_from_secretstr_api_key() -> None:
         max_tokens=250,
     )
     assert cast(SecretStr, llm.together_api_key).get_secret_value() == "secret-api-key"
+
+
+def test_together_model_params() -> None:
+    # Test standard tracing params
+    llm = Together(
+        api_key="secret-api-key",  # type: ignore[arg-type]
+        model="foo",
+    )
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "together",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+        "ls_max_tokens": 200,
+    }
+
+    llm = Together(
+        api_key="secret-api-key",  # type: ignore[arg-type]
+        model="foo",
+        temperature=0.2,
+        max_tokens=250,
+    )
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "together",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+        "ls_max_tokens": 250,
+        "ls_temperature": 0.2,
+    }