feat(core): allow overriding ls_model_name from kwargs (#32541)

This commit is contained in:
Matthew Lapointe
2025-09-11 16:18:06 -04:00
committed by GitHub
parent 2903e08311
commit b1f08467cd
11 changed files with 129 additions and 7 deletions

View File

@@ -720,7 +720,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
ls_params["ls_stop"] = stop
# model
if hasattr(self, "model") and isinstance(self.model, str):
if "model" in kwargs and isinstance(kwargs["model"], str):
ls_params["ls_model_name"] = kwargs["model"]
elif hasattr(self, "model") and isinstance(self.model, str):
ls_params["ls_model_name"] = self.model
elif hasattr(self, "model_name") and isinstance(self.model_name, str):
ls_params["ls_model_name"] = self.model_name

View File

@@ -357,7 +357,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
ls_params["ls_stop"] = stop
# model
if hasattr(self, "model") and isinstance(self.model, str):
if "model" in kwargs and isinstance(kwargs["model"], str):
ls_params["ls_model_name"] = kwargs["model"]
elif hasattr(self, "model") and isinstance(self.model, str):
ls_params["ls_model_name"] = self.model
elif hasattr(self, "model_name") and isinstance(self.model_name, str):
ls_params["ls_model_name"] = self.model_name

View File

@@ -654,3 +654,57 @@ def test_normalize_messages_edge_cases() -> None:
)
]
assert messages == _normalize_messages(messages)
def test_get_ls_params() -> None:
class LSParamsModel(BaseChatModel):
model: str = "foo"
temperature: float = 0.1
max_tokens: int = 1024
def _generate(
self,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
raise NotImplementedError
@override
def _stream(
self,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
raise NotImplementedError
@property
def _llm_type(self) -> str:
return "fake-chat-model"
llm = LSParamsModel()
# Test standard tracing params
ls_params = llm._get_ls_params()
assert ls_params == {
"ls_provider": "lsparamsmodel",
"ls_model_type": "chat",
"ls_model_name": "foo",
"ls_temperature": 0.1,
"ls_max_tokens": 1024,
}
ls_params = llm._get_ls_params(model="bar")
assert ls_params["ls_model_name"] == "bar"
ls_params = llm._get_ls_params(temperature=0.2)
assert ls_params["ls_temperature"] == 0.2
ls_params = llm._get_ls_params(max_tokens=2048)
assert ls_params["ls_max_tokens"] == 2048
ls_params = llm._get_ls_params(stop=["stop"])
assert ls_params["ls_stop"] == ["stop"]

View File

@@ -232,3 +232,48 @@ async def test_astream_implementation_uses_astream() -> None:
model = ModelWithAsyncStream()
chunks = [chunk async for chunk in model.astream("anything")]
assert chunks == ["a", "b"]
def test_get_ls_params() -> None:
class LSParamsModel(BaseLLM):
model: str = "foo"
temperature: float = 0.1
max_tokens: int = 1024
@override
def _generate(
self,
prompts: list[str],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
raise NotImplementedError
@property
def _llm_type(self) -> str:
return "fake-model"
llm = LSParamsModel()
# Test standard tracing params
ls_params = llm._get_ls_params()
assert ls_params == {
"ls_provider": "lsparamsmodel",
"ls_model_type": "llm",
"ls_model_name": "foo",
"ls_temperature": 0.1,
"ls_max_tokens": 1024,
}
ls_params = llm._get_ls_params(model="bar")
assert ls_params["ls_model_name"] == "bar"
ls_params = llm._get_ls_params(temperature=0.2)
assert ls_params["ls_temperature"] == 0.2
ls_params = llm._get_ls_params(max_tokens=2048)
assert ls_params["ls_max_tokens"] == 2048
ls_params = llm._get_ls_params(stop=["stop"])
assert ls_params["ls_stop"] == ["stop"]

View File

@@ -1376,7 +1376,7 @@ class ChatAnthropic(BaseChatModel):
params = self._get_invocation_params(stop=stop, **kwargs)
ls_params = LangSmithParams(
ls_provider="anthropic",
ls_model_name=self.model,
ls_model_name=params.get("model", self.model),
ls_model_type="chat",
ls_temperature=params.get("temperature", self.temperature),
)

View File

@@ -1215,6 +1215,22 @@ def test_cache_control_kwarg() -> None:
]
def test_anthropic_model_params() -> None:
llm = ChatAnthropic(model="claude-3-5-haiku-latest")
ls_params = llm._get_ls_params()
assert ls_params == {
"ls_provider": "anthropic",
"ls_model_type": "chat",
"ls_model_name": "claude-3-5-haiku-latest",
"ls_max_tokens": 1024,
"ls_temperature": None,
}
ls_params = llm._get_ls_params(model="claude-opus-4-1-20250805")
assert ls_params["ls_model_name"] == "claude-opus-4-1-20250805"
def test_streaming_cache_token_reporting() -> None:
"""Test that cache tokens are properly reported in streaming events."""
from unittest.mock import MagicMock

View File

@@ -413,7 +413,7 @@ class ChatFireworks(BaseChatModel):
params = self._get_invocation_params(stop=stop, **kwargs)
ls_params = LangSmithParams(
ls_provider="fireworks",
ls_model_name=self.model_name,
ls_model_name=params.get("model", self.model_name),
ls_model_type="chat",
ls_temperature=params.get("temperature", self.temperature),
)

View File

@@ -480,7 +480,7 @@ class ChatGroq(BaseChatModel):
params = self._get_invocation_params(stop=stop, **kwargs)
ls_params = LangSmithParams(
ls_provider="groq",
ls_model_name=self.model_name,
ls_model_name=params.get("model", self.model_name),
ls_model_type="chat",
ls_temperature=params.get("temperature", self.temperature),
)

View File

@@ -448,7 +448,7 @@ class ChatMistralAI(BaseChatModel):
params = self._get_invocation_params(stop=stop, **kwargs)
ls_params = LangSmithParams(
ls_provider="mistral",
ls_model_name=self.model,
ls_model_name=params.get("model", self.model),
ls_model_type="chat",
ls_temperature=params.get("temperature", self.temperature),
)

View File

@@ -1471,7 +1471,7 @@ class BaseChatOpenAI(BaseChatModel):
params = self._get_invocation_params(stop=stop, **kwargs)
ls_params = LangSmithParams(
ls_provider="openai",
ls_model_name=self.model_name,
ls_model_name=params.get("model", self.model_name),
ls_model_type="chat",
ls_temperature=params.get("temperature", self.temperature),
)

View File

@@ -25,6 +25,9 @@ def test_openai_model_param() -> None:
"ls_max_tokens": 256,
}
ls_params = llm._get_ls_params(model="bar")
assert ls_params["ls_model_name"] == "bar"
def test_openai_model_kwargs() -> None:
llm = OpenAI(model_kwargs={"foo": "bar"})