mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
feat(core): allow overriding ls_model_name from kwargs (#32541)
This commit is contained in:
@@ -720,7 +720,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
ls_params["ls_stop"] = stop
|
||||
|
||||
# model
|
||||
if hasattr(self, "model") and isinstance(self.model, str):
|
||||
if "model" in kwargs and isinstance(kwargs["model"], str):
|
||||
ls_params["ls_model_name"] = kwargs["model"]
|
||||
elif hasattr(self, "model") and isinstance(self.model, str):
|
||||
ls_params["ls_model_name"] = self.model
|
||||
elif hasattr(self, "model_name") and isinstance(self.model_name, str):
|
||||
ls_params["ls_model_name"] = self.model_name
|
||||
|
||||
@@ -357,7 +357,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
ls_params["ls_stop"] = stop
|
||||
|
||||
# model
|
||||
if hasattr(self, "model") and isinstance(self.model, str):
|
||||
if "model" in kwargs and isinstance(kwargs["model"], str):
|
||||
ls_params["ls_model_name"] = kwargs["model"]
|
||||
elif hasattr(self, "model") and isinstance(self.model, str):
|
||||
ls_params["ls_model_name"] = self.model
|
||||
elif hasattr(self, "model_name") and isinstance(self.model_name, str):
|
||||
ls_params["ls_model_name"] = self.model_name
|
||||
|
||||
@@ -654,3 +654,57 @@ def test_normalize_messages_edge_cases() -> None:
|
||||
)
|
||||
]
|
||||
assert messages == _normalize_messages(messages)
|
||||
|
||||
|
||||
def test_get_ls_params() -> None:
|
||||
class LSParamsModel(BaseChatModel):
|
||||
model: str = "foo"
|
||||
temperature: float = 0.1
|
||||
max_tokens: int = 1024
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def _stream(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-chat-model"
|
||||
|
||||
llm = LSParamsModel()
|
||||
|
||||
# Test standard tracing params
|
||||
ls_params = llm._get_ls_params()
|
||||
assert ls_params == {
|
||||
"ls_provider": "lsparamsmodel",
|
||||
"ls_model_type": "chat",
|
||||
"ls_model_name": "foo",
|
||||
"ls_temperature": 0.1,
|
||||
"ls_max_tokens": 1024,
|
||||
}
|
||||
|
||||
ls_params = llm._get_ls_params(model="bar")
|
||||
assert ls_params["ls_model_name"] == "bar"
|
||||
|
||||
ls_params = llm._get_ls_params(temperature=0.2)
|
||||
assert ls_params["ls_temperature"] == 0.2
|
||||
|
||||
ls_params = llm._get_ls_params(max_tokens=2048)
|
||||
assert ls_params["ls_max_tokens"] == 2048
|
||||
|
||||
ls_params = llm._get_ls_params(stop=["stop"])
|
||||
assert ls_params["ls_stop"] == ["stop"]
|
||||
|
||||
@@ -232,3 +232,48 @@ async def test_astream_implementation_uses_astream() -> None:
|
||||
model = ModelWithAsyncStream()
|
||||
chunks = [chunk async for chunk in model.astream("anything")]
|
||||
assert chunks == ["a", "b"]
|
||||
|
||||
|
||||
def test_get_ls_params() -> None:
|
||||
class LSParamsModel(BaseLLM):
|
||||
model: str = "foo"
|
||||
temperature: float = 0.1
|
||||
max_tokens: int = 1024
|
||||
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-model"
|
||||
|
||||
llm = LSParamsModel()
|
||||
|
||||
# Test standard tracing params
|
||||
ls_params = llm._get_ls_params()
|
||||
assert ls_params == {
|
||||
"ls_provider": "lsparamsmodel",
|
||||
"ls_model_type": "llm",
|
||||
"ls_model_name": "foo",
|
||||
"ls_temperature": 0.1,
|
||||
"ls_max_tokens": 1024,
|
||||
}
|
||||
|
||||
ls_params = llm._get_ls_params(model="bar")
|
||||
assert ls_params["ls_model_name"] == "bar"
|
||||
|
||||
ls_params = llm._get_ls_params(temperature=0.2)
|
||||
assert ls_params["ls_temperature"] == 0.2
|
||||
|
||||
ls_params = llm._get_ls_params(max_tokens=2048)
|
||||
assert ls_params["ls_max_tokens"] == 2048
|
||||
|
||||
ls_params = llm._get_ls_params(stop=["stop"])
|
||||
assert ls_params["ls_stop"] == ["stop"]
|
||||
|
||||
@@ -1376,7 +1376,7 @@ class ChatAnthropic(BaseChatModel):
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
ls_params = LangSmithParams(
|
||||
ls_provider="anthropic",
|
||||
ls_model_name=self.model,
|
||||
ls_model_name=params.get("model", self.model),
|
||||
ls_model_type="chat",
|
||||
ls_temperature=params.get("temperature", self.temperature),
|
||||
)
|
||||
|
||||
@@ -1215,6 +1215,22 @@ def test_cache_control_kwarg() -> None:
|
||||
]
|
||||
|
||||
|
||||
def test_anthropic_model_params() -> None:
|
||||
llm = ChatAnthropic(model="claude-3-5-haiku-latest")
|
||||
|
||||
ls_params = llm._get_ls_params()
|
||||
assert ls_params == {
|
||||
"ls_provider": "anthropic",
|
||||
"ls_model_type": "chat",
|
||||
"ls_model_name": "claude-3-5-haiku-latest",
|
||||
"ls_max_tokens": 1024,
|
||||
"ls_temperature": None,
|
||||
}
|
||||
|
||||
ls_params = llm._get_ls_params(model="claude-opus-4-1-20250805")
|
||||
assert ls_params["ls_model_name"] == "claude-opus-4-1-20250805"
|
||||
|
||||
|
||||
def test_streaming_cache_token_reporting() -> None:
|
||||
"""Test that cache tokens are properly reported in streaming events."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@@ -413,7 +413,7 @@ class ChatFireworks(BaseChatModel):
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
ls_params = LangSmithParams(
|
||||
ls_provider="fireworks",
|
||||
ls_model_name=self.model_name,
|
||||
ls_model_name=params.get("model", self.model_name),
|
||||
ls_model_type="chat",
|
||||
ls_temperature=params.get("temperature", self.temperature),
|
||||
)
|
||||
|
||||
@@ -480,7 +480,7 @@ class ChatGroq(BaseChatModel):
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
ls_params = LangSmithParams(
|
||||
ls_provider="groq",
|
||||
ls_model_name=self.model_name,
|
||||
ls_model_name=params.get("model", self.model_name),
|
||||
ls_model_type="chat",
|
||||
ls_temperature=params.get("temperature", self.temperature),
|
||||
)
|
||||
|
||||
@@ -448,7 +448,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
ls_params = LangSmithParams(
|
||||
ls_provider="mistral",
|
||||
ls_model_name=self.model,
|
||||
ls_model_name=params.get("model", self.model),
|
||||
ls_model_type="chat",
|
||||
ls_temperature=params.get("temperature", self.temperature),
|
||||
)
|
||||
|
||||
@@ -1471,7 +1471,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
ls_params = LangSmithParams(
|
||||
ls_provider="openai",
|
||||
ls_model_name=self.model_name,
|
||||
ls_model_name=params.get("model", self.model_name),
|
||||
ls_model_type="chat",
|
||||
ls_temperature=params.get("temperature", self.temperature),
|
||||
)
|
||||
|
||||
@@ -25,6 +25,9 @@ def test_openai_model_param() -> None:
|
||||
"ls_max_tokens": 256,
|
||||
}
|
||||
|
||||
ls_params = llm._get_ls_params(model="bar")
|
||||
assert ls_params["ls_model_name"] == "bar"
|
||||
|
||||
|
||||
def test_openai_model_kwargs() -> None:
|
||||
llm = OpenAI(model_kwargs={"foo": "bar"})
|
||||
|
||||
Reference in New Issue
Block a user