feat(core): allow overriding ls_model_name from kwargs (#32541)

Matthew Lapointe
2025-09-11 16:18:06 -04:00
committed by GitHub
parent 2903e08311
commit b1f08467cd
11 changed files with 129 additions and 7 deletions
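
With this change, a model name passed at call time through keyword arguments (**kwargs) takes precedence over the instance's model and model_name attributes when the LangSmith tracing parameters are assembled; the same override is applied to both BaseChatModel and BaseLLM below. A minimal sketch of the new behaviour, using the in-tree GenericFakeChatModel (which defines no model field of its own) and the private _get_ls_params helper that the added tests exercise; the override value "my-model" is an arbitrary placeholder:

from langchain_core.language_models import GenericFakeChatModel
from langchain_core.messages import AIMessage

llm = GenericFakeChatModel(messages=iter([AIMessage(content="hello")]))

# No model attribute on the instance, so nothing is reported by default.
assert "ls_model_name" not in llm._get_ls_params()

# A call-time model kwarg is now picked up and recorded as ls_model_name.
assert llm._get_ls_params(model="my-model")["ls_model_name"] == "my-model"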


@@ -720,7 +720,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
ls_params["ls_stop"] = stop
# model
if hasattr(self, "model") and isinstance(self.model, str):
if "model" in kwargs and isinstance(kwargs["model"], str):
ls_params["ls_model_name"] = kwargs["model"]
elif hasattr(self, "model") and isinstance(self.model, str):
ls_params["ls_model_name"] = self.model
elif hasattr(self, "model_name") and isinstance(self.model_name, str):
ls_params["ls_model_name"] = self.model_name


@@ -357,7 +357,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
ls_params["ls_stop"] = stop
# model
if hasattr(self, "model") and isinstance(self.model, str):
if "model" in kwargs and isinstance(kwargs["model"], str):
ls_params["ls_model_name"] = kwargs["model"]
elif hasattr(self, "model") and isinstance(self.model, str):
ls_params["ls_model_name"] = self.model
elif hasattr(self, "model_name") and isinstance(self.model_name, str):
ls_params["ls_model_name"] = self.model_name


@@ -654,3 +654,57 @@ def test_normalize_messages_edge_cases() -> None:
         )
     ]
     assert messages == _normalize_messages(messages)
+
+
+def test_get_ls_params() -> None:
+    class LSParamsModel(BaseChatModel):
+        model: str = "foo"
+        temperature: float = 0.1
+        max_tokens: int = 1024
+
+        def _generate(
+            self,
+            messages: list[BaseMessage],
+            stop: Optional[list[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> ChatResult:
+            raise NotImplementedError
+
+        @override
+        def _stream(
+            self,
+            messages: list[BaseMessage],
+            stop: Optional[list[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> Iterator[ChatGenerationChunk]:
+            raise NotImplementedError
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    llm = LSParamsModel()
+
+    # Test standard tracing params
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "lsparamsmodel",
+        "ls_model_type": "chat",
+        "ls_model_name": "foo",
+        "ls_temperature": 0.1,
+        "ls_max_tokens": 1024,
+    }
+
+    ls_params = llm._get_ls_params(model="bar")
+    assert ls_params["ls_model_name"] == "bar"
+
+    ls_params = llm._get_ls_params(temperature=0.2)
+    assert ls_params["ls_temperature"] == 0.2
+
+    ls_params = llm._get_ls_params(max_tokens=2048)
+    assert ls_params["ls_max_tokens"] == 2048
+
+    ls_params = llm._get_ls_params(stop=["stop"])
+    assert ls_params["ls_stop"] == ["stop"]


@@ -232,3 +232,48 @@ async def test_astream_implementation_uses_astream() -> None:
     model = ModelWithAsyncStream()
     chunks = [chunk async for chunk in model.astream("anything")]
     assert chunks == ["a", "b"]
+
+
+def test_get_ls_params() -> None:
+    class LSParamsModel(BaseLLM):
+        model: str = "foo"
+        temperature: float = 0.1
+        max_tokens: int = 1024
+
+        @override
+        def _generate(
+            self,
+            prompts: list[str],
+            stop: Optional[list[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> LLMResult:
+            raise NotImplementedError
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-model"
+
+    llm = LSParamsModel()
+
+    # Test standard tracing params
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "lsparamsmodel",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+        "ls_temperature": 0.1,
+        "ls_max_tokens": 1024,
+    }
+
+    ls_params = llm._get_ls_params(model="bar")
+    assert ls_params["ls_model_name"] == "bar"
+
+    ls_params = llm._get_ls_params(temperature=0.2)
+    assert ls_params["ls_temperature"] == 0.2
+
+    ls_params = llm._get_ls_params(max_tokens=2048)
+    assert ls_params["ls_max_tokens"] == 2048
+
+    ls_params = llm._get_ls_params(stop=["stop"])
+    assert ls_params["ls_stop"] == ["stop"]