core, partners: implement standard tracing params for LLMs (#25410)

ccurme
2024-08-16 13:18:09 -04:00
committed by GitHub
parent 9f0c76bf89
commit b83f1eb0d5
17 changed files with 298 additions and 36 deletions
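
The "standard tracing params" in the title are the LangSmithParams keys that each integration's _get_ls_params() now populates. A rough sketch of their shape, with key names taken from the unit tests in this commit and purely illustrative values:

from langchain_core.language_models import LangSmithParams

# Illustrative only; which optional keys appear depends on the model's configuration.
params: LangSmithParams = {
    "ls_provider": "openai",   # integration name, e.g. "anthropic", "fireworks", "ollama", "azure"
    "ls_model_type": "llm",    # completion-style LLM (chat models use "chat")
    "ls_model_name": "foo",    # resolved model name (or Azure deployment name)
    "ls_temperature": 0.7,     # present when a sampling temperature is set
    "ls_max_tokens": 256,      # present when a max-token limit is known
}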

View File

@@ -17,7 +17,7 @@ from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models import BaseLanguageModel
+from langchain_core.language_models import BaseLanguageModel, LangSmithParams
 from langchain_core.language_models.llms import LLM
 from langchain_core.outputs import GenerationChunk
 from langchain_core.prompt_values import PromptValue
@@ -204,6 +204,19 @@ class AnthropicLLM(LLM, _AnthropicCommon):
             "max_retries": self.max_retries,
         }

+    def _get_ls_params(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> LangSmithParams:
+        """Get standard params for tracing."""
+        params = super()._get_ls_params(stop=stop, **kwargs)
+        identifying_params = self._identifying_params
+        if max_tokens := kwargs.get(
+            "max_tokens_to_sample",
+            identifying_params.get("max_tokens"),
+        ):
+            params["ls_max_tokens"] = max_tokens
+        return params
+
     def _wrap_prompt(self, prompt: str) -> str:
         if not self.HUMAN_PROMPT or not self.AI_PROMPT:
             raise NameError("Please ensure the anthropic package is loaded")
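
For illustration, a minimal sketch (not part of the diff) of the precedence this method gives Anthropic: a call-time max_tokens_to_sample kwarg wins over the value reported by _identifying_params, whose 1024-token default is asserted in the unit test below.

import os

from langchain_anthropic import AnthropicLLM

os.environ["ANTHROPIC_API_KEY"] = "foo"  # dummy key; nothing is sent to the API
llm = AnthropicLLM(model="foo")  # type: ignore[call-arg]

llm._get_ls_params()["ls_max_tokens"]                          # 1024, the instance default
llm._get_ls_params(max_tokens_to_sample=512)["ls_max_tokens"]  # 512, the per-call kwarg takes precedence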

View File

@@ -0,0 +1,29 @@
+import os
+
+from langchain_anthropic import AnthropicLLM
+
+os.environ["ANTHROPIC_API_KEY"] = "foo"
+
+
+def test_anthropic_model_params() -> None:
+    # Test standard tracing params
+    llm = AnthropicLLM(model="foo")  # type: ignore[call-arg]
+
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "anthropic",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+        "ls_max_tokens": 1024,
+    }
+
+    llm = AnthropicLLM(model="foo", temperature=0.1)  # type: ignore[call-arg]
+
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "anthropic",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+        "ls_max_tokens": 1024,
+        "ls_temperature": 0.1,
+    }

View File

@@ -69,3 +69,31 @@ def test_fireworks_uses_actual_secret_value_from_secretstr() -> None:
         max_tokens=250,
     )
     assert cast(SecretStr, llm.fireworks_api_key).get_secret_value() == "secret-api-key"
+
+
+def test_fireworks_model_params() -> None:
+    # Test standard tracing params
+    llm = Fireworks(model="foo", api_key="secret-api-key")  # type: ignore[arg-type]
+
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "fireworks",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+    }
+
+    llm = Fireworks(
+        model="foo",
+        api_key="secret-api-key",  # type: ignore[arg-type]
+        max_tokens=10,
+        temperature=0.1,
+    )
+
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "fireworks",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+        "ls_max_tokens": 10,
+        "ls_temperature": 0.1,
+    }

View File

@@ -16,7 +16,7 @@ from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models import BaseLLM
+from langchain_core.language_models import BaseLLM, LangSmithParams
 from langchain_core.outputs import GenerationChunk, LLMResult
 from langchain_core.pydantic_v1 import Field, root_validator
 from ollama import AsyncClient, Client, Options
@@ -155,6 +155,15 @@ class OllamaLLM(BaseLLM):
         """Return type of LLM."""
         return "ollama-llm"

+    def _get_ls_params(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> LangSmithParams:
+        """Get standard params for tracing."""
+        params = super()._get_ls_params(stop=stop, **kwargs)
+        if max_tokens := kwargs.get("num_predict", self.num_predict):
+            params["ls_max_tokens"] = max_tokens
+        return params
+
     @root_validator(pre=False, skip_on_failure=True)
     def _set_clients(cls, values: dict) -> dict:
         """Set clients to use for ollama."""

View File

@@ -6,3 +6,23 @@ from langchain_ollama import OllamaLLM
 def test_initialization() -> None:
     """Test integration initialization."""
     OllamaLLM(model="llama3")
+
+
+def test_model_params() -> None:
+    # Test standard tracing params
+    llm = OllamaLLM(model="llama3")
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "ollama",
+        "ls_model_type": "llm",
+        "ls_model_name": "llama3",
+    }
+
+    llm = OllamaLLM(model="llama3", num_predict=3)
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "ollama",
+        "ls_model_type": "llm",
+        "ls_model_name": "llama3",
+        "ls_max_tokens": 3,
+    }

View File

@@ -5,6 +5,7 @@ import os
 from typing import Any, Callable, Dict, List, Mapping, Optional, Union

 import openai
+from langchain_core.language_models import LangSmithParams
 from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
@@ -195,6 +196,17 @@ class AzureOpenAI(BaseOpenAI):
         openai_params = {"model": self.deployment_name}
         return {**openai_params, **super()._invocation_params}

+    def _get_ls_params(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> LangSmithParams:
+        """Get standard params for tracing."""
+        params = super()._get_ls_params(stop=stop, **kwargs)
+        invocation_params = self._invocation_params
+        params["ls_provider"] = "azure"
+        if model_name := invocation_params.get("model"):
+            params["ls_model_name"] = model_name
+        return params
+
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""

View File

@@ -0,0 +1,23 @@
+from typing import Any
+
+from langchain_openai import AzureOpenAI
+
+
+def test_azure_model_param(monkeypatch: Any) -> None:
+    monkeypatch.delenv("OPENAI_API_BASE", raising=False)
+    llm = AzureOpenAI(
+        openai_api_key="secret-api-key",  # type: ignore[call-arg]
+        azure_endpoint="endpoint",
+        api_version="version",
+        azure_deployment="gpt-35-turbo-instruct",
+    )
+
+    # Test standard tracing params
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "azure",
+        "ls_model_type": "llm",
+        "ls_model_name": "gpt-35-turbo-instruct",
+        "ls_temperature": 0.7,
+        "ls_max_tokens": 256,
+    }

View File

@@ -14,6 +14,16 @@ def test_openai_model_param() -> None:
     llm = OpenAI(model_name="foo")  # type: ignore[call-arg]
     assert llm.model_name == "foo"

+    # Test standard tracing params
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "openai",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+        "ls_temperature": 0.7,
+        "ls_max_tokens": 256,
+    }
+

 def test_openai_model_kwargs() -> None:
     llm = OpenAI(model_kwargs={"foo": "bar"})

View File

@@ -69,3 +69,33 @@ def test_together_uses_actual_secret_value_from_secretstr_api_key() -> None:
         max_tokens=250,
     )
     assert cast(SecretStr, llm.together_api_key).get_secret_value() == "secret-api-key"
+
+
+def test_together_model_params() -> None:
+    # Test standard tracing params
+    llm = Together(
+        api_key="secret-api-key",  # type: ignore[arg-type]
+        model="foo",
+    )
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "together",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+        "ls_max_tokens": 200,
+    }
+
+    llm = Together(
+        api_key="secret-api-key",  # type: ignore[arg-type]
+        model="foo",
+        temperature=0.2,
+        max_tokens=250,
+    )
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "together",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+        "ls_max_tokens": 250,
+        "ls_temperature": 0.2,
+    }