feat(core): allow overriding ls_model_name from kwargs (#32541)
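Previously, `ls_model_name` in the LangSmith tracing params was derived only from the model object's `model` (or `model_name`) attribute, so a model passed at call time via kwargs was not reflected in traces. This change makes `_get_ls_params` prefer a string `model` kwarg when one is supplied, falling back to `self.model` and then `self.model_name`. The change is applied identically to `BaseChatModel` and `BaseLLM`, with tests for both.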
@@ -720,7 +720,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         ls_params["ls_stop"] = stop

         # model
-        if hasattr(self, "model") and isinstance(self.model, str):
+        if "model" in kwargs and isinstance(kwargs["model"], str):
+            ls_params["ls_model_name"] = kwargs["model"]
+        elif hasattr(self, "model") and isinstance(self.model, str):
             ls_params["ls_model_name"] = self.model
         elif hasattr(self, "model_name") and isinstance(self.model_name, str):
             ls_params["ls_model_name"] = self.model_name
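As a reading aid, here is a minimal standalone sketch of the resolution order this hunk introduces. The helper name `resolve_ls_model_name` is hypothetical; the real logic lives inline in `_get_ls_params`:

from typing import Any, Optional


def resolve_ls_model_name(obj: Any, **kwargs: Any) -> Optional[str]:
    # A call-time `model` kwarg wins over the instance's `model`
    # attribute, which wins over `model_name`; None if nothing
    # usable is set. Mirrors the precedence in the hunk above.
    if "model" in kwargs and isinstance(kwargs["model"], str):
        return kwargs["model"]
    if hasattr(obj, "model") and isinstance(obj.model, str):
        return obj.model
    if hasattr(obj, "model_name") and isinstance(obj.model_name, str):
        return obj.model_name
    return None

Given an instance with model="foo", resolve_ls_model_name(llm) returns "foo", while resolve_ls_model_name(llm, model="bar") returns "bar".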
@@ -357,7 +357,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         ls_params["ls_stop"] = stop

         # model
-        if hasattr(self, "model") and isinstance(self.model, str):
+        if "model" in kwargs and isinstance(kwargs["model"], str):
+            ls_params["ls_model_name"] = kwargs["model"]
+        elif hasattr(self, "model") and isinstance(self.model, str):
             ls_params["ls_model_name"] = self.model
         elif hasattr(self, "model_name") and isinstance(self.model_name, str):
             ls_params["ls_model_name"] = self.model_name
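In practice, a per-call override such as `llm.invoke("hi", model="other-model")` should now be traced with `ls_model_name == "other-model"`, assuming the invocation kwargs are forwarded through to `_get_ls_params` (as the standard `generate`/`stream` paths do). The tests below exercise the override directly via `_get_ls_params(model="bar")`.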
@@ -654,3 +654,57 @@ def test_normalize_messages_edge_cases() -> None:
         )
     ]
     assert messages == _normalize_messages(messages)
+
+
+def test_get_ls_params() -> None:
+    class LSParamsModel(BaseChatModel):
+        model: str = "foo"
+        temperature: float = 0.1
+        max_tokens: int = 1024
+
+        def _generate(
+            self,
+            messages: list[BaseMessage],
+            stop: Optional[list[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> ChatResult:
+            raise NotImplementedError
+
+        @override
+        def _stream(
+            self,
+            messages: list[BaseMessage],
+            stop: Optional[list[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> Iterator[ChatGenerationChunk]:
+            raise NotImplementedError
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    llm = LSParamsModel()
+
+    # Test standard tracing params
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "lsparamsmodel",
+        "ls_model_type": "chat",
+        "ls_model_name": "foo",
+        "ls_temperature": 0.1,
+        "ls_max_tokens": 1024,
+    }
+
+    ls_params = llm._get_ls_params(model="bar")
+    assert ls_params["ls_model_name"] == "bar"
+
+    ls_params = llm._get_ls_params(temperature=0.2)
+    assert ls_params["ls_temperature"] == 0.2
+
+    ls_params = llm._get_ls_params(max_tokens=2048)
+    assert ls_params["ls_max_tokens"] == 2048
+
+    ls_params = llm._get_ls_params(stop=["stop"])
+    assert ls_params["ls_stop"] == ["stop"]
@@ -232,3 +232,48 @@ async def test_astream_implementation_uses_astream() -> None:
     model = ModelWithAsyncStream()
     chunks = [chunk async for chunk in model.astream("anything")]
     assert chunks == ["a", "b"]
+
+
+def test_get_ls_params() -> None:
+    class LSParamsModel(BaseLLM):
+        model: str = "foo"
+        temperature: float = 0.1
+        max_tokens: int = 1024
+
+        @override
+        def _generate(
+            self,
+            prompts: list[str],
+            stop: Optional[list[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> LLMResult:
+            raise NotImplementedError
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-model"
+
+    llm = LSParamsModel()
+
+    # Test standard tracing params
+    ls_params = llm._get_ls_params()
+    assert ls_params == {
+        "ls_provider": "lsparamsmodel",
+        "ls_model_type": "llm",
+        "ls_model_name": "foo",
+        "ls_temperature": 0.1,
+        "ls_max_tokens": 1024,
+    }
+
+    ls_params = llm._get_ls_params(model="bar")
+    assert ls_params["ls_model_name"] == "bar"
+
+    ls_params = llm._get_ls_params(temperature=0.2)
+    assert ls_params["ls_temperature"] == 0.2
+
+    ls_params = llm._get_ls_params(max_tokens=2048)
+    assert ls_params["ls_max_tokens"] == 2048
+
+    ls_params = llm._get_ls_params(stop=["stop"])
+    assert ls_params["ls_stop"] == ["stop"]